diff --git a/.github/ISSUE_TEMPLATE/40-tflite-op-request.md b/.github/ISSUE_TEMPLATE/40-tflite-op-request.md new file mode 100644 index 00000000000..7b391279e47 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/40-tflite-op-request.md @@ -0,0 +1,24 @@ +--- +name: TensorFlow Lite Op Request +about: Use this template for reporting ops you are using or missing. + +--- + + +**System information** +- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): +- TensorFlow installed from (source or binary): +- TensorFlow version (or github SHA if from source): + + +**Provide the text output from tflite_convert** + +``` +# Copy and paste here +``` + +Also, please include a link to a GraphDef or the model if possible. + +**Any other info / logs** + +Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached. diff --git a/README.md b/README.md index 8af5370befb..6fefdd32244 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,8 @@ data flow graphs. The graph nodes represent mathematical operations, while the graph edges represent the multidimensional data arrays (tensors) that flow between them. This flexible architecture enables you to deploy computation to one or more CPUs or GPUs in a desktop, server, or mobile device without rewriting -code. TensorFlow also includes [TensorBoard](https://www.tensorflow.org/guide/summaries_and_tensorboard), a data visualization toolkit. +code. TensorFlow also includes [TensorBoard](https://github.com/tensorflow/tensorboard), +a data visualization toolkit. TensorFlow was originally developed by researchers and engineers working on the Google Brain team within Google's Machine Intelligence Research @@ -111,7 +112,7 @@ The TensorFlow project strives to abide by generally accepted best practices in Build Type | Status | Artifacts ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------- **IBM s390x** | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | TBA -**IBM ppc64le CPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/) | TBA +**IBM ppc64le CPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/) | TBA **IBM ppc64le GPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/) **IBM ppc64le GPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | [Release](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) **Linux CPU with IntelĀ® MKL-DNN** Nightly | [![Build 
Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) @@ -127,6 +128,7 @@ Build Type * [TensorFlow Roadmap](https://www.tensorflow.org/community/roadmap) * [TensorFlow White Papers](https://www.tensorflow.org/about/bib) * [TensorFlow YouTube Channel](https://www.youtube.com/channel/UC0rqucBdTuFTjJiefW5t-IQ) +* [TensorFlow Visualization Toolkit](https://github.com/tensorflow/tensorboard) Learn more about the TensorFlow community at the [community page of tensorflow.org](https://www.tensorflow.org/community) for a few ways to participate. diff --git a/WORKSPACE b/WORKSPACE index 0c7bc085b51..7cc08e0164a 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -1,5 +1,7 @@ workspace(name = "org_tensorflow") +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") + http_archive( name = "io_bazel_rules_closure", sha256 = "a38539c5b5c358548e75b44141b4ab637bba7c4dc02b46b1f62a96d6433f56ae", @@ -57,9 +59,9 @@ android_workspace() # Please add all new TensorFlow dependencies in workspace.bzl. tf_workspace() -new_http_archive( +http_archive( name = "inception_v1", - build_file = "models.BUILD", + build_file = "//:models.BUILD", sha256 = "7efe12a8363f09bc24d7b7a450304a15655a57a7751929b2c1593a71183bb105", urls = [ "http://storage.googleapis.com/download.tensorflow.org/models/inception_v1.zip", @@ -67,9 +69,9 @@ new_http_archive( ], ) -new_http_archive( +http_archive( name = "mobile_ssd", - build_file = "models.BUILD", + build_file = "//:models.BUILD", sha256 = "bddd81ea5c80a97adfac1c9f770e6f55cbafd7cce4d3bbe15fbeb041e6b8f3e8", urls = [ "http://storage.googleapis.com/download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_android_export.zip", @@ -77,9 +79,9 @@ new_http_archive( ], ) -new_http_archive( +http_archive( name = "mobile_multibox", - build_file = "models.BUILD", + build_file = "//:models.BUILD", sha256 = "859edcddf84dddb974c36c36cfc1f74555148e9c9213dedacf1d6b613ad52b96", urls = [ "http://storage.googleapis.com/download.tensorflow.org/models/mobile_multibox_v1a.zip", @@ -87,9 +89,9 @@ new_http_archive( ], ) -new_http_archive( +http_archive( name = "stylize", - build_file = "models.BUILD", + build_file = "//:models.BUILD", sha256 = "3d374a730aef330424a356a8d4f04d8a54277c425e274ecb7d9c83aa912c6bfa", urls = [ "http://storage.googleapis.com/download.tensorflow.org/models/stylize_v1.zip", @@ -97,9 +99,9 @@ new_http_archive( ], ) -new_http_archive( +http_archive( name = "speech_commands", - build_file = "models.BUILD", + build_file = "//:models.BUILD", sha256 = "c3ec4fea3158eb111f1d932336351edfe8bd515bb6e87aad4f25dbad0a600d0c", urls = [ "http://storage.googleapis.com/download.tensorflow.org/models/speech_commands_v0.01.zip", diff --git a/configure.py b/configure.py index 234561d94a4..5f429c3de89 100644 --- a/configure.py +++ b/configure.py @@ -238,6 +238,13 @@ def setup_python(environ_cp): write_to_bazelrc('build --python_path=\"%s"' % python_bin_path) environ_cp['PYTHON_BIN_PATH'] = python_bin_path + # If choosen python_lib_path is from a path specified in the PYTHONPATH + # variable, need to tell bazel to include PYTHONPATH + if environ_cp.get('PYTHONPATH'): + python_paths = environ_cp.get('PYTHONPATH').split(':') + if python_lib_path in python_paths: + write_action_env_to_bazelrc('PYTHONPATH', environ_cp.get('PYTHONPATH')) + # Write tools/python_bin_path.sh with open( os.path.join(_TF_WORKSPACE_ROOT, 'tools', 
'python_bin_path.sh'), @@ -445,11 +452,12 @@ def convert_version_to_int(version): return int(version_str) -def check_bazel_version(min_version): - """Check installed bazel version is at least min_version. +def check_bazel_version(min_version, max_version): + """Check installed bazel version is between min_version and max_version. Args: min_version: string for minimum bazel version. + max_version: string for maximum bazel version. Returns: The bazel version detected. @@ -467,6 +475,7 @@ def check_bazel_version(min_version): min_version_int = convert_version_to_int(min_version) curr_version_int = convert_version_to_int(curr_version) + max_version_int = convert_version_to_int(max_version) # Check if current bazel version can be detected properly. if not curr_version_int: @@ -480,6 +489,10 @@ def check_bazel_version(min_version): print('Please upgrade your bazel installation to version %s or higher to ' 'build TensorFlow!' % min_version) sys.exit(0) + if curr_version_int > max_version_int: + print('Please downgrade your bazel installation to version %s or lower to ' + 'build TensorFlow!' % max_version) + sys.exit(0) return curr_version @@ -859,7 +872,7 @@ def set_tf_cuda_version(environ_cp): cuda_toolkit_paths_full = [ os.path.join(cuda_toolkit_path, x) for x in cuda_rt_lib_paths ] - if any([os.path.exists(x) for x in cuda_toolkit_paths_full]): + if any(os.path.exists(x) for x in cuda_toolkit_paths_full): break # Reset and retry @@ -1552,7 +1565,7 @@ def main(): # environment variables. environ_cp = dict(os.environ) - check_bazel_version('0.15.0') + check_bazel_version('0.15.0', '0.20.0') reset_tf_configure_bazelrc() # Explicitly import tools/bazel.rc, this is needed for Bazel 0.19.0 or later @@ -1694,6 +1707,7 @@ def main(): config_info_line('nohdfs', 'Disable HDFS support.') config_info_line('noignite', 'Disable Apacha Ignite support.') config_info_line('nokafka', 'Disable Apache Kafka support.') + config_info_line('nonccl', 'Disable NVIDIA NCCL support.') if __name__ == '__main__': diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 859dc3b8d77..fd4b94202aa 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -43,6 +43,11 @@ TENSORFLOW_API_INIT_FILES_V2 = ( TENSORFLOW_API_INIT_FILES + get_compat_files(TENSORFLOW_API_INIT_FILES_V1, 1) ) +# @unused +TENSORFLOW_API_INIT_FILES_V1_WITH_COMPAT = ( + TENSORFLOW_API_INIT_FILES_V1 + get_compat_files(TENSORFLOW_API_INIT_FILES_V1, 1) +) + # Config setting used when building for products # which requires restricted licenses to be avoided. 
config_setting( @@ -213,31 +218,37 @@ config_setting( # config_setting( name = "no_aws_support", - define_values = {"no_aws_support": "false"}, + define_values = {"no_aws_support": "true"}, visibility = ["//visibility:public"], ) config_setting( name = "no_gcp_support", - define_values = {"no_gcp_support": "false"}, + define_values = {"no_gcp_support": "true"}, visibility = ["//visibility:public"], ) config_setting( name = "no_hdfs_support", - define_values = {"no_hdfs_support": "false"}, + define_values = {"no_hdfs_support": "true"}, visibility = ["//visibility:public"], ) config_setting( name = "no_ignite_support", - define_values = {"no_ignite_support": "false"}, + define_values = {"no_ignite_support": "true"}, visibility = ["//visibility:public"], ) config_setting( name = "no_kafka_support", - define_values = {"no_kafka_support": "false"}, + define_values = {"no_kafka_support": "true"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "no_nccl_support", + define_values = {"no_nccl_support": "true"}, visibility = ["//visibility:public"], ) @@ -350,7 +361,7 @@ package_group( "-//third_party/tensorflow/python/estimator", "//learning/meta_rank/...", "//tensorflow/...", - "//tensorflow_estimator/...", + "//tensorflow_estimator/contrib/...", "//tensorflow_fold/llgtm/...", "//tensorflow_text/...", "//third_party/py/tensor2tensor/...", @@ -554,18 +565,24 @@ genrule( }), outs = ["__init__.py"], cmd = select({ - "api_version_2": "cp $(@D)/_api/v2/__init__.py $(OUTS)", - "//conditions:default": "cp $(@D)/_api/v1/__init__.py $(OUTS)", + "api_version_2": "cp $(@D)/_api/v2/v2.py $(OUTS)", + "//conditions:default": "cp $(@D)/_api/v1/v1.py $(OUTS)", }), ) gen_api_init_files( name = "tf_python_api_gen_v1", - srcs = ["api_template_v1.__init__.py"], + srcs = [ + "api_template_v1.__init__.py", + "compat_template_v1.__init__.py", + ], api_version = 1, + compat_api_versions = [1], + compat_init_templates = ["compat_template_v1.__init__.py"], output_dir = "_api/v1/", - output_files = TENSORFLOW_API_INIT_FILES_V1, + output_files = TENSORFLOW_API_INIT_FILES_V1_WITH_COMPAT, output_package = "tensorflow._api.v1", + root_file_name = "v1.py", root_init_template = "api_template_v1.__init__.py", ) @@ -581,6 +598,7 @@ gen_api_init_files( output_dir = "_api/v2/", output_files = TENSORFLOW_API_INIT_FILES_V2, output_package = "tensorflow._api.v2", + root_file_name = "v2.py", root_init_template = "api_template.__init__.py", ) diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py index 0d497568385..d81cf067eb0 100644 --- a/tensorflow/api_template.__init__.py +++ b/tensorflow/api_template.__init__.py @@ -21,8 +21,6 @@ from __future__ import print_function as _print_function import os as _os # pylint: disable=g-bad-import-order -from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import - from tensorflow.python.tools import component_api_helper as _component_api_helper _component_api_helper.package_hook( parent_package_str=__name__, @@ -30,16 +28,16 @@ _component_api_helper.package_hook( # API IMPORTS PLACEHOLDER -from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top - # Make sure directory containing top level submodules is in # the __path__ so that "from tensorflow.foo import bar" works. -_tf_api_dir = _os.path.dirname(_os.path.dirname(app.__file__)) # pylint: disable=undefined-variable +# We're using bitwise, but there's nothing special about that. 
+_tf_api_dir = _os.path.dirname(_os.path.dirname(bitwise.__file__)) # pylint: disable=undefined-variable if _tf_api_dir not in __path__: __path__.append(_tf_api_dir) -# Calls to enable and disable features. -enable_eager_execution() # pylint: disable=undefined-variable +# Enable TF2 behaviors +from tensorflow.python.compat import compat as _compat # pylint: disable=g-import-not-at-top +_compat.enable_v2_behavior() # These symbols appear because we import the python package which # in turn imports from tensorflow.core and tensorflow.python. They diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index b8db1b21449..59c23e7c184 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -60,6 +60,7 @@ tf_cuda_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:op_gen_lib", + "//tensorflow/core/distributed_runtime:server_lib", ], }), ) @@ -120,7 +121,8 @@ tf_cuda_library( ":c_api", ":c_api_internal", "//tensorflow/c/eager:c_api", - "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags", + "//tensorflow/c/eager:c_api_internal", + "//tensorflow/compiler/jit:flags", "//tensorflow/contrib/tpu:all_ops", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", @@ -173,6 +175,30 @@ tf_cuda_library( ], ) +tf_cuda_library( + name = "kernels", + srcs = [ + "kernels.cc", + ], + hdrs = [ + "kernels.h", + ], + copts = tf_copts(), + visibility = ["//visibility:public"], + deps = select({ + "//tensorflow:android": [ + ":c_api", + ":c_api_internal", + "//tensorflow/core:android_tensorflow_lib_lite", + ], + "//conditions:default": [ + ":c_api", + ":c_api_internal", + "//tensorflow/core:framework", + ], + }), +) + # ----------------------------------------------------------------------------- # Tests @@ -208,7 +234,10 @@ tf_cuda_cc_test( "//tensorflow:darwin": ["-headerpad_max_install_names"], "//conditions:default": [], }), - tags = ["noasan"], + tags = [ + "no_oss", # http://b/119522529 + "noasan", + ], # We must ensure that the dependencies can be dynamically linked since # the shared library must be able to use core:framework. # linkstatic = tf_kernel_tests_linkstatic(), @@ -237,7 +266,7 @@ tf_cuda_cc_test( tf_cc_test( name = "c_api_experimental_test", - size = "small", + size = "medium", srcs = ["c_api_experimental_test.cc"], data = ["testdata/tf_record"], linkopts = select({ @@ -248,8 +277,11 @@ tf_cc_test( # the shared library must be able to use core:framework. # linkstatic = tf_kernel_tests_linkstatic(), deps = [ + ":c_api", ":c_api_experimental", ":c_test_util", + "//tensorflow/c/eager:c_api", + "//tensorflow/c/eager:c_api_test_util", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", @@ -300,6 +332,30 @@ tf_kernel_library( alwayslink = 1, ) +tf_cuda_cc_test( + name = "kernels_test", + size = "small", + srcs = ["kernels_test.cc"], + linkopts = select({ + "//tensorflow:darwin": ["-headerpad_max_install_names"], + "//conditions:default": [], + }), + tags = ["noasan"], + # We must ensure that the dependencies can be dynamically linked since + # the shared library must be able to use core:framework. 
+ # linkstatic = tf_kernel_tests_linkstatic(), + deps = [ + ":c_api", + ":kernels", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:proto_text", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + # ----------------------------------------------------------------------------- # Python API target diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index fabe2fa0f60..38e29aa74a9 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -15,13 +15,18 @@ limitations under the License. #include "tensorflow/c/c_api_experimental.h" +#include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api_internal.h" -#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_internal.h" +#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/core/common_runtime/eager/attr_builder.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/net.h" #include "tensorflow/core/platform/platform.h" #include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/protobuf/tensorflow_server.pb.h" @@ -51,8 +56,8 @@ void TF_EnableXLACompilation(TF_SessionOptions* options, unsigned char enable) { // These XLA flags are needed to trigger XLA properly from C (more generally // non-Python) clients. If this API is called again with `enable` set to // false, it is safe to keep these flag values as is. - tensorflow::legacy_flags::MarkForCompilationPassFlags* flags = - tensorflow::legacy_flags::GetMarkForCompilationPassFlags(); + tensorflow::MarkForCompilationPassFlags* flags = + tensorflow::GetMarkForCompilationPassFlags(); flags->tf_xla_cpu_global_jit = true; flags->tf_xla_min_cluster_size = 1; } else { @@ -71,8 +76,8 @@ TF_Buffer* TF_CreateConfig(unsigned char enable_xla_compilation, // These XLA flags are needed to trigger XLA properly from C (more generally // non-Python) clients. If this API is called again with `enable` set to // false, it is safe to keep these flag values as is. 
- tensorflow::legacy_flags::MarkForCompilationPassFlags* flags = - tensorflow::legacy_flags::GetMarkForCompilationPassFlags(); + tensorflow::MarkForCompilationPassFlags* flags = + tensorflow::GetMarkForCompilationPassFlags(); flags->tf_xla_cpu_global_jit = true; flags->tf_xla_min_cluster_size = 1; } else { @@ -6525,7 +6530,7 @@ library { } } node_def { - name: "ParallelInterleaveDataset/cycle_length" + name: "ExperimentalParallelInterleaveDataset/cycle_length" op: "Const" attr { key: "dtype" @@ -6546,7 +6551,7 @@ library { } } node_def { - name: "ParallelInterleaveDataset/block_length" + name: "ExperimentalParallelInterleaveDataset/block_length" op: "Const" attr { key: "dtype" @@ -6567,7 +6572,7 @@ library { } } node_def { - name: "ParallelInterleaveDataset/sloppy" + name: "ExperimentalParallelInterleaveDataset/sloppy" op: "Const" attr { key: "dtype" @@ -6588,7 +6593,7 @@ library { } } node_def { - name: "ParallelInterleaveDataset/buffer_output_elements" + name: "ExperimentalParallelInterleaveDataset/buffer_output_elements" op: "Const" attr { key: "dtype" @@ -6609,7 +6614,7 @@ library { } } node_def { - name: "ParallelInterleaveDataset/prefetch_input_elements" + name: "ExperimentalParallelInterleaveDataset/prefetch_input_elements" op: "Const" attr { key: "dtype" @@ -6630,14 +6635,14 @@ library { } } node_def { - name: "ParallelInterleaveDataset" - op: "ParallelInterleaveDataset" + name: "ExperimentalParallelInterleaveDataset" + op: "ExperimentalParallelInterleaveDataset" input: "RepeatDataset:handle:0" - input: "ParallelInterleaveDataset/cycle_length:output:0" - input: "ParallelInterleaveDataset/block_length:output:0" - input: "ParallelInterleaveDataset/sloppy:output:0" - input: "ParallelInterleaveDataset/buffer_output_elements:output:0" - input: "ParallelInterleaveDataset/prefetch_input_elements:output:0" + input: "ExperimentalParallelInterleaveDataset/cycle_length:output:0" + input: "ExperimentalParallelInterleaveDataset/block_length:output:0" + input: "ExperimentalParallelInterleaveDataset/sloppy:output:0" + input: "ExperimentalParallelInterleaveDataset/buffer_output_elements:output:0" + input: "ExperimentalParallelInterleaveDataset/prefetch_input_elements:output:0" attr { key: "Targuments" value { @@ -6737,7 +6742,7 @@ library { node_def { name: "ShuffleDataset_2" op: "ShuffleDataset" - input: "ParallelInterleaveDataset:handle:0" + input: "ExperimentalParallelInterleaveDataset:handle:0" input: "ShuffleDataset_2/buffer_size_1:output:0" input: "ShuffleDataset_2/seed_2:output:0" input: "ShuffleDataset_2/seed2_2:output:0" @@ -8739,14 +8744,65 @@ void TFE_TensorHandlePrintDebugString(TFE_TensorHandle* handle) { TF_DeleteStatus(status); } -TF_CAPI_EXPORT extern void TF_MakeInternalErrorStatus(TF_Status* status, - const char* errMsg) { +struct TFE_ExecuteOpNotification { + TFE_ExecuteOpNotification() : status(TF_NewStatus(), TF_DeleteStatus) {} + tensorflow::Notification n; + std::unique_ptr thread; + std::unique_ptr status; +}; + +TFE_ExecuteOpNotification* TFE_ExecuteOpInNewThread(TFE_Op* op, + TFE_TensorHandle** retvals, + int* num_retvals, + TF_Status* status) { + TFE_ExecuteOpNotification* n = new TFE_ExecuteOpNotification; + + n->thread.reset(op->operation.EagerContext()->TFEnv()->StartThread( + tensorflow::ThreadOptions(), "ExecuteOpThread", + [op, retvals, num_retvals, n]() { + TFE_Execute(op, retvals, num_retvals, n->status.get()); + n->n.Notify(); + })); + + return n; +} + +void TFE_ExecuteOpNotificationWaitAndDelete( + TFE_ExecuteOpNotification* notification, TF_Status* status) { + 
if (notification == nullptr) { + status->status = tensorflow::errors::InvalidArgument( + "Passed in notification is a nullptr."); + + return; + } + if (notification->thread == nullptr) { + status->status = tensorflow::errors::InvalidArgument( + "Passed in notification didn't start a thread correctly. Cleaning up " + "this notification. Please re-execute the operation to get a new " + "notification."); + + delete notification; + return; + } + + notification->n.WaitForNotification(); + + status->status = notification->status->status; + + delete notification; +} + +void TF_MakeInternalErrorStatus(TF_Status* status, const char* errMsg) { status->status = tensorflow::errors::Internal(errMsg); } // This builder is used in the eager API to build a NodeDef. struct TF_AttrBuilder : public tensorflow::AttrBuilder { using tensorflow::AttrBuilder::AttrBuilder; + // The string buffers to make sure that any `attr_name` we pass into + // `builder->Set()` will outlive the subsequent + // `TF_AttrBuilderCheckCanRunOnDevice()` call(s) on the same `builder`. + std::set attr_names; }; TF_AttrBuilder* TF_NewAttrBuilder(const char* op_name) { @@ -8757,13 +8813,15 @@ void TF_DeleteAttrBuilder(TF_AttrBuilder* builder) { delete builder; } void TF_AttrBuilderSetType(TF_AttrBuilder* builder, const char* attr_name, TF_DataType value) { - builder->Set(attr_name, static_cast(value)); + auto iter = builder->attr_names.insert(attr_name).first; + builder->Set((*iter).c_str(), static_cast(value)); } void TF_AttrBuilderSetTypeList(TF_AttrBuilder* builder, const char* attr_name, const TF_DataType* values, int num_values) { + auto iter = builder->attr_names.insert(attr_name).first; builder->Set( - attr_name, + (*iter).c_str(), tensorflow::gtl::ArraySlice( reinterpret_cast(values), num_values)); } @@ -8800,3 +8858,31 @@ const char* TF_GetNumberAttrForOpListInput(const char* op_name, int input_index, // The returned string is owned by OpRegistry, so liveness is not a concern. return input_arg.number_attr().c_str(); } + +int TF_OpIsStateful(const char* op_type, TF_Status* status) { + const tensorflow::OpRegistrationData* op_reg_data; + status->status = + tensorflow::OpRegistry::Global()->LookUp(op_type, &op_reg_data); + if (!status->status.ok()) { + return 0; + } + return op_reg_data->op_def.is_stateful(); +} + +void TF_InitMain(const char* usage, int* argc, char*** argv) { + tensorflow::port::InitMain(usage, argc, argv); +} + +int TF_PickUnusedPortOrDie() { + return tensorflow::internal::PickUnusedPortOrDie(); +} + +TFE_TensorHandle* TFE_NewTensorHandleFromScalar(TF_DataType dtype_arg, + void* data, size_t len) { + auto dtype = static_cast(dtype_arg); + DCHECK(tensorflow::DataTypeCanUseMemcpy(dtype)); + + tensorflow::Tensor tensor(dtype, tensorflow::TensorShape({})); + std::memcpy(tensorflow::TensorCApi::Buffer(tensor)->data(), data, len); + return new TFE_TensorHandle(tensor, nullptr, nullptr); +} diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h index 6639b0be72b..80c8bfe594c 100644 --- a/tensorflow/c/c_api_experimental.h +++ b/tensorflow/c/c_api_experimental.h @@ -180,6 +180,25 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueVariantTensor( TF_CAPI_EXPORT extern void TFE_TensorHandlePrintDebugString( TFE_TensorHandle* handle); +typedef struct TFE_ExecuteOpNotification TFE_ExecuteOpNotification; + +// Allows invoking a kernel asynchronously, and explicitly returns a +// notification that can be waited upon. This always executes the kernel in a +// new thread. +// 1. 
`retvals` and `num_retvals` can only be consumed after +// `TFE_ExecuteOp` returns successfully. They shouldn't be used +// if the return is unsuccessful +// 2. These new APIs cannot be used together with the TFE context level async +// support. +TF_CAPI_EXPORT extern TFE_ExecuteOpNotification* TFE_ExecuteOpInNewThread( + TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, + TF_Status* status); + +// Waits to complete the op execution, and cleans up the notification. +// Errors reported by op execution are set in `status`. +TF_CAPI_EXPORT extern void TFE_ExecuteOpNotificationWaitAndDelete( + TFE_ExecuteOpNotification* notification, TF_Status* status); + TF_CAPI_EXPORT extern void TF_MakeInternalErrorStatus(TF_Status* status, const char* errMsg); @@ -209,6 +228,24 @@ TF_CAPI_EXPORT extern void TF_AttrBuilderCheckCanRunOnDevice( TF_CAPI_EXPORT extern const char* TF_GetNumberAttrForOpListInput( const char* op_name, int input_index, TF_Status* status); +// Returns 1 if the op is stateful, 0 otherwise. The return value is undefined +// if the status is not ok. +TF_CAPI_EXPORT extern int TF_OpIsStateful(const char* op_type, + TF_Status* status); + +// Platform specific initialization routine. Very few platforms actually require +// this to be called. +TF_CAPI_EXPORT void TF_InitMain(const char* usage, int* argc, char*** argv); + +// Platform-specific implementation to return an unused port. (This should used +// in tests only.) +TF_CAPI_EXPORT int TF_PickUnusedPortOrDie(); + +// Fast path method that makes constructing a single scalar tensor require less +// overhead and copies. +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandleFromScalar( + TF_DataType dtype, void* scalar, size_t len); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/c_api_experimental_test.cc b/tensorflow/c/c_api_experimental_test.cc index c6effd39697..daa7701b7fe 100644 --- a/tensorflow/c/c_api_experimental_test.cc +++ b/tensorflow/c/c_api_experimental_test.cc @@ -15,6 +15,8 @@ limitations under the License. 
#include "tensorflow/c/c_api_experimental.h" #include "tensorflow/c/c_test_util.h" +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_test_util.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" @@ -162,5 +164,137 @@ protocol: "grpc" TF_DeleteStatus(status); } +TEST(CAPI_EXPERIMENTAL, IsStateful) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + int assign = TF_OpIsStateful("AssignAddVariableOp", status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + EXPECT_EQ(assign, 1); + int id = TF_OpIsStateful("Identity", status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + EXPECT_EQ(id, 0); +} + +TEST(CAPI_EXPERIMENTAL, TFE_ExecuteOpInNewThreadTest_Simple) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_TensorHandle* m = TestMatrixTensorHandle(); + + TFE_Op* matmul_op = MatMulOp(ctx, m, m); + + TFE_TensorHandle* retvals[1] = {nullptr}; + int num_retvals = 1; + + auto* r = + TFE_ExecuteOpInNewThread(matmul_op, &retvals[0], &num_retvals, status); + + TFE_ExecuteOpNotificationWaitAndDelete(r, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + float product[4] = {0}; + EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); + memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + EXPECT_EQ(7, product[0]); + EXPECT_EQ(10, product[1]); + EXPECT_EQ(15, product[2]); + EXPECT_EQ(22, product[3]); + + TFE_DeleteOp(matmul_op); + TFE_DeleteTensorHandle(m); + + TFE_DeleteTensorHandle(retvals[0]); + TFE_DeleteContext(ctx); + TF_DeleteStatus(status); +} + +// Perform a send/recv test. Recv blocks, so they need to be executed +// asynchronously. +TEST(CAPI_EXPERIMENTAL, TFE_ExecuteOpInNewThreadTest_Blocking) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_Context* ctx = TFE_NewContext(opts, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + // Returns a 2x2 float32 Tensor on the CPU, with data 1., 2., 3., 4. + TFE_TensorHandle* m = TestMatrixTensorHandle(); + + // Build a send op. + TFE_Op* send_op = TFE_NewOp(ctx, "_Send", status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(send_op, m, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + string tensor_name = "Tensor"; + TFE_OpSetAttrType(send_op, "T", TF_FLOAT); + TFE_OpSetAttrString(send_op, "tensor_name", tensor_name.c_str(), + tensor_name.size()); + string send_device = "/job:localhost/replica:0/task:0/device:CPU:0"; + TFE_OpSetAttrString(send_op, "send_device", send_device.c_str(), + send_device.size()); + TFE_OpSetAttrInt(send_op, "send_device_incarnation", 1234); + string recv_device = "/job:localhost/replica:0/task:0/device:CPU:0"; + TFE_OpSetAttrString(send_op, "recv_device", recv_device.c_str(), + recv_device.size()); + TFE_OpSetAttrBool(send_op, "client_terminated", true); + + // Build a recv op. 
+ TFE_Op* recv_op = TFE_NewOp(ctx, "_Recv", status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_OpSetAttrType(recv_op, "tensor_type", TF_FLOAT); + TFE_OpSetAttrString(recv_op, "tensor_name", tensor_name.c_str(), + tensor_name.size()); + TFE_OpSetAttrString(recv_op, "send_device", send_device.c_str(), + send_device.size()); + TFE_OpSetAttrInt(recv_op, "send_device_incarnation", 1234); + TFE_OpSetAttrString(recv_op, "recv_device", recv_device.c_str(), + recv_device.size()); + TFE_OpSetAttrBool(recv_op, "client_terminated", true); + + TFE_TensorHandle* send_retvals; + int send_num_retvals = 0; + auto* send_result = TFE_ExecuteOpInNewThread(send_op, &send_retvals, + &send_num_retvals, status); + + TFE_TensorHandle* recv_retvals[1] = {nullptr}; + int recv_num_retvals = 1; + auto* recv_result = TFE_ExecuteOpInNewThread(recv_op, &recv_retvals[0], + &recv_num_retvals, status); + + TFE_ExecuteOpNotificationWaitAndDelete(send_result, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_ExecuteOpNotificationWaitAndDelete(recv_result, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TF_Tensor* t = TFE_TensorHandleResolve(recv_retvals[0], status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + float product[4] = {0}; + EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); + memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + EXPECT_EQ(1, product[0]); + EXPECT_EQ(2, product[1]); + EXPECT_EQ(3, product[2]); + EXPECT_EQ(4, product[3]); + + TFE_DeleteOp(send_op); + TFE_DeleteOp(recv_op); + TFE_DeleteTensorHandle(m); + + TFE_DeleteTensorHandle(recv_retvals[0]); + TFE_DeleteContext(ctx); + TF_DeleteStatus(status); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc index f68f8a3e90a..28b9f8df9c8 100644 --- a/tensorflow/c/c_api_function.cc +++ b/tensorflow/c/c_api_function.cc @@ -392,26 +392,26 @@ Status ProcessInputs( EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) { input_tensors->reserve(ninputs); for (int i = 0; i < ninputs; ++i) { - const Node& node = inputs[i].oper->node; + Node* node = &inputs[i].oper->node; int idx = inputs[i].index; TF_RETURN_WITH_CONTEXT_IF_ERROR( - fn_body->graph.IsValidOutputTensor(&node, idx), + fn_body->graph.IsValidOutputTensor(node, idx), "Encountered while processing input ", i, " into function '", fn_name, "'"); - TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(&node, idx), + TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(node, idx), "Encountered while processing input ", i, " into function '", fn_name, "'"); - input_tensors->emplace_back(&node, idx); + input_tensors->emplace_back(node, idx); - const auto& iter = input_nodes->find(&node); + const auto& iter = input_nodes->find(node); if (iter == input_nodes->end()) { - input_nodes->insert({&node, {idx}}); + input_nodes->insert({node, {idx}}); } else { auto& indices = iter->second; if (std::find(indices.begin(), indices.end(), idx) != indices.end()) { - return InvalidArgument("TF_Output ", node.name(), ":", idx, + return InvalidArgument("TF_Output ", node->name(), ":", idx, " appears more than once in the input list"); } indices.push_back(idx); @@ -428,16 +428,16 @@ Status ProcessOutputs(const TF_Graph* fn_body, const char* fn_name, EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) { output_tensors->reserve(noutputs); for (int i = 0; i < noutputs; ++i) { - const Node& node = outputs[i].oper->node; + Node* node = &outputs[i].oper->node; int idx = 
outputs[i].index; TF_RETURN_WITH_CONTEXT_IF_ERROR( - fn_body->graph.IsValidOutputTensor(&node, idx), + fn_body->graph.IsValidOutputTensor(node, idx), "Encountered while processing output ", i, " from function '", fn_name, "'"); - TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(&node, idx), + TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(node, idx), "Encountered while creating function '", fn_name, "'"); - output_tensors->emplace_back(&node, idx); + output_tensors->emplace_back(node, idx); } return Status::OK(); } diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index ba3d8533db7..c34a84fcfee 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -50,6 +50,7 @@ tf_cuda_library( ], "//conditions:default": [], }) + [ + "@com_google_absl//absl/memory", "//tensorflow/core/common_runtime/eager:eager_operation", "//tensorflow/core/distributed_runtime/eager:eager_client", "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client", @@ -143,6 +144,7 @@ tf_cuda_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 408277468d7..027d752f420 100755 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -21,9 +21,11 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/eager/c_api_internal.h" +#include "tensorflow/core/platform/host_info.h" #ifdef TENSORFLOW_EAGER_USE_XLA #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #endif // TENSORFLOW_EAGER_USE_XLA @@ -79,7 +81,7 @@ tensorflow::Status GetAllRemoteDevices( const std::vector& remote_workers, tensorflow::WorkerCacheInterface* worker_cache, std::unique_ptr* device_mgr) { - std::vector remote_devices; + std::vector> remote_devices; tensorflow::Status status; // TODO(nareshmodi) do this in parallel instead of serially. 
for (const string& remote_worker : remote_workers) { @@ -92,7 +94,7 @@ tensorflow::Status GetAllRemoteDevices( status = s; if (s.ok()) { for (tensorflow::Device* d : *devices) { - remote_devices.push_back(d); + remote_devices.emplace_back(d); } } n.Notify(); @@ -100,7 +102,7 @@ tensorflow::Status GetAllRemoteDevices( n.WaitForNotification(); } std::unique_ptr remote_device_mgr( - new tensorflow::DeviceMgr(remote_devices)); + new tensorflow::DeviceMgr(std::move(remote_devices))); TF_RETURN_IF_ERROR(status); @@ -261,13 +263,13 @@ TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx, void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; } TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { - std::vector devices; + std::vector> devices; status->status = tensorflow::DeviceFactory::AddDevices( opts->session_options.options, "/job:localhost/replica:0/task:0", &devices); if (!status->status.ok()) return nullptr; std::unique_ptr device_mgr( - new tensorflow::DeviceMgr(devices)); + new tensorflow::DeviceMgr(std::move(devices))); tensorflow::Rendezvous* r = new tensorflow::IntraProcessRendezvous(device_mgr.get()); @@ -409,6 +411,18 @@ const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) { : d->name().c_str(); } +const char* TFE_TensorHandleBackingDeviceName(TFE_TensorHandle* h, + TF_Status* status) { + if (h == nullptr || h->handle == nullptr) { + status->status = tensorflow::errors::InvalidArgument( + "The passed in handle is a nullptr"); + return nullptr; + } + tensorflow::Device* d = h->handle->device(); + return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0" + : d->name().c_str(); +} + TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor( TFE_TensorHandle* h, TF_Status* status) { if (h == nullptr || h->handle == nullptr) { @@ -458,13 +472,20 @@ TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, TF_Status* status) { const char* name = op_or_function_name; // Shorthand const tensorflow::AttrTypeMap* types; - status->status = tensorflow::AttrTypeMapForOp(name, &types); - if (status->status.ok()) return new TFE_Op(ctx, name, types); - if (TF_GetCode(status) == TF_NOT_FOUND) { - if (ctx->context.FindFunctionByName(name)) { - status->status = tensorflow::Status::OK(); - return new TFE_Op(ctx, name, nullptr); + bool is_function = false; + status->status = tensorflow::AttrTypeMapForOp(name, &types, &is_function); + if (status->status.ok()) { + if (is_function && !ctx->context.FindFunctionByName(name)) { + status->status = tensorflow::errors::NotFound( + "'", name, + "' is neither a type of a primitive operation nor a name " + "of a function registered in binary running on ", + tensorflow::port::Hostname(), + ". Make sure the operation or function is " + "registered in the binary running in this process."); + return nullptr; } + return new TFE_Op(ctx, name, is_function, types); } return nullptr; } @@ -497,12 +518,6 @@ void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) { TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name, unsigned char* is_list, TF_Status* status) { TF_AttrType ret; - if (op->operation.is_function()) { - status->status = tensorflow::errors::Unimplemented( - "TODO(apassos): Support for attributes for TensorFlow functions is not " - "ready yet."); - return TF_ATTR_INT; // The compiler requires that we return something. 
- } status->status = tensorflow::AttrTypeByName(*op->operation.AttrTypes(), attr_name, &ret, is_list); return ret; diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index b2454d87220..8d6c8d958d5 100755 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -169,10 +169,33 @@ TF_CAPI_EXPORT extern int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, TF_CAPI_EXPORT extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index, TF_Status* status); + +// Returns the device of the operation that produced `h`. +// If `h` was produced by a copy, returns the destination device of +// the copy. Note that returned device name is not always the device +// holding the tensor handle's memory. If you want the latter, use +// TFE_TensorHandleBackingDeviceName. +// This function will block till the operation that produces `h` has completed. +// +// Device on which the kernel of the operation that produced `h` ran. +// +// If `h` was produced by a copy, returns the destination device of +// the copy. +// +// Note that returned device name is not always the device that owns the memory +// that backs the tensor handle. For the latter see +// TFE_TensorHandleBackingDeviceName. +// // This function will block till the operation that produces `h` has completed. TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceName( TFE_TensorHandle* h, TF_Status* status); +// Returns the name of the device in whose memory `h` resides. +// +// This function will block till the operation that produces `h` has completed. +TF_CAPI_EXPORT extern const char* TFE_TensorHandleBackingDeviceName( + TFE_TensorHandle* h, TF_Status* status); + // Return a pointer to a new TFE_TensorHandle that shares the underlying tensor // with `h`. On success, `status` is set to OK. On failure, `status` reflects // the error and a nullptr is returned. diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index fa1b22e3af4..67bc1bcd246 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -93,10 +93,9 @@ struct TFE_TensorDebugInfo { }; struct TFE_Op { - // t is NULL iff the TFE_Op corresponds to a TensorFlow function instead of a - // primitive operation. - TFE_Op(TFE_Context* ctx, const char* op, const tensorflow::AttrTypeMap* t) - : operation(&ctx->context, op, t) {} + TFE_Op(TFE_Context* ctx, const char* op, bool is_function, + const tensorflow::AttrTypeMap* t) + : operation(&ctx->context, op, is_function, t) {} tensorflow::EagerOperation operation; }; diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 55331022b9d..6b39b79ee82 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/c/eager/c_api.h" #include +#include "absl/strings/match.h" #include "tensorflow/c/eager/c_api_test_util.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" #include "tensorflow/core/framework/function.pb.h" @@ -589,9 +590,22 @@ void TensorHandleCopyBetweenTwoGPUDevices(bool async) { TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); const int num_devices = TF_DeviceListCount(devices); + bool has_gpu0 = false; + bool has_gpu1 = false; + for (int i = 0; i < num_devices; ++i) { + const char* dev = TF_DeviceListName(devices, i, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + string device_name(dev); + if (device_name.find("GPU:0") != string::npos) { + has_gpu0 = true; + } + if (device_name.find("GPU:1") != string::npos) { + has_gpu1 = true; + } + } const char* kCPUDevice = "CPU:0"; - if (num_devices < 3) { + if (!has_gpu0 || !has_gpu1) { TF_DeleteDeviceList(devices); TF_DeleteTensor(t); TFE_DeleteTensorHandle(hcpu); @@ -781,6 +795,14 @@ TEST(CAPI, TensorHandleNullptr) { TF_SetStatus(status.get(), TF_OK, ""); + device_name = TFE_TensorHandleBackingDeviceName(h, status.get()); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get())); + ASSERT_EQ(device_name, nullptr); + ASSERT_EQ("The passed in handle is a nullptr", + string(TF_Message(status.get()))); + + TF_SetStatus(status.get(), TF_OK, ""); + int num_dims = TFE_TensorHandleNumDims(h, status.get()); ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get())); ASSERT_EQ(num_dims, -1); @@ -796,6 +818,62 @@ TEST(CAPI, TensorHandleNullptr) { string(TF_Message(status.get()))); } +TEST(CAPI, TensorHandleDevices) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status.get()); + TFE_DeleteContextOptions(opts); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + TFE_TensorHandle* hcpu = TestMatrixTensorHandle(); + const char* device_name = TFE_TensorHandleDeviceName(hcpu, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + ASSERT_TRUE(absl::StrContains(device_name, "CPU:0")) << device_name; + const char* backing_device_name = + TFE_TensorHandleBackingDeviceName(hcpu, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + ASSERT_TRUE(absl::StrContains(backing_device_name, "CPU:0")) + << backing_device_name; + + // Disable the test if no GPU is present. 
+ string gpu_device_name; + if (GetDeviceName(ctx, &gpu_device_name, "GPU")) { + TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice( + hcpu, ctx, gpu_device_name.c_str(), status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + TFE_Op* shape_op = ShapeOp(ctx, hgpu); + TFE_OpSetDevice(shape_op, gpu_device_name.c_str(), status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + TFE_TensorHandle* retvals[1]; + int num_retvals = 1; + TFE_Execute(shape_op, &retvals[0], &num_retvals, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + // .device of shape is GPU since the op is executed on GPU + device_name = TFE_TensorHandleDeviceName(retvals[0], status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + ASSERT_TRUE(absl::StrContains(device_name, "GPU:0")) << device_name; + + // .backing_device of shape is CPU since the tensor is backed by CPU + backing_device_name = + TFE_TensorHandleBackingDeviceName(retvals[0], status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + ASSERT_TRUE(absl::StrContains(backing_device_name, "CPU:0")) + << backing_device_name; + + TFE_DeleteOp(shape_op); + TFE_DeleteTensorHandle(retvals[0]); + TFE_DeleteTensorHandle(hgpu); + } + + TFE_DeleteTensorHandle(hcpu); + TFE_ContextAsyncWait(ctx, status.get()); + EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TFE_DeleteContext(ctx); +} + void Execute_MatMul_CPU(bool async) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); diff --git a/tensorflow/c/eager/c_api_test_util.cc b/tensorflow/c/eager/c_api_test_util.cc index 008f088c2dc..bd38127d50c 100644 --- a/tensorflow/c/eager/c_api_test_util.cc +++ b/tensorflow/c/eager/c_api_test_util.cc @@ -104,6 +104,19 @@ TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) { return op; } +TFE_Op* ShapeOp(TFE_Context* ctx, TFE_TensorHandle* a) { + TF_Status* status = TF_NewStatus(); + + TFE_Op* op = TFE_NewOp(ctx, "Shape", status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(op, a, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); + TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a)); + + return op; +} + TFE_TensorHandle* TestAxisTensorHandle() { int64_t dims[] = {1}; int data[] = {1}; diff --git a/tensorflow/c/eager/c_api_test_util.h b/tensorflow/c/eager/c_api_test_util.h index 474cae67c89..75ef9459e93 100644 --- a/tensorflow/c/eager/c_api_test_util.h +++ b/tensorflow/c/eager/c_api_test_util.h @@ -37,6 +37,9 @@ TFE_TensorHandle* TestMatrixTensorHandle3X2(); // Return a matmul op multiplying `a` by `b`. TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b); +// Return a shape op fetching the shape of `a`. +TFE_Op* ShapeOp(TFE_Context* ctx, TFE_TensorHandle* a); + // Return an 1-D INT32 tensor containing a single value 1. TFE_TensorHandle* TestAxisTensorHandle(); diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index 5ba55a203ff..5c11f51e874 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -141,8 +141,9 @@ class GradientTape { // null. The result is populated with one tensor per target element. 
Status ComputeGradient( const VSpace& vspace, - gtl::ArraySlice target_tensor_ids, - gtl::ArraySlice source_tensor_id, + const gtl::ArraySlice target_tensor_ids, + const gtl::ArraySlice source_tensor_ids, + const gtl::FlatMap sources_that_are_targets, gtl::ArraySlice output_gradients, std::vector* result); @@ -396,6 +397,7 @@ template Status InitialGradients( const VSpace& vspace, gtl::ArraySlice target_tensor_ids, + gtl::FlatMap sources_that_are_targets, gtl::ArraySlice output_gradients, const TensorTape& tensor_tape, const OpTape& op_tape, gtl::FlatMap>* result) { @@ -425,8 +427,13 @@ Status InitialGradients( "none of operations outputs match expected tensor"); } } else { - // No record of the target tensor found on the tape, so no gradient - // needs to be computed from it. Do nothing. + // This target tensor was not generated by any operation recorded on + // the tape, so no gradient needs to be computed from it unless this + // target is also a source. + auto source_tensor = sources_that_are_targets.find(id); + if (source_tensor != sources_that_are_targets.end()) { + (*result)[id].push_back(vspace.Ones(source_tensor->second)); + } } } else { (*result)[id].push_back(output_gradients[i]); @@ -467,8 +474,9 @@ constexpr int kMinAggregateBytes = 128 * 1024 * 1024; template Status GradientTape::ComputeGradient( const VSpace& vspace, - gtl::ArraySlice target_tensor_ids, - gtl::ArraySlice source_tensor_ids, + const gtl::ArraySlice target_tensor_ids, + const gtl::ArraySlice source_tensor_ids, + const gtl::FlatMap sources_that_are_targets, gtl::ArraySlice output_gradients, std::vector* result) { gtl::FlatSet sources_set(source_tensor_ids.begin(), @@ -478,7 +486,8 @@ Status GradientTape::ComputeGradient( std::vector op_stack = InitialStack(state.op_tape, state.op_missing_tensor); gtl::FlatMap> gradients; - Status s = InitialGradients(vspace, target_tensor_ids, output_gradients, + Status s = InitialGradients(vspace, target_tensor_ids, + sources_that_are_targets, output_gradients, tensor_tape_, state.op_tape, &gradients); auto cleanup = [this, &state]() { if (!persistent_) { diff --git a/tensorflow/c/kernels.cc b/tensorflow/c/kernels.cc new file mode 100644 index 00000000000..3caa5bcb038 --- /dev/null +++ b/tensorflow/c/kernels.cc @@ -0,0 +1,143 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/c/c_api_internal.h" +#include "tensorflow/c/kernels.h" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" + +// This file forms the basis of a stable ABI for third-party kernel +// implementations. It is crucial that changes to this file are made cautiously +// and with a focus on maintaining both source and binary compatibility. 
+ +struct TF_KernelBuilder { + ::tensorflow::KernelDefBuilder* cc_builder; + + void* (*create_function)(TF_OpKernelConstruction*); + void (*compute_function)(void*, TF_OpKernelContext*); + void (*delete_function)(void*); +}; + +TF_KernelBuilder* TF_NewKernelBuilder( + const char* op_name, const char* device_name, + void* (*create_func)(TF_OpKernelConstruction*), + void (*compute_func)(void*, TF_OpKernelContext*), + void (*delete_func)(void*)) { + TF_KernelBuilder* result = new TF_KernelBuilder; + result->cc_builder = new ::tensorflow::KernelDefBuilder(op_name); + result->cc_builder->Device(device_name); + result->create_function = create_func; + result->compute_function = compute_func; + result->delete_function = delete_func; + return result; +} + +void TF_DeleteKernelBuilder(TF_KernelBuilder* builder) { + DCHECK_NE(builder, nullptr); + delete builder->cc_builder; + delete builder; +} + +namespace tensorflow { +namespace { + +// An OpKernel whose methods delegate to C function pointers. +class COpKernel : public OpKernel { + public: + explicit COpKernel(OpKernelConstruction* ctx, + void* (*create_func)(TF_OpKernelConstruction*), + void (*compute_func)(void*, TF_OpKernelContext*), + void (*delete_func)(void*)) + : OpKernel(ctx), compute_func_(compute_func), delete_func_(delete_func) { + if (create_func != nullptr) { + c_kernel_ = + (*create_func)(reinterpret_cast(ctx)); + } else { + c_kernel_ = nullptr; + } + } + + void Compute(OpKernelContext* ctx) override { + (*compute_func_)(c_kernel_, reinterpret_cast(ctx)); + } + + ~COpKernel() override { + if (delete_func_ != nullptr) { + (*delete_func_)(c_kernel_); + } + } + + private: + void (*compute_func_)(void*, TF_OpKernelContext* context); + void (*delete_func_)(void*); + void* c_kernel_; +}; + +// A KernelFactory that returns COpKernel instances. 
+class KernelBuilderFactory + : public ::tensorflow::kernel_factory::OpKernelFactory { + public: + explicit KernelBuilderFactory(TF_KernelBuilder* builder) + : builder_(builder) {} + ::tensorflow::OpKernel* Create( + ::tensorflow::OpKernelConstruction* context) override { + return new ::tensorflow::COpKernel(context, builder_->create_function, + builder_->compute_function, + builder_->delete_function); + } + ~KernelBuilderFactory() override { TF_DeleteKernelBuilder(builder_); } + + private: + TF_KernelBuilder* builder_; +}; +} // namespace +} // namespace tensorflow + +void TF_RegisterKernelBuilder(const char* name, TF_KernelBuilder* builder, + TF_Status* status) { + using tensorflow::register_kernel::Name; + + tensorflow::kernel_factory::OpKernelRegistrar( + builder->cc_builder->Build(), name, + absl::make_unique(builder)); + + TF_SetStatus(status, TF_OK, ""); +} + +int TF_NumInputs(TF_OpKernelContext* ctx) { + auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); + return cc_ctx->num_inputs(); +} + +int TF_NumOutputs(TF_OpKernelContext* ctx) { + auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); + return cc_ctx->num_outputs(); +} + +void TF_GetInput(TF_OpKernelContext* ctx, int i, TF_Tensor** tensor, + TF_Status* status) { + auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); + if (i < 0 || i >= cc_ctx->num_inputs()) { + TF_SetStatus(status, TF_OUT_OF_RANGE, "input index out of range"); + return; + } + const ::tensorflow::Tensor& cc_tensor(cc_ctx->input(i)); + TF_Tensor* result = ::tensorflow::TF_TensorFromTensor(cc_tensor, status); + if (TF_GetCode(status) == TF_OK) { + *tensor = result; + } +} diff --git a/tensorflow/c/kernels.h b/tensorflow/c/kernels.h new file mode 100644 index 00000000000..d7778829bca --- /dev/null +++ b/tensorflow/c/kernels.h @@ -0,0 +1,110 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_KERNELS_H_ +#define TENSORFLOW_C_KERNELS_H_ + +#include "tensorflow/c/c_api.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// -------------------------------------------------------------------------- +// C API for TensorFlow Kernels. +// +// This API allows developers to register custom kernel implementations for +// TensorFlow. +// +// See c_api.h header comments for a discussion about API conventions. +// +// Users wishing to extend TensorFlow with new kernels will call +// `TF_NewKernelBuilder`. The resulting kernel builder can be registered with +// `TF_RegisterKernelBuilder`, which will allow TF to construct user-provided +// kernels when necessary. + +struct TF_KernelBuilder; +struct TF_OpKernelConstruction; +struct TF_OpKernelContext; + +// Allocates a new kernel builder and returns a pointer to it. +// +// If non-null, TensorFlow will call create_func when it needs to instantiate +// the kernel. 
The pointer returned by create_func will be passed to +// compute_func and delete_func, thereby functioning as a "this" pointer for +// referring to kernel instances. +// +// The TF_OpKernelConstruction pointer passed to create_func is owned by +// TensorFlow and will be deleted once create_func returns. It must not be used +// after this. +// +// When TensorFlow needs to perform a computation with this kernel, it will +// call compute_func. This function will receive the pointer returned by +// create_func (or null if no create_func was provided), along with the inputs +// to the computation. +// +// The TF_OpKernelContext pointer received by compute_func is owned by +// TensorFlow and will be deleted once compute_func returns. It must not be used +// after this. +// +// Finally, when TensorFlow no longer needs the kernel, it will call +// delete_func if one is provided. This function will receive the pointer +// returned in `create_func` or nullptr if no `create_func` was provided. +// +// The caller should pass the result of this function to +// TF_RegisterKernelBuilder, which will take ownership of the pointer. If, for +// some reason, the kernel builder will not be registered, the caller should +// delete it with TF_DeleteKernelBuilder. +TF_CAPI_EXPORT extern TF_KernelBuilder* TF_NewKernelBuilder( + const char* op_name, const char* device_name, + void* (*create_func)(TF_OpKernelConstruction*), + void (*compute_func)(void*, TF_OpKernelContext*), + void (*delete_func)(void*)); + +// Register the given kernel builder with the TensorFlow runtime. If +// registration fails, the given status will be populated. +// +// This call takes ownership of the `builder` pointer. +TF_CAPI_EXPORT extern void TF_RegisterKernelBuilder(const char* kernel_name, + TF_KernelBuilder* builder, + TF_Status* status); + +// Deletes the given TF_KernelBuilder. This should be called only if the kernel +// builder is not registered with TensorFlow via TF_RegisterKernelBuilder. +TF_CAPI_EXPORT extern void TF_DeleteKernelBuilder(TF_KernelBuilder* builder); + +// -------------------------------------------------------------------------- +// OpKernelContext routines + +// TF_NumInputs returns the number of inputs available in ctx. +TF_CAPI_EXPORT extern int TF_NumInputs(TF_OpKernelContext* ctx); + +// TF_NumOutputs returns the number of outputs to be placed in *ctx by the +// kernel. +TF_CAPI_EXPORT extern int TF_NumOutputs(TF_OpKernelContext* ctx); + +// Retrieves the ith input from ctx. If TF_GetCode(status) is TF_OK, *tensor is +// populated and its ownership is passed to the caller. In any other case, +// *tensor is not modified. +// +// If i < 0 or i >= TF_NumInputs(ctx), *status is set to TF_OUT_OF_RANGE. +TF_CAPI_EXPORT extern void TF_GetInput(TF_OpKernelContext* ctx, int i, + TF_Tensor** tensor, TF_Status* status); + +#ifdef __cplusplus +} /* end extern "C" */ +#endif + +#endif // TENSORFLOW_C_KERNELS_H_ diff --git a/tensorflow/c/kernels_test.cc b/tensorflow/c/kernels_test.cc new file mode 100644 index 00000000000..80bf12c0969 --- /dev/null +++ b/tensorflow/c/kernels_test.cc @@ -0,0 +1,194 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/c/kernels.h"
+
+#include "tensorflow/c/c_api.h"
+#include "tensorflow/core/framework/kernel_def.pb.h"
+#include "tensorflow/core/framework/node_def.pb_text.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+struct MyCustomKernel {
+  bool created;
+  bool compute_called;
+};
+
+static bool delete_called = false;
+
+static void* MyCreateFunc(TF_OpKernelConstruction* ctx) {
+  struct MyCustomKernel* s = new struct MyCustomKernel;
+  s->created = true;
+  s->compute_called = false;
+  return s;
+}
+
+static void MyComputeFunc(void* kernel, TF_OpKernelContext* ctx) {
+  struct MyCustomKernel* s = static_cast<struct MyCustomKernel*>(kernel);
+  s->compute_called = true;
+}
+
+static void MyDeleteFunc(void* kernel) {
+  struct MyCustomKernel* s = static_cast<struct MyCustomKernel*>(kernel);
+  EXPECT_TRUE(s->created);
+  EXPECT_TRUE(s->compute_called);
+  delete_called = true;
+  delete s;
+}
+
+namespace tensorflow {
+
+static std::unique_ptr<OpKernel> GetFakeKernel(const char* device_name,
+                                               const char* op_name,
+                                               Status* status) {
+  NodeDef def;
+  def.set_op(op_name);
+  def.set_device(device_name);
+  def.add_input("input1");
+  def.add_input("input2");
+  return CreateOpKernel(DeviceType(device_name), nullptr, nullptr, def, 1,
+                        status);
+}
+
+// Tests registration of a single C kernel and checks that calls through the
+// C/C++ boundary are being made.
+TEST(TestKernel, TestRegisterKernelBuilder) { + const char* kernel_name = "SomeKernelName"; + const char* op_name = "FooOp"; + const char* device_name = "FakeDeviceName1"; + + REGISTER_OP(op_name) + .Input("input1: double") + .Input("input2: uint8") + .Output("output1: uint8"); + + TF_KernelBuilder* builder = TF_NewKernelBuilder( + op_name, device_name, &MyCreateFunc, &MyComputeFunc, &MyDeleteFunc); + + { + TF_Status* status = TF_NewStatus(); + TF_RegisterKernelBuilder(kernel_name, builder, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)); + TF_Buffer* buf = TF_GetRegisteredKernelsForOp(op_name, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)); + KernelList list; + list.ParseFromArray(buf->data, buf->length); + ASSERT_EQ(1, list.kernel_size()); + ASSERT_EQ(device_name, list.kernel(0).device_type()); + TF_DeleteBuffer(buf); + TF_DeleteStatus(status); + } + + { + Status status; + std::unique_ptr kernel = + GetFakeKernel(device_name, op_name, &status); + TF_EXPECT_OK(status); + ASSERT_NE(nullptr, kernel.get()); + kernel->Compute(nullptr); + } + + ASSERT_TRUE(delete_called); +} + +class DummyDevice : public DeviceBase { + public: + DummyDevice(Env* env, bool save) : DeviceBase(env), save_(save) {} + bool RequiresRecordingAccessedTensors() const override { return save_; } + Allocator* GetAllocator(AllocatorAttributes /*attr*/) override { + return cpu_allocator(); + } + + private: + bool save_; +}; + +TEST(TestKernel, TestInputAndOutputCount) { + const char* kernel_name = "InputOutputCounterKernel"; + const char* op_name = "BarOp"; + const char* device_name = "FakeDeviceName2"; + + REGISTER_OP(op_name) + .Input("input1: double") + .Input("input2: uint8") + .Output("output1: uint8"); + + static int num_inputs = 0; + static int num_outputs = 0; + + // A kernel whose Compute function has a side-effect of updating num_inputs + // and num_outputs. Various functions on TF_OpKernelContext are also + // exercised. 
+ auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) { + num_inputs = TF_NumInputs(ctx); + num_outputs = TF_NumOutputs(ctx); + + TF_Tensor* input = nullptr; + TF_Status* s = TF_NewStatus(); + TF_GetInput(ctx, 0, &input, s); + EXPECT_EQ(TF_OK, TF_GetCode(s)) << "Failed to get input: " << TF_Message(s); + EXPECT_EQ(123, *static_cast(TF_TensorData(input))); + TF_GetInput(ctx, -1, &input, s); + EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s)); + TF_GetInput(ctx, 3, &input, s); + EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s)); + TF_DeleteStatus(s); + if (input != nullptr) { + TF_DeleteTensor(input); + } + }; + + TF_KernelBuilder* builder = TF_NewKernelBuilder(op_name, device_name, nullptr, + my_compute_func, nullptr); + + { + TF_Status* status = TF_NewStatus(); + TF_RegisterKernelBuilder(kernel_name, builder, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)); + TF_DeleteStatus(status); + } + + { + OpKernelContext::Params p; + DummyDevice dummy_device(nullptr, false); + p.device = &dummy_device; + + Tensor t(tensorflow::uint8(123)); + + gtl::InlinedVector inputs; + // Simulate 2 inputs + inputs.emplace_back(&t); + inputs.emplace_back(); + p.inputs = &inputs; + + Status status; + std::unique_ptr kernel = + GetFakeKernel(device_name, op_name, &status); + TF_EXPECT_OK(status); + ASSERT_NE(nullptr, kernel.get()); + + p.op_kernel = kernel.get(); + OpKernelContext ctx(&p); + kernel->Compute(&ctx); + + ASSERT_EQ(2, num_inputs); + ASSERT_EQ(1, num_outputs); + } +} + +} // namespace tensorflow diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc index 247236b760d..98d83933322 100644 --- a/tensorflow/c/python_api.cc +++ b/tensorflow/c/python_api.cc @@ -160,4 +160,17 @@ void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto, ic->set_output_handle_shapes_and_types(output.index, shapes_and_types); } +void AddWhileInputHack(TF_Graph* graph, TF_Output new_src, TF_Operation* dst, + TF_Status* status) { + mutex_lock l(graph->mu); + status->status = graph->graph.AddWhileInputHack(&new_src.oper->node, + new_src.index, &dst->node); + if (status->status.ok()) { + // This modification only updates the destination node for + // the purposes of running this graph in a session. Thus, we don't + // record the source node as being modified. + RecordMutation(graph, *dst, "adding input tensor"); + } +} + } // namespace tensorflow diff --git a/tensorflow/c/python_api.h b/tensorflow/c/python_api.h index 5cce84020bc..44779ca6561 100644 --- a/tensorflow/c/python_api.h +++ b/tensorflow/c/python_api.h @@ -34,6 +34,7 @@ void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name, void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device); +// Updates 'dst' to consume 'new_src'. void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, TF_Status* status); @@ -65,6 +66,13 @@ std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output); // because I couldn't get SWIG to work otherwise. void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto, size_t proto_len, TF_Status* status); + +// This method is used to add a new input edge to 'dst', which must be a While +// op. The While op's "T" attribute must have already been updated to include +// the new edge. This is used to construct tf.while_loop gradients. 
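The compute callback above leans on the `TF_GetInput` contract spelled out in kernels.h: on success, ownership of `*tensor` passes to the caller; for an out-of-range index only the status is touched. A small helper capturing that contract (the helper itself is ours, not part of the API):

```c++
#include "tensorflow/c/kernels.h"

// Returns the i-th input, or nullptr if it is unavailable. The caller owns
// (and must TF_DeleteTensor) any non-null result.
static TF_Tensor* GetInputOrNull(TF_OpKernelContext* ctx, int i) {
  TF_Tensor* tensor = nullptr;
  TF_Status* status = TF_NewStatus();
  TF_GetInput(ctx, i, &tensor, status);
  if (TF_GetCode(status) != TF_OK) {
    // e.g. TF_OUT_OF_RANGE when i < 0 or i >= TF_NumInputs(ctx); *tensor was
    // left untouched, so it is still nullptr here.
    tensor = nullptr;
  }
  TF_DeleteStatus(status);
  return tensor;
}
```

This mirrors what the test lambda does by hand, including deleting the tensor it received.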
+void AddWhileInputHack(TF_Graph* graph, TF_Output new_src, TF_Operation* dst, + TF_Status* status); + } // namespace tensorflow #endif // TENSORFLOW_C_PYTHON_API_H_ diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index 83353b79f72..a09becc49b1 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -489,6 +489,7 @@ tf_gen_op_wrappers_cc( "image_ops", "io_ops", "linalg_ops", + "list_ops", "logging_ops", "lookup_ops", "manip_ops", diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index 3d3895c8fa8..52345a376cc 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -133,5 +133,6 @@ filegroup( "testdata/half_plus_two_pbtxt/**", "testdata/half_plus_two_main_op/**", "testdata/half_plus_two/**", + "testdata/half_plus_two_v2/**", ]), ) diff --git a/tensorflow/cc/saved_model/constants.h b/tensorflow/cc/saved_model/constants.h index 645a3f101d1..6f00dc324bd 100644 --- a/tensorflow/cc/saved_model/constants.h +++ b/tensorflow/cc/saved_model/constants.h @@ -33,10 +33,10 @@ constexpr char kSavedModelFilenamePb[] = "saved_model.pb"; /// SavedModel text format proto filename. constexpr char kSavedModelFilenamePbTxt[] = "saved_model.pbtxt"; -/// SavedModel legacy init op key. +/// SavedModel legacy init op collection key. Used in v1 SavedModels. constexpr char kSavedModelLegacyInitOpKey[] = "legacy_init_op"; -/// SavedModel main op key. +/// SavedModel main op collection key. Used in v1 SavedModels. constexpr char kSavedModelMainOpKey[] = "saved_model_main_op"; /// Directory in which to save the SavedModel variables. @@ -45,6 +45,11 @@ constexpr char kSavedModelVariablesDirectory[] = "variables"; /// SavedModel variables filename. constexpr char kSavedModelVariablesFilename[] = "variables"; +/// SavedModel SignatureDef keys for the initialization and train ops. Used in +/// V2 SavedModels. +constexpr char kSavedModelInitOpSignatureKey[] = "__saved_model_init_op"; +constexpr char kSavedModelTrainOpSignatureKey[] = "__saved_model_train_op"; + } // namespace tensorflow #endif // TENSORFLOW_CC_SAVED_MODEL_CONSTANTS_H_ diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index c6abe2f41b9..85d3dd01fa5 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -122,38 +122,58 @@ Status RunOnce(const RunOptions& run_options, return run_status; } -bool HasMainOp(const MetaGraphDef& meta_graph_def) { - const auto& collection_def_map = meta_graph_def.collection_def(); - if (collection_def_map.find(kSavedModelMainOpKey) != - collection_def_map.end()) { - return true; - } - return false; -} - -Status RunMainOp(const RunOptions& run_options, const string& export_dir, +// RunInitOp will return OK if the initialization op was run successfully. +// An empty init_op_name indicates that there are no init ops to run. 
+Status RunInitOp(const RunOptions& run_options, const string& export_dir, const MetaGraphDef& meta_graph_def, const std::vector& asset_file_defs, - Session* session, const string& main_op_key) { - LOG(INFO) << "Running MainOp with key " << main_op_key - << " on SavedModel bundle."; - const auto& collection_def_map = meta_graph_def.collection_def(); - const auto main_op_it = collection_def_map.find(main_op_key); - if (main_op_it != collection_def_map.end()) { - if (main_op_it->second.node_list().value_size() != 1) { - return errors::FailedPrecondition( - strings::StrCat("Expected exactly one main op in : ", export_dir)); - } + Session* session, const string& init_op_name) { + if (!init_op_name.empty()) { + LOG(INFO) << "Running initialization op on SavedModel bundle."; std::vector> inputs; AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs); RunMetadata run_metadata; - const StringPiece main_op_name = main_op_it->second.node_list().value(0); - return RunOnce(run_options, inputs, {}, {string(main_op_name)}, + return RunOnce(run_options, inputs, {}, {init_op_name}, nullptr /* outputs */, &run_metadata, session); } return Status::OK(); } +// A SavedModel may store the name of the initialization op to run in the +// in the SignatureDef (v2) or a collection (v1). If an init_op collection +// exists, then the collection must contain exactly one op. +Status GetInitOp(const string& export_dir, const MetaGraphDef& meta_graph_def, + string* init_op_name) { + const auto& sig_def_map = meta_graph_def.signature_def(); + const auto& init_op_sig_it = + meta_graph_def.signature_def().find(kSavedModelInitOpSignatureKey); + if (init_op_sig_it != sig_def_map.end()) { + *init_op_name = init_op_sig_it->second.outputs() + .find(kSavedModelInitOpSignatureKey) + ->second.name(); + return Status::OK(); + } + + const auto& collection_def_map = meta_graph_def.collection_def(); + string init_op_collection_key; + if (collection_def_map.find(kSavedModelMainOpKey) != + collection_def_map.end()) { + init_op_collection_key = kSavedModelMainOpKey; + } else { + init_op_collection_key = kSavedModelLegacyInitOpKey; + } + + const auto init_op_it = collection_def_map.find(init_op_collection_key); + if (init_op_it != collection_def_map.end()) { + if (init_op_it->second.node_list().value_size() != 1) { + return errors::FailedPrecondition( + strings::StrCat("Expected exactly one main op in : ", export_dir)); + } + *init_op_name = init_op_it->second.node_list().value(0); + } + return Status::OK(); +} + Status RunRestore(const RunOptions& run_options, const string& export_dir, const StringPiece restore_op_name, const StringPiece variable_filename_const_op_name, @@ -193,6 +213,15 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir, Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def, std::vector* asset_file_defs) { + // With SavedModel v2, we write asset file def into metagraph instead of + // collection, so read from metagraph first. + if (meta_graph_def.asset_file_def_size() > 0) { + for (const auto& asset : meta_graph_def.asset_file_def()) { + asset_file_defs->push_back(asset); + } + return Status::OK(); + } + // Fall back to read from collection to be backward compatible with v1. 
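For context, the v2 test data added below is loaded through exactly this path: `LoadSavedModel` runs the restore op and then whatever init op `GetInitOp` resolves, preferring the SignatureDef key over the v1 collections. A minimal, hedged sketch of the client side (the export directory is hypothetical; `kSavedModelTagServe` comes from tag_constants.h):

```c++
#include <string>

#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/cc/saved_model/tag_constants.h"

tensorflow::Status LoadV2Bundle(const std::string& export_dir,
                                tensorflow::SavedModelBundle* bundle) {
  tensorflow::SessionOptions session_options;
  tensorflow::RunOptions run_options;
  // The restore op and the resolved init op are run before this returns.
  return tensorflow::LoadSavedModel(session_options, run_options, export_dir,
                                    {tensorflow::kSavedModelTagServe}, bundle);
}
```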
const auto& collection_def_map = meta_graph_def.collection_def(); const auto assets_it = collection_def_map.find(kSavedModelAssetsKey); if (assets_it == collection_def_map.end()) { @@ -227,15 +256,12 @@ Status LoadSavedModelInternal(const SessionOptions& session_options, bundle->meta_graph_def.saver_def().restore_op_name(), bundle->meta_graph_def.saver_def().filename_tensor_name(), asset_file_defs, bundle->session.get())); - if (HasMainOp(bundle->meta_graph_def)) { - TF_RETURN_IF_ERROR(RunMainOp(run_options, export_dir, - bundle->meta_graph_def, asset_file_defs, - bundle->session.get(), kSavedModelMainOpKey)); - } else { - TF_RETURN_IF_ERROR(RunMainOp( - run_options, export_dir, bundle->meta_graph_def, asset_file_defs, - bundle->session.get(), kSavedModelLegacyInitOpKey)); - } + string init_op_name; + TF_RETURN_IF_ERROR( + GetInitOp(export_dir, bundle->meta_graph_def, &init_op_name)); + TF_RETURN_IF_ERROR(RunInitOp(run_options, export_dir, bundle->meta_graph_def, + asset_file_defs, bundle->session.get(), + init_op_name)); return Status::OK(); } diff --git a/tensorflow/cc/saved_model/loader_test.cc b/tensorflow/cc/saved_model/loader_test.cc index 72b8bc18710..597e42bb65a 100644 --- a/tensorflow/cc/saved_model/loader_test.cc +++ b/tensorflow/cc/saved_model/loader_test.cc @@ -36,6 +36,8 @@ constexpr char kTestDataMainOp[] = "cc/saved_model/testdata/half_plus_two_main_op/00000123"; constexpr char kTestDataSharded[] = "cc/saved_model/testdata/half_plus_two/00000123"; +constexpr char kTestDataInitOpV2[] = + "cc/saved_model/testdata/half_plus_two_v2/00000123"; class LoaderTest : public ::testing::Test { protected: @@ -227,5 +229,17 @@ TEST_F(LoaderTest, MaybeSavedModelDirectory) { EXPECT_FALSE(MaybeSavedModelDirectory(invalid_export_dir)); } +TEST_F(LoaderTest, SavedModelInitOpV2Format) { + SavedModelBundle bundle; + SessionOptions session_options; + RunOptions run_options; + + const string export_dir = + io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataInitOpV2); + TF_ASSERT_OK(LoadSavedModel(session_options, run_options, export_dir, + {kSavedModelTagServe}, &bundle)); + CheckSavedModelBundle(export_dir, bundle); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/assets/foo.txt b/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/assets/foo.txt new file mode 100644 index 00000000000..f9ff0366880 --- /dev/null +++ b/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/assets/foo.txt @@ -0,0 +1 @@ +asset-file-contents \ No newline at end of file diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/saved_model.pb b/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/saved_model.pb new file mode 100644 index 00000000000..a10bbf8fb6b Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/saved_model.pb differ diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/variables/variables.data-00000-of-00001 b/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/variables/variables.data-00000-of-00001 new file mode 100644 index 00000000000..15b75d6ef6b Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/variables/variables.data-00000-of-00001 differ diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/variables/variables.index b/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/variables/variables.index new file mode 100644 index 00000000000..7ec9fb4fe2d 
Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/variables/variables.index differ diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index b17bc658fa0..ab1c1be344e 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -164,7 +164,8 @@ string RewriteWithName(const string& name, string code, } // Generate methods for args (inputs). -Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps, +Status GenArgMethods(const tf2xla::Config& config, + const xla::ProgramShapeProto& ps, const CompileResult& compile_result, string* methods) { size_t num_args = ps.parameters_size(); if (config.feed_size() != num_args) { @@ -174,9 +175,10 @@ Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps, } for (int i = 0; i < num_args; ++i) { std::vector> rewrites; - TF_RETURN_IF_ERROR(AddRewritesForShape(i, ps.parameters(i), &rewrites)); + TF_RETURN_IF_ERROR( + AddRewritesForShape(i, xla::Shape(ps.parameters(i)), &rewrites)); const string code = R"( - void set_arg{{NAME}}_data(void* data) { + void set_arg{{NAME}}_data(const void* data) { set_arg_data({{I}}, data); } {{TYPE}}* arg{{NAME}}_data() { @@ -204,7 +206,7 @@ Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps, // Generate methods for results (outputs). Status GenResultMethods(const tf2xla::Config& config, - const xla::ProgramShape& ps, string* methods) { + const xla::ProgramShapeProto& ps, string* methods) { if (ps.result().element_type() != xla::TUPLE) { // The XlaCompiler we use to build the xla computation always generates a // tuple result, and we rely on this to simplify code generation. @@ -217,8 +219,8 @@ Status GenResultMethods(const tf2xla::Config& config, } for (int i = 0; i < ps.result().tuple_shapes_size(); ++i) { std::vector> rewrites; - TF_RETURN_IF_ERROR( - AddRewritesForShape(i, ps.result().tuple_shapes(i), &rewrites)); + TF_RETURN_IF_ERROR(AddRewritesForShape( + i, xla::Shape(ps.result().tuple_shapes(i)), &rewrites)); string code = R"( {{TYPE}}* result{{NAME}}_data() { return static_cast<{{TYPE}}*>(result_data({{I}})); @@ -336,7 +338,7 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config, ExtractEntryParamBufferInfos(buffer_infos); std::vector buffer_infos_for_temps = ExtractTempBufferInfos(buffer_infos); - const xla::ProgramShape& ps = compile_result.program_shape; + const xla::ProgramShapeProto& ps = compile_result.program_shape; string methods_arg, methods_result; TF_RETURN_IF_ERROR(GenArgMethods(config, ps, compile_result, &methods_arg)); TF_RETURN_IF_ERROR(GenResultMethods(config, ps, &methods_result)); @@ -548,8 +550,8 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { static const char** StaticResultNames() {{RESULT_NAMES_CODE}} // Shape of the args and results. 
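Throughout these codegen changes the program shape is kept in proto form and only rehydrated where the XLA utilities need real shape objects. A small sketch of that round trip, assuming the `ToProto()` conversion and the proto-taking `xla::Shape` / `xla::ProgramShape` constructors this diff relies on (header paths are indicative):

```c++
#include <string>

#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Build a ProgramShape, keep only its proto, then rebuild shapes on demand,
// mirroring what GenArgMethods / GenResultMethods now do.
xla::ProgramShapeProto MakeExampleProgramShapeProto() {
  xla::ProgramShape program_shape = xla::ShapeUtil::MakeProgramShape(
      {xla::ShapeUtil::MakeShape(xla::F32, {1, 2})},
      xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::U32, {5, 6})}));
  return program_shape.ToProto();
}

void UseProgramShapeProto(const xla::ProgramShapeProto& ps) {
  // Rehydrate a single parameter shape, as the generated arg methods do.
  xla::Shape arg0(ps.parameters(0));
  // Or rebuild the whole ProgramShape, e.g. for a human-readable dump.
  std::string readable = xla::ShapeUtil::HumanString(xla::ProgramShape(ps));
  (void)arg0;
  (void)readable;
}
```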
- static const xla::ProgramShape* StaticProgramShape() { - static const xla::ProgramShape* kShape = {{PROGRAM_SHAPE_SHIM_EXPRESSION}}; + static const xla::ProgramShapeProto* StaticProgramShape() { + static const xla::ProgramShapeProto* kShape = {{PROGRAM_SHAPE_SHIM_EXPRESSION}}; return kShape; } @@ -587,7 +589,7 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { {"{{METHODS_RESULT}}\n", methods_result}, {"{{NS_END}}\n", ns_end}, {"{{NS_START}}\n", ns_start}, - {"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(ps)}, + {"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(xla::ProgramShape(ps))}, {"{{PROGRAM_SHAPE_SHIM_EXPRESSION}}", metadata_result.program_shape_access_shim}, {"{{RESULT_INDEX}}", absl::StrCat(result_index)}, @@ -615,11 +617,11 @@ static string CreateUniqueIdentifier(const CodegenOpts& opts, Status GenerateMetadata(const CodegenOpts& opts, const CompileResult& compile_result, MetadataResult* metadata_result) { - std::unique_ptr program_shape; + std::unique_ptr program_shape; if (opts.gen_program_shape) { program_shape = - absl::make_unique(compile_result.program_shape); + absl::make_unique(compile_result.program_shape); // The parameter names are currently meaningless, and redundant with the // rest of our metadata, so clear them out to avoid confusion and save @@ -631,8 +633,8 @@ Status GenerateMetadata(const CodegenOpts& opts, // a shim that evaluates to nullptr, which is what we want. ProtobufToEmbed program_shape_protobuf{ - CreateUniqueIdentifier(opts, "ProgramShape"), "xla::ProgramShape", - program_shape.get()}; + CreateUniqueIdentifier(opts, "ProgramShapeProto"), + "xla::ProgramShapeProto", program_shape.get()}; ProtobufToEmbed hlo_profile_printer_data_protobuf{ CreateUniqueIdentifier(opts, "HloProfilePrinterData"), diff --git a/tensorflow/compiler/aot/codegen.h b/tensorflow/compiler/aot/codegen.h index 90410c46a8e..9485e86b10e 100644 --- a/tensorflow/compiler/aot/codegen.h +++ b/tensorflow/compiler/aot/codegen.h @@ -57,7 +57,7 @@ struct MetadataResult { std::vector header_variable_decls; // program_shape_access_shim is a C++ expression that constructs the - // xla::ProgramShape instance for the CompileResult passed to + // xla::ProgramShapeProto instance for the CompileResult passed to // GenerateMetadata. 
string program_shape_access_shim; diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc index bb288d23000..c1788ca32a1 100644 --- a/tensorflow/compiler/aot/codegen_test.cc +++ b/tensorflow/compiler/aot/codegen_test.cc @@ -181,13 +181,15 @@ TEST(CodegenTest, Golden) { BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/1), BufferInfo::MakeTempBuffer(3), BufferInfo::MakeTempBuffer(120)}, 5, {})); - compile_result.program_shape = xla::ShapeUtil::MakeProgramShape( - { - xla::ShapeUtil::MakeShape(xla::F32, {1, 2}), - xla::ShapeUtil::MakeShape(xla::S64, {3, 4}), - }, - xla::ShapeUtil::MakeTupleShape( - {xla::ShapeUtil::MakeShape(xla::U32, {5, 6})})); + compile_result.program_shape = + xla::ShapeUtil::MakeProgramShape( + { + xla::ShapeUtil::MakeShape(xla::F32, {1, 2}), + xla::ShapeUtil::MakeShape(xla::S64, {3, 4}), + }, + xla::ShapeUtil::MakeTupleShape( + {xla::ShapeUtil::MakeShape(xla::U32, {5, 6})})) + .ToProto(); compile_result.entry_point = "entry_point"; compile_result.pointer_size = 8; diff --git a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden index e4d8a02877c..968afad65ed 100644 --- a/tensorflow/compiler/aot/codegen_test_h.golden +++ b/tensorflow/compiler/aot/codegen_test_h.golden @@ -22,7 +22,7 @@ extern "C" void entry_point( void* result, const xla::ExecutableRunOptions* run_options, const void** args, void** temps, tensorflow::int64* profile_counters); -extern "C" char __tfcompile_foo_bar_MyClass_ProgramShape_protobuf_array_contents[]; +extern "C" char __tfcompile_foo_bar_MyClass_ProgramShapeProto_protobuf_array_contents[]; namespace foo { @@ -114,7 +114,7 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { // with dim indices specifying which value. No bounds checking is performed // on dim indices. - void set_arg0_data(void* data) { + void set_arg0_data(const void* data) { set_arg_data(0, data); } float* arg0_data() { @@ -132,7 +132,7 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { arg_data(0)))[dim0][dim1]; } - void set_arg_myfeed_data(void* data) { + void set_arg_myfeed_data(const void* data) { set_arg_data(0, data); } float* arg_myfeed_data() { @@ -150,7 +150,7 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { arg_data(0)))[dim0][dim1]; } - void set_arg1_data(void* data) { + void set_arg1_data(const void* data) { set_arg_data(1, data); } tensorflow::int64* arg1_data() { @@ -253,10 +253,10 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { } // Shape of the args and results. 
- static const xla::ProgramShape* StaticProgramShape() { - static const xla::ProgramShape* kShape = []() { - xla::ProgramShape* proto = new xla::ProgramShape; - proto->ParseFromArray(&__tfcompile_foo_bar_MyClass_ProgramShape_protobuf_array_contents[0], 52); + static const xla::ProgramShapeProto* StaticProgramShape() { + static const xla::ProgramShapeProto* kShape = []() { + xla::ProgramShapeProto* proto = new xla::ProgramShapeProto; + proto->ParseFromArray(&__tfcompile_foo_bar_MyClass_ProgramShapeProto_protobuf_array_contents[0], 52); return proto; }(); return kShape; diff --git a/tensorflow/compiler/aot/codegen_test_o.golden b/tensorflow/compiler/aot/codegen_test_o.golden index eb001c5d45b..ce8e5ec8c96 100644 Binary files a/tensorflow/compiler/aot/codegen_test_o.golden and b/tensorflow/compiler/aot/codegen_test_o.golden differ diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc index 2b5f97b34cd..9fc223bdc7c 100644 --- a/tensorflow/compiler/aot/compile.cc +++ b/tensorflow/compiler/aot/compile.cc @@ -56,17 +56,23 @@ Status CompileXla(xla::CompileOnlyClient* client, return errors::Unknown("Couldn't get XLA program shape: ", pshape_or.status().error_message()); } - compile_result->program_shape = *pshape_or.ValueOrDie(); - xla::ProgramShape* pshape = &compile_result->program_shape; - std::vector arg_layouts; - arg_layouts.reserve(pshape->parameters_size()); + compile_result->program_shape = pshape_or.ValueOrDie()->ToProto(); + xla::ProgramShapeProto* pshape = &compile_result->program_shape; + + // AotXlaComputationInstance::argument_layouts is a vector of Shape + // pointers. Accumulate the Shape objects themselves in a separate vector + // while building the vector of pointers. + std::vector arg_layout_ptrs(pshape->parameters_size()); + std::vector arg_layouts(pshape->parameters_size()); for (int i = 0; i < pshape->parameters_size(); ++i) { - arg_layouts.push_back(pshape->mutable_parameters(i)); + arg_layouts[i] = xla::Shape(*pshape->mutable_parameters(i)); + arg_layout_ptrs[i] = &arg_layouts[i]; } xla::CompileOnlyClient::AotXlaComputationInstance instance; instance.computation = &computation; - instance.argument_layouts = std::move(arg_layouts); - instance.result_layout = &pshape->result(); + instance.argument_layouts = std::move(arg_layout_ptrs); + xla::Shape result_shape(pshape->result()); + instance.result_layout = &result_shape; xla::StatusOr>> aot_or = client->CompileAheadOfTime({instance}, aot_opts); if (!aot_or.ok()) { diff --git a/tensorflow/compiler/aot/compile.h b/tensorflow/compiler/aot/compile.h index e03c5b1aa77..ee7bb26fabd 100644 --- a/tensorflow/compiler/aot/compile.h +++ b/tensorflow/compiler/aot/compile.h @@ -33,9 +33,9 @@ namespace tfcompile { struct CompileResult { // Contains object file and meta-info. std::unique_ptr aot; - xla::ProgramShape program_shape; // Static shape of args and results. - string entry_point; // Name of generated function. - int pointer_size = 0; // Size of a pointer in bytes. + xla::ProgramShapeProto program_shape; // Static shape of args and results. + string entry_point; // Name of generated function. + int pointer_size = 0; // Size of a pointer in bytes. 
}; // CompileGraph compiles the graph_def into an object file containing a function diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc index f10852c7850..4dd79e5882d 100644 --- a/tensorflow/compiler/aot/tests/tfcompile_test.cc +++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc @@ -526,13 +526,15 @@ TEST(TFCompileTest, ProgramShape) { // muladd has the program shape defined. MatMulAndAddComp muladd; - const xla::ProgramShape* muladd_shape = muladd.ProgramShape(); + const xla::ProgramShapeProto* muladd_shape = muladd.ProgramShape(); ASSERT_TRUE(muladd_shape != nullptr); ASSERT_EQ(muladd_shape->parameters_size(), 2); - EXPECT_TRUE(ShapeUtil::Compatible(muladd_shape->parameters(0), f32_2x2)); - EXPECT_TRUE(ShapeUtil::Compatible(muladd_shape->parameters(1), f32_2x2)); + EXPECT_TRUE( + ShapeUtil::Compatible(xla::Shape(muladd_shape->parameters(0)), f32_2x2)); + EXPECT_TRUE( + ShapeUtil::Compatible(xla::Shape(muladd_shape->parameters(1)), f32_2x2)); - const xla::Shape& muladd_result = muladd_shape->result(); + const xla::Shape muladd_result(muladd_shape->result()); ASSERT_EQ(muladd_result.element_type(), xla::TUPLE); ASSERT_EQ(ShapeUtil::TupleElementCount(muladd_result), 2); const xla::Shape& muladd_result0 = diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 5f25e4626ad..be91ed4f432 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -23,7 +23,6 @@ package( load("//tensorflow:tensorflow.bzl", "cc_header_only_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") -load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") @@ -38,7 +37,7 @@ cc_library( ":xla_cpu_device", ":xla_cpu_jit", "//tensorflow/compiler/plugin", - ] + if_cuda_is_configured([ + ] + if_cuda([ ":xla_gpu_device", ":xla_gpu_jit", ]), @@ -51,6 +50,7 @@ cc_library( deps = [ ":jit_compilation_passes", "//tensorflow/compiler/jit/kernels:xla_ops", + "//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops", "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla/service:cpu_plugin", @@ -76,10 +76,10 @@ cc_library( srcs = ["xla_cpu_device.cc"], visibility = [":friends"], deps = [ + ":flags", ":jit_compilation_passes", ":xla_device", "//tensorflow/compiler/jit/kernels:xla_ops", - "//tensorflow/compiler/jit/legacy_flags:xla_device_flags", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla/service:cpu_plugin", # buildcleaner: keep @@ -210,6 +210,18 @@ cc_library( # Internal targets below this point. 
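The `flags` target added just below centralizes what the removed legacy_flags dependencies used to provide. A minimal sketch of how a pass reads one of these flags, mirroring the `GetBuildXlaOpsPassFlags()` call that build_xla_ops_pass.cc switches to later in this diff (the wrapper function here is ours):

```c++
#include "tensorflow/compiler/jit/flags.h"

// Hypothetical helper: reports whether lazy compilation is enabled, using the
// accessor provided by the new //tensorflow/compiler/jit:flags target.
bool LazyCompilationEnabled() {
  return tensorflow::GetBuildXlaOpsPassFlags().tf_xla_enable_lazy_compilation;
}
```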
+cc_library( + name = "flags", + srcs = ["flags.cc"], + hdrs = ["flags.h"], + visibility = [":friends"], + deps = [ + "//tensorflow/compiler/xla:parse_flags_from_env", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], +) + cc_library( name = "common", srcs = [ @@ -256,6 +268,7 @@ cc_library( "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", @@ -268,6 +281,7 @@ cc_library( "//tensorflow/core/kernels:variable_ops", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], ) @@ -487,6 +501,7 @@ cc_library( deps = [ ":common", ":encapsulate_util", + ":flags", ":shape_inference_helpers", ":union_find", ":xla_cluster_util", @@ -494,8 +509,6 @@ cc_library( "//tensorflow/cc:ops", "//tensorflow/cc:scope_internal", "//tensorflow/compiler/jit/graphcycles", - "//tensorflow/compiler/jit/legacy_flags:build_xla_ops_pass_flags", - "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags", "//tensorflow/compiler/jit/ops:xla_ops", "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla:resource_operation_table", @@ -724,7 +737,10 @@ tf_custom_op_py_library( visibility = [ ":friends", ], - deps = ["//tensorflow/compiler/jit/ops:xla_ops_wrapper_py"], + deps = [ + "//tensorflow/compiler/jit/ops:xla_ops_grad", + "//tensorflow/compiler/jit/ops:xla_ops_wrapper_py", + ], ) # This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library. diff --git a/tensorflow/compiler/jit/build_xla_ops_pass.cc b/tensorflow/compiler/jit/build_xla_ops_pass.cc index 93637a69d5d..9f4042630ed 100644 --- a/tensorflow/compiler/jit/build_xla_ops_pass.cc +++ b/tensorflow/compiler/jit/build_xla_ops_pass.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/cc/ops/control_flow_ops.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" -#include "tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.h" +#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_ops.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" @@ -320,10 +320,10 @@ Status BuildXlaOpsPass::Run(const GraphOptimizationPassOptions& options) { return IsXlaCompiledKernel(*n); }); - bool lazy_compilation_enabled = enable_lazy_compilation_ - ? *enable_lazy_compilation_ - : legacy_flags::GetBuildXlaOpsPassFlags() - .tf_xla_enable_lazy_compilation; + bool lazy_compilation_enabled = + enable_lazy_compilation_ + ? 
*enable_lazy_compilation_ + : GetBuildXlaOpsPassFlags().tf_xla_enable_lazy_compilation; for (Node* n : xla_compiled_kernels) { TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndXlaRun( diff --git a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc index 11df946cc18..48a23a4c171 100644 --- a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc +++ b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc @@ -42,14 +42,8 @@ class BuildXlaOpsTest : public ::testing::Test { .ok()); } - void TearDown() override { - for (Device* device : devices_) { - delete device; - } - } - private: - std::vector devices_; + std::vector> devices_; }; using ::tensorflow::testing::FindNodeByName; diff --git a/tensorflow/compiler/jit/create_xla_launch_op_test.cc b/tensorflow/compiler/jit/create_xla_launch_op_test.cc index 73866607621..0f872a480f4 100644 --- a/tensorflow/compiler/jit/create_xla_launch_op_test.cc +++ b/tensorflow/compiler/jit/create_xla_launch_op_test.cc @@ -59,8 +59,9 @@ class CreateXlaLaunchOpTest : public ::testing::Test { SessionOptions options; auto* device_count = options.config.mutable_device_count(); device_count->insert({"CPU", 1}); + std::vector> devices; TF_CHECK_OK(DeviceFactory::AddDevices( - options, "/job:localhost/replica:0/task:0", &devices_)); + options, "/job:localhost/replica:0/task:0", &devices)); FunctionDefLibrary proto; for (const auto& fdef : flib) { @@ -69,7 +70,7 @@ class CreateXlaLaunchOpTest : public ::testing::Test { lib_def_ = absl::make_unique( OpRegistry::Global(), proto); OptimizerOptions opts; - device_mgr_ = absl::make_unique(devices_); + device_mgr_ = absl::make_unique(std::move(devices)); pflr_ = absl::make_unique( device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(), opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr); @@ -77,7 +78,6 @@ class CreateXlaLaunchOpTest : public ::testing::Test { } FunctionLibraryRuntime* flr_; - std::vector devices_; std::unique_ptr device_mgr_; std::unique_ptr lib_def_; std::unique_ptr pflr_; diff --git a/tensorflow/compiler/jit/encapsulate_util.cc b/tensorflow/compiler/jit/encapsulate_util.cc index 28ec37b1b9c..1f4b9c90a4f 100644 --- a/tensorflow/compiler/jit/encapsulate_util.cc +++ b/tensorflow/compiler/jit/encapsulate_util.cc @@ -86,7 +86,7 @@ Status ProcessControlEdges(Graph* g, const string& xla_computation_attr_name, continue; } else if (src_xla_computation && !dst_xla_computation) { if (src_outside_compilation) { - // Case 1d: outside compilation to host computation control edge. + // Case 1c: outside compilation to host computation control edge. edges_to_remove.push_back(e); TF_RETURN_IF_ERROR(AppendToListAttr( @@ -94,7 +94,7 @@ Status ProcessControlEdges(Graph* g, const string& xla_computation_attr_name, } } else if (!src_xla_computation && dst_xla_computation) { if (dst_outside_compilation) { - // Case 1d: host computation control to outside compilation edge. + // Case 1c: host computation control to outside compilation edge. edges_to_remove.push_back(e); TF_RETURN_IF_ERROR(AppendToListAttr( @@ -103,40 +103,24 @@ Status ProcessControlEdges(Graph* g, const string& xla_computation_attr_name, } else { // src_xla_computation && dst_xla_computation if (*src_xla_computation != *dst_xla_computation) { if (src_outside_compilation && dst_outside_compilation) { - // Case 1c: outside compilation to outside compilation control edge. + // Case 1b: outside compilation to outside compilation control edge. 
edges_to_remove.push_back(e); TF_RETURN_IF_ERROR(AppendToListAttr( e->dst(), kXlaControlDependenciesAttrName, e->src()->name())); } else if (src_outside_compilation && !dst_outside_compilation) { - // Case 1b: outside compilation to another XLA computaition control + // Case 1a: outside compilation to another XLA computaition control // edge. TF_RETURN_IF_ERROR(AppendToListAttr( e->src(), kXlaConnectedToOtherXlaComputationAttrName, *dst_xla_computation)); } else if (!src_outside_compilation && dst_outside_compilation) { - // Case 1b: another XLA computaition to outside compilation control + // Case 1a: another XLA computaition to outside compilation control // edge. TF_RETURN_IF_ERROR(AppendToListAttr( e->dst(), kXlaConnectedFromOtherXlaComputationAttrName, *src_xla_computation)); } - } else { // *src_xla_computation == *dst_xla_computation - if (src_outside_compilation && dst_outside_compilation) { - if (*src_outside_compilation != *dst_outside_compilation) { - // Case 1c: outside compilation to outside compilation control edge. - edges_to_remove.push_back(e); - - TF_RETURN_IF_ERROR(AppendToListAttr( - e->dst(), kXlaControlDependenciesAttrName, e->src()->name())); - } - } else if (src_outside_compilation && !dst_outside_compilation) { - // Case 1a: outside compilation to its XLA computation control edge. - ReplaceAttr(e->src(), kXlaConnectedToXlaComputationAttrName, true); - } else if (!src_outside_compilation && dst_outside_compilation) { - // Case 1a: XLA computation to outside compilation in it control edge. - ReplaceAttr(e->dst(), kXlaConnectedFromXlaComputationAttrName, true); - } } } } @@ -181,12 +165,6 @@ Status ProcessXlaToXlaDataEdges(Graph* g, edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id()}); VLOG(4) << "XLA -> XLA edge: " << e->DebugString(); } - } else { // *src_xla_computation == *dst_xla_computation - if (src_outside_compilation && dst_outside_compilation && - *src_outside_compilation != *dst_outside_compilation) { - edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id()}); - VLOG(4) << "XLA -> XLA edge: " << e->DebugString(); - } } } @@ -263,7 +241,7 @@ Status ProcessDataEdgeBetweenOutsideCompilationAndHostComputation( // Remove the edge from host to outside compilation. Add a placeholder as // outside compilation node input. - std::map placeholders; + std::map, Node*> placeholders; for (int i = 0; i < edges.size(); i++) { Node* dst = g->FindNodeId(edges[i].dst_node_id); const Edge* e; @@ -275,9 +253,10 @@ Status ProcessDataEdgeBetweenOutsideCompilationAndHostComputation( // Find or create placeholder node. string new_name = edges[i].is_host_to_outside_compilation - ? absl::StrCat(src->name(), "_host_to_oc_placeholder") - : absl::StrCat(src->name(), "_oc_to_host_placeholder"); - auto iter = placeholders.find(new_name); + ? 
absl::StrCat(src->name(), "_host_to_oc_placeholder_", src_output) + : absl::StrCat(src->name(), "_oc_to_host_placeholder_", src_output); + auto placeholder_index = std::make_pair(src->name(), src_output); + auto iter = placeholders.find(placeholder_index); Node* placeholder_node; if (iter == placeholders.end()) { NodeDefBuilder placeholder_builder(new_name, "Placeholder"); @@ -310,7 +289,7 @@ Status ProcessDataEdgeBetweenOutsideCompilationAndHostComputation( Status s; placeholder_node = g->AddNode(placeholder_def, &s); TF_RETURN_IF_ERROR(s); - placeholders[new_name] = placeholder_node; + placeholders[placeholder_index] = placeholder_node; } else { placeholder_node = iter->second; } @@ -594,14 +573,244 @@ Status AddControlDependencies( return Status::OK(); } +// Step 1 for `PreprocessEdgesBetweenOutsideCompilations`. See comments of +// `PreprocessEdgesBetweenOutsideCompilations` for details. +Status PreprocessControlEdgesBetweenOutsideCompilations( + Graph* g, const string& outside_compilation_attr_name) { + // Gather edges to remove. We should not remove the edge while iterating. + std::vector edges_to_remove; + for (const Edge* e : g->edges()) { + if (!e->IsControlEdge()) { + continue; + } + + auto src_outside_compilation = + GetStringAttr(*e->src(), outside_compilation_attr_name); + auto dst_outside_compilation = + GetStringAttr(*e->dst(), outside_compilation_attr_name); + + if (src_outside_compilation && dst_outside_compilation) { + if (*src_outside_compilation != *dst_outside_compilation) { + // Case 1a: outside compilation to outside compilation control edge. + edges_to_remove.push_back(e); + + TF_RETURN_IF_ERROR(AppendToListAttr( + e->dst(), kXlaControlDependenciesWithinXlaClusterAttrName, + e->src()->name())); + } + } else if (src_outside_compilation && !dst_outside_compilation) { + // Case 1b: outside compilation to its XLA computation control edge. + ReplaceAttr(e->src(), kXlaConnectedToXlaComputationAttrName, true); + } else if (!src_outside_compilation && dst_outside_compilation) { + // Case 1b: XLA computation to outside compilation in it control edge. + ReplaceAttr(e->dst(), kXlaConnectedFromXlaComputationAttrName, true); + } + } + + for (auto e : edges_to_remove) { + g->RemoveEdge(e); + } + return Status::OK(); +} + +// Step 2 for `PreprocessEdgesBetweenOutsideCompilations`. See comments of +// `PreprocessEdgesBetweenOutsideCompilations` for details. +Status PreprocessDataEdgesBetweenOutsideCompilations( + Graph* g, const string& outside_compilation_attr_name) { + // Gather edges between outside compilation and host computation. Notice that + // we do not store `Edge*` directly because we remove some nodes while adding + // Identity nodes, and those Edge pointers might be invalidated. + struct EdgeInfo { + int dst_input, dst_node_id; + }; + std::vector edges; + for (const Edge* e : g->edges()) { + if (e->IsControlEdge()) { + continue; + } + + auto src_outside_compilation = + GetStringAttr(*e->src(), outside_compilation_attr_name); + auto dst_outside_compilation = + GetStringAttr(*e->dst(), outside_compilation_attr_name); + + if (src_outside_compilation && dst_outside_compilation && + *src_outside_compilation != *dst_outside_compilation) { + edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id()}); + VLOG(4) << "Oc -> oc edge: " << e->DebugString(); + } + } + + // Remove the edge from host to outside compilation. Add a placeholder as + // outside compilation node input. 
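The comment above describes the placeholder step; the sketch below distills how one such placeholder is materialized with `NodeDefBuilder` (the helper name is ours), while the full implementation that follows additionally dedupes placeholders by (source node, output index) and copies the outside compilation attributes.

```c++
#include "absl/strings/str_cat.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"

// Sketch: create a Placeholder standing in for one (node, output) pair.
tensorflow::Status AddPlaceholderFor(tensorflow::Graph* g,
                                     const tensorflow::Node* src, int src_output,
                                     tensorflow::Node** placeholder) {
  tensorflow::NodeDefBuilder builder(
      absl::StrCat(src->name(), "_oc_to_oc_placeholder_", src_output),
      "Placeholder");
  builder.Attr("dtype", src->output_type(src_output));
  tensorflow::NodeDef def;
  TF_RETURN_IF_ERROR(builder.Finalize(&def));
  tensorflow::Status s;
  *placeholder = g->AddNode(def, &s);
  return s;
}
```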
+ std::map, Node*> placeholders; + for (int i = 0; i < edges.size(); i++) { + Node* dst = g->FindNodeId(edges[i].dst_node_id); + const Edge* e; + TF_RETURN_IF_ERROR(dst->input_edge(edges[i].dst_input, &e)); + Node* src = e->src(); + int src_output = e->src_output(), dst_input = e->dst_input(); + g->RemoveEdge(e); + + // Find or create placeholder node. + string new_name = + absl::StrCat(src->name(), "_oc_to_oc_placeholder_", src_output); + auto placeholder_index = std::make_pair(src->name(), src_output); + auto iter = placeholders.find(placeholder_index); + Node* placeholder_node; + if (iter == placeholders.end()) { + NodeDefBuilder placeholder_builder(new_name, "Placeholder"); + placeholder_builder.Attr("dtype", src->output_type(src_output)); + string outside_compilation_attr; + TF_RETURN_IF_ERROR(GetNodeAttr(dst->attrs(), + outside_compilation_attr_name, + &outside_compilation_attr)); + placeholder_builder.Attr(outside_compilation_attr_name, + outside_compilation_attr); + placeholder_builder.Attr(kOutsideCompilationOriginalNodeAttrName, + src->name()); + placeholder_builder.Attr(kOutsideCompilationSrcOutputAttrName, + src_output); + NodeDef placeholder_def; + TF_RETURN_IF_ERROR(placeholder_builder.Finalize(&placeholder_def)); + Status s; + placeholder_node = g->AddNode(placeholder_def, &s); + TF_RETURN_IF_ERROR(s); + placeholders[placeholder_index] = placeholder_node; + } else { + placeholder_node = iter->second; + } + g->AddEdge(placeholder_node, 0, dst, dst_input); + + // Replace `e->dst()` because its input node changed. + NodeDef new_def = dst->def(); + *new_def.mutable_input(dst_input) = placeholder_node->name(); + TF_ASSIGN_OR_RETURN(Node * dst_replace_node, ReplaceNode(g, dst, new_def)); + + // Other edge in `edges` might have `e->dst()` as src or dst + // node. Before removing `e->dst()`, replace those edges with + // corresponding edges for `dst_replace_node`. + for (int j = i + 1; j < edges.size(); j++) { + if (edges[j].dst_node_id == edges[i].dst_node_id) { + edges[j].dst_node_id = dst_replace_node->id(); + } + } + } + return Status::OK(); +} + +// Step 1 for `PostprocessEdgesBetweenOutsideCompilations`. See comments of +// `PostprocessEdgesBetweenOutsideCompilations` for details. +Status PostprocessDataEdgesBetweenOutsideCompilations( + Graph* g, const string& outside_compilation_attr_name) { + // Gather all outside compilation to outside compilation nodes. + std::vector placeholder_nodes; + for (Node* n : g->nodes()) { + if (n->type_string() == "Placeholder" && + HasNodeAttr(n->def(), kOutsideCompilationOriginalNodeAttrName)) { + placeholder_nodes.push_back(n); + } + } + + // Remove the placeholder nodes, and reconnect original edge. + auto node_name_index = g->BuildNodeNameIndex(); + for (auto n : placeholder_nodes) { + string node_name; + int node_src_output; + TF_RETURN_IF_ERROR(GetNodeAttr( + n->attrs(), kOutsideCompilationOriginalNodeAttrName, &node_name)); + TF_RETURN_IF_ERROR(GetNodeAttr( + n->attrs(), kOutsideCompilationSrcOutputAttrName, &node_src_output)); + auto iter = node_name_index.find(node_name); + if (iter == node_name_index.end()) { + return errors::Internal( + "Cannot find original node for oc -> host placeholder node ", + node_name); + } + + // Change all usage node to use the original node instead. 
+ Node* original_node = iter->second; + std::vector control_edges; + std::vector data_edges; + for (auto e : n->out_edges()) { + if (e->IsControlEdge()) { + control_edges.push_back(e); + } else { + data_edges.push_back({e->dst(), e->src_output(), e->dst_input()}); + } + } + for (const Edge* e : control_edges) { + g->AddControlEdge(original_node, e->dst()); + g->RemoveEdge(e); + } + for (int i = 0; i < data_edges.size(); i++) { + Node* dst = data_edges[i].dst; + NodeDef new_def = dst->def(); + int dst_input = data_edges[i].dst_input; + *new_def.mutable_input(dst_input) = + absl::StrCat(original_node->name(), ":", node_src_output); + TF_ASSIGN_OR_RETURN(Node * replace_node, ReplaceNode(g, dst, new_def)); + + const Edge* edge_to_replace = nullptr; + TF_RETURN_IF_ERROR(replace_node->input_edge(dst_input, &edge_to_replace)); + g->RemoveEdge(edge_to_replace); + g->AddEdge(original_node, node_src_output, replace_node, dst_input); + + // Other edges might have `dst` as dst node. Update those edges with + // `replace_node`. + for (int j = i + 1; j < data_edges.size(); j++) { + if (data_edges[j].dst == dst) { + data_edges[j].dst = replace_node; + } + } + + // Other placeholder node might have `dst` as original node. Update + // `node_name_index` with `replace_node`. + node_name_index[replace_node->name()] = replace_node; + } + + // Remove placeholder node. + g->RemoveNode(n); + } + return Status::OK(); +} + +// Step 2 for `PostprocessEdgesBetweenOutsideCompilations`. See comments of +// `PostprocessEdgesBetweenOutsideCompilations` for details. +Status PostprocessControlEdgesBetweenOutsideCompilations( + Graph* g, const string& outside_compilation_attr_name) { + auto node_name_index = g->BuildNodeNameIndex(); + + // Reconnect outside compilation to outside compilation control edge. 
+ for (Node* n : g->nodes()) { + std::vector control_deps; + Status s = + GetNodeAttr(n->attrs(), kXlaControlDependenciesWithinXlaClusterAttrName, + &control_deps); + if (!s.ok()) { + if (s.code() != error::NOT_FOUND) { + return s; + } else { + continue; + } + } else { + n->ClearAttr(kXlaControlDependenciesWithinXlaClusterAttrName); + for (const string& control_input : control_deps) { + auto iter = node_name_index.find(control_input); + if (iter == node_name_index.end()) { + return errors::Internal("Cannot find original node for ", + control_input); + } + g->AddControlEdge(iter->second, n); + } + } + } + return Status::OK(); +} } // namespace const char kXlaInferredShapesAttrName[] = "_xla_inferred_shapes"; -const char kXlaConnectedToXlaComputationAttrName[] = - "_xla_connected_to_xla_computation"; -const char kXlaConnectedFromXlaComputationAttrName[] = - "_xla_connected_from_xla_computation"; const char kXlaConnectedToOtherXlaComputationAttrName[] = "_xla_connected_to_other_xla_computation"; const char kXlaConnectedFromOtherXlaComputationAttrName[] = @@ -616,6 +825,15 @@ const char kHostToOutsideCompilationOriginalNodeAttrName[] = "_xla_host_to_oc_node_name"; const char kHostToOutsideCompilationSrcOutputAttrName[] = "_xla_host_to_oc_src_output"; +const char kXlaConnectedToXlaComputationAttrName[] = + "_xla_connected_to_xla_computation"; +const char kXlaConnectedFromXlaComputationAttrName[] = + "_xla_connected_from_xla_computation"; +const char kOutsideCompilationOriginalNodeAttrName[] = + "_xla_oc_to_oc_node_name"; +const char kOutsideCompilationSrcOutputAttrName[] = "_xla_oc_to_oc_src_output"; +const char kXlaControlDependenciesWithinXlaClusterAttrName[] = + "_xla_control_dependencies_within_xla_cluster"; Status PerformStaticShapeInferenceBeforeEncapsulation( Graph* g, const string& xla_computation_attr_name, @@ -699,4 +917,39 @@ Status PostprocessForEncapsulation( return Status::OK(); } +Status PreprocessEdgesBetweenOutsideCompilations( + Graph* g, const string& outside_compilation_attr_name) { + // Remove edges from source node to outside compilation nodes, and edges + // from outside compilation nodes to sink node. 
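The attribute round trip above is the core of the control-edge handling: node names stashed in a list(string) attr during preprocessing are turned back into real control edges here. A condensed helper showing the same steps (the helper is ours; the attr constant comes from encapsulate_util.h):

```c++
#include <string>
#include <vector>

#include "tensorflow/compiler/jit/encapsulate_util.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"

tensorflow::Status RestoreControlDeps(tensorflow::Graph* g, tensorflow::Node* n) {
  std::vector<std::string> control_deps;
  TF_RETURN_IF_ERROR(tensorflow::GetNodeAttr(
      n->attrs(), tensorflow::kXlaControlDependenciesWithinXlaClusterAttrName,
      &control_deps));
  n->ClearAttr(tensorflow::kXlaControlDependenciesWithinXlaClusterAttrName);
  auto node_name_index = g->BuildNodeNameIndex();
  for (const std::string& name : control_deps) {
    auto it = node_name_index.find(name);
    if (it == node_name_index.end()) {
      return tensorflow::errors::Internal("Cannot find original node for ", name);
    }
    g->AddControlEdge(it->second, n);
  }
  return tensorflow::Status::OK();
}
```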
+ std::vector edges_to_remove; + for (const Edge* e : g->source_node()->out_edges()) { + if (HasNodeAttr(e->dst()->def(), outside_compilation_attr_name)) { + edges_to_remove.push_back(e); + } + } + for (const Edge* e : g->sink_node()->in_edges()) { + if (HasNodeAttr(e->src()->def(), outside_compilation_attr_name)) { + edges_to_remove.push_back(e); + } + } + for (auto e : edges_to_remove) { + g->RemoveEdge(e); + } + + TF_RETURN_IF_ERROR(PreprocessControlEdgesBetweenOutsideCompilations( + g, outside_compilation_attr_name)); + TF_RETURN_IF_ERROR(PreprocessDataEdgesBetweenOutsideCompilations( + g, outside_compilation_attr_name)); + return Status::OK(); +} + +Status PostprocessEdgesBetweenOutsideCompilations( + Graph* g, const string& outside_compilation_attr_name) { + TF_RETURN_IF_ERROR(PostprocessDataEdgesBetweenOutsideCompilations( + g, outside_compilation_attr_name)); + TF_RETURN_IF_ERROR(PostprocessControlEdgesBetweenOutsideCompilations( + g, outside_compilation_attr_name)); + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/encapsulate_util.h b/tensorflow/compiler/jit/encapsulate_util.h index 5e0c4bf6a0c..e363bc5754a 100644 --- a/tensorflow/compiler/jit/encapsulate_util.h +++ b/tensorflow/compiler/jit/encapsulate_util.h @@ -44,14 +44,6 @@ Status PerformStaticShapeInferenceBeforeEncapsulation( Graph* g, const string& xla_computation_attr_name, const string& outside_compilation_attr_name); -// Attribute indicating that some ops in this node's XLA computation has control -// dependency on this node. Attribute value will always be "true". -extern const char kXlaConnectedToXlaComputationAttrName[]; - -// Attribute indicating that this node has control dependency on some ops in -// this node's XLA computation. Attribute value will always be "true". -extern const char kXlaConnectedFromXlaComputationAttrName[]; - // Attribute indicating that some ops in other XLA computation has control // dependency on this node. Attribute value will be a list of string (XLA // computation names). @@ -81,6 +73,14 @@ extern const char kOutsideCompilationToHostOriginalNodeAttrName[]; // int (src_output for original edge). extern const char kOutsideCompilationToHostSrcOutputAttrName[]; +// Attribute indicating that some ops in this node's XLA computation has control +// dependency on this node. Attribute value will always be "true". +extern const char kXlaConnectedToXlaComputationAttrName[]; + +// Attribute indicating that this node has control dependency on some ops in +// this node's XLA computation. Attribute value will always be "true". +extern const char kXlaConnectedFromXlaComputationAttrName[]; + // Attribute indicating that this is an Placeholder node added to act as a // temporary input node for an host node. Attribute value will be string // (original input node name). @@ -91,19 +91,31 @@ extern const char kHostToOutsideCompilationOriginalNodeAttrName[]; // for original edge). extern const char kHostToOutsideCompilationSrcOutputAttrName[]; -// Preprocesses the graph for encapsulation. It will perform the following -// operations in order: +// Attribute indicating that this is an Placeholder node added to act as a +// temporary input node for an outside compilation node. Attribute value will be +// string (original input node name). +extern const char kOutsideCompilationOriginalNodeAttrName[]; + +// Attribute indicating that this is an Placeholder node added to act as a +// temporary input node for an outside compilation node. 
Attribute value will be +// int (src_output for original edge). +extern const char kOutsideCompilationSrcOutputAttrName[]; + +// Attribute indicating that this node has control dependencies on some other +// nodes within the same XLA cluster. Attribute value will be a list of string +// (node names). +extern const char kXlaControlDependenciesWithinXlaClusterAttrName[]; + +// Preprocesses edges between different XLA clusters for encapsulation. It will +// perform the following operations in order: // -// 1a. For control edges between outside compilation and its XLA computation, -// add attr "kXlaConnected{From, To}XlaComputationAttrName = true" to the -// outside compilation node. -// 1b. For control edges between outside compilation and another XLA +// 1a. For control edges between outside compilation and another XLA // computation, add attr "kXlaConnected{From, To}OtherXlaComputationAttrName // = XLA computation node name" to the outside compilation node. -// 1c. For control edges between different outside compilations, remove the edge -// and add attr "kXlaControlDependenciesAttrName = src node name" to dst -// node. -// 1d. For control edges between outside compilation and host computation, +// 1b. For control edges between different outside compilations (in different +// XLA computations), remove the edge and add attr +// "kXlaControlDependenciesAttrName = src node name" to dst node. +// 1c. For control edges between outside compilation and host computation, // remove the edge and add attr "kXlaControlDependenciesAttrName = src node // name" to dst node. // 2. For data edges between different XLA computations, if either src or dst @@ -146,26 +158,53 @@ struct XlaClusterInfo { const std::map host_compute_core; }; -// Postprocesses the graph for encapsulation. This function reverts what -// `PreprocessForEncapsulation` did. It will perform the following operations in -// order: +// Postprocesses edges between different XLA clusters for encapsulation. This +// function reverts what `PreprocessForEncapsulation` did. It will perform the +// following operations in order: // // 1. Remove Placeholder nodes between outside compilation and host computation // (created in `PreprocessForEncapsulation` step 3). // 2. Remove Identity nodes created in `PreprocessForEncapsulation` step 2. -// 3a. Reconnect control edges between different outside compilations (marked by -// `PreprocessForEncapsulation` step 1c) and control edges between outside -// compilation and host computation (marked by `PreprocessForEncapsulation` -// step 1d). -// 3b. Reconnect control edges between outside compilation and another XLA -// computation (marked by `PreprocessForEncapsulation` step 1b). -// Notice that control edges marked by `PreprocessForEncapsulation` step 1a are -// not handled here. They are handled in `RewriteOutsideCompilationSubgraphFn`. +// 3a. Reconnect control edges between outside compilation and another XLA +// computation (marked by `PreprocessForEncapsulation` step 1a). +// 3b. Reconnect control edges between different outside compilations (marked by +// `PreprocessForEncapsulation` step 1b). +// 3c. Reconnect control edges between outside compilation and host computation +// (marked by `PreprocessForEncapsulation` step 1c). Status PostprocessForEncapsulation( Graph* g, const string& xla_computation_attr_name, const string& outside_compilation_attr_name, const std::unordered_map& clusters); +// Preprocesses edges within the same XLA cluster. 
It will perform the following +// operations in order: +// +// 0. Remove edges from source node to outside compilation nodes, and edges +// from outside compilation nodes to sink node. +// 1a. For edges between different outside compilation clusters, remove the edge +// and add attr "kXlaControlDependenciesWithinXlaClusterAttrName = src node +// name" to dst node. +// 1b. For control edges between outside compilation and its XLA computation, +// add attr "kXlaConnected{From, To}XlaComputationAttrName = true" to the +// outside compilation node. +// 2. For data edges between different outside compilations, remove the edge +// and create a Placeholder node as dst node's input. +Status PreprocessEdgesBetweenOutsideCompilations( + Graph* g, const string& outside_compilation_attr_name); + +// Postprocesses edges within the same XLA cluster. This function reverts what +// `PreprocessEdgesBetweenOutsideCompilations` did. It will perform the +// following operations in order: +// +// 1. Remove Placeholder nodes between different outside compilations (created +// in `PreprocessEdgesBetweenOutsideCompilations` step 2). +// 2a. Reconnect control edges between different outside compilations (marked by +// `PreprocessEdgesBetweenOutsideCompilations` step 1a). +// Notice that control edges marked by +// `PreprocessEdgesBetweenOutsideCompilations` step 1b are not handled here. +// They are handled in `RewriteOutsideCompilationSubgraphFn`. +Status PostprocessEdgesBetweenOutsideCompilations( + Graph* g, const string& outside_compilation_attr_name); } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_ENCAPSULATE_UTIL_H_ diff --git a/tensorflow/compiler/jit/encapsulate_util_test.cc b/tensorflow/compiler/jit/encapsulate_util_test.cc index 7255df31129..3b8b49cb92f 100644 --- a/tensorflow/compiler/jit/encapsulate_util_test.cc +++ b/tensorflow/compiler/jit/encapsulate_util_test.cc @@ -107,28 +107,19 @@ TEST(PreprocessForEncapsulationTest, ControlEdges) { identity4_node->AddAttr("_xla", "1"); identity4_node->AddAttr("_oc", "0"); identity5_node->AddAttr("_xla", "1"); - // Case 1a: control edges between outside compilation and its XLA computation. - g.AddControlEdge(add_node, identity0_node); - g.AddControlEdge(identity0_node, identity1_node); - // Case 1b: control edges between outside compilation and another XLA + // Case 1a: control edges between outside compilation and another XLA // computation. g.AddControlEdge(identity0_node, identity3_node); g.AddControlEdge(identity1_node, identity4_node); - // Case 1c: control edges between different outside compilations. + // Case 1b: control edges between different outside compilations. g.AddControlEdge(identity0_node, identity4_node); - // Case 1d: control edges between outside compilation and host computation. + // Case 1c: control edges between outside compilation and host computation. g.AddControlEdge(const0_node, identity0_node); g.AddControlEdge(identity0_node, identity2_node); TF_CHECK_OK(PreprocessForEncapsulation(&g, "_xla", "_oc")); - // Case 1a: add attr "_xla_connected_{from/to}_xla_computation = true" to the - // outside compilation node. - EXPECT_TRUE(HasNodeAttr(identity0_node->def(), - kXlaConnectedFromXlaComputationAttrName)); - EXPECT_TRUE(HasNodeAttr(identity0_node->def(), - kXlaConnectedToXlaComputationAttrName)); - // Case 1b: add attr "_xla_control_deps_{from/to} = XLA computation node name" + // Case 1a: add attr "_xla_control_deps_{from/to} = XLA computation node name" // to the outside compilation node. 
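// Editor's sketch (not part of this change; the helper name below is
// hypothetical): the attribute round trip behind
// "_xla_control_dependencies_within_xla_cluster". Preprocessing folds a
// control edge into a name-list attribute on the destination node; the
// postprocessing loop earlier in this patch reads that list, clears the
// attribute, and re-adds the edges by looking the source nodes up by name.
//
//   Status RecordControlDepAsAttr(Graph* g, const Edge* e) {
//     std::vector<string> deps;
//     Status s = GetNodeAttr(e->dst()->attrs(),
//                            kXlaControlDependenciesWithinXlaClusterAttrName,
//                            &deps);
//     if (!s.ok() && s.code() != error::NOT_FOUND) return s;
//     deps.push_back(e->src()->name());
//     e->dst()->ClearAttr(kXlaControlDependenciesWithinXlaClusterAttrName);
//     e->dst()->AddAttr(kXlaControlDependenciesWithinXlaClusterAttrName, deps);
//     g->RemoveEdge(e);
//     return Status::OK();
//   }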
std::vector attr; TF_CHECK_OK(GetNodeAttr(identity0_node->def(), @@ -140,13 +131,13 @@ TEST(PreprocessForEncapsulationTest, ControlEdges) { kXlaConnectedFromOtherXlaComputationAttrName, &attr)); EXPECT_EQ(attr.size(), 1); EXPECT_EQ(attr[0], "0"); - // Case 1c: add attr "_xla_control_deps = src node name" to dst node. + // Case 1b: add attr "_xla_control_deps = src node name" to dst node. attr.clear(); TF_CHECK_OK(GetNodeAttr(identity4_node->def(), kXlaControlDependenciesAttrName, &attr)); EXPECT_EQ(attr.size(), 1); EXPECT_EQ(attr[0], "identity0"); - // Case 1d: add attr "_xla_control_deps = src node name" to dst node. + // Case 1c: add attr "_xla_control_deps = src node name" to dst node. attr.clear(); TF_CHECK_OK(GetNodeAttr(identity0_node->def(), kXlaControlDependenciesAttrName, &attr)); @@ -162,23 +153,33 @@ TEST(PreprocessForEncapsulationTest, ControlEdges) { TEST(PreprocessForEncapsulationTest, DataEdges) { // Build the graph: // "const_0" and "const_1" in host computation + // "identityn0" = ("const_0", "const_1") in host computation 0 // "add0" = "const_0" + "const_1" in XLA computation 0 // "add1" = "add0" + "const_0" in XLA computation 0 & outside compilation 0 // "identity0" = "add1" in XLA computation 0 // "add2" = "add1" + "identity0" in host computation // "add3" = "add1" + "add2" in XLA computation 1 - // "add4" = "identity0" + "add2" in XLA computation 1 & outside compilation 1 + // "add4" = "identity0" + "add2" in XLA computation 1 & outside compilation 0 + // "add5" = "identityn0"[0] + "identityn0"[1] in XLA computation 1 & + // outside compilation 0 + // "identityn1" = ("identityn0"[0], "identityn0"[1]) in XLA computation 1 & + // outside compilation 0 // "identity1" = "add4" in XLA computation 1 // "identity2" = "identity1" in host computation tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output const_0 = ops::Const(s.WithOpName("const_0"), 1, {}); Output const_1 = ops::Const(s.WithOpName("const_1"), 2, {}); + auto identityn0 = + ops::IdentityN(s.WithOpName("identityn_0"), {const_0, const_1}); Output add0 = ops::Add(s.WithOpName("add0"), const_0, const_1); Output add1 = ops::Add(s.WithOpName("add1"), add0, const_0); Output identity0 = ops::Identity(s.WithOpName("identity0"), add1); Output add2 = ops::Add(s.WithOpName("add2"), add1, identity0); Output add3 = ops::Add(s.WithOpName("add3"), add1, add2); Output add4 = ops::Add(s.WithOpName("add4"), identity0, add2); + Output add5 = ops::Add(s.WithOpName("add5"), identityn0[0], identityn0[1]); + auto identityn1 = ops::IdentityN(s.WithOpName("identityn_1"), + {identityn0[0], identityn0[1]}); Output identity1 = ops::Identity(s.WithOpName("identity1"), add4); Output identity2 = ops::Identity(s.WithOpName("identity2"), add4); Graph g(OpRegistry::Global()); @@ -189,6 +190,8 @@ TEST(PreprocessForEncapsulationTest, DataEdges) { Node *add0_node = node_index["add0"], *add1_node = node_index["add1"], *identity0_node = node_index["identity0"], *add3_node = node_index["add3"], *add4_node = node_index["add4"], + *add5_node = node_index["add5"], + *identityn1_node = node_index["identityn_1"], *identity1_node = node_index["identity1"]; add0_node->AddAttr("_xla", "0"); add1_node->AddAttr("_xla", "0"); @@ -197,6 +200,10 @@ TEST(PreprocessForEncapsulationTest, DataEdges) { add3_node->AddAttr("_xla", "1"); add4_node->AddAttr("_xla", "1"); add4_node->AddAttr("_oc", "0"); + add5_node->AddAttr("_xla", "1"); + add5_node->AddAttr("_oc", "0"); + identityn1_node->AddAttr("_xla", "1"); + identityn1_node->AddAttr("_oc", "0"); 
identity1_node->AddAttr("_xla", "1"); TF_CHECK_OK(PreprocessForEncapsulation(&g, "_xla", "_oc")); @@ -214,8 +221,9 @@ TEST(PreprocessForEncapsulationTest, DataEdges) { EXPECT_NE(bridge_identity0_add4, nullptr); // Step 3: add placeholder for edges between host computation and outside // compilation. - EXPECT_EQ(bridge_add1_add3->def().input(0), "add1_oc_to_host_placeholder"); - Node *add1_oc_to_host_placeholder = node_index["add1_oc_to_host_placeholder"]; + EXPECT_EQ(bridge_add1_add3->def().input(0), "add1_oc_to_host_placeholder_0"); + Node *add1_oc_to_host_placeholder = + node_index["add1_oc_to_host_placeholder_0"]; TF_CHECK_OK(GetNodeAttr(add1_oc_to_host_placeholder->attrs(), kOutsideCompilationToHostOriginalNodeAttrName, &str)); EXPECT_EQ(str, "add1"); @@ -226,15 +234,34 @@ TEST(PreprocessForEncapsulationTest, DataEdges) { add4_node = node_index["add4"]; ASSERT_NE(add4_node, nullptr); EXPECT_EQ(add4_node->def().input(0), - "bridge_identity0_add4_host_to_oc_placeholder"); + "bridge_identity0_add4_host_to_oc_placeholder_0"); Node *identity0_host_to_oc_placeholder = - node_index["bridge_identity0_add4_host_to_oc_placeholder"]; + node_index["bridge_identity0_add4_host_to_oc_placeholder_0"]; TF_CHECK_OK(GetNodeAttr(identity0_host_to_oc_placeholder->attrs(), kHostToOutsideCompilationOriginalNodeAttrName, &str)); EXPECT_EQ(str, "bridge_identity0_add4"); TF_CHECK_OK(GetNodeAttr(identity0_host_to_oc_placeholder->attrs(), kHostToOutsideCompilationSrcOutputAttrName, &i)); EXPECT_EQ(i, 0); + + // Check different placeholder nodes are created for different src_output. + Node *placeholder0 = node_index["identityn_0_host_to_oc_placeholder_0"], + *placeholder1 = node_index["identityn_0_host_to_oc_placeholder_1"]; + EXPECT_NE(placeholder0, nullptr); + EXPECT_NE(placeholder1, nullptr); + // Check we only have 2 placeholder nodes created for "identityn_0". + int placeholder_count = 0; + for (Node *n : g.nodes()) { + if (HasNodeAttr(n->def(), kHostToOutsideCompilationOriginalNodeAttrName)) { + string attr; + TF_CHECK_OK(GetNodeAttr( + n->attrs(), kHostToOutsideCompilationOriginalNodeAttrName, &attr)); + if (attr == "identityn_0") { + ++placeholder_count; + } + } + } + EXPECT_EQ(placeholder_count, 2); } TEST(PostprocessForEncapsulationTest, ControlEdges) { diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc index 2ce6fa73fc4..d334100aa4a 100644 --- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc @@ -195,8 +195,11 @@ Status RewriteSubgraph(const std::vector& arg_source_tensors, e->dst()->attrs().Find(kXlaClusterAttr) == nullptr && e->dst()->type_string() != kXlaClusterOutput) { return errors::InvalidArgument( - "Undeclared output of XLA computation. A common cause of this error " - "is variable initializers that depend on the XLA computation. Edge: ", + "Undeclared output of XLA computation. Some common causes of this " + "error are: 1) variable initializers that depend on the XLA " + "computation; 2) gradient computations that depend on the XLA " + "computation, which can be mitigated by moving gradient computations " + "inside XLA computation. 
Offending edge: ", e->src()->name(), ":", e->src_output(), " -> ", e->dst()->name(), ":", e->dst_input()); } diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc index 8b3587c5087..e3c7e2f89be 100644 --- a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc @@ -366,7 +366,7 @@ Status ReplaceOrRemoveOutsideCompilationCallNode( // replace this node with compilation result node. // 3) all outside compilation graphs. Status ConstructHostGraph( - const string& xla_cluster_name, + const string& xla_cluster_name, const string& outside_compilation_attr_name, const std::vector& outside_compilation_host_graphs, FunctionLibraryDefinition* fld, std::unique_ptr* host_graph) { host_graph->reset(new Graph(fld)); @@ -476,6 +476,10 @@ Status ConstructHostGraph( host_graph->get(), std::unordered_set{(*host_graph)->sink_node()}); + // Postprocess edges between different outside compilations. + TF_RETURN_IF_ERROR(PostprocessEdgesBetweenOutsideCompilations( + host_graph->get(), outside_compilation_attr_name)); + if (VLOG_IS_ON(4)) { dump_graph::DumpGraphToFile( absl::StrCat("extract_outside_compilation_host_graph_for_", @@ -801,6 +805,11 @@ Status ExtractOutsideCompilationForFunction( }, &fbody)); std::unique_ptr fbody_deleter(fbody); + + // Preprocess edges between different outside compilations. They will be + // restored in `ConstructHostGraph()`. + TF_RETURN_IF_ERROR(PreprocessEdgesBetweenOutsideCompilations( + fbody->graph, outside_compilation_attr_name)); if (VLOG_IS_ON(4)) { dump_graph::DumpGraphToFile( absl::StrCat("extract_outside_compilation_for_func_before_", func_name), @@ -860,8 +869,9 @@ Status ExtractOutsideCompilationForFunction( // Construct host graph. if (!outside_compilation_host_graphs.empty()) { - TF_RETURN_IF_ERROR(ConstructHostGraph( - xla_cluster_name, outside_compilation_host_graphs, fld, host_graph)); + TF_RETURN_IF_ERROR( + ConstructHostGraph(xla_cluster_name, outside_compilation_attr_name, + outside_compilation_host_graphs, fld, host_graph)); } // Remove the outside compilation graphs from function library. diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc index c5bd64f004e..bff956100da 100644 --- a/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc @@ -290,21 +290,18 @@ TEST(ExtractOutsideCompilationForFunctionTest, Basic) { TF_CHECK_OK(GetNodeAttr(host_compute_1->attrs(), "shapes", &shapes)); EXPECT_EQ(shapes.size(), 1); EXPECT_EQ(shapes[0].dim_size(), 1); - // Check XlaHostCompute nodes' "shape_inference_graph" attr. "0" should have a - // non-empty value, and "1" should have an empty value. + // Check XlaHostCompute nodes' "shape_inference_graph" attr. Both should have + // empty values. string shape_inference_graph; TF_CHECK_OK(GetNodeAttr(host_compute_0->attrs(), "shape_inference_graph", &shape_inference_graph)); - EXPECT_EQ(shape_inference_graph, - "_outside_compilation_shape_inference_cluster_0"); + EXPECT_EQ(shape_inference_graph, ""); TF_CHECK_OK(GetNodeAttr(host_compute_1->attrs(), "shape_inference_graph", &shape_inference_graph)); EXPECT_EQ(shape_inference_graph, ""); // Check `shape_inference_graphs`. 
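// Editor's sketch (illustrative only; variable names are placeholders): the
// call order this change introduces around extraction. Edges between outside
// compilation clusters are folded into attributes on the function body graph
// before rewriting, and restored on the host graph built by
// ConstructHostGraph().
//
//   TF_RETURN_IF_ERROR(PreprocessEdgesBetweenOutsideCompilations(
//       fbody_graph, outside_compilation_attr_name));
//   // ... extract outside compilation clusters, build `host_graph` ...
//   TF_RETURN_IF_ERROR(PostprocessEdgesBetweenOutsideCompilations(
//       host_graph, outside_compilation_attr_name));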
- EXPECT_EQ(shape_inference_graphs.size(), 1); - EXPECT_EQ(shape_inference_graphs[0], - "_outside_compilation_shape_inference_cluster_0"); + EXPECT_EQ(shape_inference_graphs.size(), 0); // Check `host_graph`: verify we have key placeholder and sequencer. Node *key_placeholder = nullptr, *sequencer = nullptr; @@ -333,8 +330,8 @@ TEST(ExtractOutsideCompilationForFunctionTest, Basic) { send_recv_nodes.push_back(n); } } - EXPECT_EQ(num_send_from_host, 2); - EXPECT_EQ(num_recv_at_host, 2); + EXPECT_EQ(num_send_from_host, 1); + EXPECT_EQ(num_recv_at_host, 1); for (Node *n : send_recv_nodes) { Node *input_node; TF_CHECK_OK(n->input_node(n->num_inputs() - 1, &input_node)); diff --git a/tensorflow/compiler/jit/flags.cc b/tensorflow/compiler/jit/flags.cc new file mode 100644 index 00000000000..98e344b3a08 --- /dev/null +++ b/tensorflow/compiler/jit/flags.cc @@ -0,0 +1,152 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include // NOLINT + +#include "tensorflow/compiler/jit/flags.h" +#include "tensorflow/compiler/xla/parse_flags_from_env.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace tensorflow { +namespace { + +BuildXlaOpsPassFlags* build_ops_flags; +DumpGraphFlags* dump_graph_flags; +MarkForCompilationPassFlags* mark_for_compilation_flags; +XlaDeviceFlags* device_flags; +XlaOpsCommonFlags* ops_flags; + +std::vector* flag_list; +std::once_flag flags_init; + +void AppendDumpGraphFlagsInternal(std::vector* flag_list) { + std::vector new_flags = { + Flag("tf_dump_graph_prefix", &dump_graph_flags->tf_dump_graph_prefix, + "Path prefix to which graphs dumped during debugging should be " + "written."), + }; + flag_list->insert(flag_list->end(), new_flags.begin(), new_flags.end()); +} + +void AppendMarkForCompilationPassFlagsInternal(std::vector* flag_list) { + std::vector new_flags = { + Flag("tf_xla_auto_jit", &mark_for_compilation_flags->tf_xla_auto_jit, + "Control compilation of operators into XLA computations on CPU and " + "GPU devices. 0 = use ConfigProto setting; -1 = off; 1 = on for " + "things very likely to be improved; 2 = on for everything. " + "Experimental."), + Flag("tf_xla_min_cluster_size", + &mark_for_compilation_flags->tf_xla_min_cluster_size, + "Minimum number of operators in an XLA compilation. 
Ignored for " + "operators placed on an XLA device or operators explicitly marked " + "for compilation."), + Flag("tf_xla_max_cluster_size", + &mark_for_compilation_flags->tf_xla_max_cluster_size, + "Maximum number of operators in an XLA compilation."), + Flag("tf_xla_clustering_debug", + &mark_for_compilation_flags->tf_xla_clustering_debug, + "Dump graphs during XLA compilation."), + Flag("tf_xla_cpu_global_jit", + &mark_for_compilation_flags->tf_xla_cpu_global_jit, + "Enables global JIT compilation for CPU via SessionOptions."), + Flag("tf_xla_clustering_fuel", + &mark_for_compilation_flags->tf_xla_clustering_fuel, + "Places an artificial limit on the number of ops marked as " + "eligible for clustering."), + Flag("tf_xla_fusion_only", + &mark_for_compilation_flags->tf_xla_fusion_only, + "enable fusion of element-wise operations only using XLA when " + "global_jit_level is ON*.")}; + flag_list->insert(flag_list->end(), new_flags.begin(), new_flags.end()); +} + +void AllocateAndParseFlags() { + build_ops_flags = new BuildXlaOpsPassFlags; + build_ops_flags->tf_xla_enable_lazy_compilation = true; + + dump_graph_flags = new DumpGraphFlags; + dump_graph_flags->tf_dump_graph_prefix = "/tmp/"; + + mark_for_compilation_flags = new MarkForCompilationPassFlags; + mark_for_compilation_flags->tf_xla_auto_jit = 0; + mark_for_compilation_flags->tf_xla_min_cluster_size = 2; + mark_for_compilation_flags->tf_xla_max_cluster_size = + std::numeric_limits::max(); + mark_for_compilation_flags->tf_xla_clustering_debug = false; + mark_for_compilation_flags->tf_xla_cpu_global_jit = false; + mark_for_compilation_flags->tf_xla_clustering_fuel = + std::numeric_limits::max(); + mark_for_compilation_flags->tf_xla_fusion_only = false; + + device_flags = new XlaDeviceFlags; + device_flags->tf_xla_compile_on_demand = false; + + ops_flags = new XlaOpsCommonFlags; + ops_flags->tf_xla_always_defer_compilation = false; + + flag_list = new std::vector({ + Flag("tf_xla_enable_lazy_compilation", + &build_ops_flags->tf_xla_enable_lazy_compilation, ""), + + Flag("tf_xla_compile_on_demand", &device_flags->tf_xla_compile_on_demand, + "Switch a device into 'on-demand' mode, where instead of " + "autoclustering ops are compiled one by one just-in-time."), + + Flag("tf_xla_always_defer_compilation", + &ops_flags->tf_xla_always_defer_compilation, ""), + }); + AppendDumpGraphFlagsInternal(flag_list); + AppendMarkForCompilationPassFlagsInternal(flag_list); + xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list); +} + +} // namespace + +const BuildXlaOpsPassFlags& GetBuildXlaOpsPassFlags() { + std::call_once(flags_init, &AllocateAndParseFlags); + return *build_ops_flags; +} + +DumpGraphFlags* GetDumpGraphFlags() { + std::call_once(flags_init, &AllocateAndParseFlags); + return dump_graph_flags; +} + +MarkForCompilationPassFlags* GetMarkForCompilationPassFlags() { + std::call_once(flags_init, &AllocateAndParseFlags); + return mark_for_compilation_flags; +} + +XlaDeviceFlags* GetXlaDeviceFlags() { + std::call_once(flags_init, &AllocateAndParseFlags); + return device_flags; +} + +const XlaOpsCommonFlags& GetXlaOpsCommonFlags() { + std::call_once(flags_init, &AllocateAndParseFlags); + return *ops_flags; +} + +void AppendMarkForCompilationPassFlags(std::vector* flag_list) { + std::call_once(flags_init, &AllocateAndParseFlags); + AppendMarkForCompilationPassFlagsInternal(flag_list); +} + +void AppendDumpGraphFlags(std::vector* flag_list) { + std::call_once(flags_init, &AllocateAndParseFlags); + 
AppendDumpGraphFlagsInternal(flag_list); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h b/tensorflow/compiler/jit/flags.h similarity index 57% rename from tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h rename to tensorflow/compiler/jit/flags.h index 79b47357a17..5ddea588eef 100644 --- a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h +++ b/tensorflow/compiler/jit/flags.h @@ -13,10 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_MARK_FOR_COMPILATION_PASS_FLAGS_H_ -#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_MARK_FOR_COMPILATION_PASS_FLAGS_H_ - -// Legacy flags for the XLA bridge's mark_for_compilation_pass module. +#ifndef TENSORFLOW_COMPILER_JIT_FLAGS_H_ +#define TENSORFLOW_COMPILER_JIT_FLAGS_H_ #include @@ -24,15 +22,8 @@ limitations under the License. #include "tensorflow/core/util/command_line_flags.h" namespace tensorflow { -namespace legacy_flags { -// Append to *flag_list flag definitions associated with the XLA bridge's -// mark_for_compilation_pass module. -void AppendMarkForCompilationPassFlags( - std::vector* flag_list); - -// The values of flags associated with the XLA bridge's -// mark_for_compilation_pass module. +// Flags associated with the XLA bridge's mark_for_compilation_pass module. struct MarkForCompilationPassFlags { int32 tf_xla_auto_jit; // Control compilation of operators into XLA // computations on CPU and GPU devices. 0 = use @@ -57,12 +48,56 @@ struct MarkForCompilationPassFlags { // only using XLA. }; -// Return a pointer to the MarkForCompilationPassFlags struct; +// Flags associated with the XLA bridge's xla_device module. +struct XlaDeviceFlags { + // Switch the CPU device into "on-demand" mode, where instead of + // autoclustering ops are compiled one by one just-in-time. + // Enabling this mode by a legacy flag is a temporary mechanism. When this + // feature is battle-tested, we will switch this to be a session option. + bool tf_xla_compile_on_demand; +}; + +// Flags common to the _Xla* ops and their kernels. +struct XlaOpsCommonFlags { + // If true, _XlaCompile always refuses to compile the cluster, which means the + // XLA clusters always run in the TF executor. Defaults to false. + bool tf_xla_always_defer_compilation; +}; + +// Flags for the build_xla_ops pass. +struct BuildXlaOpsPassFlags { + // Enables lazy compilation for TF/XLA (only when auto-clustering) if true. + // Defaults to true. + bool tf_xla_enable_lazy_compilation; +}; + +// Flags for the XLA bridge's dump_graph module. +struct DumpGraphFlags { + // Path prefix to which graphs dumped during debugging should be written. + string tf_dump_graph_prefix; +}; + +// Return a pointer to the DumpGraphFlags struct; // repeated calls return the same pointer. // This should be called only after Flags::Parse() has returned. -MarkForCompilationPassFlags* GetMarkForCompilationPassFlags(); -} // namespace legacy_flags +// Getters for flags structs defined above. The first call to any of these +// parses TF_XLA_FLAGS for all of them. Those functions which return a pointer +// always return the same pointer. 
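// Example usage (editor's sketch, mirroring the call sites updated elsewhere
// in this change): the first getter call parses TF_XLA_FLAGS exactly once via
// std::call_once, so setting e.g.
//   TF_XLA_FLAGS="--tf_xla_auto_jit=2 --tf_xla_clustering_debug=true"
// in the environment is visible to every subsequent caller:
//
//   #include "tensorflow/compiler/jit/flags.h"
//   ...
//   MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
//   if (flags->tf_xla_clustering_debug) {
//     // dump graphs for debugging
//   }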
+MarkForCompilationPassFlags* GetMarkForCompilationPassFlags(); +const BuildXlaOpsPassFlags& GetBuildXlaOpsPassFlags(); +XlaDeviceFlags* GetXlaDeviceFlags(); +const XlaOpsCommonFlags& GetXlaOpsCommonFlags(); +DumpGraphFlags* GetDumpGraphFlags(); + +// Appends the flag definitions associated with +// MarkForCompilationPassFlags/DumpGraphFlags to `flag_list`. +// +// Has the side-effect of parsing TF_XLA_FLAGS if that hasn't happened yet. +void AppendMarkForCompilationPassFlags( + std::vector* flag_list); +void AppendDumpGraphFlags(std::vector* flag_list); + } // namespace tensorflow -#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_MARK_FOR_COMPILATION_PASS_FLAGS_H_ +#endif // TENSORFLOW_COMPILER_JIT_FLAGS_H_ diff --git a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc index d984ca15cb7..ce53f70b79d 100644 --- a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc +++ b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/const_op.h" #include "tensorflow/cc/ops/math_ops.h" -#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" +#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/cc/ops/xla_ops.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" @@ -208,8 +208,12 @@ Status ComputeSliceSize(const Scope& host_scope, DCHECK_EQ(slice_size.back().type(), DT_INT64); } - *size = ops::Concat(host_scope.WithOpName("slice_size"), slice_size, - ops::Const(host_scope.WithOpName("concat_axis"), 0)); + // Trivial ConcatV2 nodes (with exactly one input) are disallowed. + *size = + slice_size.size() == 1 + ? slice_size[0] + : ops::Concat(host_scope.WithOpName("slice_size"), slice_size, + ops::Const(host_scope.WithOpName("concat_axis"), 0)); return Status::OK(); } @@ -242,6 +246,9 @@ Status ConvertTensorFlowSliceToStaticShapedSlice( .WithOpName("static_shaped_slice"), slice_inputs_int64.input, slice_inputs_int64.begin, slice_size) .node(); + + TF_RETURN_IF_ERROR(main_scope.status()); + std::vector compile_time_const_inputs; compile_time_const_inputs.push_back("size"); (*result)->AddAttr(kXlaCompileTimeConstantInputsAttr, @@ -284,49 +291,45 @@ Status RewriteSlice(Graph* g, Node* slice, const SliceInputs& slice_inputs, return Status::OK(); } -// If `n` is a slice we can rewrite to have a static shape (i.e. have the output -// shape only depend on the "size" input) then returns the a SliceInputs -// representing the inputs to `n`. Otherwise returns nullopt. -StatusOrOptional IsRewritableSlice(Node* n) { +// Return true if `n` is a slice we can rewrite to have a static shape +// (i.e. have the output shape only depend on the "size" input). +xla::StatusOr IsRewritableSlice(Node* n) { if (n->type_string() != "Slice") { - return {absl::nullopt}; + return false; } if (!GetXlaClusterForNode(*n).has_value()) { // There is no need to change slice ops outside XLA clusters. - return {absl::nullopt}; + return false; } TF_ASSIGN_OR_RETURN(absl::optional slice_inputs, GetSliceInputs(n)); if (!slice_inputs.has_value()) { - return {absl::nullopt}; + return false; } // If slice_size[i] < -1 for any i then executing the slice will throw an // error, and we don't do anything here. 
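// Editor's note (sketch, not part of this change): a Slice size of -1 is the
// TensorFlow convention for "take everything from `begin` to the end of that
// dimension", so only sizes strictly below -1 can never execute successfully.
// The check therefore reduces to the predicate the rewritten function returns:
//
//   absl::c_all_of(slice_inputs->size_as_vector,
//                  [](int64 size_i) { return size_i >= -1; });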
- bool slice_is_ok = absl::c_all_of(slice_inputs->size_as_vector, - [](int64 size_i) { return size_i >= -1; }); - if (!slice_is_ok) { - return {absl::nullopt}; - } - - return slice_inputs; + return absl::c_all_of(slice_inputs->size_as_vector, + [](int64 size_i) { return size_i >= -1; }); } Status FindAndRewriteSlices(Graph* g, bool* changed) { - std::vector> slices_to_rewrite; + std::vector slices_to_rewrite; for (Node* n : g->nodes()) { - TF_ASSIGN_OR_RETURN(absl::optional slice_inputs, - IsRewritableSlice(n)); - if (slice_inputs.has_value()) { - slices_to_rewrite.push_back({n, std::move(*slice_inputs)}); + TF_ASSIGN_OR_RETURN(bool is_rewritable, IsRewritableSlice(n)); + if (is_rewritable) { + slices_to_rewrite.push_back(n); } } - for (const auto& pair : slices_to_rewrite) { - TF_RETURN_IF_ERROR(RewriteSlice(g, pair.first, pair.second, - *GetXlaClusterForNode(*pair.first))); + for (Node* n : slices_to_rewrite) { + TF_ASSIGN_OR_RETURN(absl::optional slice_inputs, + GetSliceInputs(n)); + TF_RET_CHECK(slice_inputs.has_value()); + TF_RETURN_IF_ERROR( + RewriteSlice(g, n, *slice_inputs, *GetXlaClusterForNode(*n))); } if (!slices_to_rewrite.empty()) { @@ -342,8 +345,7 @@ Status FindAndRewriteSlices(Graph* g, bool* changed) { Status IncreaseDynamismForAutoJitPass::Run( const GraphOptimizationPassOptions& options) { - legacy_flags::MarkForCompilationPassFlags* flags = - legacy_flags::GetMarkForCompilationPassFlags(); + MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); if (flags->tf_xla_clustering_debug) { dump_graph::DumpGraphToFile("before_increase_dynamism_for_auto_jit_pass", **options.graph, options.flib_def); diff --git a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc index 0f6f612e967..a2f1b831ad7 100644 --- a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc +++ b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc @@ -27,6 +27,7 @@ limitations under the License. 
namespace tensorflow { namespace { +using ::testing::_; using testing::matchers::AssignedDevice; using testing::matchers::Attr; using testing::matchers::Const; @@ -142,6 +143,26 @@ TEST(SliceToDynamicSliceRewriteTest, Basic) { EXPECT_THAT(static_shaped_slice, m_dynamic_slice); } +TEST(SliceToDynamicSliceRewriteTest, SliceFromVector) { + Scope root = Scope::NewRootScope() + .ExitOnError() + .WithAssignedDevice(kDeviceName) + .WithXlaCluster("cluster_0"); + + Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT); + Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT32); + Output size = ops::Const(root.WithOpName("size"), {-1}); + Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size); + + std::unique_ptr result; + TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result)); + + Node* static_shaped_slice = testing::FindNodeByName( + result.get(), "slice/static_shaped_slice/static_shaped_slice"); + EXPECT_NE(static_shaped_slice, nullptr); + EXPECT_THAT(result->nodes(), Not(Contains(NodeWith(Op("ConcatV2"))))); +} + TEST(SliceToDynamicSliceRewriteTest, ControlDependencePreserved) { Scope root = Scope::NewRootScope() .ExitOnError() @@ -166,18 +187,18 @@ TEST(SliceToDynamicSliceRewriteTest, ControlDependencePreserved) { CtrlDeps(NodeWith(Op("Placeholder"), Name("control"))))); } +int64 ToInt64(int v) { return static_cast(v); } + TEST(SliceToDynamicSliceRewriteTest, Int64Indices) { Scope root = Scope::NewRootScope() .ExitOnError() .WithAssignedDevice(kDeviceName) .WithXlaCluster("cluster_0"); - auto to_int64 = [](int v) { return static_cast(v); }; - Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT); Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT64); Output size = - ops::Const(root.WithOpName("size"), {to_int64(-1), to_int64(500)}); + ops::Const(root.WithOpName("size"), {ToInt64(-1), ToInt64(500)}); Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size); std::unique_ptr result; @@ -252,13 +273,35 @@ TEST(SliceToDynamicSliceRewriteTest, DontRewriteSliceWithNonConstSize) { Attr(kXlaCompileTimeConstantInputsAttr))))); } +TEST(SliceToDynamicSliceRewriteTest, ScalarSlice) { + Scope root = Scope::NewRootScope() + .ExitOnError() + .WithAssignedDevice(kDeviceName) + .WithXlaCluster("cluster_0"); + + Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT); + Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT64); + Output size = ops::Const(root.WithOpName("size"), {}); + Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size); + + std::unique_ptr result; + TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result)); + + Node* static_shaped_slice = testing::FindNodeByName( + result.get(), "slice/static_shaped_slice/static_shaped_slice"); + ASSERT_NE(static_shaped_slice, nullptr); + EXPECT_THAT(static_shaped_slice, + NodeWith(Op("Slice"), Attr(kXlaCompileTimeConstantInputsAttr), + Inputs(_, _, Out(NodeWith(Name(size.node()->name())))))); +} + TEST(SliceToDynamicSliceRewriteTest, IndicesNotVector) { Scope root = Scope::NewRootScope() .ExitOnError() .WithAssignedDevice(kDeviceName) .WithXlaCluster("cluster_0"); - auto to_int64 = [](int v) { return static_cast(v); }; + auto ToInt64 = [](int v) { return static_cast(v); }; Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT); Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT64); @@ -271,7 +314,7 @@ TEST(SliceToDynamicSliceRewriteTest, IndicesNotVector) { ops::Slice(root.WithOpName("slice"), input, begin, 
size_placeholder); Output size = - ops::Const(root.WithOpName("size"), {{to_int64(-1)}, {to_int64(500)}}); + ops::Const(root.WithOpName("size"), {{ToInt64(-1)}, {ToInt64(500)}}); TF_ASSERT_OK(root.graph()->UpdateEdge(size.node(), 0, slice.node(), 2)); std::unique_ptr result; @@ -281,5 +324,82 @@ TEST(SliceToDynamicSliceRewriteTest, IndicesNotVector) { Not(Contains(NodeWith(Op("Slice"), Attr(kXlaCompileTimeConstantInputsAttr))))); } + +TEST(SliceToDynamicSliceRewriteTest, SliceWithSliceInput) { + Scope root = Scope::NewRootScope() + .ExitOnError() + .WithAssignedDevice(kDeviceName) + .WithXlaCluster("cluster_0"); + + Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT); + Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT32); + Output size_a = ops::Const(root.WithOpName("size_a"), {-1, 500}); + Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size_a); + + Output size_b = ops::Const(root.WithOpName("size_a"), {-1, 200}); + Output slice_with_slice_input = ops::Slice( + root.WithOpName("slice_with_slice_input"), slice, begin, size_b); + + std::unique_ptr result; + TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result)); + + Node* static_shaped_slice = testing::FindNodeByName( + result.get(), + "slice_with_slice_input/static_shaped_slice/static_shaped_slice"); + ASSERT_NE(static_shaped_slice, nullptr); + EXPECT_EQ(static_shaped_slice->output_type(0), DT_FLOAT) + << "Expected DT_FLOAT, was " + << DataType_Name(static_shaped_slice->output_type(0)); + EXPECT_THAT( + static_shaped_slice, + NodeWith( + Op("Slice"), + Inputs(Out(NodeWith( + Op("Slice"), + Name("slice/static_shaped_slice/static_shaped_slice"))), + _, _))); +} + +TEST(SliceToDynamicSliceRewriteTest, SliceWithSliceBegin) { + Scope root = Scope::NewRootScope() + .ExitOnError() + .WithAssignedDevice(kDeviceName) + .WithXlaCluster("cluster_0"); + + Output input_float = + ops::Placeholder(root.WithOpName("input_float"), DT_FLOAT); + Output input_i64 = ops::Placeholder(root.WithOpName("input_i64"), DT_INT64); + + Output begin_begin = + ops::Placeholder(root.WithOpName("begin_begin"), DT_INT32); + Output begin_size = ops::Const(root.WithOpName("begin_size"), {-1}); + Output begin = + ops::Slice(root.WithOpName("begin"), input_i64, begin_begin, begin_size); + + Output size = + ops::Const(root.WithOpName("size"), {ToInt64(-1), ToInt64(200)}); + Output slice_with_slice_begin = ops::Slice( + root.WithOpName("slice_with_slice_begin"), input_float, begin, size); + + std::unique_ptr result; + TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result)); + + Node* static_shaped_slice = testing::FindNodeByName( + result.get(), + "slice_with_slice_begin/static_shaped_slice/static_shaped_slice"); + ASSERT_NE(static_shaped_slice, nullptr); + EXPECT_EQ(static_shaped_slice->output_type(0), DT_FLOAT) + << "Expected DT_FLOAT, was " + << DataType_Name(static_shaped_slice->output_type(0)); + EXPECT_THAT( + static_shaped_slice, + NodeWith( + Op("Slice"), + Inputs(_, + Out(NodeWith( + Op("Slice"), + Name("begin/static_shaped_slice/static_shaped_slice"))), + _))); +} } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD index 830db9ebdd9..0583774714c 100644 --- a/tensorflow/compiler/jit/kernels/BUILD +++ b/tensorflow/compiler/jit/kernels/BUILD @@ -12,10 +12,10 @@ cc_library( hdrs = ["xla_ops.h"], deps = [ "//tensorflow/compiler/jit:common", + "//tensorflow/compiler/jit:flags", "//tensorflow/compiler/jit:xla_compilation_cache", 
"//tensorflow/compiler/jit:xla_device", "//tensorflow/compiler/jit:xla_launch_util", - "//tensorflow/compiler/jit/legacy_flags:xla_ops_common_flags", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:tf2xla_util", "//tensorflow/compiler/tf2xla:xla_compiler", diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index 055de7afcc5..ad71df5a694 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -18,7 +18,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/jit/defs.h" -#include "tensorflow/compiler/jit/legacy_flags/xla_ops_common_flags.h" +#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" @@ -418,7 +418,7 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) { cannot_compile_cluster = cannot_compile_cluster_; } - if (legacy_flags::GetXlaOpsCommonFlags().tf_xla_always_defer_compilation || + if (GetXlaOpsCommonFlags().tf_xla_always_defer_compilation || cannot_compile_cluster) { executable = nullptr; } else { diff --git a/tensorflow/compiler/jit/legacy_flags/BUILD b/tensorflow/compiler/jit/legacy_flags/BUILD deleted file mode 100644 index 5fa6c85f06f..00000000000 --- a/tensorflow/compiler/jit/legacy_flags/BUILD +++ /dev/null @@ -1,65 +0,0 @@ -# Legacy command line flags for the XLA bridge libraries. - -# Please do not add more flags to this package. - -# The XLA bridge libraries were written in an environment that allowed -# command-line flags to be scattered freely throughout the libraries. This -# model, while initially convenient, leads to a proliferation in unused command -# line flags in tests and binaries, and serious problems in servers, where one -# might wish parameters to be different in independent RPC calls to the same -# routine. -# -# Please don't add more flags. If you're a library author, pass options and -# parameters explicitly through the library's interface. 
- -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = ["//tensorflow:internal"]) - -cc_library( - name = "mark_for_compilation_pass_flags", - srcs = ["mark_for_compilation_pass_flags.cc"], - hdrs = ["mark_for_compilation_pass_flags.h"], - deps = - [ - "//tensorflow/compiler/xla:parse_flags_from_env", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - -cc_library( - name = "xla_device_flags", - srcs = ["xla_device_flags.cc"], - hdrs = ["xla_device_flags.h"], - deps = - [ - "//tensorflow/compiler/xla:parse_flags_from_env", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - -cc_library( - name = "build_xla_ops_pass_flags", - srcs = ["build_xla_ops_pass_flags.cc"], - hdrs = ["build_xla_ops_pass_flags.h"], - deps = - [ - "//tensorflow/compiler/xla:parse_flags_from_env", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - -cc_library( - name = "xla_ops_common_flags", - srcs = ["xla_ops_common_flags.cc"], - hdrs = ["xla_ops_common_flags.h"], - deps = - [ - "//tensorflow/compiler/xla:parse_flags_from_env", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) diff --git a/tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.cc b/tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.cc deleted file mode 100644 index 961c17c17ea..00000000000 --- a/tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.cc +++ /dev/null @@ -1,47 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include // NOLINT - -#include "tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.h" -#include "tensorflow/compiler/xla/parse_flags_from_env.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace tensorflow { -namespace legacy_flags { -namespace { - -BuildXlaOpsPassFlags* flags; -std::vector* flag_list; -std::once_flag flags_init; - -void AllocateAndParseFlags() { - flags = new BuildXlaOpsPassFlags; - flags->tf_xla_enable_lazy_compilation = true; - flag_list = new std::vector({ - Flag("tf_xla_enable_lazy_compilation", - &flags->tf_xla_enable_lazy_compilation, ""), - }); - xla::ParseFlagsFromEnv(*flag_list); -} - -} // namespace - -const BuildXlaOpsPassFlags& GetBuildXlaOpsPassFlags() { - std::call_once(flags_init, &AllocateAndParseFlags); - return *flags; -} -} // namespace legacy_flags -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.h b/tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.h deleted file mode 100644 index 9aa5cf64d6d..00000000000 --- a/tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.h +++ /dev/null @@ -1,37 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_BUILD_XLA_OPS_PASS_FLAGS_H_ -#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_BUILD_XLA_OPS_PASS_FLAGS_H_ - -namespace tensorflow { -namespace legacy_flags { - -// Flags for the build_xla_ops pass. -struct BuildXlaOpsPassFlags { - // Enables lazy compilation for TF/XLA (only when auto-clustering) if true. - // Defaults to true. - bool tf_xla_enable_lazy_compilation; -}; - -// Parses the flags in BuildXlaOpsPassFlags from the TF_XLA_FLAGS environment -// variable and returns a reference to the parsed copy. Parses TF_XLA_FLAGS -// only the first time this routine is called. -const BuildXlaOpsPassFlags& GetBuildXlaOpsPassFlags(); - -} // namespace legacy_flags -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_BUILD_XLA_OPS_PASS_FLAGS_H_ diff --git a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc deleted file mode 100644 index bad306e0b0a..00000000000 --- a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc +++ /dev/null @@ -1,98 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for the XLA bridge's mark_for_compilation_pass module. - -#include -#include - -#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" -#include "tensorflow/compiler/xla/parse_flags_from_env.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace tensorflow { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static MarkForCompilationPassFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). 
-static void AllocateFlags() { - flags = new MarkForCompilationPassFlags; - flags->tf_xla_auto_jit = 0; - flags->tf_xla_min_cluster_size = 2; - flags->tf_xla_max_cluster_size = std::numeric_limits::max(); - flags->tf_xla_clustering_debug = false; - flags->tf_xla_cpu_global_jit = false; - flags->tf_xla_clustering_fuel = std::numeric_limits::max(); - flags->tf_xla_fusion_only = false; - flag_list = new std::vector( - {Flag("tf_xla_auto_jit", &flags->tf_xla_auto_jit, - "Control compilation of operators into XLA computations on CPU and " - "GPU devices. 0 = use ConfigProto setting; -1 = off; 1 = on for " - "things very likely to be improved; 2 = on for everything. " - "Experimental."), - Flag("tf_xla_min_cluster_size", &flags->tf_xla_min_cluster_size, - "Minimum number of operators in an XLA compilation. Ignored for " - "operators placed on an XLA device or operators explicitly marked " - "for compilation."), - Flag("tf_xla_max_cluster_size", &flags->tf_xla_max_cluster_size, - "Maximum number of operators in an XLA compilation."), - Flag("tf_xla_clustering_debug", &flags->tf_xla_clustering_debug, - "Dump graphs during XLA compilation."), - Flag("tf_xla_cpu_global_jit", &flags->tf_xla_cpu_global_jit, - "Enables global JIT compilation for CPU via SessionOptions."), - Flag("tf_xla_clustering_fuel", &flags->tf_xla_clustering_fuel, - "Places an artificial limit on the number of ops marked as " - "eligible for clustering."), - Flag("tf_xla_fusion_only", &flags->tf_xla_fusion_only, - "enable fusion of element-wise operations only using XLA when " - "global_jit_level is ON*.")}); - xla::ParseFlagsFromEnv(*flag_list); - - if (VLOG_IS_ON(1)) { - VLOG(1) << "Parsed MarkForCompilationPassFlags:"; - VLOG(1) << " tf_xla_auto_jit = " << flags->tf_xla_auto_jit; - VLOG(1) << " tf_xla_min_cluster_size = " << flags->tf_xla_min_cluster_size; - VLOG(1) << " tf_xla_max_cluster_size = " << flags->tf_xla_max_cluster_size; - VLOG(1) << " tf_xla_clustering_debug = " << flags->tf_xla_clustering_debug; - VLOG(1) << " tf_xla_cpu_global_jit = " << flags->tf_xla_cpu_global_jit; - VLOG(1) << " tf_xla_clustering_fuel = " << flags->tf_xla_clustering_fuel; - VLOG(1) << " tf_xla_fusion_only = " << flags->tf_xla_fusion_only; - } -} - -// Append to *append_to flag definitions associated with the XLA bridge's -// mark_for_compilation_pass module. -void AppendMarkForCompilationPassFlags(std::vector* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the MarkForCompilationPassFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -MarkForCompilationPassFlags* GetMarkForCompilationPassFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/legacy_flags/xla_device_flags.cc b/tensorflow/compiler/jit/legacy_flags/xla_device_flags.cc deleted file mode 100644 index 76b80d3034c..00000000000 --- a/tensorflow/compiler/jit/legacy_flags/xla_device_flags.cc +++ /dev/null @@ -1,56 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for the XLA bridge's xla_device module. - -#include -#include - -#include "tensorflow/compiler/jit/legacy_flags/xla_device_flags.h" -#include "tensorflow/compiler/xla/parse_flags_from_env.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace tensorflow { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static XlaDeviceFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). -static void AllocateFlags() { - flags = new XlaDeviceFlags; - flags->tf_xla_compile_on_demand = false; - flag_list = new std::vector({ - Flag("tf_xla_compile_on_demand", &flags->tf_xla_compile_on_demand, - "Switch a device into 'on-demand' mode, where instead of " - "autoclustering ops are compiled one by one just-in-time."), - }); - xla::ParseFlagsFromEnv(*flag_list); -} - -// Return a pointer to the XlaDeviceFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -XlaDeviceFlags* GetXlaDeviceFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/legacy_flags/xla_device_flags.h b/tensorflow/compiler/jit/legacy_flags/xla_device_flags.h deleted file mode 100644 index 27b22121ac1..00000000000 --- a/tensorflow/compiler/jit/legacy_flags/xla_device_flags.h +++ /dev/null @@ -1,47 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_DEVICE_FLAGS_H_ -#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_DEVICE_FLAGS_H_ - -// Legacy flags for the XLA bridge's xla_device module. - -#include - -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace tensorflow { -namespace legacy_flags { - -// The values of flags associated with the XLA bridge's -// xla_device module. -typedef struct { - // Switch the CPU device into "on-demand" mode, where instead of - // autoclustering ops are compiled one by one just-in-time. - // Enabling this mode by a legacy flag is a temporary mechanism. When this - // feature is battle-tested, we will switch this to be a session option. 
- bool tf_xla_compile_on_demand; -} XlaDeviceFlags; - -// Return a pointer to the XlaDeviceFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -XlaDeviceFlags* GetXlaDeviceFlags(); - -} // namespace legacy_flags -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_DEVICE_FLAGS_H_ diff --git a/tensorflow/compiler/jit/legacy_flags/xla_ops_common_flags.cc b/tensorflow/compiler/jit/legacy_flags/xla_ops_common_flags.cc deleted file mode 100644 index 1443d48a734..00000000000 --- a/tensorflow/compiler/jit/legacy_flags/xla_ops_common_flags.cc +++ /dev/null @@ -1,52 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include // NOLINT -#include - -#include "tensorflow/compiler/jit/legacy_flags/xla_ops_common_flags.h" -#include "tensorflow/compiler/xla/parse_flags_from_env.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace tensorflow { -namespace legacy_flags { - -XlaOpsCommonFlags* flags; -std::vector* flag_list; -std::once_flag flags_init; - -void AllocateAndParseFlags() { - flags = new XlaOpsCommonFlags; - flags->tf_xla_always_defer_compilation = false; - flag_list = new std::vector({ - Flag("tf_xla_always_defer_compilation", - &flags->tf_xla_always_defer_compilation, ""), - }); - xla::ParseFlagsFromEnv(*flag_list); - - if (VLOG_IS_ON(1)) { - VLOG(1) << "Parsed XlaOpsCommonFlags:"; - VLOG(1) << " tf_xla_always_defer_compilation = " - << flags->tf_xla_always_defer_compilation; - } -} - -const XlaOpsCommonFlags& GetXlaOpsCommonFlags() { - std::call_once(flags_init, &AllocateAndParseFlags); - return *flags; -} -} // namespace legacy_flags -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/legacy_flags/xla_ops_common_flags.h b/tensorflow/compiler/jit/legacy_flags/xla_ops_common_flags.h deleted file mode 100644 index 7c5c1818ef2..00000000000 --- a/tensorflow/compiler/jit/legacy_flags/xla_ops_common_flags.h +++ /dev/null @@ -1,36 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_OPS_COMMON_FLAGS_H_ -#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_OPS_COMMON_FLAGS_H_ - -namespace tensorflow { -namespace legacy_flags { - -// Flags common to the _Xla* ops and their kernels. -struct XlaOpsCommonFlags { - // If true, _XlaCompile always refuses to compile the cluster, which means the - // XLA clusters always run in the TF executor. Defaults to false. - bool tf_xla_always_defer_compilation; -}; - -// Parses the flags in XlaOpsCommonFlags from the TF_XLA_FLAGS environment -// variable and returns a reference to the parsed copy. Parses TF_XLA_FLAGS -// only the first time this routine is called. -const XlaOpsCommonFlags& GetXlaOpsCommonFlags(); - -} // namespace legacy_flags -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_OPS_COMMON_FLAGS_H_ diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 70033cae0af..6618e3a58ab 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -24,8 +24,8 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/jit/deadness_analysis.h" #include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" -#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" #include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" @@ -72,6 +72,11 @@ struct OperationFilter { // to resort to a dummy implementation. Currently Assert and CheckNumerics ops // have dummy XLA implementations. bool allow_dummy_ops; + + // Whether ops that produce or consume DT_VARIANT values are allowed. We + // don't auto-cluster these ops because we don't yet support live-in or + // live-out DT_VARIANT values. 
+ bool allow_ops_producing_or_consuming_variant; }; bool IsDummyImplOp(absl::string_view op_name) { @@ -81,7 +86,13 @@ bool IsDummyImplOp(absl::string_view op_name) { bool IsStatefulRandomOp(absl::string_view op_name) { return op_name == "RandomUniform" || op_name == "RandomShuffle" || op_name == "RandomUniformInt" || op_name == "RandomStandardNormal" || - op_name == "TruncatedNormal"; + op_name == "TruncatedNormal" || op_name == "Multinomial"; +} + +bool OpProducesOrConsumesVariant(const Node& node) { + auto is_variant = [](DataType dtype) { return dtype == DT_VARIANT; }; + return absl::c_any_of(node.input_types(), is_variant) || + absl::c_any_of(node.output_types(), is_variant); } bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) { @@ -246,6 +257,10 @@ bool IsCompilableCall(const NodeDef& call_def, if (!op_filter.allow_dummy_ops && IsDummyImplOp(node->type_string())) { return false; } + if (!op_filter.allow_ops_producing_or_consuming_variant && + OpProducesOrConsumesVariant(*node)) { + return false; + } if (!HasXLAKernel(*node, jit_device_type) && !IsCompilableCall(node->def(), jit_device_type, op_filter, depth + 1, lib_runtime)) { @@ -427,8 +442,7 @@ Status FindCompilationCandidates( BackwardsConstAnalysis(graph, /*compile_time_const_arg_indices=*/nullptr, &compile_time_const_nodes)); - int64& fuel = - legacy_flags::GetMarkForCompilationPassFlags()->tf_xla_clustering_fuel; + int64& fuel = GetMarkForCompilationPassFlags()->tf_xla_clustering_fuel; // Iterate over nodes in sorted order so that compiler fuel is deterministic. // We can't simply pass op_nodes().begin() and op_nodes().end to the @@ -471,16 +485,15 @@ Status FindCompilationCandidates( XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)); DeviceType jit_device_type(registration->compilation_device_name); + bool always_auto_cluster = registration->autoclustering_policy == + XlaOpRegistry::AutoclusteringPolicy::kAlways; + OperationFilter op_filter; op_filter.allow_resource_ops = registration->compile_resource_ops; - op_filter.allow_stateful_rng_ops = - (registration->autoclustering_policy == - XlaOpRegistry::AutoclusteringPolicy::kAlways); - op_filter.allow_control_trigger = - (registration->autoclustering_policy == - XlaOpRegistry::AutoclusteringPolicy::kAlways); - op_filter.allow_dummy_ops = (registration->autoclustering_policy == - XlaOpRegistry::AutoclusteringPolicy::kAlways); + op_filter.allow_stateful_rng_ops = always_auto_cluster; + op_filter.allow_control_trigger = always_auto_cluster; + op_filter.allow_dummy_ops = always_auto_cluster; + op_filter.allow_ops_producing_or_consuming_variant = always_auto_cluster; if (!HasXLAKernel(*node, jit_device_type) && !IsCompilableCall(node->def(), jit_device_type, op_filter, 0, @@ -504,6 +517,12 @@ Status FindCompilationCandidates( << node->type_string() << ")"; continue; } + if (!op_filter.allow_ops_producing_or_consuming_variant && + OpProducesOrConsumesVariant(*node)) { + VLOG(2) << "Rejecting " << node->name() + << ": produces or consumes DT_VARIANT"; + continue; + } if (!op_filter.allow_resource_ops && (HasResourceOutput(*node) || IsNonResourceVarResourceOp(*node))) { @@ -607,8 +626,7 @@ OptimizerOptions::GlobalJitLevel GetGlobalJitLevel( // To set compilation to be on by default, change the following line. 
global_jit_level = OptimizerOptions::OFF; } - legacy_flags::MarkForCompilationPassFlags* flags = - legacy_flags::GetMarkForCompilationPassFlags(); + MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); if (flags->tf_xla_auto_jit == -1 || (1 <= flags->tf_xla_auto_jit && flags->tf_xla_auto_jit <= 2)) { // If the flag tf_xla_auto_jit is a valid, non-zero setting, it overrides @@ -641,6 +659,7 @@ bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) { op_filter.allow_stateful_rng_ops = true; op_filter.allow_control_trigger = true; op_filter.allow_dummy_ops = true; + op_filter.allow_ops_producing_or_consuming_variant = true; return IsCompilableCall(ndef, jit_device_type, op_filter, 0, flr); } @@ -651,8 +670,7 @@ Status MarkForCompilationPass::Run( // device ahead of time. OptimizerOptions::GlobalJitLevel global_jit_level = GetGlobalJitLevel(options); - legacy_flags::MarkForCompilationPassFlags* flags = - legacy_flags::GetMarkForCompilationPassFlags(); + MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); bool fusion_only = flags->tf_xla_fusion_only; VLOG(1) << "flags->tf_xla_fusion_only = " << flags->tf_xla_fusion_only; @@ -953,8 +971,7 @@ Status MarkForCompilationPass::RunImpl( OptimizerOptions::GlobalJitLevel global_jit_level = GetGlobalJitLevel(options); - legacy_flags::MarkForCompilationPassFlags* flags = - legacy_flags::GetMarkForCompilationPassFlags(); + MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); // Repeatedly contract edges between clusters that are on the same device, // provided the contraction would not create a cycle. diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index 24d78c07726..bf2c5508ea9 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -22,6 +22,7 @@ limitations under the License. 
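The new allow_ops_producing_or_consuming_variant filter above keys on DT_VARIANT edges, which is exactly what TensorList ops traffic in. A small sketch of that relationship, assuming the Python tensor-list wrappers in tensorflow.python.ops.list_ops:

```python
# Sketch: a TensorList handle is a DT_VARIANT tensor, which is why ops such as
# TensorListReserve are excluded from auto-clustering by the filter above.
import tensorflow as tf
from tensorflow.python.ops import list_ops

handle = list_ops.tensor_list_reserve(
    element_shape=[2], num_elements=3, element_dtype=tf.float32)
print(handle.dtype)  # variant
```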
#include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/control_flow_ops_internal.h" #include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/list_ops.h" #include "tensorflow/cc/ops/resource_variable_ops.h" #include "tensorflow/cc/ops/sendrecv_ops.h" #include "tensorflow/cc/ops/standard_ops.h" @@ -1147,5 +1148,80 @@ TEST(XlaCompilationTest, DontAutoClusterDummyOps) { EXPECT_EQ(clusters["test/check"], ""); } +TEST(XlaCompilationTest, DontAutoClusterOpsProducingVariant) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output a = ops::Placeholder(root.WithOpName("test/a"), DT_INT64); + Output b = ops::Placeholder(root.WithOpName("test/b"), DT_INT64); + + Output cast_a = ops::Cast(root.WithOpName("test/cast_a"), a, DT_INT32); + Output cast_b = ops::Cast(root.WithOpName("test/cast_b"), b, DT_INT32); + + Output tensor_list_reserve = ops::TensorListReserve( + root.WithOpName("test/tensor_list_reserve"), cast_a, cast_b, DT_FLOAT); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + std::unordered_map clusters = GetClusters(*graph); + EXPECT_EQ(clusters["test/tensor_list_reserve"], ""); +} + +TEST(XlaCompilationTest, DontAutoClusterOpsConsumingVariant) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output dummy_input = + ops::Placeholder(root.WithOpName("test/dummy_input"), DT_INT64); + Output variant_input = + ops::Placeholder(root.WithOpName("test/variant_input"), DT_VARIANT); + + // Create one more node so that we don't avoid creating a cluster solely + // because it would be trivial. + Output dummy_cast = + ops::Cast(root.WithOpName("test/dummy_cast"), dummy_input, DT_INT32); + + Output tensor_list_element_shape = ops::TensorListElementShape( + root.WithOpName("test/tensor_list_element_shape"), variant_input, + DT_INT32); + + root.graph()->AddControlEdge(dummy_cast.node(), + tensor_list_element_shape.node()); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + std::unordered_map clusters = GetClusters(*graph); + EXPECT_EQ(clusters["test/tensor_list_element_shape"], ""); +} + +TEST(XlaCompilationTest, ClusterOpsProducingVariantIfOnXlaDevice) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output a = ops::Placeholder(root.WithOpName("test/a"), DT_INT64); + Output b = ops::Placeholder(root.WithOpName("test/b"), DT_INT64); + + Output cast_a = ops::Cast(root.WithOpName("test/cast_a"), a, DT_INT32); + Output cast_b = ops::Cast(root.WithOpName("test/cast_b"), b, DT_INT32); + + Output tensor_list_reserve = ops::TensorListReserve( + root.WithOpName("test/tensor_list_reserve"), cast_a, cast_b, DT_FLOAT); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + string xla_cpu_device = "/job:worker/replica:0/task:0/device:XLA_CPU:0"; + for (Node* n : graph->nodes()) { + if (absl::StartsWith(n->name(), /*prefix=*/"test/")) { + n->set_assigned_device_name(xla_cpu_device); + } + } + + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + std::unordered_map clusters = GetClusters(*graph); + EXPECT_NE(clusters["test/tensor_list_reserve"], ""); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc 
b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc index d56d0f8ccfc..64a33017457 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc @@ -34,15 +34,9 @@ namespace tensorflow { // // It may be worth refactoring out XlaOpRegistry::RegisterCompilationDevice to // make this more direct, but probably not worth it solely for this test. - std::vector<Device*> devices; + std::vector<std::unique_ptr<Device>> devices; TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(*session_options, "", &devices)); - auto delete_devices = gtl::MakeCleanup([&] { - for (Device* d : devices) { - delete d; - } - }); - GraphOptimizationPassOptions opt_options; opt_options.graph = graph; opt_options.session_options = session_options; diff --git a/tensorflow/compiler/jit/ops/BUILD b/tensorflow/compiler/jit/ops/BUILD index f72224545b2..64409d93347 100644 --- a/tensorflow/compiler/jit/ops/BUILD +++ b/tensorflow/compiler/jit/ops/BUILD @@ -18,3 +18,9 @@ tf_gen_op_wrapper_py( out = "xla_ops.py", deps = ["//tensorflow/compiler/jit/ops:xla_ops"], ) + +py_library( + name = "xla_ops_grad", + srcs = ["xla_ops_grad.py"], + deps = ["//tensorflow/python:framework_ops"], +) diff --git a/tensorflow/contrib/estimator/python/estimator/dnn.py b/tensorflow/compiler/jit/ops/xla_ops_grad.py similarity index 62% rename from tensorflow/contrib/estimator/python/estimator/dnn.py rename to tensorflow/compiler/jit/ops/xla_ops_grad.py index 10f657df8de..2d31d8dc714 100644 --- a/tensorflow/contrib/estimator/python/estimator/dnn.py +++ b/tensorflow/compiler/jit/ops/xla_ops_grad.py @@ -1,3 +1,4 @@ +"""Gradients for XLA ops.""" # Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,21 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""dnn python module. - -Importing from tensorflow.python.estimator is unsupported -and will soon break! -""" -# pylint: disable=unused-import,g-bad-import-order,g-import-not-at-top,wildcard-import from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow_estimator.contrib.estimator.python.estimator import dnn +from tensorflow.python.framework import ops -# Include attrs that start with single underscore. -_HAS_DYNAMIC_ATTRIBUTES = True -dnn.__all__ = [s for s in dir(dnn) if not s.startswith('__')] -from tensorflow_estimator.contrib.estimator.python.estimator.dnn import * +@ops.RegisterGradient("XlaClusterOutput") +def _XlaClusterOutputGrad(_, grad): + del grad # unused + raise RuntimeError("Gradient computation of graph in xla.compile() is " + "prohibited because it can cause performance degradation. " + "Please move gradient computation inside xla.compile().") diff --git a/tensorflow/compiler/jit/partially_decluster_pass.cc b/tensorflow/compiler/jit/partially_decluster_pass.cc index 36b345ecbff..42ea3926e16 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass.cc @@ -26,6 +26,10 @@ limitations under the License.
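The xla_ops_grad.py registration above deliberately raises instead of returning a gradient, so any attempt to differentiate through an XlaClusterOutput fails loudly. The same raise-on-gradient pattern can be reproduced without building an XLA cluster by rerouting an ordinary op's gradient; the registration name RaiseOnGrad below is hypothetical and only for illustration:

```python
# Illustration of the raise-on-gradient pattern used for XlaClusterOutput,
# applied here to Identity via gradient_override_map (hypothetical names).
import tensorflow as tf
from tensorflow.python.framework import ops


@ops.RegisterGradient("RaiseOnGrad")
def _raise_on_grad(op, grad):
  del op, grad  # unused
  raise RuntimeError("Gradient computation through this op is prohibited.")


g = tf.Graph()
with g.as_default():
  x = tf.constant(3.0)
  # Route Identity's gradient through the raising function registered above.
  with g.gradient_override_map({"Identity": "RaiseOnGrad"}):
    y = tf.identity(x)
  try:
    tf.gradients(y, x)
  except RuntimeError as err:
    print("caught:", err)
```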
namespace tensorflow { namespace { + +bool NotBackedge(const Edge& edge) { return !edge.src()->IsNextIteration(); } + +namespace reduce_device_to_host_copies { Status FindNodesToDecluster(const Graph& graph, absl::flat_hash_set* result, absl::Span post_order) { @@ -140,8 +144,6 @@ Status PartiallyDeclusterNode(Graph* graph, Node* n) { return Status::OK(); } -bool NotBackedge(const Edge& edge) { return !edge.src()->IsNextIteration(); } - // Clones nodes to outside their cluster to avoid device-to-host copies. For // instance, converts this: // @@ -168,7 +170,7 @@ bool NotBackedge(const Edge& edge) { return !edge.src()->IsNextIteration(); } // where the ===> arrow has a hostmem source and destination and would entail a // device to host copy if the source and destination were not in the same XLA // cluster. -Status PartiallyDeclusterToRemoveDeviceToHostCopies(Graph* graph) { +Status PartiallyDeclusterGraph(Graph* graph) { // When deciding whether to decluster a particular node, we base our decision // on if we've decided that some of its consumers have to be declustered too. // Iterating the graph in post-order guarantees that consumers have been @@ -206,7 +208,9 @@ Status PartiallyDeclusterToRemoveDeviceToHostCopies(Graph* graph) { return Status::OK(); } +} // namespace reduce_device_to_host_copies +namespace reduce_recompilation { bool IsIntraClusterEdge(const Edge& edge) { absl::optional src_cluster_name = GetXlaClusterForNode(*edge.src()); @@ -269,7 +273,7 @@ Status MustCompileNode(const Node* n, bool* must_compile) { // regress performance in any significant manner. We will have to revisit this // algorith with a more complex cost model if this assumption turns out to be // incorrect. -Status DeclusterNodesToReduceRecompilations(Graph* graph) { +Status PartiallyDeclusterGraph(Graph* graph) { std::vector compile_time_const_nodes(graph->num_node_ids()); TF_RETURN_IF_ERROR(BackwardsConstAnalysis( *graph, nullptr, &compile_time_const_nodes, IsIntraClusterEdge)); @@ -322,7 +326,7 @@ Status DeclusterNodesToReduceRecompilations(Graph* graph) { return Status::OK(); } - +} // namespace reduce_recompilation } // namespace Status PartiallyDeclusterPass::Run( @@ -334,8 +338,9 @@ Status PartiallyDeclusterPass::Run( Graph* graph = options.graph->get(); - TF_RETURN_IF_ERROR(PartiallyDeclusterToRemoveDeviceToHostCopies(graph)); - TF_RETURN_IF_ERROR(DeclusterNodesToReduceRecompilations(graph)); + TF_RETURN_IF_ERROR( + reduce_device_to_host_copies::PartiallyDeclusterGraph(graph)); + TF_RETURN_IF_ERROR(reduce_recompilation::PartiallyDeclusterGraph(graph)); return Status::OK(); } diff --git a/tensorflow/compiler/jit/partially_decluster_pass_test.cc b/tensorflow/compiler/jit/partially_decluster_pass_test.cc index 1fc5da5071f..38a54cc5efa 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass_test.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass_test.cc @@ -386,7 +386,7 @@ TEST(PartiallyDeclusterPassTest, DontDeclusterXlaDeviceOps) { TF_ASSERT_OK(s.ToGraph(graph.get())); // This is needed to register the XLA_GPU device. 
- std::vector devices; + std::vector> devices; TF_ASSERT_OK(DeviceFactory::AddDevices( SessionOptions(), "/job:localhost/replica:0/task:0", &devices)); @@ -400,10 +400,6 @@ TEST(PartiallyDeclusterPassTest, DontDeclusterXlaDeviceOps) { TF_ASSERT_OK(PartiallyDecluster(&graph)); EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0"); - - for (Device* d : devices) { - delete d; - } } TEST(PartiallyDeclusterPassTest, DontDeclusterNonTensorFlowOps) { diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc index 116e0756036..7df898ad12a 100644 --- a/tensorflow/compiler/jit/xla_cpu_device.cc +++ b/tensorflow/compiler/jit/xla_cpu_device.cc @@ -17,8 +17,8 @@ limitations under the License. // operators using XLA via the XLA "Host" (CPU) backend. #include "absl/memory/memory.h" +#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/kernels/xla_ops.h" -#include "tensorflow/compiler/jit/legacy_flags/xla_device_flags.h" #include "tensorflow/compiler/jit/xla_compile_on_demand_op.h" #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_device_ops.h" @@ -31,13 +31,13 @@ namespace tensorflow { class XlaCpuDeviceFactory : public DeviceFactory { public: Status CreateDevices(const SessionOptions& options, const string& name_prefix, - std::vector* devices) override; + std::vector>* devices) override; }; -Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& session_options, - const string& name_prefix, - std::vector* devices) { - legacy_flags::XlaDeviceFlags* flags = legacy_flags::GetXlaDeviceFlags(); +Status XlaCpuDeviceFactory::CreateDevices( + const SessionOptions& session_options, const string& name_prefix, + std::vector>* devices) { + XlaDeviceFlags* flags = GetXlaDeviceFlags(); bool compile_on_demand = flags->tf_xla_compile_on_demand; XlaOpRegistry::DeviceRegistration registration; @@ -63,8 +63,7 @@ Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& session_options, options.device_ordinal = 0; options.compilation_device_name = DEVICE_CPU_XLA_JIT; options.use_multiple_streams = false; - auto device = absl::make_unique(session_options, options); - devices->push_back(device.release()); + devices->push_back(absl::make_unique(session_options, options)); return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index 5c1b55cb57f..4201ff91a89 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -218,6 +218,9 @@ XlaDevice::XlaDevice(const SessionOptions& session_options, XlaDevice::~XlaDevice() { VLOG(1) << "Destroying XLA device " << jit_device_name_ << " " << this; mutex_lock lock(mu_); + while (outstanding_asynchronous_operations_ > 0) { + outstanding_asynchronous_operations_cv_.wait(lock); + } if (device_context_) { device_context_->Unref(); } @@ -384,6 +387,7 @@ void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, Status XlaDevice::Sync() { VLOG(1) << "XlaDevice::Sync"; + tracing::ScopedActivity activity("XlaDevice::Sync", /*is_expensive=*/true); std::shared_ptr stream; { mutex_lock lock(mu_); @@ -391,13 +395,46 @@ Status XlaDevice::Sync() { } if (!stream) return Status::OK(); - if (!stream->parent()->SynchronizeAllActivity() || !stream->ok()) { + Status status = stream->BlockHostUntilDone(); + { + mutex_lock lock(mu_); + while (outstanding_asynchronous_operations_ > 0) { + outstanding_asynchronous_operations_cv_.wait(lock); + } + } + TF_RETURN_IF_ERROR(status); + if 
(!stream->ok()) { return errors::Internal("XlaDevice::Sync() failed."); } VLOG(1) << "XlaDevice::Sync completed"; return Status::OK(); } +void XlaDevice::Sync(const DoneCallback& done) { + VLOG(1) << "XlaDevice::Sync (asynchronous)"; + std::shared_ptr stream; + { + mutex_lock lock(mu_); + stream = stream_; + } + if (!stream) { + done(Status::OK()); + return; + } + + stream->ThenEnqueueOnBackgroundThread( + [this, stream, done](se::StreamExecutor*) { + tracing::ScopedActivity activity("XlaDevice::Sync::Callback", + /*is_expensive=*/true); + mutex_lock lock(mu_); + while (outstanding_asynchronous_operations_ > 0) { + outstanding_asynchronous_operations_cv_.wait(lock); + } + done(stream->ok() ? Status::OK() + : errors::Internal("XlaDevice::Sync() failed.")); + }); +} + Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto, const AllocatorAttributes alloc_attrs, Tensor* tensor) { @@ -441,6 +478,49 @@ bool XlaDevice::RequiresSyncOnCompletion() const { return sync_on_completion_; } +XlaDevice::AsynchronousOperationHandle::AsynchronousOperationHandle( + XlaDevice* device) + : device_(device) { + mutex_lock lock(device_->mu_); + ++device_->outstanding_asynchronous_operations_; +} + +XlaDevice::AsynchronousOperationHandle::~AsynchronousOperationHandle() { + if (device_) { + mutex_lock lock(device_->mu_); + --device_->outstanding_asynchronous_operations_; + device_->outstanding_asynchronous_operations_cv_.notify_all(); + } +} + +XlaDevice::AsynchronousOperationHandle::AsynchronousOperationHandle( + const XlaDevice::AsynchronousOperationHandle& other) + : device_(other.device_) { + mutex_lock lock(device_->mu_); + ++device_->outstanding_asynchronous_operations_; +} + +XlaDevice::AsynchronousOperationHandle::AsynchronousOperationHandle( + XlaDevice::AsynchronousOperationHandle&& other) + : device_(other.device_) { + other.device_ = nullptr; +} + +XlaDevice::AsynchronousOperationHandle& XlaDevice::AsynchronousOperationHandle:: +operator=(const XlaDevice::AsynchronousOperationHandle& other) { + device_ = other.device_; + mutex_lock lock(device_->mu_); + ++device_->outstanding_asynchronous_operations_; + return *this; +} + +XlaDevice::AsynchronousOperationHandle& XlaDevice::AsynchronousOperationHandle:: +operator=(XlaDevice::AsynchronousOperationHandle&& other) { + device_ = other.device_; + other.device_ = nullptr; + return *this; +} + XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device, const char* jit_device) { // Any op assigned to the device that isn't rewritten by the graph rewriter diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index 49f53b477ef..c8bb276cdb9 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -135,6 +135,7 @@ class XlaDevice : public LocalDevice { void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, AsyncOpKernel::DoneCallback done) override; Status Sync() override; + void Sync(const DoneCallback& done) override; Status FillContextMap(const Graph* graph, DeviceContextMap* device_context_map) override @@ -164,7 +165,30 @@ class XlaDevice : public LocalDevice { bool RequiresSyncOnCompletion() const override LOCKS_EXCLUDED(mu_); + // A simple RAII handle. On construction the device's + // outstanding_asynchronous_operations_ field is incremented; on destruction + // it is decremented. 
+ class AsynchronousOperationHandle { + public: + AsynchronousOperationHandle(XlaDevice* device); + ~AsynchronousOperationHandle(); + AsynchronousOperationHandle(const AsynchronousOperationHandle& other); + AsynchronousOperationHandle(AsynchronousOperationHandle&& other); + AsynchronousOperationHandle& operator=( + const AsynchronousOperationHandle& other); + AsynchronousOperationHandle& operator=(AsynchronousOperationHandle&& other); + + private: + XlaDevice* device_ = nullptr; + }; + + AsynchronousOperationHandle CreateAsynchronousOperationHandle() { + return AsynchronousOperationHandle(this); + } + private: + friend class AsynchronousOperationHandle; + xla::LocalClient* client() const; Allocator* GetAllocatorLocked(AllocatorAttributes attr) EXCLUSIVE_LOCKS_REQUIRED(mu_); @@ -227,6 +251,11 @@ class XlaDevice : public LocalDevice { // True if the device requires XlaDevice::Sync to be called on completion // regardless of status. bool sync_on_completion_ GUARDED_BY(mu_) = false; + + // Count of outstanding asynchronous operations which must be zero on Sync() + // completion. + int64 outstanding_asynchronous_operations_ GUARDED_BY(mu_) = 0; + condition_variable outstanding_asynchronous_operations_cv_; }; // Builds OpKernel registrations on 'device' for the JIT operators diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc index 44197016958..944f732b99c 100644 --- a/tensorflow/compiler/jit/xla_gpu_device.cc +++ b/tensorflow/compiler/jit/xla_gpu_device.cc @@ -29,12 +29,12 @@ namespace tensorflow { class XlaGpuDeviceFactory : public DeviceFactory { public: Status CreateDevices(const SessionOptions& options, const string& name_prefix, - std::vector* devices) override; + std::vector>* devices) override; }; -Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& session_options, - const string& name_prefix, - std::vector* devices) { +Status XlaGpuDeviceFactory::CreateDevices( + const SessionOptions& session_options, const string& name_prefix, + std::vector>* devices) { XlaOpRegistry::DeviceRegistration registration; registration.compilation_device_name = DEVICE_GPU_XLA_JIT; registration.autoclustering_policy = @@ -70,7 +70,7 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& session_options, return status; } - devices->push_back(device.release()); + devices->push_back(std::move(device)); } return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_interpreter_device.cc b/tensorflow/compiler/jit/xla_interpreter_device.cc index e828bae865d..4007309ed1c 100644 --- a/tensorflow/compiler/jit/xla_interpreter_device.cc +++ b/tensorflow/compiler/jit/xla_interpreter_device.cc @@ -33,12 +33,12 @@ constexpr std::array kExecAllTypes = { class XlaInterpreterDeviceFactory : public DeviceFactory { public: Status CreateDevices(const SessionOptions& options, const string& name_prefix, - std::vector* devices) override; + std::vector>* devices) override; }; Status XlaInterpreterDeviceFactory::CreateDevices( const SessionOptions& session_options, const string& name_prefix, - std::vector* devices) { + std::vector>* devices) { static XlaDeviceOpRegistrations* registrations = RegisterXlaDeviceKernels( DEVICE_XLA_INTERPRETER, DEVICE_INTERPRETER_XLA_JIT); (void)registrations; @@ -61,8 +61,7 @@ Status XlaInterpreterDeviceFactory::CreateDevices( options.device_ordinal = 0; options.compilation_device_name = DEVICE_INTERPRETER_XLA_JIT; options.use_multiple_streams = false; - auto device = absl::make_unique(session_options, options); - 
devices->push_back(device.release()); + devices->push_back(absl::make_unique(session_options, options)); return Status::OK(); } diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 6b8e6bba1e1..bc3d60b90e5 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -375,27 +375,6 @@ tf_xla_py_test( ], ) -tf_xla_py_test( - name = "resampler_ops_test", - size = "small", - srcs = ["resampler_ops_test.py"], - disabled_backends = [ - # TODO(b/74459949) Support BatchDot in CPU backend. - "cpu", - "cpu_ondemand", - ], - # TODO(b/112295522): figure out how to make OSS build pass. - tags = ["no_oss"], - deps = [ - ":xla_test", - "//tensorflow/contrib/resampler:resampler_ops", - "//tensorflow/contrib/resampler:resampler_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:platform_test", - ], -) - tf_xla_py_test( name = "dynamic_stitch_test", size = "small", @@ -474,7 +453,6 @@ tf_xla_py_test( "//tensorflow/python:extra_py_tests_deps", "//tensorflow/python:framework", "//tensorflow/python:platform_test", - "//tensorflow/python:spectral_ops", "//tensorflow/python/ops/signal", ], ) diff --git a/tensorflow/compiler/tests/adagrad_da_test.py b/tensorflow/compiler/tests/adagrad_da_test.py index 69fb3ec2964..e9c2d363aca 100644 --- a/tensorflow/compiler/tests/adagrad_da_test.py +++ b/tensorflow/compiler/tests/adagrad_da_test.py @@ -50,8 +50,8 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase): zip([grads0, grads1], [var0, var1]), global_step=global_step) variables.global_variables_initializer().run() - self.assertAllClose([0.0, 0.0], var0.eval()) - self.assertAllClose([0.0, 0.0], var1.eval()) + self.assertAllClose([0.0, 0.0], self.evaluate(var0)) + self.assertAllClose([0.0, 0.0], self.evaluate(var1)) # Run a step of AdagradDA update.run() @@ -63,9 +63,9 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase): # For -0.1*3.0*(0.1 - 0)/(0 + sqrt(0.1 + 0.1*0.1)) = -0.904534 # similarly for others. 
self.assertAllCloseAccordingToType( - np.array([-0.904534, -1.603567]), var0.eval()) + np.array([-0.904534, -1.603567]), self.evaluate(var0)) self.assertAllCloseAccordingToType( - np.array([-0.094821, -0.189358]), var1.eval()) + np.array([-0.094821, -0.189358]), self.evaluate(var1)) def testAdagradDAwithoutRegularizationBasic2(self): for dtype in self.float_types: @@ -87,16 +87,16 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase): zip([grads0, grads1], [var0, var1]), global_step=global_step) variables.global_variables_initializer().run() - self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval()) - self.assertAllCloseAccordingToType([4.0, 3.0], var1.eval()) + self.assertAllCloseAccordingToType([1.0, 2.0], self.evaluate(var0)) + self.assertAllCloseAccordingToType([4.0, 3.0], self.evaluate(var1)) # Run a step of AdagradDA update.run() self.assertAllCloseAccordingToType( - np.array([-0.904534, -1.603567]), var0.eval()) + np.array([-0.904534, -1.603567]), self.evaluate(var0)) self.assertAllCloseAccordingToType( - np.array([-0.094821, -0.189358]), var1.eval()) + np.array([-0.094821, -0.189358]), self.evaluate(var1)) def testAdagradDAWithL1(self): for dtype in self.float_types: @@ -118,16 +118,16 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase): zip([grads0, grads1], [var0, var1]), global_step=global_step) variables.global_variables_initializer().run() - self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval()) - self.assertAllCloseAccordingToType([4.0, 3.0], var1.eval()) + self.assertAllCloseAccordingToType([1.0, 2.0], self.evaluate(var0)) + self.assertAllCloseAccordingToType([4.0, 3.0], self.evaluate(var1)) # Run a step of AdagradDA update.run() self.assertAllCloseAccordingToType( - np.array([-0.895489, -1.59555]), var0.eval()) + np.array([-0.895489, -1.59555]), self.evaluate(var0)) self.assertAllCloseAccordingToType( - np.array([-0.085339, -0.17989]), var1.eval()) + np.array([-0.085339, -0.17989]), self.evaluate(var1)) def testAdagradDAWithL1_L2(self): for dtype in self.float_types: @@ -149,16 +149,16 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase): zip([grads0, grads1], [var0, var1]), global_step=global_step) variables.global_variables_initializer().run() - self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval()) - self.assertAllCloseAccordingToType([4.0, 3.0], var1.eval()) + self.assertAllCloseAccordingToType([1.0, 2.0], self.evaluate(var0)) + self.assertAllCloseAccordingToType([4.0, 3.0], self.evaluate(var1)) # Run a step of AdagradDA update.run() self.assertAllCloseAccordingToType( - np.array([-0.046907, -0.093659]), var0.eval()) + np.array([-0.046907, -0.093659]), self.evaluate(var0)) self.assertAllCloseAccordingToType( - np.array([-0.004275, -0.009023]), var1.eval()) + np.array([-0.004275, -0.009023]), self.evaluate(var1)) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/adagrad_test.py b/tensorflow/compiler/tests/adagrad_test.py index ab69319c59f..e26483303c3 100644 --- a/tensorflow/compiler/tests/adagrad_test.py +++ b/tensorflow/compiler/tests/adagrad_test.py @@ -42,17 +42,19 @@ class AdagradOptimizerTest(xla_test.XLATestCase): zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) # Run 3 steps of adagrad for _ in range(3): ada_update.run() # Validate updated params 
self.assertAllCloseAccordingToType( - np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval(), + np.array([-1.6026098728179932, -0.6026098728179932]), + self.evaluate(var0), float_rtol=1e-5) self.assertAllCloseAccordingToType( - np.array([2.715679168701172, 3.715679168701172]), var1.eval(), + np.array([2.715679168701172, 3.715679168701172]), + self.evaluate(var1), float_rtol=1e-5) def testTensorLearningRate(self): @@ -68,17 +70,19 @@ class AdagradOptimizerTest(xla_test.XLATestCase): zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) # Run 3 steps of adagrad for _ in range(3): ada_update.run() # Validate updated params self.assertAllCloseAccordingToType( - np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval(), + np.array([-1.6026098728179932, -0.6026098728179932]), + self.evaluate(var0), float_rtol=1e-5) self.assertAllCloseAccordingToType( - np.array([2.715679168701172, 3.715679168701172]), var1.eval(), + np.array([2.715679168701172, 3.715679168701172]), + self.evaluate(var1), float_rtol=1e-5) def testSharing(self): @@ -103,18 +107,20 @@ class AdagradOptimizerTest(xla_test.XLATestCase): variables.global_variables_initializer().run() # Fetch params to validate initial values. - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) # Mix the first and the second adagrad for 3 steps. ada_update1.run() ada_update2.run() ada_update1.run() # Validate updated params (the same as with only 1 Adagrad). 
self.assertAllCloseAccordingToType( - np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval(), + np.array([-1.6026098728179932, -0.6026098728179932]), + self.evaluate(var0), float_rtol=1e-5) self.assertAllCloseAccordingToType( - np.array([2.715679168701172, 3.715679168701172]), var1.eval(), + np.array([2.715679168701172, 3.715679168701172]), + self.evaluate(var1), float_rtol=1e-5) diff --git a/tensorflow/compiler/tests/adam_test.py b/tensorflow/compiler/tests/adam_test.py index 058576b3d4b..8bcff9d379d 100644 --- a/tensorflow/compiler/tests/adam_test.py +++ b/tensorflow/compiler/tests/adam_test.py @@ -75,23 +75,24 @@ class AdamOptimizerTest(xla_test.XLATestCase): variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) beta1_power, beta2_power = opt._get_beta_accumulators() # Run 3 steps of Adam for t in range(1, 4): - self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) - self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + self.assertAllCloseAccordingToType(0.9**t, self.evaluate(beta1_power)) + self.assertAllCloseAccordingToType(0.999**t, + self.evaluate(beta2_power)) update.run(feed_dict={grads0: grads0_np, grads1: grads1_np}) var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) # Validate updated params - self.assertAllCloseAccordingToType(var0_np, var0.eval()) - self.assertAllCloseAccordingToType(var1_np, var1.eval()) + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) def testTensorLearningRate(self): for dtype in self.float_types: @@ -117,23 +118,24 @@ class AdamOptimizerTest(xla_test.XLATestCase): variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) beta1_power, beta2_power = opt._get_beta_accumulators() # Run 3 steps of Adam for t in range(1, 4): - self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) - self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + self.assertAllCloseAccordingToType(0.9**t, self.evaluate(beta1_power)) + self.assertAllCloseAccordingToType(0.999**t, + self.evaluate(beta2_power)) update.run(feed_dict={grads0: grads0_np, grads1: grads1_np}) var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) # Validate updated params - self.assertAllCloseAccordingToType(var0_np, var0.eval()) - self.assertAllCloseAccordingToType(var1_np, var1.eval()) + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) def testSharing(self): for dtype in self.float_types: @@ -162,13 +164,14 @@ class AdamOptimizerTest(xla_test.XLATestCase): beta1_power, beta2_power = opt._get_beta_accumulators() # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) # Run 3 steps of 
intertwined Adam1 and Adam2. for t in range(1, 4): - self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) - self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + self.assertAllCloseAccordingToType(0.9**t, self.evaluate(beta1_power)) + self.assertAllCloseAccordingToType(0.999**t, + self.evaluate(beta2_power)) if t % 2 == 0: update1.run(feed_dict={grads0: grads0_np, grads1: grads1_np}) else: @@ -178,8 +181,8 @@ class AdamOptimizerTest(xla_test.XLATestCase): var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) # Validate updated params - self.assertAllCloseAccordingToType(var0_np, var0.eval()) - self.assertAllCloseAccordingToType(var1_np, var1.eval()) + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/adamax_test.py b/tensorflow/compiler/tests/adamax_test.py index 3ed1d41b712..961b46375c9 100644 --- a/tensorflow/compiler/tests/adamax_test.py +++ b/tensorflow/compiler/tests/adamax_test.py @@ -78,8 +78,8 @@ class AdaMaxOptimizerTest(xla_test.XLATestCase): variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) beta1_power = opt._get_beta_accumulators() @@ -87,14 +87,17 @@ class AdaMaxOptimizerTest(xla_test.XLATestCase): for t in range(1, 4): update.run() - self.assertAllCloseAccordingToType(0.9**(t + 1), beta1_power.eval()) + self.assertAllCloseAccordingToType(0.9**(t + 1), + self.evaluate(beta1_power)) var0_np, m0, v0 = adamax_update_numpy(var0_np, grads0_np, t, m0, v0) var1_np, m1, v1 = adamax_update_numpy(var1_np, grads1_np, t, m1, v1) # Validate updated params - self.assertAllCloseAccordingToType(var0_np, var0.eval(), rtol=1e-2) - self.assertAllCloseAccordingToType(var1_np, var1.eval(), rtol=1e-2) + self.assertAllCloseAccordingToType( + var0_np, self.evaluate(var0), rtol=1e-2) + self.assertAllCloseAccordingToType( + var1_np, self.evaluate(var1), rtol=1e-2) self.assertEqual("var0_%d/AdaMax:0" % (i,), opt.get_slot(var=var0, name="m").name) @@ -118,22 +121,23 @@ class AdaMaxOptimizerTest(xla_test.XLATestCase): variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) beta1_power = opt._get_beta_accumulators() # Run 3 steps of AdaMax for t in range(1, 4): - self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) + self.assertAllCloseAccordingToType(0.9**t, self.evaluate(beta1_power)) update.run() var0_np, m0, v0 = adamax_update_numpy(var0_np, grads0_np, t, m0, v0) var1_np, m1, v1 = adamax_update_numpy(var1_np, grads1_np, t, m1, v1) # Validate updated params - self.assertAllCloseAccordingToType(var0_np, var0.eval()) - self.assertAllCloseAccordingToType(var1_np, var1.eval()) + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + if __name__ == "__main__": test.main() diff --git a/tensorflow/compiler/tests/addsign_test.py b/tensorflow/compiler/tests/addsign_test.py index 1bc07ace23c..a37c97e6d37 100644 --- a/tensorflow/compiler/tests/addsign_test.py +++ 
b/tensorflow/compiler/tests/addsign_test.py @@ -90,8 +90,8 @@ class AddSignTest(xla_test.XLATestCase): variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) # Run 7 steps of AddSign # first 4 steps with positive gradient @@ -125,8 +125,8 @@ class AddSignTest(xla_test.XLATestCase): # Validate updated params self.assertAllCloseAccordingToType( - var0_np, var0.eval(), half_rtol=1e-2) - self.assertAllCloseAccordingToType(var1_np, var1.eval()) + var0_np, self.evaluate(var0), half_rtol=1e-2) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) def testDense(self): decay_steps = 10 diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 332381c59ee..9a5423c1b2a 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -218,6 +218,21 @@ class BinaryOpsTest(xla_test.XLATestCase): ], equality_test=self.ListsAreClose) + # TF doesn't define these for bf16. + if dtype != dtypes.bfloat16.as_numpy_dtype: + self._testBinary( + gen_math_ops.xdivy, + np.array([0, 4, 3, 2, 1, 0], dtype=dtype), + np.array([0, 5, 6, 7, 8, float("NaN")], dtype=dtype), + expected=np.array([0, 0.8, 0.5, 0.285714, 0.125, 0], dtype=dtype)) + + self._testBinary( + gen_math_ops.xlogy, + np.array([0, 4, 3, 2, 1, 0], dtype=dtype), + np.array([0, 5, 6, 7, 8, float("NaN")], dtype=dtype), + expected=np.array([0, 6.437752, 5.375278, 3.89182, 2.079442, 0], + dtype=dtype)) + def testIntOps(self): for dtype in self.signed_int_types: self._testBinary( diff --git a/tensorflow/compiler/tests/categorical_op_test.py b/tensorflow/compiler/tests/categorical_op_test.py index a57d1dc81ea..5d5e486f616 100644 --- a/tensorflow/compiler/tests/categorical_op_test.py +++ b/tensorflow/compiler/tests/categorical_op_test.py @@ -27,6 +27,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops from tensorflow.python.ops import random_ops +from tensorflow.python.ops import stateless_random_ops from tensorflow.python.platform import googletest @@ -56,11 +57,11 @@ class CategoricalTest(xla_test.XLATestCase): Returns: Frequencies from sampled classes; shape [batch_size, num_classes]. """ - with self.cached_session() as sess, self.test_scope(): + with self.cached_session(), self.test_scope(): random_seed.set_random_seed(1618) op = random_ops.multinomial(logits, num_samples, output_dtype=dtypes.int32) - d = sess.run(op) + d = self.evaluate(op) batch_size, num_classes = logits.shape freqs_mat = [] @@ -79,15 +80,15 @@ class CategoricalTest(xla_test.XLATestCase): def _testRngIsNotConstant(self, rng, dtype, output_dtype): # Tests that 'rng' does not always return the same value. - with self.cached_session() as sess: + with self.cached_session(): with self.test_scope(): x = rng(dtype, output_dtype) # The random-number generator, if working correctly, should produce the # same output multiple times with low probability. - y = sess.run(x) - z = sess.run(x) - w = sess.run(x) + y = self.evaluate(x) + z = self.evaluate(x) + w = self.evaluate(x) # We use exact equality here. If the random-number generator is producing # deterministic output, all three outputs will be bitwise identical. 
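The xdivy/xlogy expectations added to binary_ops_test.py above encode the defining corner case of these ops: the result is 0 wherever x is 0, even when y is NaN or would otherwise divide to infinity. A NumPy reference that reproduces the expected arrays used in the test:

```python
# NumPy reference for the xdivy/xlogy expected values in binary_ops_test.py:
# both ops return 0 wherever x == 0, regardless of y (including y == NaN).
import numpy as np

x = np.array([0., 4., 3., 2., 1., 0.], dtype=np.float32)
y = np.array([0., 5., 6., 7., 8., np.nan], dtype=np.float32)

with np.errstate(divide="ignore", invalid="ignore"):
  xdivy = np.where(x == 0., 0., x / y)
  xlogy = np.where(x == 0., 0., x * np.log(y))

print(xdivy)  # [0.  0.8  0.5  0.2857143  0.125  0.]
print(xlogy)  # [0.  6.437752  5.375278  3.89182  2.0794415  0.]
```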
@@ -107,12 +108,12 @@ class CategoricalTest(xla_test.XLATestCase): def testCategoricalIsInRange(self): for dtype in self.float_types: for output_dtype in self.output_dtypes(): - with self.cached_session() as sess: + with self.cached_session(): with self.test_scope(): x = random_ops.multinomial( array_ops.ones(shape=[1, 20], dtype=dtype), 1000, output_dtype=output_dtype) - y = sess.run(x) + y = self.evaluate(x) self.assertTrue((y >= 0).sum() == 1000) self.assertTrue((y < 20).sum() == 1000) @@ -138,6 +139,57 @@ class CategoricalTest(xla_test.XLATestCase): chi2 = self._chi2(probs, freqs) self.assertLess(chi2, 1e-3) + def testStatelessMultinomialIsInRange(self): + for dtype in self.float_types: + for output_dtype in self.output_dtypes(): + with self.cached_session() as sess: + with self.test_scope(): + seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) + x = stateless_random_ops.stateless_multinomial( + array_ops.ones(shape=[1, 20], dtype=dtype), + 1000, + seed_t, + output_dtype=output_dtype) + y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]}) + self.assertTrue((y >= 0).sum() == 1000) + self.assertTrue((y < 20).sum() == 1000) + + def testDeterminismMultinomial(self): + # Stateless values should be equal iff the seeds are equal (roughly) + num_samples = 10 + with self.cached_session(), self.test_scope(): + seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) + seeds = [(x, y) for x in range(5) for y in range(5)] * 3 + for logits in ([[0.1, 0.25, 0.5, 0.15]], [[0.5, 0.5], [0.8, 0.2], + [0.25, 0.75]]): + pure = stateless_random_ops.stateless_multinomial( + logits, num_samples, seed=seed_t) + values = [(seed, pure.eval(feed_dict={seed_t: seed})) for seed in seeds] + for s0, v0 in values: + for s1, v1 in values: + self.assertEqual(s0 == s1, np.all(v0 == v1)) + + def testEmpty(self): + with self.cached_session(): + with self.test_scope(): + x = random_ops.multinomial( + array_ops.zeros([42, 40]), 0, output_dtype=dtypes.int32) + y = self.evaluate(x) + self.assertEqual(y.shape, (42, 0)) + + def testEmptyStateless(self): + with self.cached_session() as sess: + with self.test_scope(): + seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) + x = stateless_random_ops.stateless_multinomial( + array_ops.zeros([42, 40]), + 0, + seed=seed_t, + output_dtype=dtypes.int32) + y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]}) + self.assertEqual(y.shape, (42, 0)) + + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/compiler/tests/clustering_test.py b/tensorflow/compiler/tests/clustering_test.py index 88bd58b2da6..ef2d7af69de 100644 --- a/tensorflow/compiler/tests/clustering_test.py +++ b/tensorflow/compiler/tests/clustering_test.py @@ -43,7 +43,7 @@ class ClusteringTest(xla_test.XLATestCase): input1 = constant_op.constant(val1, name="const1") input2 = constant_op.constant(val2, name="const2") output = math_ops.add(input1, input2) - result = output.eval() + result = self.evaluate(output) self.assertAllClose(result, expected, rtol=1e-3) def testAddFromCpuMultiple(self): @@ -57,7 +57,7 @@ class ClusteringTest(xla_test.XLATestCase): with self.test_scope(): output = math_ops.add(input1, input2) for _ in xrange(10): - result = output.eval() + result = self.evaluate(output) self.assertAllClose(result, expected, rtol=1e-3) def testDeadlock(self): diff --git a/tensorflow/compiler/tests/concat_ops_test.py b/tensorflow/compiler/tests/concat_ops_test.py index 2d225ad226c..2187f57960f 100644 --- a/tensorflow/compiler/tests/concat_ops_test.py +++ b/tensorflow/compiler/tests/concat_ops_test.py 
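The new testDeterminismMultinomial above relies on the stateless-op contract: the samples are a pure function of the logits, the sample count, and the two-element seed. A standalone graph-mode sketch of that contract, reusing the stateless_random_ops module imported by the test:

```python
# Sketch of the stateless seeding contract: identical seeds reproduce the
# samples exactly; different seeds should give different draws.
import numpy as np
import tensorflow as tf
from tensorflow.python.ops import stateless_random_ops

logits = [[0.1, 0.25, 0.5, 0.15]]
seed_t = tf.placeholder(tf.int32, shape=[2])
samples = stateless_random_ops.stateless_multinomial(logits, 10, seed=seed_t)

with tf.Session() as sess:
  a = sess.run(samples, {seed_t: [1, 2]})
  b = sess.run(samples, {seed_t: [1, 2]})
  c = sess.run(samples, {seed_t: [3, 4]})
  print(np.array_equal(a, b))  # True: same seed, same samples
  print(np.array_equal(a, c))  # almost surely False: different seed
```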
@@ -72,7 +72,7 @@ class ConcatTest(xla_test.XLATestCase): x2 = constant_op.constant(p2) with self.test_scope(): c = array_ops.concat([x1, x2], 0) - result = c.eval() + result = self.evaluate(c) self.assertAllEqual(result[:2, :], p1) self.assertAllEqual(result[2:, :], p2) @@ -150,7 +150,7 @@ class ConcatTest(xla_test.XLATestCase): [float(x) for x in grad_inp.flatten()], shape=output_shape) grad = gradients_impl.gradients([c], inp_tensors, [grad_tensor]) concated_grad = array_ops.concat(grad, 1) - result = concated_grad.eval() + result = self.evaluate(concated_grad) self.assertAllEqual(result, grad_inp) def testGradientsSimpleAll(self): @@ -177,7 +177,7 @@ class ConcatTest(xla_test.XLATestCase): [float(x) for x in grad_inp.flatten()], shape=output_shape) grad = gradients_impl.gradients([c], inp_tensors, [grad_tensor]) concated_grad = array_ops.concat(grad, 0) - result = concated_grad.eval() + result = self.evaluate(concated_grad) self.assertAllEqual(result, grad_inp) @@ -205,7 +205,7 @@ class ConcatTest(xla_test.XLATestCase): [float(x) for x in grad_inp.flatten()], shape=output_shape) grad = gradients_impl.gradients([c], inp_tensors, [grad_tensor]) concated_grad = array_ops.concat(grad, 2) - result = concated_grad.eval() + result = self.evaluate(concated_grad) self.assertAllEqual(result, grad_inp) @@ -242,7 +242,7 @@ class ConcatTest(xla_test.XLATestCase): [float(x) for x in grad_inp.flatten()], shape=output_shape) grad = gradients_impl.gradients([c], inp_tensors, [grad_tensor]) concated_grad = array_ops.concat(grad, concat_dim) - result = concated_grad.eval() + result = self.evaluate(concated_grad) self.assertAllEqual(result, grad_inp) @@ -254,7 +254,7 @@ class ConcatTest(xla_test.XLATestCase): def DISABLED_testZeroSize(self): # Verify that concat doesn't crash and burn for zero size inputs np.random.seed(7) - with self.cached_session() as sess: + with self.cached_session(): with self.test_scope(): for shape0 in (), (2,): axis = len(shape0) @@ -270,7 +270,7 @@ class ConcatTest(xla_test.XLATestCase): self.assertAllEqual(c.eval(), correct) # Check gradients dc = np.random.randn(*c.get_shape().as_list()) - dxs = sess.run(gradients_impl.gradients(c, xs, dc)) + dxs = self.evaluate(gradients_impl.gradients(c, xs, dc)) self.assertAllEqual(dc, np.concatenate(dxs, axis=axis)) def testConcatTuple(self): @@ -280,7 +280,7 @@ class ConcatTest(xla_test.XLATestCase): with self.test_scope(): concat_list_t = array_ops.concat([c1, c2], 0) concat_tuple_t = array_ops.concat((c1, c2), 0) - self.assertAllEqual(concat_list_t.eval(), concat_tuple_t.eval()) + self.assertAllEqual(concat_list_t.eval(), self.evaluate(concat_tuple_t)) def testConcatNoScalars(self): with self.cached_session(): @@ -330,47 +330,47 @@ class ConcatTest(xla_test.XLATestCase): class ConcatOffsetTest(xla_test.XLATestCase): def testBasic(self): - with self.cached_session() as sess: + with self.cached_session(): with self.test_scope(): cdim = constant_op.constant(1, dtypes.int32) s0 = constant_op.constant([2, 3, 5], dtypes.int32) s1 = constant_op.constant([2, 7, 5], dtypes.int32) s2 = constant_op.constant([2, 20, 5], dtypes.int32) off = gen_array_ops.concat_offset(cdim, [s0, s1, s2]) - ans = sess.run(off) + ans = self.evaluate(off) self.assertAllEqual(ans, [[0, 0, 0], [0, 3, 0], [0, 10, 0]]) class PackTest(xla_test.XLATestCase): def testBasic(self): - with self.cached_session() as sess: + with self.cached_session(): with self.test_scope(): s0 = constant_op.constant([2, 3, 5], dtypes.int32) s1 = constant_op.constant([2, 7, 5], dtypes.int32) s2 = 
constant_op.constant([2, 20, 5], dtypes.int32) packed = array_ops.stack([s0, s1, s2]) - ans = sess.run(packed) + ans = self.evaluate(packed) self.assertAllEqual(ans, [[2, 3, 5], [2, 7, 5], [2, 20, 5]]) def testScalars(self): - with self.cached_session() as sess: + with self.cached_session(): with self.test_scope(): s0 = constant_op.constant(2, dtypes.int32) s1 = constant_op.constant(3, dtypes.int32) s2 = constant_op.constant(5, dtypes.int32) packed = array_ops.stack([s0, s1, s2]) - ans = sess.run(packed) + ans = self.evaluate(packed) self.assertAllEqual(ans, [2, 3, 5]) def testEmpty(self): - with self.cached_session() as sess: + with self.cached_session(): with self.test_scope(): s0 = constant_op.constant([[]], dtypes.int32) s1 = constant_op.constant([[]], dtypes.int32) s2 = constant_op.constant([[]], dtypes.int32) packed = array_ops.stack([s0, s1, s2]) - ans = sess.run(packed) + ans = self.evaluate(packed) self.assertAllEqual(ans, [[[]], [[]], [[]]]) diff --git a/tensorflow/compiler/tests/conv3d_test.py b/tensorflow/compiler/tests/conv3d_test.py index d59fd0236f4..01cc1b63928 100644 --- a/tensorflow/compiler/tests/conv3d_test.py +++ b/tensorflow/compiler/tests/conv3d_test.py @@ -85,7 +85,7 @@ class Conv3DTransposeTest(xla_test.XLATestCase): 1.0, shape=f_shape, name="filter", dtype=dtypes.float32) output = nn_ops.conv3d_transpose( x, f, y_shape, strides=strides, padding="SAME") - value = output.eval() + value = self.evaluate(output) # We count the number of cells being added at the locations in the output. # At the center, #cells = kernel_depth * kernel_height * kernel_width @@ -135,7 +135,7 @@ class Conv3DTransposeTest(xla_test.XLATestCase): 1.0, shape=f_shape, name="filter", dtype=dtypes.float32) output = nn_ops.conv3d_transpose( x, f, y_shape, strides=strides, padding="SAME") - value = output.eval() + value = self.evaluate(output) for n in xrange(x_shape[0]): for k in xrange(f_shape[3]): @@ -173,7 +173,7 @@ class Conv3DTransposeTest(xla_test.XLATestCase): 1.0, shape=f_shape, name="filter", dtype=dtypes.float32) output = nn_ops.conv3d_transpose( x, f, y_shape, strides=strides, padding="VALID") - value = output.eval() + value = self.evaluate(output) cache_values = np.zeros(y_shape, dtype=np.float32) diff --git a/tensorflow/compiler/tests/dense_layer_test.py b/tensorflow/compiler/tests/dense_layer_test.py index d1b90f098d7..bf5ea7b1fb6 100644 --- a/tensorflow/compiler/tests/dense_layer_test.py +++ b/tensorflow/compiler/tests/dense_layer_test.py @@ -42,7 +42,7 @@ def GetRunMetadataLabels(run_metadata): def InLabels(labels, substr): """Returns true iff one of the labels contains substr.""" - return any([substr in x for x in labels]) + return any(substr in x for x in labels) class DenseLayerTest(test.TestCase): @@ -72,7 +72,7 @@ class DenseLayerTest(test.TestCase): x = array_ops.placeholder(shape=[None, None, 3], dtype=np.float32) y = layers.dense(x, 3) - sess.run(variables.initialize_all_variables()) + self.evaluate(variables.initialize_all_variables()) run_metadata = config_pb2.RunMetadata() test_utils.RunWithWarmup( sess, @@ -97,7 +97,7 @@ class DenseLayerTest(test.TestCase): with jit_scope(): y = layers.dense(x, 3) - sess.run(variables.initialize_all_variables()) + self.evaluate(variables.initialize_all_variables()) run_metadata = config_pb2.RunMetadata() test_utils.RunWithWarmup( sess, @@ -126,7 +126,7 @@ class DenseLayerTest(test.TestCase): with jit_scope(): y = layers.dense(x, 3) - sess.run(variables.initialize_all_variables()) + self.evaluate(variables.initialize_all_variables()) 
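The recurring rewrites from Tensor.eval() and sess.run(...) to self.evaluate(...) in these tests are about portability: self.evaluate resolves a tensor through the test's session in graph mode and fetches the value directly in eager mode, where .eval() has no default session. A minimal TestCase sketch of the pattern:

```python
# Minimal sketch of the self.evaluate pattern the updated tests use; this
# example runs in graph mode via the cached session, but the same call also
# works when the test executes eagerly.
import tensorflow as tf


class EvaluateExampleTest(tf.test.TestCase):

  def testEvaluate(self):
    with self.cached_session():
      x = tf.constant([1.0, 2.0])
      self.assertAllClose([1.0, 2.0], self.evaluate(x))


if __name__ == "__main__":
  tf.test.main()
```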
run_metadata = config_pb2.RunMetadata() test_utils.RunWithWarmup( sess, diff --git a/tensorflow/compiler/tests/dynamic_stitch_test.py b/tensorflow/compiler/tests/dynamic_stitch_test.py index 50b04daa6b9..e89cf975f5d 100644 --- a/tensorflow/compiler/tests/dynamic_stitch_test.py +++ b/tensorflow/compiler/tests/dynamic_stitch_test.py @@ -58,6 +58,15 @@ class DynamicStitchTest(xla_test.XLATestCase): [idx1, idx2], [val1, val2], expected=np.array([[], [], [], []], np.int32)) + def testEmptyIndex(self): + idx1 = np.array([], dtype=np.int32) + idx2 = np.array([[], []], dtype=np.int32) + val1 = np.ndarray(shape=(0, 9), dtype=np.int32) + val2 = np.ndarray(shape=(2, 0, 9), dtype=np.int32) + self._AssertDynamicStitchResultIs([idx1, idx2], [val1, val2], + expected=np.ndarray( + shape=(0, 9), dtype=np.int32)) + def testSimple1D(self): val1 = np.array([0, 4, 7], dtype=np.int32) val2 = np.array([1, 6, 2, 3, 5], dtype=np.int32) diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py index 63cee550fde..2af32b537ba 100644 --- a/tensorflow/compiler/tests/eager_test.py +++ b/tensorflow/compiler/tests/eager_test.py @@ -101,12 +101,12 @@ class EagerTest(xla_test.XLATestCase): self.assertAllEqual(15, product) # Run some ops graphly - with context.graph_mode(), self.cached_session() as sess: + with context.graph_mode(), self.cached_session(): with self.test_scope(): three = constant_op.constant(3) five = constant_op.constant(5) product = three * five - self.assertAllEqual(15, sess.run(product)) + self.assertAllEqual(15, self.evaluate(product)) def testDegenerateSlices(self): with self.test_scope(): diff --git a/tensorflow/compiler/tests/fft_test.py b/tensorflow/compiler/tests/fft_test.py index e92afd5d6fe..0edd0c35aa2 100644 --- a/tensorflow/compiler/tests/fft_test.py +++ b/tensorflow/compiler/tests/fft_test.py @@ -27,8 +27,7 @@ from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradients_impl -from tensorflow.python.ops import signal -from tensorflow.python.ops import spectral_ops +from tensorflow.python.ops.signal import signal from tensorflow.python.platform import googletest BATCH_DIMS = (3, 5) @@ -107,39 +106,39 @@ class FFTTest(xla_test.XLATestCase): def testFFT(self): self._VerifyFftMethod(INNER_DIMS_1D, lambda x: x, np.fft.fft, - spectral_ops.fft) + signal.fft) def testFFT2D(self): self._VerifyFftMethod(INNER_DIMS_2D, lambda x: x, np.fft.fft2, - spectral_ops.fft2d) + signal.fft2d) def testFFT3D(self): self._VerifyFftMethod(INNER_DIMS_3D, lambda x: x, lambda x: np.fft.fftn(x, axes=(-3, -2, -1)), - spectral_ops.fft3d) + signal.fft3d) def testIFFT(self): self._VerifyFftMethod(INNER_DIMS_1D, lambda x: x, np.fft.ifft, - spectral_ops.ifft) + signal.ifft) def testIFFT2D(self): self._VerifyFftMethod(INNER_DIMS_2D, lambda x: x, np.fft.ifft2, - spectral_ops.ifft2d) + signal.ifft2d) def testIFFT3D(self): self._VerifyFftMethod(INNER_DIMS_3D, lambda x: x, lambda x: np.fft.ifftn(x, axes=(-3, -2, -1)), - spectral_ops.ifft3d) + signal.ifft3d) def testRFFT(self): self._VerifyFftMethod( INNER_DIMS_1D, np.real, lambda x: np.fft.rfft(x, n=x.shape[-1]), - lambda x: spectral_ops.rfft(x, fft_length=[x.shape[-1].value])) + lambda x: signal.rfft(x, fft_length=[x.shape[-1].value])) def testRFFT2D(self): def _tf_fn(x): - return spectral_ops.rfft2d( + return signal.rfft2d( x, fft_length=[x.shape[-2].value, x.shape[-1].value]) self._VerifyFftMethod( @@ -153,16 +152,33 @@ class 
FFTTest(xla_test.XLATestCase): x, axes=(-3, -2, -1), s=[x.shape[-3], x.shape[-2], x.shape[-1]]) def _tf_fn(x): - return spectral_ops.rfft3d( + return signal.rfft3d( x, fft_length=[x.shape[-3].value, x.shape[-2].value, x.shape[-1].value]) self._VerifyFftMethod(INNER_DIMS_3D, np.real, _to_expected, _tf_fn) + def testRFFT3DMismatchedSize(self): + + def _to_expected(x): + return np.fft.rfftn( + x, + axes=(-3, -2, -1), + s=[x.shape[-3] // 2, x.shape[-2], x.shape[-1] * 2]) + + def _tf_fn(x): + return signal.rfft3d( + x, + fft_length=[ + x.shape[-3].value // 2, x.shape[-2].value, x.shape[-1].value * 2 + ]) + + self._VerifyFftMethod(INNER_DIMS_3D, np.real, _to_expected, _tf_fn) + def testIRFFT(self): def _tf_fn(x): - return spectral_ops.irfft(x, fft_length=[2 * (x.shape[-1].value - 1)]) + return signal.irfft(x, fft_length=[2 * (x.shape[-1].value - 1)]) self._VerifyFftMethod( INNER_DIMS_1D, lambda x: np.fft.rfft(np.real(x), n=x.shape[-1]), @@ -171,7 +187,7 @@ class FFTTest(xla_test.XLATestCase): def testIRFFT2D(self): def _tf_fn(x): - return spectral_ops.irfft2d( + return signal.irfft2d( x, fft_length=[x.shape[-2].value, 2 * (x.shape[-1].value - 1)]) self._VerifyFftMethod( @@ -195,7 +211,7 @@ class FFTTest(xla_test.XLATestCase): s=[x.shape[-3], x.shape[-2], 2 * (x.shape[-1] - 1)]) def _tf_fn(x): - return spectral_ops.irfft3d( + return signal.irfft3d( x, fft_length=[ x.shape[-3].value, x.shape[-2].value, 2 * (x.shape[-1].value - 1) @@ -203,6 +219,30 @@ class FFTTest(xla_test.XLATestCase): self._VerifyFftMethod(INNER_DIMS_3D, _to_input, _to_expected, _tf_fn) + def testIRFFT3DMismatchedSize(self): + + def _to_input(x): + return np.fft.rfftn( + np.real(x), + axes=(-3, -2, -1), + s=[x.shape[-3] // 2, x.shape[-2], x.shape[-1] * 2]) + + def _to_expected(x): + return np.fft.irfftn( + x, + axes=(-3, -2, -1), + s=[x.shape[-3] // 2, x.shape[-2], x.shape[-1] * 2]) + + def _tf_fn(x): + return signal.irfft3d( + x, + fft_length=[ + x.shape[-3].value // 2, x.shape[-2].value, x.shape[-1].value * 2 + ]) + + self._VerifyFftMethod(INNER_DIMS_3D, _to_input, _to_expected, _tf_fn) + + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/compiler/tests/fifo_queue_test.py b/tensorflow/compiler/tests/fifo_queue_test.py index 8c7edfd277c..91d77d2f791 100644 --- a/tensorflow/compiler/tests/fifo_queue_test.py +++ b/tensorflow/compiler/tests/fifo_queue_test.py @@ -129,7 +129,7 @@ class FIFOQueueTest(xla_test.XLATestCase): enqueue_op.run() for i in xrange(len(elems)): - vals = dequeued_t.eval() + vals = self.evaluate(dequeued_t) self.assertEqual([elems[i]], vals) def testEnqueueAndBlockingDequeue(self): @@ -192,9 +192,9 @@ class FIFOQueueTest(xla_test.XLATestCase): self.assertEqual([], size.get_shape()) enqueue_op.run() - self.assertEqual(1, size.eval()) + self.assertEqual(1, self.evaluate(size)) dequeued_t.op.run() - self.assertEqual(0, size.eval()) + self.assertEqual(0, self.evaluate(size)) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/ftrl_test.py b/tensorflow/compiler/tests/ftrl_test.py index 5b197afd655..b078053cdbd 100644 --- a/tensorflow/compiler/tests/ftrl_test.py +++ b/tensorflow/compiler/tests/ftrl_test.py @@ -50,14 +50,14 @@ class FtrlOptimizerTest(xla_test.XLATestCase): ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([0.0, 0.0], var0.eval()) - self.assertAllClose([0.0, 0.0], var1.eval()) + self.assertAllClose([0.0, 0.0], self.evaluate(var0)) + 
self.assertAllClose([0.0, 0.0], self.evaluate(var1)) # Run Ftrl for a few steps for _ in range(steps): ftrl_update.run() - return var0.eval(), var1.eval() + return self.evaluate(var0), self.evaluate(var1) def equivAdagradTest_AdagradPart(self, steps, dtype): var0, var1, grads0, grads1 = self.initVariableAndGradient(dtype) @@ -65,14 +65,14 @@ class FtrlOptimizerTest(xla_test.XLATestCase): adagrad_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([0.0, 0.0], var0.eval()) - self.assertAllClose([0.0, 0.0], var1.eval()) + self.assertAllClose([0.0, 0.0], self.evaluate(var0)) + self.assertAllClose([0.0, 0.0], self.evaluate(var1)) # Run Adagrad for a few steps for _ in range(steps): adagrad_update.run() - return var0.eval(), var1.eval() + return self.evaluate(var0), self.evaluate(var1) def equivGradientDescentTest_FtrlPart(self, steps, dtype): var0, var1, grads0, grads1 = self.initVariableAndGradient(dtype) @@ -85,14 +85,14 @@ class FtrlOptimizerTest(xla_test.XLATestCase): ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([0.0, 0.0], var0.eval()) - self.assertAllClose([0.0, 0.0], var1.eval()) + self.assertAllClose([0.0, 0.0], self.evaluate(var0)) + self.assertAllClose([0.0, 0.0], self.evaluate(var1)) # Run Ftrl for a few steps for _ in range(steps): ftrl_update.run() - return var0.eval(), var1.eval() + return self.evaluate(var0), self.evaluate(var1) def equivGradientDescentTest_GradientDescentPart(self, steps, dtype): var0, var1, grads0, grads1 = self.initVariableAndGradient(dtype) @@ -100,14 +100,14 @@ class FtrlOptimizerTest(xla_test.XLATestCase): sgd_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([0.0, 0.0], var0.eval()) - self.assertAllClose([0.0, 0.0], var1.eval()) + self.assertAllClose([0.0, 0.0], self.evaluate(var0)) + self.assertAllClose([0.0, 0.0], self.evaluate(var1)) # Run GradientDescent for a few steps for _ in range(steps): sgd_update.run() - return var0.eval(), var1.eval() + return self.evaluate(var0), self.evaluate(var1) def testFtrlwithoutRegularization(self): for dtype in self.float_types: @@ -124,8 +124,8 @@ class FtrlOptimizerTest(xla_test.XLATestCase): ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([0.0, 0.0], var0.eval()) - self.assertAllClose([0.0, 0.0], var1.eval()) + self.assertAllClose([0.0, 0.0], self.evaluate(var0)) + self.assertAllClose([0.0, 0.0], self.evaluate(var1)) # Run 3 steps FTRL for _ in range(3): @@ -134,12 +134,12 @@ class FtrlOptimizerTest(xla_test.XLATestCase): # Validate updated params self.assertAllCloseAccordingToType( np.array([-2.60260963, -4.29698515]), - var0.eval(), + self.evaluate(var0), float_rtol=1e-4, half_rtol=1e-2) self.assertAllCloseAccordingToType( np.array([-0.28432083, -0.56694895]), - var1.eval(), + self.evaluate(var1), float_rtol=1e-5, half_rtol=1e-2) @@ -158,8 +158,8 @@ class FtrlOptimizerTest(xla_test.XLATestCase): ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - 
self.assertAllClose([4.0, 3.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([4.0, 3.0], self.evaluate(var1)) # Run 3 steps FTRL for _ in range(3): @@ -167,10 +167,14 @@ class FtrlOptimizerTest(xla_test.XLATestCase): # Validate updated params self.assertAllCloseAccordingToType( - np.array([-2.55607247, -3.98729396]), var0.eval(), 1e-5, 1e-5, + np.array([-2.55607247, -3.98729396]), + self.evaluate(var0), + 1e-5, + 1e-5, float_rtol=1e-4) self.assertAllCloseAccordingToType( - np.array([-0.28232238, -0.56096673]), var1.eval(), 1e-5, 1e-5) + np.array([-0.28232238, -0.56096673]), self.evaluate(var1), 1e-5, + 1e-5) def testFtrlWithL1(self): for dtype in self.float_types: @@ -187,8 +191,8 @@ class FtrlOptimizerTest(xla_test.XLATestCase): ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([4.0, 3.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([4.0, 3.0], self.evaluate(var1)) # Run 10 steps FTRL for _ in range(10): @@ -197,12 +201,14 @@ class FtrlOptimizerTest(xla_test.XLATestCase): # Validate updated params self.assertAllCloseAccordingToType( np.array([-7.66718769, -10.91273689]), - var0.eval(), + self.evaluate(var0), rtol=1e-4, bfloat16_rtol=1e-1, bfloat16_atol=1e-1) self.assertAllCloseAccordingToType( - np.array([-0.93460727, -1.86147261]), var1.eval(), rtol=1e-4) + np.array([-0.93460727, -1.86147261]), + self.evaluate(var1), + rtol=1e-4) def testFtrlWithL1_L2(self): for dtype in self.float_types: @@ -219,8 +225,8 @@ class FtrlOptimizerTest(xla_test.XLATestCase): ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([4.0, 3.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([4.0, 3.0], self.evaluate(var1)) # Run 10 steps FTRL for _ in range(10): @@ -228,9 +234,13 @@ class FtrlOptimizerTest(xla_test.XLATestCase): # Validate updated params self.assertAllCloseAccordingToType( - np.array([-0.24059935, -0.46829352]), var0.eval(), rtol=1e-5) + np.array([-0.24059935, -0.46829352]), + self.evaluate(var0), + rtol=1e-5) self.assertAllCloseAccordingToType( - np.array([-0.02406147, -0.04830509]), var1.eval(), rtol=1e-5) + np.array([-0.02406147, -0.04830509]), + self.evaluate(var1), + rtol=1e-5) def testFtrlWithL1_L2_L2Shrinkage(self): """Test the new FTRL op with support for l2 shrinkage. 
@@ -254,8 +264,8 @@ class FtrlOptimizerTest(xla_test.XLATestCase): ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval()) - self.assertAllCloseAccordingToType([4.0, 3.0], var1.eval()) + self.assertAllCloseAccordingToType([1.0, 2.0], self.evaluate(var0)) + self.assertAllCloseAccordingToType([4.0, 3.0], self.evaluate(var1)) # Run 10 steps FTRL for _ in range(10): @@ -263,9 +273,13 @@ class FtrlOptimizerTest(xla_test.XLATestCase): # Validate updated params self.assertAllCloseAccordingToType( - np.array([-0.22578996, -0.44345799]), var0.eval(), rtol=1e-4) + np.array([-0.22578996, -0.44345799]), + self.evaluate(var0), + rtol=1e-4) self.assertAllCloseAccordingToType( - np.array([-0.14378493, -0.13229476]), var1.eval(), rtol=1e-4) + np.array([-0.14378493, -0.13229476]), + self.evaluate(var1), + rtol=1e-4) def testFtrlWithL2ShrinkageDoesNotChangeLrSchedule(self): """Verifies that l2 shrinkage in FTRL does not change lr schedule.""" @@ -291,8 +305,8 @@ class FtrlOptimizerTest(xla_test.XLATestCase): update1 = opt1.apply_gradients([(grads1, var1)]) variables.global_variables_initializer().run() - self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval()) - self.assertAllCloseAccordingToType([1.0, 2.0], var1.eval()) + self.assertAllCloseAccordingToType([1.0, 2.0], self.evaluate(var0)) + self.assertAllCloseAccordingToType([1.0, 2.0], self.evaluate(var1)) # Run 10 steps FTRL for _ in range(10): @@ -301,7 +315,7 @@ class FtrlOptimizerTest(xla_test.XLATestCase): # var0 is experiencing L2 shrinkage so it should be smaller than var1 # in magnitude. - self.assertTrue((var0.eval()**2 < var1.eval()**2).all()) + self.assertTrue((var0.eval()**2 < self.evaluate(var1)**2).all()) accum0 = list(opt0._slots["accum"].values())[0].eval() accum1 = list(opt1._slots["accum"].values())[0].eval() # L2 shrinkage should not change how we update grad accumulator. 
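The bulk of the Python test changes in this patch follow one mechanical pattern: calls to `Tensor.eval()` and `sess.run(...)` are replaced with `self.evaluate(...)`, which `tf.test.TestCase` supports under both graph and eager execution. A minimal sketch of that pattern, not part of the patch itself (the class and test names below are illustrative):

```
# Minimal sketch of the eval()/sess.run() -> self.evaluate() migration.
# Uses only the public tf.test API; class/test names are illustrative.
import tensorflow as tf


class EvaluateMigrationSketch(tf.test.TestCase):

  def testProduct(self):
    three = tf.constant(3)
    five = tf.constant(5)
    product = three * five
    # Old, graph-only style:
    #   with self.cached_session() as sess:
    #     self.assertAllEqual(15, sess.run(product))
    # New style, valid under both graph and eager execution:
    self.assertAllEqual(15, self.evaluate(product))


if __name__ == "__main__":
  tf.test.main()
```

The same substitution accounts for most of the hunks in the optimizer, queue, and tensor-array tests in this patch.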
diff --git a/tensorflow/compiler/tests/function_test.py b/tensorflow/compiler/tests/function_test.py index b1891b918c6..a61827c2ae4 100644 --- a/tensorflow/compiler/tests/function_test.py +++ b/tensorflow/compiler/tests/function_test.py @@ -40,7 +40,7 @@ class FunctionTest(xla_test.XLATestCase): bval = np.array([5, 6, 7, 8]).reshape([2, 2]).astype(np.float32) expected = APlus2B(aval, bval) - with self.cached_session() as sess: + with self.cached_session(): @function.Defun(dtypes.float32, dtypes.float32) def Foo(a, b): @@ -50,7 +50,7 @@ class FunctionTest(xla_test.XLATestCase): b = constant_op.constant(bval, name="b") with self.test_scope(): call_f = Foo(a, b) - result = sess.run(call_f) + result = self.evaluate(call_f) self.assertAllClose(result, expected, rtol=1e-3) def testNestedFunctions(self): @@ -66,7 +66,7 @@ class FunctionTest(xla_test.XLATestCase): bval = np.array([4, 3, 2, 1]).reshape([2, 2]).astype(np.float32) expected = APlus2B(aval, bval) - with self.cached_session() as sess: + with self.cached_session(): @function.Defun(dtypes.float32, dtypes.float32) def Foo(a, b): @@ -76,7 +76,7 @@ class FunctionTest(xla_test.XLATestCase): b = constant_op.constant(bval, name="b") with self.test_scope(): call_g = Foo(a, b) - result = sess.run(call_g) + result = self.evaluate(call_g) self.assertAllClose(result, expected, rtol=1e-3) def testFunctionMultipleRetvals(self): @@ -90,7 +90,7 @@ class FunctionTest(xla_test.XLATestCase): bval = np.array([5, 6, 7, 8]).reshape([2, 2]).astype(np.float32) expected = Func(aval, bval) - with self.cached_session() as sess: + with self.cached_session(): @function.Defun(dtypes.float32, dtypes.float32) def Foo(a, b): @@ -100,7 +100,7 @@ class FunctionTest(xla_test.XLATestCase): b = constant_op.constant(bval, name="b") with self.test_scope(): call_f = Foo(a, b) - result = sess.run(call_f) + result = self.evaluate(call_f) self.assertAllClose(result, expected, rtol=1e-3) def testCompileTimeConstantsInDefun(self): diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py index 6f51ae33a1b..dbea9849e21 100644 --- a/tensorflow/compiler/tests/jit_test.py +++ b/tensorflow/compiler/tests/jit_test.py @@ -75,7 +75,7 @@ def RunMetadataLabels(run_metadata): def InLabels(labels, substr): """Returns true iff one of the labels contains substr.""" - return any([substr in x for x in labels]) + return any(substr in x for x in labels) def MetadataHasXlaRunOp(run_metadata): diff --git a/tensorflow/compiler/tests/listdiff_op_test.py b/tensorflow/compiler/tests/listdiff_op_test.py index 58622114e4f..0210201fa71 100644 --- a/tensorflow/compiler/tests/listdiff_op_test.py +++ b/tensorflow/compiler/tests/listdiff_op_test.py @@ -33,13 +33,13 @@ class ListDiffTest(xla_test.XLATestCase): def _testListDiff(self, x, y, out, idx): for dtype in [dtypes.int32, dtypes.int64]: for index_dtype in [dtypes.int32, dtypes.int64]: - with self.cached_session() as sess: + with self.cached_session(): x_tensor = ops.convert_to_tensor(x, dtype=dtype) y_tensor = ops.convert_to_tensor(y, dtype=dtype) with self.test_scope(): out_tensor, idx_tensor = array_ops.listdiff( x_tensor, y_tensor, out_idx=index_dtype) - tf_out, tf_idx = sess.run([out_tensor, idx_tensor]) + tf_out, tf_idx = self.evaluate([out_tensor, idx_tensor]) self.assertAllEqual(out, tf_out) self.assertAllEqual(idx, tf_idx) self.assertEqual(1, out_tensor.get_shape().ndims) diff --git a/tensorflow/compiler/tests/lrn_ops_test.py b/tensorflow/compiler/tests/lrn_ops_test.py index c6ad67993e8..5dddf6ae4e8 100644 --- 
a/tensorflow/compiler/tests/lrn_ops_test.py +++ b/tensorflow/compiler/tests/lrn_ops_test.py @@ -120,8 +120,8 @@ class LRNTest(xla_test.XLATestCase): with self.test_scope(): actual = gen_nn_ops.lrn_grad(out_grads, in_image, out_image, depth_radius, bias, alpha, beta) - expected_val = expected.eval() - actual_val = actual.eval() + expected_val = self.evaluate(expected) + actual_val = self.evaluate(actual) self.assertAllClose(actual_val, expected_val, rtol=1e-3) diff --git a/tensorflow/compiler/tests/lstm_test.py b/tensorflow/compiler/tests/lstm_test.py index 265c0b6d141..776ed899e68 100644 --- a/tensorflow/compiler/tests/lstm_test.py +++ b/tensorflow/compiler/tests/lstm_test.py @@ -88,8 +88,8 @@ class LSTMTest(test.TestCase): (basename, m_prev_scalar, c_prev_scalar, pad_scalar)) # Initialize variables and run the unrolled LSTM step. - sess.run(variables.global_variables_initializer()) - return sess.run([m, c]) + self.evaluate(variables.global_variables_initializer()) + return self.evaluate([m, c]) def testLSTMCell(self): # Run with all-0 weights, no padding. @@ -173,8 +173,8 @@ class LSTMTest(test.TestCase): (basename, m_init_scalar, c_init_scalar, pad_scalar)) # Initialize variables and run the unrolled LSTM layer. - sess.run(variables.global_variables_initializer()) - return sess.run(out_seq) + self.evaluate(variables.global_variables_initializer()) + return self.evaluate(out_seq) def testLSTMLayer(self): # Run with all-0 weights, no padding. diff --git a/tensorflow/compiler/tests/momentum_test.py b/tensorflow/compiler/tests/momentum_test.py index f77521a7c49..3416f7dbd6b 100644 --- a/tensorflow/compiler/tests/momentum_test.py +++ b/tensorflow/compiler/tests/momentum_test.py @@ -61,37 +61,43 @@ class MomentumOptimizerTest(xla_test.XLATestCase): self.assertFalse(slot1 in variables.trainable_variables()) # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) # Step 1: the momentum accumulators where 0. So we should see a normal # update: v -= grad * learning_rate mom_update.run() # Check that the momentum accumulators have been updated. - self.assertAllCloseAccordingToType(np.array([0.1, 0.1]), slot0.eval()) - self.assertAllCloseAccordingToType(np.array([0.01, 0.01]), slot1.eval()) + self.assertAllCloseAccordingToType( + np.array([0.1, 0.1]), self.evaluate(slot0)) + self.assertAllCloseAccordingToType( + np.array([0.01, 0.01]), self.evaluate(slot1)) # Check that the parameters have been updated. self.assertAllCloseAccordingToType( - np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]), var0.eval()) + np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]), + self.evaluate(var0)) self.assertAllCloseAccordingToType( - np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]), var1.eval()) + np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]), + self.evaluate(var1)) # Step 2: the momentum accumulators contain the previous update. mom_update.run() # Check that the momentum accumulators have been updated. self.assertAllCloseAccordingToType( - np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]), slot0.eval()) + np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]), + self.evaluate(slot0)) self.assertAllCloseAccordingToType( - np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]), slot1.eval()) + np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]), + self.evaluate(slot1)) # Check that the parameters have been updated. 
self.assertAllCloseAccordingToType( np.array([ 1.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0), 2.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0) - ]), var0.eval()) + ]), self.evaluate(var0)) self.assertAllCloseAccordingToType( np.array([ - 2.98 - ((0.9 * 0.01 + 0.01) * 2.0), 3.98 - ( - (0.9 * 0.01 + 0.01) * 2.0) - ]), var1.eval()) + 2.98 - ((0.9 * 0.01 + 0.01) * 2.0), + 3.98 - ((0.9 * 0.01 + 0.01) * 2.0) + ]), self.evaluate(var1)) def testNesterovMomentum(self): for dtype in self.float_types: @@ -115,8 +121,8 @@ class MomentumOptimizerTest(xla_test.XLATestCase): var0_np, accum0_np, var0_np * 0.8, 0.1, 0.9) var1_np, accum1_np = self._update_nesterov_momentum_numpy( var1_np, accum1_np, 0.9, 0.1, 0.9) - self.assertAllCloseAccordingToType(var0_np, var0.eval()) - self.assertAllCloseAccordingToType(var1_np, var1.eval()) + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) def testTensorLearningRateAndMomentum(self): for dtype in self.float_types: @@ -141,37 +147,43 @@ class MomentumOptimizerTest(xla_test.XLATestCase): self.assertFalse(slot1 in variables.trainable_variables()) # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) # Step 1: the momentum accumulators where 0. So we should see a normal # update: v -= grad * learning_rate mom_update.run() # Check that the momentum accumulators have been updated. - self.assertAllCloseAccordingToType(np.array([0.1, 0.1]), slot0.eval()) - self.assertAllCloseAccordingToType(np.array([0.01, 0.01]), slot1.eval()) + self.assertAllCloseAccordingToType( + np.array([0.1, 0.1]), self.evaluate(slot0)) + self.assertAllCloseAccordingToType( + np.array([0.01, 0.01]), self.evaluate(slot1)) # Check that the parameters have been updated. self.assertAllCloseAccordingToType( - np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]), var0.eval()) + np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]), + self.evaluate(var0)) self.assertAllCloseAccordingToType( - np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]), var1.eval()) + np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]), + self.evaluate(var1)) # Step 2: the momentum accumulators contain the previous update. mom_update.run() # Check that the momentum accumulators have been updated. self.assertAllCloseAccordingToType( - np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]), slot0.eval()) + np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]), + self.evaluate(slot0)) self.assertAllCloseAccordingToType( - np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]), slot1.eval()) + np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]), + self.evaluate(slot1)) # Check that the parameters have been updated. 
self.assertAllCloseAccordingToType( np.array([ 1.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0), 2.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0) - ]), var0.eval()) + ]), self.evaluate(var0)) self.assertAllCloseAccordingToType( np.array([ - 2.98 - ((0.9 * 0.01 + 0.01) * 2.0), 3.98 - ( - (0.9 * 0.01 + 0.01) * 2.0) - ]), var1.eval()) + 2.98 - ((0.9 * 0.01 + 0.01) * 2.0), + 3.98 - ((0.9 * 0.01 + 0.01) * 2.0) + ]), self.evaluate(var1)) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/placeholder_test.py b/tensorflow/compiler/tests/placeholder_test.py index 77bb839409f..9671ae0ae97 100644 --- a/tensorflow/compiler/tests/placeholder_test.py +++ b/tensorflow/compiler/tests/placeholder_test.py @@ -33,7 +33,7 @@ class PlaceholderTest(xla_test.XLATestCase): ph = array_ops.placeholder_with_default(v, shape=[]) out = ph * 2 sess.run(variables.variables_initializer([v])) - self.assertEqual(8.0, sess.run(out)) + self.assertEqual(8.0, self.evaluate(out)) def test_placeholder_with_default_fed(self): with self.cached_session() as sess, self.test_scope(): diff --git a/tensorflow/compiler/tests/powersign_test.py b/tensorflow/compiler/tests/powersign_test.py index 86536da7fed..5b35c200277 100644 --- a/tensorflow/compiler/tests/powersign_test.py +++ b/tensorflow/compiler/tests/powersign_test.py @@ -91,8 +91,8 @@ class PowerSignTest(xla_test.XLATestCase): variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) # Run 7 steps of powersign # first 4 steps with positive gradient @@ -125,8 +125,8 @@ class PowerSignTest(xla_test.XLATestCase): ) # Validate updated params - self.assertAllCloseAccordingToType(var0_np, var0.eval()) - self.assertAllCloseAccordingToType(var1_np, var1.eval()) + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) def testDense(self): decay_steps = 10 diff --git a/tensorflow/compiler/tests/proximal_adagrad_test.py b/tensorflow/compiler/tests/proximal_adagrad_test.py index c41b4171e26..63cc51a4701 100644 --- a/tensorflow/compiler/tests/proximal_adagrad_test.py +++ b/tensorflow/compiler/tests/proximal_adagrad_test.py @@ -45,15 +45,17 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase): update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() - self.assertAllClose([0.0, 0.0], var0.eval()) - self.assertAllClose([0.0, 0.0], var1.eval()) + self.assertAllClose([0.0, 0.0], self.evaluate(var0)) + self.assertAllClose([0.0, 0.0], self.evaluate(var1)) # Run 3 steps Proximal Adagrad. 
for _ in range(3): update.run() - self.assertAllClose(np.array([-2.60260963, -4.29698515]), var0.eval()) - self.assertAllClose(np.array([-0.28432083, -0.56694895]), var1.eval()) + self.assertAllClose( + np.array([-2.60260963, -4.29698515]), self.evaluate(var0)) + self.assertAllClose( + np.array([-0.28432083, -0.56694895]), self.evaluate(var1)) opt_vars = opt.variables() self.assertStartsWith(opt_vars[0].name, var0._shared_name) self.assertStartsWith(opt_vars[1].name, var1._shared_name) @@ -74,14 +76,14 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase): update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([4.0, 3.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([4.0, 3.0], self.evaluate(var1)) # Run 3 steps Proximal Adagrad. for _ in range(3): update.run() - self.assertAllClose(np.array([-1.60261, -2.296985]), var0.eval()) - self.assertAllClose(np.array([3.715679, 2.433051]), var1.eval()) + self.assertAllClose(np.array([-1.60261, -2.296985]), self.evaluate(var0)) + self.assertAllClose(np.array([3.715679, 2.433051]), self.evaluate(var1)) def testProximalAdagradWithL1(self): with self.cached_session(), self.test_scope(): @@ -98,14 +100,14 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase): update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([4.0, 3.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([4.0, 3.0], self.evaluate(var1)) # Run 10 steps Proximal Adagrad for _ in range(10): update.run() - self.assertAllClose(np.array([-6.663634, -9.190331]), var0.eval()) - self.assertAllClose(np.array([2.959304, 1.029232]), var1.eval()) + self.assertAllClose(np.array([-6.663634, -9.190331]), self.evaluate(var0)) + self.assertAllClose(np.array([2.959304, 1.029232]), self.evaluate(var1)) def testProximalAdagradWithL1_L2(self): with self.cached_session(), self.test_scope(): @@ -122,15 +124,15 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase): update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([4.0, 3.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([4.0, 3.0], self.evaluate(var1)) # Run 10 steps Proximal Adagrad. 
for _ in range(10): update.run() - self.assertAllClose(np.array([-0.0495, -0.0995]), var0.eval()) - self.assertAllClose(np.array([-0.0045, -0.0095]), var1.eval()) + self.assertAllClose(np.array([-0.0495, -0.0995]), self.evaluate(var0)) + self.assertAllClose(np.array([-0.0045, -0.0095]), self.evaluate(var1)) def applyOptimizer(self, opt, steps=5): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) @@ -141,14 +143,14 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase): update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) # Run ProximalAdagrad for a few steps for _ in range(steps): update.run() - return var0.eval(), var1.eval() + return self.evaluate(var0), self.evaluate(var1) def testEquivAdagradwithoutRegularization(self): with self.cached_session(), self.test_scope(): diff --git a/tensorflow/compiler/tests/proximal_gradient_descent_test.py b/tensorflow/compiler/tests/proximal_gradient_descent_test.py index 3d808e6b8a7..5aec433be76 100644 --- a/tensorflow/compiler/tests/proximal_gradient_descent_test.py +++ b/tensorflow/compiler/tests/proximal_gradient_descent_test.py @@ -42,15 +42,15 @@ class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase): update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() - self.assertAllClose([0.0, 0.0], var0.eval()) - self.assertAllClose([0.0, 0.0], var1.eval()) + self.assertAllClose([0.0, 0.0], self.evaluate(var0)) + self.assertAllClose([0.0, 0.0], self.evaluate(var1)) # Run 3 steps Proximal Gradient Descent. for _ in range(3): update.run() - self.assertAllClose(np.array([-0.9, -1.8]), var0.eval()) - self.assertAllClose(np.array([-0.09, -0.18]), var1.eval()) + self.assertAllClose(np.array([-0.9, -1.8]), self.evaluate(var0)) + self.assertAllClose(np.array([-0.09, -0.18]), self.evaluate(var1)) def testProximalGradientDescentwithoutRegularization2(self): with self.cached_session(), self.test_scope(): @@ -64,15 +64,15 @@ class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase): update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([4.0, 3.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([4.0, 3.0], self.evaluate(var1)) # Run 3 steps Proximal Gradient Descent for _ in range(3): update.run() - self.assertAllClose(np.array([0.1, 0.2]), var0.eval()) - self.assertAllClose(np.array([3.91, 2.82]), var1.eval()) + self.assertAllClose(np.array([0.1, 0.2]), self.evaluate(var0)) + self.assertAllClose(np.array([3.91, 2.82]), self.evaluate(var1)) def testProximalGradientDescentWithL1(self): with self.cached_session(), self.test_scope(): @@ -86,15 +86,15 @@ class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase): update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([4.0, 3.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([4.0, 3.0], self.evaluate(var1)) # Run 10 steps proximal gradient descent. 
for _ in range(10): update.run() - self.assertAllClose(np.array([-1.988, -3.988001]), var0.eval()) - self.assertAllClose(np.array([3.67, 2.37]), var1.eval()) + self.assertAllClose(np.array([-1.988, -3.988001]), self.evaluate(var0)) + self.assertAllClose(np.array([3.67, 2.37]), self.evaluate(var1)) def testProximalGradientDescentWithL1_L2(self): with self.cached_session(), self.test_scope(): @@ -108,15 +108,15 @@ class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase): update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([4.0, 3.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([4.0, 3.0], self.evaluate(var1)) # Run 10 steps Proximal Gradient Descent for _ in range(10): update.run() - self.assertAllClose(np.array([-0.0495, -0.0995]), var0.eval()) - self.assertAllClose(np.array([-0.0045, -0.0095]), var1.eval()) + self.assertAllClose(np.array([-0.0495, -0.0995]), self.evaluate(var0)) + self.assertAllClose(np.array([-0.0045, -0.0095]), self.evaluate(var1)) def applyOptimizer(self, opt, steps=5): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) @@ -127,14 +127,14 @@ class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase): update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) # Run ProximalAdagrad for a few steps for _ in range(steps): update.run() - return var0.eval(), var1.eval() + return self.evaluate(var0), self.evaluate(var1) def testEquivGradientDescentwithoutRegularization(self): with self.cached_session(), self.test_scope(): diff --git a/tensorflow/compiler/tests/qr_op_test.py b/tensorflow/compiler/tests/qr_op_test.py index 236b1b881dc..b4d4193e35f 100644 --- a/tensorflow/compiler/tests/qr_op_test.py +++ b/tensorflow/compiler/tests/qr_op_test.py @@ -63,7 +63,7 @@ class QrOpTest(xla_test.XLATestCase, parameterized.TestCase): # Tests that x[...,:,:]^H * x[...,:,:] is close to the identity. xx = math_ops.matmul(x, x, adjoint_a=True) identity = array_ops.matrix_band_part(array_ops.ones_like(xx), 0, 0) - precision = self.AdjustedNorm(xx.eval() - identity.eval()) + precision = self.AdjustedNorm(xx.eval() - self.evaluate(identity)) self.assertTrue(np.all(precision < 5.0)) def _test(self, dtype, shape, full_matrices): diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py index 36ef6ed5fee..97ffad34c00 100644 --- a/tensorflow/compiler/tests/random_ops_test.py +++ b/tensorflow/compiler/tests/random_ops_test.py @@ -46,9 +46,9 @@ class RandomOpsTest(xla_test.XLATestCase): # The random-number generator, if working correctly, should produce the # same output multiple times with low probability. - y = sess.run(x) - z = sess.run(x) - w = sess.run(x) + y = self.evaluate(x) + z = self.evaluate(x) + w = self.evaluate(x) # We use exact equality here. If the random-number generator is producing # deterministic output, all three outputs will be bitwise identical. 
@@ -83,7 +83,7 @@ class RandomOpsTest(xla_test.XLATestCase): with self.test_scope(): x = random_ops.random_uniform( shape=[1000], dtype=dtype, minval=-2, maxval=33) - y = sess.run(x) + y = self.evaluate(x) self.assertTrue((y >= -2).sum() == 1000) self.assertTrue((y < 33).sum() == 1000) @@ -102,7 +102,7 @@ class RandomOpsTest(xla_test.XLATestCase): with self.cached_session() as sess: with self.test_scope(): x = random_ops.truncated_normal(shape=[count], dtype=dtype) - y = sess.run(x) + y = self.evaluate(x) def normal_cdf(x): return .5 * math.erfc(-x / math.sqrt(2)) @@ -111,7 +111,7 @@ class RandomOpsTest(xla_test.XLATestCase): return math.exp(-(x**2) / 2.) / math.sqrt(2 * math.pi) def probit(x, sess=sess): - return sess.run(special_math.ndtri(x)) + return self.evaluate(special_math.ndtri(x)) a = -2. b = 2. @@ -148,7 +148,7 @@ class RandomOpsTest(xla_test.XLATestCase): with self.test_scope(): x = math_ops.range(1 << 16) shuffle = random_ops.random_shuffle(x) - result = sess.run(shuffle) + result = self.evaluate(shuffle) expected = range(1 << 16) # Compare sets to avoid randomness behavior changes but make sure still # have all the values. @@ -159,7 +159,7 @@ class RandomOpsTest(xla_test.XLATestCase): with self.test_scope(): x = array_ops.diag(math_ops.range(20)) shuffle = random_ops.random_shuffle(x) - result = sess.run(shuffle) + result = self.evaluate(shuffle) expected = np.diag(range(20)).flatten() # Compare sets to avoid randomness behavior changes but make sure still # have all the values. diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index a6b58020126..d23fd125163 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -3382,10 +3382,10 @@ int main(int argc, char** argv) { } // XLA devices register kernels at construction time; create all known devices // to make sure the kernels are registered. 
- std::vector<tensorflow::Device*> devices; + std::vector<std::unique_ptr<tensorflow::Device>> devices; TF_CHECK_OK(tensorflow::DeviceFactory::AddDevices( tensorflow::SessionOptions(), "", &devices)); - tensorflow::DeviceMgr device_mgr(devices); + tensorflow::DeviceMgr device_mgr(std::move(devices)); tensorflow::Device* ignored; TF_QCHECK_OK( diff --git a/tensorflow/compiler/tests/reduce_ops_test.py b/tensorflow/compiler/tests/reduce_ops_test.py index 132c59c32c9..e8fc81bbb54 100644 --- a/tensorflow/compiler/tests/reduce_ops_test.py +++ b/tensorflow/compiler/tests/reduce_ops_test.py @@ -91,6 +91,7 @@ class ReduceOpsTest(xla_test.XLATestCase, parameterized.TestCase): np.array([], dtype=np.bool).reshape(0, 3), np.array([[False, True, False], [True, True, False]]), ] + ONES = [np.ones([34000, 2])] def testReduceSumF32(self, index_dtype): self._testReduction(math_ops.reduce_sum, np.sum, np.float32, self.REAL_DATA, @@ -149,6 +150,11 @@ class ReduceOpsTest(xla_test.XLATestCase, parameterized.TestCase): self._testReduction(math_ops.reduce_mean, np.mean, np.float32, self.NONEMPTY_REAL_DATA, index_dtype) + def testReduceMeanF16(self, index_dtype): + if np.float16 in self.all_types: + self._testReduction(math_ops.reduce_mean, np.mean, np.float16, self.ONES, + index_dtype) + def testReduceMeanC64(self, index_dtype): self._testReduction(math_ops.reduce_mean, np.mean, np.complex64, self.NONEMPTY_COMPLEX_DATA, index_dtype) diff --git a/tensorflow/compiler/tests/rmsprop_test.py b/tensorflow/compiler/tests/rmsprop_test.py index 8840a1329a9..dc3e90b4afa 100644 --- a/tensorflow/compiler/tests/rmsprop_test.py +++ b/tensorflow/compiler/tests/rmsprop_test.py @@ -76,7 +76,7 @@ class RmspropTest(xla_test.XLATestCase): rms_opt = rmsprop.RMSPropOptimizer(learning_rate, centered=centered) rms_update = rms_opt.apply_gradients( zip([grads0, grads1], [var0, var1])) - variables.global_variables_initializer().run() + self.evaluate(variables.global_variables_initializer()) mg0 = rms_opt.get_slot(var0, "mg") self.assertEqual(mg0 is not None, centered) @@ -92,12 +92,12 @@ class RmspropTest(xla_test.XLATestCase): self.assertTrue(mom1 is not None) # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) # Run 3 steps of RMSProp for _ in range(3): - rms_update.run() + self.evaluate(rms_update) var0_np, mg0_np, rms0_np, mom0_np = self._rmsprop_update_numpy( var0_np, @@ -118,14 +118,14 @@ class RmspropTest(xla_test.XLATestCase): # Validate updated params if centered: - self.assertAllCloseAccordingToType(mg0_np, mg0.eval()) - self.assertAllCloseAccordingToType(mg1_np, mg1.eval()) - self.assertAllCloseAccordingToType(rms0_np, rms0.eval()) - self.assertAllCloseAccordingToType(rms1_np, rms1.eval()) - self.assertAllCloseAccordingToType(mom0_np, mom0.eval()) - self.assertAllCloseAccordingToType(mom1_np, mom1.eval()) - self.assertAllCloseAccordingToType(var0_np, var0.eval()) - self.assertAllCloseAccordingToType(var1_np, var1.eval()) + self.assertAllCloseAccordingToType(mg0_np, self.evaluate(mg0)) + self.assertAllCloseAccordingToType(mg1_np, self.evaluate(mg1)) + self.assertAllCloseAccordingToType(rms0_np, self.evaluate(rms0)) + self.assertAllCloseAccordingToType(rms1_np, self.evaluate(rms1)) + self.assertAllCloseAccordingToType(mom0_np, self.evaluate(mom0)) + self.assertAllCloseAccordingToType(mom1_np, self.evaluate(mom1)) + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) +
self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/scan_ops_test.py b/tensorflow/compiler/tests/scan_ops_test.py index 897db384b7e..17639bd8a75 100644 --- a/tensorflow/compiler/tests/scan_ops_test.py +++ b/tensorflow/compiler/tests/scan_ops_test.py @@ -71,7 +71,7 @@ def handle_options(func, x, axis, exclusive, reverse): class CumsumTest(xla_test.XLATestCase): - valid_dtypes = [np.float32] + valid_dtypes = [np.float32, np.int32] def axis_dtypes(self): return set(self.int_types).intersection([np.int32, np.int64]) @@ -149,7 +149,7 @@ class CumsumTest(xla_test.XLATestCase): class CumprodTest(xla_test.XLATestCase): - valid_dtypes = [np.float32] + valid_dtypes = [np.float32, np.int32] def axis_dtypes(self): return set(self.int_types).intersection([np.int32, np.int64]) diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py index 21708aa1587..ee7ca7e6f19 100644 --- a/tensorflow/compiler/tests/stateless_random_ops_test.py +++ b/tensorflow/compiler/tests/stateless_random_ops_test.py @@ -156,7 +156,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): return math.exp(-(x**2) / 2.) / math.sqrt(2 * math.pi) def probit(x, sess=sess): - return sess.run(special_math.ndtri(x)) + return self.evaluate(special_math.ndtri(x)) a = -2. b = 2. diff --git a/tensorflow/compiler/tests/tensor_array_ops_test.py b/tensorflow/compiler/tests/tensor_array_ops_test.py index 46ca371c8ab..d7e26d79c4c 100644 --- a/tensorflow/compiler/tests/tensor_array_ops_test.py +++ b/tensorflow/compiler/tests/tensor_array_ops_test.py @@ -79,7 +79,8 @@ class TensorArrayTest(xla_test.XLATestCase): c0 = w2.stack() self.assertAllEqual( - convert([[[4.0, 5.0]], [[6.0, 7.0]], [[8.0, 9.0]]]), c0.eval()) + convert([[[4.0, 5.0]], [[6.0, 7.0]], [[8.0, 9.0]]]), + self.evaluate(c0)) def testTensorArrayWritePack(self): for dtype in self.numeric_tf_types: @@ -97,7 +98,7 @@ class TensorArrayTest(xla_test.XLATestCase): c0 = w2.stack() - self.assertAllEqual([3, 0, 1], c0.eval().shape) + self.assertAllEqual([3, 0, 1], self.evaluate(c0).shape) def _testTensorArrayWriteConcat(self, tf_dtype): with self.cached_session(), self.test_scope(): @@ -113,8 +114,8 @@ class TensorArrayTest(xla_test.XLATestCase): c0 = w2.concat() self.assertAllEqual( - convert([[4.0, 5.0], [104.0, 105.0], [6.0, 7.0], - [106.0, 107.0], [8.0, 9.0], [204.0, 205.0]]), c0.eval()) + convert([[4.0, 5.0], [104.0, 105.0], [6.0, 7.0], [106.0, 107.0], + [8.0, 9.0], [204.0, 205.0]]), self.evaluate(c0)) def testTensorArrayWriteConcat(self): for dtype in self.numeric_tf_types: @@ -341,7 +342,7 @@ class TensorArrayTest(xla_test.XLATestCase): r0_bad = gen_data_flow_ops.tensor_array_read_v3( handle=w0.handle, index=0, dtype=dtype2, flow_in=w0.flow) with self.assertRaisesOpError("TensorArray dtype is "): - r0_bad.eval() + self.evaluate(r0_bad) # Test reading from a different index than the one we wrote to w0.read(1) @@ -422,7 +423,7 @@ class TensorArrayTest(xla_test.XLATestCase): w2 = h2.write(0, 5.0) r2 = w2.read(0) r = r1 + r2 - self.assertAllClose(9.0, r.eval()) + self.assertAllClose(9.0, self.evaluate(r)) def _testTensorArrayGradientWriteReadType(self, dtype): with self.cached_session() as session, self.test_scope(): @@ -504,7 +505,7 @@ class TensorArrayTest(xla_test.XLATestCase): [-0.5, 1.5], # read(0) gradient [20.0, 30.0, 40.0, 50.0], # concat gradient ]) - grad_vals = sess.run(grad_r) # 2 + 2 entries + grad_vals = self.evaluate(grad_r) # 2 + 2 
entries self.assertAllClose([2.0 - 0.5 + 20.0, 3.0 + 1.5 + 30.0], grad_vals[0]) self.assertAllEqual([4.0 + 40.0, 5.0 + 50.0], grad_vals[1]) @@ -526,7 +527,7 @@ class TensorArrayTest(xla_test.XLATestCase): with ops.control_dependencies([r0_readtwice]): r1_readtwice = w_readtwice.read(0) - self.assertAllEqual([1.0, -1.0], r1_readtwice.eval()) + self.assertAllEqual([1.0, -1.0], self.evaluate(r1_readtwice)) def _testTensorArrayGradientUnpackRead(self): with self.cached_session() as session, self.test_scope(): @@ -592,7 +593,7 @@ class TensorArrayTest(xla_test.XLATestCase): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", size=3) s = ta.size() - self.assertAllEqual(3, s.eval()) + self.assertAllEqual(3, self.evaluate(s)) def testWriteCloseTensorArray(self): with self.cached_session(), self.test_scope(): @@ -722,7 +723,7 @@ class TensorArrayTest(xla_test.XLATestCase): # r = acc2.stack() # grad = gradients_impl.gradients(r, [x])[0] - # self.assertAllClose(31.0, grad.eval()) + # self.assertAllClose(31.0, self.evaluate(grad)) def testSumOfTwoReadVariablesWithoutRepeatGrad(self): with self.cached_session() as session, self.test_scope(): @@ -912,7 +913,7 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertEqual(0, ta.size().eval()) ta = ta.unstack(array_ops.zeros([0, 3, 5])) packed = ta.stack() - self.assertAllEqual([0, 3, 5], packed.eval().shape) + self.assertAllEqual([0, 3, 5], self.evaluate(packed).shape) # Concatenating zero tensors along their first dimension gives a # first dimension of zero self.assertAllEqual([0, 5], ta.concat().eval().shape) @@ -1041,8 +1042,8 @@ class TensorArrayTest(xla_test.XLATestCase): (read0, read1, size0, size1)) # Tests that the control dependencies was added and executed. - self.assertEqual(1, v0.eval()) - self.assertEqual(1, v1.eval()) + self.assertEqual(1, self.evaluate(v0)) + self.assertEqual(1, self.evaluate(v1)) # Tests correct TensorArray. self.assertEqual(read0_v, 0) diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index d612d3b32dd..95c9e7ffd46 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -481,6 +481,72 @@ class UnaryOpsTest(xla_test.XLATestCase): np.array([-1, -0.5, 0, 0.3], dtype=dtype), expected=np.array([-1., -0.5, 0., 0.296875], dtype=dtype)) + def quantize_and_dequantize_v2_round_half_up(x): + return array_ops.quantize_and_dequantize_v2( + x, + -1, + 1.0, + signed_input=True, + num_bits=8, + range_given=True, + round_mode="HALF_UP") + + self._assertOpOutputMatchesExpected( + quantize_and_dequantize_v2_round_half_up, + np.array([-0.8, -0.5, 0, 0.3, 0.8, -2, 33], dtype=dtype), + expected=np.array([ + -102.0 / 127, + -63.0 / 127, + 0, + 38.0 / 127, + 102.0 / 127, + -128.0 / 127, + 1, + ], + dtype=dtype)) + + def quantize_and_dequantize_v2_round_half_to_even(x): + return array_ops.quantize_and_dequantize_v2( + x, + -1.0, + 1.0, + signed_input=True, + num_bits=8, + range_given=True, + round_mode="HALF_TO_EVEN") + + self._assertOpOutputMatchesExpected( + quantize_and_dequantize_v2_round_half_to_even, + np.array( + [ + -0.8, + # The -0.5 should become -63.5 after scaling and with + # rounding this should become -64. But with the test + # unary_ops_test_cpu_ondemand, this fails as the result + # before scaling becomes -63.499996 and gets rounded to -63. + # TODO(sreenik): Some one more familiar with this test needs + # to take a look and resolve this. 
This works on all other + # variations of the platform like cpu, and gpu. + # -0.5, + 0, + 0.3, + 0.8, + -2, + 33 + ], + dtype=dtype), + expected=np.array( + [ + -102.0 / 127, + # -64.0 / 127, + 0, + 38.0 / 127, + 102.0 / 127, + -128.0 / 127, + 1, + ], + dtype=dtype)) + def quantize_and_dequantize_v3(x): return array_ops.quantize_and_dequantize_v3( x, -127, 127, num_bits=8, signed_input=True, range_given=False) diff --git a/tensorflow/compiler/tests/variable_ops_test.py b/tensorflow/compiler/tests/variable_ops_test.py index 77cdeac8168..fcd7ac5ba1c 100644 --- a/tensorflow/compiler/tests/variable_ops_test.py +++ b/tensorflow/compiler/tests/variable_ops_test.py @@ -77,7 +77,7 @@ class VariableOpsTest(xla_test.XLATestCase): sess.run(variables.variables_initializer([v])) x = v.sparse_read(2) self.assertAllClose( - np.array([8j, 9, 10, 11]).astype(dtype), sess.run(x)) + np.array([8j, 9, 10, 11]).astype(dtype), self.evaluate(x)) def testSparseRead1DIndices(self): for dtype in self.numeric_types: @@ -89,7 +89,7 @@ class VariableOpsTest(xla_test.XLATestCase): x = v.sparse_read([2, 1]) self.assertAllClose( np.array([[8, 9, 10, 11], [4, 5, 6j, 7]]).astype(dtype), - sess.run(x)) + self.evaluate(x)) def testSparseRead2DIndices(self): for dtype in self.numeric_types: @@ -102,7 +102,7 @@ class VariableOpsTest(xla_test.XLATestCase): self.assertAllClose( np.array([[[8, 9, 10, 11], [4, 5, 6, 7]], [[0, 1, 2j, 3], [8, 9, 10, 11]]]).astype(dtype), - sess.run(x)) + self.evaluate(x)) def testSparseRead2DIndices3DTensor(self): for dtype in self.numeric_types: @@ -115,9 +115,9 @@ class VariableOpsTest(xla_test.XLATestCase): x = v.sparse_read([[2, 1], [3, 0]]) self.assertAllClose( np.array( - [[[[20, 21, 22], [23, 24j, 25]], [[10, 11, 12], [13, 14, 15]] - ], [[[30, 31, 32], [33, 34, 35]], [[0, 1, 2], [3, 4, 5]]] - ],).astype(dtype), sess.run(x)) + [[[[20, 21, 22], [23, 24j, 25]], [[10, 11, 12], [13, 14, 15]]], + [[[30, 31, 32], [33, 34, 35]], [[0, 1, 2], [3, 4, 5]]] + ],).astype(dtype), self.evaluate(x)) def testShape(self): for dtype in self.numeric_types: @@ -229,7 +229,7 @@ class VariableOpsTest(xla_test.XLATestCase): resource_variable_ops.resource_scatter_add( handle, [0], constant_op.constant([[2]], dtype=dtypes.int32))) read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - self.assertAllEqual(sess.run(read), [[3], [7]]) + self.assertAllEqual(self.evaluate(read), [[3], [7]]) def testScatterSub(self): with self.test_session() as sess, self.test_scope(): @@ -242,7 +242,7 @@ class VariableOpsTest(xla_test.XLATestCase): resource_variable_ops.resource_scatter_sub( handle, [1], constant_op.constant([[2]], dtype=dtypes.int32))) read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - self.assertAllEqual(sess.run(read), [[4], [-1]]) + self.assertAllEqual(self.evaluate(read), [[4], [-1]]) def testScatterMul(self): with self.test_session() as sess, self.test_scope(): @@ -255,7 +255,7 @@ class VariableOpsTest(xla_test.XLATestCase): resource_variable_ops.resource_scatter_mul( handle, [0], constant_op.constant([[5]], dtype=dtypes.int32))) read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - self.assertEqual(sess.run(read), [[5]]) + self.assertEqual(self.evaluate(read), [[5]]) def testScatterDiv(self): with self.test_session() as sess, self.test_scope(): @@ -268,7 +268,7 @@ class VariableOpsTest(xla_test.XLATestCase): resource_variable_ops.resource_scatter_div( handle, [0], constant_op.constant([[3]], dtype=dtypes.int32))) read = 
resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - self.assertAllEqual(sess.run(read), [[2]]) + self.assertAllEqual(self.evaluate(read), [[2]]) def testScatterMin(self): with self.test_session() as sess, self.test_scope(): @@ -281,7 +281,7 @@ class VariableOpsTest(xla_test.XLATestCase): resource_variable_ops.resource_scatter_min( handle, [0], constant_op.constant([[3]], dtype=dtypes.int32))) read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - self.assertEqual(sess.run(read), [[3]]) + self.assertEqual(self.evaluate(read), [[3]]) def testScatterMax(self): with self.test_session() as sess, self.test_scope(): @@ -294,7 +294,7 @@ class VariableOpsTest(xla_test.XLATestCase): resource_variable_ops.resource_scatter_max( handle, [0], constant_op.constant([[3]], dtype=dtypes.int32))) read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - self.assertEqual(sess.run(read), [[6]]) + self.assertEqual(self.evaluate(read), [[6]]) def testScatterUpdate(self): with self.test_session() as sess, self.test_scope(): @@ -307,7 +307,7 @@ class VariableOpsTest(xla_test.XLATestCase): resource_variable_ops.resource_scatter_update( handle, [0], constant_op.constant([[3]], dtype=dtypes.int32))) read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - self.assertEqual(sess.run(read), [[3]]) + self.assertEqual(self.evaluate(read), [[3]]) def testScatterAddScalar(self): with self.test_session() as sess, self.test_scope(): @@ -320,7 +320,7 @@ class VariableOpsTest(xla_test.XLATestCase): resource_variable_ops.resource_scatter_add( handle, [0], constant_op.constant(2, dtype=dtypes.int32))) read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - self.assertEqual(sess.run(read), [[3]]) + self.assertEqual(self.evaluate(read), [[3]]) def testScatterSubScalar(self): with self.test_session() as sess, self.test_scope(): @@ -333,7 +333,7 @@ class VariableOpsTest(xla_test.XLATestCase): resource_variable_ops.resource_scatter_sub( handle, [0], constant_op.constant(2, dtype=dtypes.int32))) read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - self.assertEqual(sess.run(read), [[-1]]) + self.assertEqual(self.evaluate(read), [[-1]]) def testScatterMulScalar(self): with self.test_session() as sess, self.test_scope(): @@ -346,7 +346,7 @@ class VariableOpsTest(xla_test.XLATestCase): resource_variable_ops.resource_scatter_mul( handle, [0], constant_op.constant(5, dtype=dtypes.int32))) read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - self.assertEqual(sess.run(read), [[5]]) + self.assertEqual(self.evaluate(read), [[5]]) def testScatterDivScalar(self): with self.test_session() as sess, self.test_scope(): @@ -359,7 +359,7 @@ class VariableOpsTest(xla_test.XLATestCase): resource_variable_ops.resource_scatter_div( handle, [0], constant_op.constant(3, dtype=dtypes.int32))) read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - self.assertEqual(sess.run(read), [[2]]) + self.assertEqual(self.evaluate(read), [[2]]) def testScatterMinScalar(self): with self.test_session() as sess, self.test_scope(): @@ -372,7 +372,7 @@ class VariableOpsTest(xla_test.XLATestCase): resource_variable_ops.resource_scatter_min( handle, [0], constant_op.constant(3, dtype=dtypes.int32))) read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - self.assertEqual(sess.run(read), [[3]]) + self.assertEqual(self.evaluate(read), [[3]]) def testScatterMaxScalar(self): with self.test_session() as 
sess, self.test_scope(): @@ -385,7 +385,7 @@ class VariableOpsTest(xla_test.XLATestCase): resource_variable_ops.resource_scatter_max( handle, [0], constant_op.constant(3, dtype=dtypes.int32))) read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - self.assertEqual(sess.run(read), [[6]]) + self.assertEqual(self.evaluate(read), [[6]]) def testScatterNdAddOps(self): with self.test_session() as sess, self.test_scope(): @@ -400,7 +400,7 @@ class VariableOpsTest(xla_test.XLATestCase): sess.run(gen_state_ops.resource_scatter_nd_add(handle, indices, updates)) read = resource_variable_ops.read_variable_op( handle, dtype=dtypes.float32) - self.assertAllClose(expected, sess.run(read)) + self.assertAllClose(expected, self.evaluate(read)) def testScatterNdUpdateAddOps(self): with self.test_session() as sess, self.test_scope(): @@ -416,7 +416,7 @@ class VariableOpsTest(xla_test.XLATestCase): gen_state_ops.resource_scatter_nd_update(handle, indices, updates)) read = resource_variable_ops.read_variable_op( handle, dtype=dtypes.float32) - self.assertAllClose(expected, sess.run(read)) + self.assertAllClose(expected, self.evaluate(read)) class StridedSliceAssignChecker(object): diff --git a/tensorflow/compiler/tests/xla_device_test.py b/tensorflow/compiler/tests/xla_device_test.py index 28d61fb07dc..ef55292b1be 100644 --- a/tensorflow/compiler/tests/xla_device_test.py +++ b/tensorflow/compiler/tests/xla_device_test.py @@ -81,7 +81,7 @@ class XlaDeviceTest(xla_test.XLATestCase): with self.cached_session() as sess: with self.test_scope(): x = gen_control_flow_ops.control_trigger() - sess.run(x) + self.evaluate(x) if __name__ == "__main__": diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index e0171415492..5a0d9b9af9d 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -9,6 +9,7 @@ package_group( "//tensorflow/compiler/jit/...", "//tensorflow/compiler/tests/...", "//tensorflow/compiler/tf2xla/...", + "//tensorflow/contrib/compiler/...", ], ) @@ -195,8 +196,8 @@ cc_library( ":sharding_util", ":side_effect_util", ":tf2xla_util", + "//tensorflow/compiler/jit:flags", "//tensorflow/compiler/jit:xla_cluster_util", - "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags", "//tensorflow/compiler/tf2xla/lib:util", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -204,13 +205,13 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/compiler/xla/client/lib:constants", - "//tensorflow/compiler/xla/client/lib:numeric", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -221,6 +222,7 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], alwayslink = 1, @@ -437,21 +439,15 @@ cc_library( name = "dump_graph", srcs = [ "dump_graph.cc", - "dump_graph_flags.cc", - "dump_graph_flags.h", ], hdrs = [ "dump_graph.h", ], deps = [ - "//tensorflow/compiler/xla:parse_flags_from_env", - "//tensorflow/core:core_cpu", - 
"//tensorflow/core:core_cpu_internal", + "//tensorflow/compiler/jit:flags", "//tensorflow/core:framework", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", + "//tensorflow/core:graph", "//tensorflow/core:protos_all_cc", - "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/tf2xla/dump_graph.cc b/tensorflow/compiler/tf2xla/dump_graph.cc index 380c6a7e23d..64fdbbebc65 100644 --- a/tensorflow/compiler/tf2xla/dump_graph.cc +++ b/tensorflow/compiler/tf2xla/dump_graph.cc @@ -18,87 +18,26 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/dump_graph.h" -#include "absl/strings/str_cat.h" -#include "tensorflow/compiler/tf2xla/dump_graph_flags.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/compiler/jit/flags.h" +#include "tensorflow/core/util/dump_graph.h" namespace tensorflow { namespace dump_graph { -namespace { - -struct NameCounts { - mutex counts_mutex; - std::unordered_map counts; -}; - -string MakeUniqueFilename(string name) { - static NameCounts& instance = *new NameCounts; - - // Remove illegal characters from `name`. - for (int i = 0; i < name.size(); ++i) { - char ch = name[i]; - if (ch == '/' || ch == '[' || ch == ']' || ch == '*' || ch == '?') { - name[i] = '_'; - } - } - - int count; - { - mutex_lock lock(instance.counts_mutex); - count = instance.counts[name]++; - } - - string filename = name; - if (count > 0) { - absl::StrAppend(&filename, "_", count); - } - absl::StrAppend(&filename, ".pbtxt"); - return filename; -} - -string WriteTextProtoToUniqueFile( - Env* env, const string& name, const char* proto_type, - const ::tensorflow::protobuf::Message& proto) { - const string& dirname = - legacy_flags::GetDumpGraphFlags()->tf_dump_graph_prefix; - Status status = env->RecursivelyCreateDir(dirname); - if (!status.ok()) { - LOG(WARNING) << "Failed to create " << dirname << " for dumping " - << proto_type << ": " << status; - return "(unavailable)"; - } - string filepath = absl::StrCat(dirname, "/", MakeUniqueFilename(name)); - status = WriteTextProto(Env::Default(), filepath, proto); - if (!status.ok()) { - LOG(WARNING) << "Failed to dump " << proto_type << " to file: " << filepath - << " : " << status; - return "(unavailable)"; - } - LOG(INFO) << "Dumped " << proto_type << " to " << filepath; - return filepath; -} - -} // anonymous namespace - string DumpGraphDefToFile(const string& name, GraphDef const& graph_def) { - return WriteTextProtoToUniqueFile(Env::Default(), name, "GraphDef", - graph_def); + return tensorflow::DumpGraphDefToFile( + name, graph_def, GetDumpGraphFlags()->tf_dump_graph_prefix); } string DumpGraphToFile(const string& name, Graph const& graph, const FunctionLibraryDefinition* flib_def) { - GraphDef graph_def; - graph.ToGraphDef(&graph_def); - if (flib_def) { - *graph_def.mutable_library() = flib_def->ToProto(); - } - return DumpGraphDefToFile(name, graph_def); + return tensorflow::DumpGraphToFile(name, graph, flib_def, + GetDumpGraphFlags()->tf_dump_graph_prefix); } string DumpFunctionDefToFile(const string& name, FunctionDef const& fdef) { - return WriteTextProtoToUniqueFile(Env::Default(), name, "FunctionDef", fdef); + return tensorflow::DumpFunctionDefToFile( + name, fdef, GetDumpGraphFlags()->tf_dump_graph_prefix); } } // namespace dump_graph diff --git a/tensorflow/compiler/tf2xla/dump_graph_flags.cc b/tensorflow/compiler/tf2xla/dump_graph_flags.cc deleted file mode 100644 index 2eb1f8cd849..00000000000 --- 
a/tensorflow/compiler/tf2xla/dump_graph_flags.cc +++ /dev/null @@ -1,63 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for the XLA bridge's dump_graph module. - -#include -#include - -#include "tensorflow/compiler/tf2xla/dump_graph_flags.h" -#include "tensorflow/compiler/xla/parse_flags_from_env.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace tensorflow { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static DumpGraphFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). -static void AllocateFlags() { - flags = new DumpGraphFlags; - flags->tf_dump_graph_prefix = "/tmp/"; - flag_list = new std::vector({ - Flag("tf_dump_graph_prefix", &flags->tf_dump_graph_prefix, - "Path prefix to which graphs dumped during debugging should be " - "written."), - }); - xla::ParseFlagsFromEnv(*flag_list); -} - -// Append to *append_to flag definitions associated with the XLA bridge's -// dump_graph module. -void AppendDumpGraphFlags(std::vector* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the DumpGraphFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -DumpGraphFlags* GetDumpGraphFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/dump_graph_flags.h b/tensorflow/compiler/tf2xla/dump_graph_flags.h deleted file mode 100644 index 80a3307d920..00000000000 --- a/tensorflow/compiler/tf2xla/dump_graph_flags.h +++ /dev/null @@ -1,48 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_TF2XLA_DUMP_GRAPH_FLAGS_H_ -#define TENSORFLOW_COMPILER_TF2XLA_DUMP_GRAPH_FLAGS_H_ - -// Legacy flags for the XLA bridge's dump_graph module. 
- -#include - -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace tensorflow { -namespace legacy_flags { - -// Append to *flag_list flag definitions associated with the XLA bridge's -// dump_graph module. -void AppendDumpGraphFlags(std::vector* flag_list); - -// The values of flags associated with the XLA bridge's -// dump_graph module. -typedef struct { - string tf_dump_graph_prefix; // Path prefix to which graphs dumped during - // debugging should be written. -} DumpGraphFlags; - -// Return a pointer to the DumpGraphFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -DumpGraphFlags* GetDumpGraphFlags(); - -} // namespace legacy_flags -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_TF2XLA_DUMP_GRAPH_FLAGS_H_ diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 9ef9f49f422..3dfd3f854c8 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -75,6 +75,25 @@ Status FunctionalizeControlFlow(Graph* graph, return FunctionalizeControlFlow(/*lookup_library=*/nullptr, graph, library); } +Status FunctionalizeControlFlowForGraphDef(GraphDef* graph_def, + FunctionLibraryDefinition* library) { + return FunctionalizeControlFlowForGraphDef(/*lookup_library=*/nullptr, + graph_def, library); +} + +Status FunctionalizeControlFlowForGraphDef( + const FunctionLibraryDefinition* lookup_library, GraphDef* graph_def, + FunctionLibraryDefinition* library) { + FunctionDefLibrary function_lib = graph_def->library(); + Graph graph(OpRegistry::Global()); + + TF_RETURN_IF_ERROR(ConvertGraphDefToGraph({}, *graph_def, &graph)); + TF_RETURN_IF_ERROR(FunctionalizeControlFlow(lookup_library, &graph, library)); + graph.ToGraphDef(graph_def); + std::swap(*graph_def->mutable_library(), function_lib); + return Status::OK(); +} + Status FunctionalizeControlFlowForFunction( const string& func_name, const string& new_func_name, const protobuf::Map& attrs, diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.h b/tensorflow/compiler/tf2xla/functionalize_control_flow.h index ba99205640c..91d33fa4058 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.h +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.h @@ -33,6 +33,12 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, Graph* graph, FunctionLibraryDefinition* library); +Status FunctionalizeControlFlowForGraphDef(GraphDef* graph_def, + FunctionLibraryDefinition* library); +Status FunctionalizeControlFlowForGraphDef( + const FunctionLibraryDefinition* lookup_library, GraphDef* graph_def, + FunctionLibraryDefinition* library); + // This pass looks at the graph and all associated FunctionDefs, and turns // traditional control flow structure (Switch/Merge/etc.) into functional // control flow structure (If/While). 
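The new `FunctionalizeControlFlowForGraphDef` overloads declared above take a `GraphDef` rather than a `Graph`: they convert the `GraphDef` to a `Graph` internally, run the same Switch/Merge-to-If/While rewrite, write the result back, and restore the original function library, while the cond/body functions created for the new `If`/`While` nodes are added to the supplied `FunctionLibraryDefinition`. A minimal caller might look like the sketch below; the wrapper name and the core-framework include paths are assumptions for illustration, not part of this change.

```
// Sketch only: functionalize control flow directly on a GraphDef, in place.
#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
#include "tensorflow/core/framework/function.h"   // FunctionLibraryDefinition (assumed path)
#include "tensorflow/core/framework/graph.pb.h"   // GraphDef (assumed path)
#include "tensorflow/core/framework/op.h"         // OpRegistry (assumed path)

namespace tensorflow {

// Hypothetical helper, not part of this change.
Status FunctionalizeGraphDefInPlace(GraphDef* graph_def) {
  // Cond/body functions generated for the new If/While nodes are added to
  // `library`; graph_def itself keeps its original FunctionDefLibrary.
  FunctionLibraryDefinition library(OpRegistry::Global(), graph_def->library());
  return FunctionalizeControlFlowForGraphDef(graph_def, &library);
}

}  // namespace tensorflow
```

The updated tests below call these overloads alongside the existing in-place `Graph` conversion and check that both produce the same functionalized result.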
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc index c3841f996f8..9784985af83 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc @@ -95,77 +95,87 @@ TEST(FunctionalizeControlFlow, Conditional) { } FunctionLibraryDefinition library(OpRegistry::Global(), {}); + GraphDef optimized_graph_def; + graph.ToGraphDef(&optimized_graph_def); + TF_ASSERT_OK( + FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library)); TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library)); + GraphDef converted_graph_def; + graph.ToGraphDef(&converted_graph_def); - GraphDef graph_def; - graph.ToGraphDef(&graph_def); - string op_name; - NameAttrList then_fn; - NameAttrList else_fn; - TF_EXPECT_OK(FindIfThenAndElse(graph_def, &op_name, &then_fn, &else_fn)); - InstantiationResultForTest else_result; - TF_EXPECT_OK( - InstantiateFunctionForTest(else_fn.name(), library, &else_result)); + for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) { + string op_name; + NameAttrList then_fn; + NameAttrList else_fn; + TF_EXPECT_OK(FindIfThenAndElse(graph_def, &op_name, &then_fn, &else_fn)); + InstantiationResultForTest else_result; + TF_EXPECT_OK( + InstantiateFunctionForTest(else_fn.name(), library, &else_result)); - // Outer graph - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32); - auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); - auto less = ops::Less(scope.WithOpName("cond/Less"), y, x); - auto if_op = ops::If(scope.WithOpName(op_name), less, - std::initializer_list{less, y, x}, {DT_INT32}, - then_fn, else_fn); - auto id = ops::Identity(scope.WithOpName("cond/Merge"), if_op.output[0]); - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - TF_EXPECT_GRAPH_EQ(expected, graph_def); - } + // Outer graph + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32); + auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); + auto less = ops::Less(scope.WithOpName("cond/Less"), y, x); + auto if_op = ops::If(scope.WithOpName(op_name), less, + std::initializer_list{less, y, x}, {DT_INT32}, + then_fn, else_fn); + auto id = ops::Identity(scope.WithOpName("cond/Merge"), if_op.output[0]); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, graph_def); + } - // then body. - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg_0 = ops::_Arg(scope.WithOpName("_arg0"), DT_BOOL, 0); - auto arg_1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); - auto arg_2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); - auto identity = ops::Identity(scope.WithOpName("cond/Identity"), arg_0); - auto cond = ops::Const( - scope.WithOpName("cond").WithControlDependencies(identity), 17); - auto mul = ops::Mul(scope.WithOpName("cond/Mul"), arg_1, cond); - auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), mul, 0); + // then body. 
+ { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg_0 = ops::_Arg(scope.WithOpName("_arg0"), DT_BOOL, 0); + auto arg_1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + auto arg_2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); + auto identity = ops::Identity(scope.WithOpName("cond/Identity"), arg_0); + auto cond = ops::Const( + scope.WithOpName("cond").WithControlDependencies(identity), 17); + auto mul = ops::Mul(scope.WithOpName("cond/Mul"), arg_1, cond); + auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), mul, 0); - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); - InstantiationResultForTest result; - TF_EXPECT_OK(InstantiateFunctionForTest(then_fn.name(), library, &result)); + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(then_fn.name(), library, &result)); - EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); - EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), result.arg_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); - } + EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); + EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), + result.arg_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } - // else body. - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg_0 = ops::_Arg(scope.WithOpName("_arg0"), DT_BOOL, 0); - auto arg_1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); - auto arg_2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); - auto identity = ops::Identity(scope.WithOpName("cond/Identity_1"), arg_0); - auto cond_1 = ops::Const( - scope.WithOpName("cond_1").WithControlDependencies(identity), 23); - auto add = ops::Add(scope.WithOpName("cond/false/add"), arg_2, cond_1); - auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0); + // else body. 
+ { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg_0 = ops::_Arg(scope.WithOpName("_arg0"), DT_BOOL, 0); + auto arg_1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + auto arg_2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); + auto identity = ops::Identity(scope.WithOpName("cond/Identity_1"), arg_0); + auto cond_1 = ops::Const( + scope.WithOpName("cond_1").WithControlDependencies(identity), 23); + auto add = ops::Add(scope.WithOpName("cond/false/add"), arg_2, cond_1); + auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0); - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); - InstantiationResultForTest result; - TF_EXPECT_OK(InstantiateFunctionForTest(else_fn.name(), library, &result)); + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(else_fn.name(), library, &result)); - EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); - EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), result.arg_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); + EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); + EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), + result.arg_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } } } @@ -239,75 +249,77 @@ TEST(FunctionalizeControlFlow, OneLoopVar) { } FunctionLibraryDefinition library(OpRegistry::Global(), {}); + GraphDef optimized_graph_def; + graph.ToGraphDef(&optimized_graph_def); + TF_ASSERT_OK( + FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library)); TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library)); + GraphDef converted_graph_def; + graph.ToGraphDef(&converted_graph_def); - GraphDef graph_def; - graph.ToGraphDef(&graph_def); + for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) { + NameAttrList cond_fn, body_fn; + TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); - NameAttrList cond_fn, body_fn; - TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); + // Outer graph + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); + auto while_op = + ops::While(scope.WithOpName("while/LoopCond"), + std::initializer_list{source}, cond_fn, body_fn); + auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, graph_def); + } - // Outer graph - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); - auto while_op = - ops::While(scope.WithOpName("while/LoopCond"), - std::initializer_list{source}, cond_fn, body_fn); - auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]); - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - TF_EXPECT_GRAPH_EQ(expected, graph_def); - } + // Condition graph + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto ten = ops::Const( + scope.WithOpName("while/Less/y").WithControlDependencies(arg), 10); + auto less = ops::Less(scope.WithOpName("while/Less"), arg, ten); + auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0); - // Condition graph - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); - auto ten = ops::Const( - 
scope.WithOpName("while/Less/y").WithControlDependencies(arg), 10); - auto less = ops::Less(scope.WithOpName("while/Less"), arg, ten); - auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(cond_fn.name(), library, &result)); - InstantiationResultForTest result; - TF_EXPECT_OK(InstantiateFunctionForTest(cond_fn.name(), library, &result)); + EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); + EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } - EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); - EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); - } + // Body graph. + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg); + auto one = ops::Const( + scope.WithOpName("while/add/y").WithControlDependencies(identity), 1); + auto add = ops::Add(scope.WithOpName("while/add"), identity, one); + auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0); - // Body graph. - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); - auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg); - auto one = ops::Const( - scope.WithOpName("while/add/y").WithControlDependencies(identity), 1); - auto add = ops::Add(scope.WithOpName("while/add"), identity, one); - auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(body_fn.name(), library, &result)); - InstantiationResultForTest result; - TF_EXPECT_OK(InstantiateFunctionForTest(body_fn.name(), library, &result)); - - EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); - EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); + EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); + EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } } } -// @function.Defun(noinline=True) -// def increment_fn(x): -// return [x + 1] -// Define the above function, and add it to the given graph. It's used as the -// while loop body in NoinlineLoopBody test. -Status AddNoinlineFunctionToGraph(const string& node_name, Graph* graph) { +FunctionDef GetNoinlineFunctionDef() { FunctionDef fdef = FunctionDefHelper::Create( "increment_fn", {"x:int32"}, {"add:int32"}, {}, { @@ -316,8 +328,17 @@ Status AddNoinlineFunctionToGraph(const string& node_name, Graph* graph) { }, {{"add", "add_0:z:0"}}); (*fdef.mutable_attr())["_noinline"].set_b(true); + return fdef; +} + +// @function.Defun(noinline=True) +// def increment_fn(x): +// return [x + 1] +// Define the above function, and add it to the given graph. It's used as the +// while loop body in NoinlineLoopBody test. 
+Status AddNoinlineFunctionToGraph(const string& node_name, Graph* graph) { FunctionDefLibrary fdef_lib; - *(fdef_lib.add_function()) = fdef; + *(fdef_lib.add_function()) = GetNoinlineFunctionDef(); TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(fdef_lib)); NodeDef increment_fn; increment_fn.set_name(node_name); @@ -376,55 +397,88 @@ TEST(FunctionalizeControlFlow, NoinlineLoopBody) { FunctionLibraryDefinition lookup_lib(graph.flib_def()); FunctionLibraryDefinition library(OpRegistry::Global(), {}); // Function increment_fn will be copied from lookup_lib to library. + GraphDef optimized_graph_def; + graph.ToGraphDef(&optimized_graph_def); + + *(optimized_graph_def.mutable_library()->add_function()) = + GetNoinlineFunctionDef(); + + TF_ASSERT_OK(FunctionalizeControlFlowForGraphDef( + &lookup_lib, &optimized_graph_def, &library)); TF_ASSERT_OK(FunctionalizeControlFlow(&lookup_lib, &graph, &library)); + GraphDef converted_graph_def; + graph.ToGraphDef(&converted_graph_def); - GraphDef graph_def; - graph.ToGraphDef(&graph_def); + for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) { + NameAttrList cond_fn, body_fn; + TF_ASSERT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); - NameAttrList cond_fn, body_fn; - TF_ASSERT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); + // Outer graph + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); + auto while_op = + ops::While(scope.WithOpName("while/LoopCond"), + std::initializer_list{source}, cond_fn, body_fn); + GraphDef expected; + TF_ASSERT_OK(scope.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, graph_def); + } - // Outer graph + // Body graph. + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + TF_ASSERT_OK( + AddNoinlineFunctionToGraph(noinline_node_name, scope.graph())); + auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg); + NodeDef retval; + retval.set_name("_retval0_RetVal"); + retval.set_op(FunctionLibraryDefinition::kRetOp); + *retval.add_input() = noinline_node_name; + (*retval.mutable_attr())["T"].set_type(DT_INT32); + (*retval.mutable_attr())["index"].set_i(0); + Status status; + scope.graph()->AddNode(retval, &status); + TF_ASSERT_OK(status); + + GraphDef expected; + TF_ASSERT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + // Verify that increment_fn has been copied to library. + TF_EXPECT_OK( + InstantiateFunctionForTest(body_fn.name(), library, &result)); + + EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); + EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); + // Ignore the function library when comparing the graphs. + expected.clear_library(); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } + } +} + +TEST(FunctionalizeControlFlow, MissingFunctionDefInLibrary) { + const string& noinline_node_name = "while/increment_fn"; + Graph graph(OpRegistry::Global()); { Scope scope = Scope::NewRootScope().ExitOnError(); auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); - auto while_op = - ops::While(scope.WithOpName("while/LoopCond"), - std::initializer_list{source}, cond_fn, body_fn); - GraphDef expected; - TF_ASSERT_OK(scope.ToGraphDef(&expected)); - TF_EXPECT_GRAPH_EQ(expected, graph_def); - } - - // Body graph. 
- { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto identity = ops::Identity(scope.WithOpName("while/Identity"), source); TF_ASSERT_OK(AddNoinlineFunctionToGraph(noinline_node_name, scope.graph())); - auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg); - NodeDef retval; - retval.set_name("_retval0_RetVal"); - retval.set_op(FunctionLibraryDefinition::kRetOp); - *retval.add_input() = noinline_node_name; - (*retval.mutable_attr())["T"].set_type(DT_INT32); - (*retval.mutable_attr())["index"].set_i(0); - Status status; - scope.graph()->AddNode(retval, &status); - TF_ASSERT_OK(status); - - GraphDef expected; - TF_ASSERT_OK(scope.ToGraphDef(&expected)); - - InstantiationResultForTest result; - // Verify that increment_fn has been copied to library. - TF_EXPECT_OK(InstantiateFunctionForTest(body_fn.name(), library, &result)); - - EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); - EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); - // Ignore the function library when comparing the graphs. - expected.clear_library(); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); + TF_ASSERT_OK(scope.ToGraph(&graph)); } + + FunctionLibraryDefinition lookup_lib(graph.flib_def()); + FunctionLibraryDefinition library(OpRegistry::Global(), {}); + GraphDef graph_def; + graph.ToGraphDef(&graph_def); + graph_def.clear_library(); + + Status status = + FunctionalizeControlFlowForGraphDef(&lookup_lib, &graph_def, &library); + EXPECT_EQ(tensorflow::error::NOT_FOUND, status.code()); } // Tests functionalizing OneLoopVar where the loop value is not used post the @@ -467,65 +521,72 @@ TEST(FunctionalizeControlFlow, OneLoopVarWithoutExit) { } FunctionLibraryDefinition library(OpRegistry::Global(), {}); + GraphDef optimized_graph_def; + graph.ToGraphDef(&optimized_graph_def); + TF_ASSERT_OK( + FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library)); TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library)); + GraphDef converted_graph_def; + graph.ToGraphDef(&converted_graph_def); - GraphDef graph_def; - graph.ToGraphDef(&graph_def); + for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) { + NameAttrList cond_fn, body_fn; + TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); - NameAttrList cond_fn, body_fn; - TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); + // Outer graph + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); + auto while_op = + ops::While(scope.WithOpName("while/LoopCond"), + std::initializer_list{source}, cond_fn, body_fn); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, graph_def); + } - // Outer graph - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); - auto while_op = - ops::While(scope.WithOpName("while/LoopCond"), - std::initializer_list{source}, cond_fn, body_fn); - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - TF_EXPECT_GRAPH_EQ(expected, graph_def); - } + // Condition graph + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto ten = ops::Const( + scope.WithOpName("while/Less/y").WithControlDependencies(arg), 10); + auto less = ops::Less(scope.WithOpName("while/Less"), arg, ten); + auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), 
less, 0); - // Condition graph - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); - auto ten = ops::Const( - scope.WithOpName("while/Less/y").WithControlDependencies(arg), 10); - auto less = ops::Less(scope.WithOpName("while/Less"), arg, ten); - auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(cond_fn.name(), library, &result)); - InstantiationResultForTest result; - TF_EXPECT_OK(InstantiateFunctionForTest(cond_fn.name(), library, &result)); + EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); + EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } - EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); - EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); - } + // Body graph. + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg); + auto one = ops::Const( + scope.WithOpName("while/add/y").WithControlDependencies(identity), 1); + auto add = ops::Add(scope.WithOpName("while/add"), identity, one); + auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0); - // Body graph. - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); - auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg); - auto one = ops::Const( - scope.WithOpName("while/add/y").WithControlDependencies(identity), 1); - auto add = ops::Add(scope.WithOpName("while/add"), identity, one); - auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(body_fn.name(), library, &result)); - InstantiationResultForTest result; - TF_EXPECT_OK(InstantiateFunctionForTest(body_fn.name(), library, &result)); - - EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); - EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); + EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); + EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } } } @@ -608,86 +669,95 @@ TEST(FunctionalizeControlFlow, TwoLoopVars) { } FunctionLibraryDefinition library(OpRegistry::Global(), {}); + GraphDef optimized_graph_def; + graph.ToGraphDef(&optimized_graph_def); + TF_ASSERT_OK( + FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library)); TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library)); + GraphDef converted_graph_def; + graph.ToGraphDef(&converted_graph_def); - GraphDef graph_def; - graph.ToGraphDef(&graph_def); + for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) { + NameAttrList cond_fn, body_fn; + TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); - NameAttrList cond_fn, body_fn; - TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); + // Outer graph. 
+ { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto x = ops::Placeholder(scope.WithOpName("Placeholder/x"), DT_INT32); + auto y = ops::Placeholder(scope.WithOpName("Placeholder/y"), DT_INT32); + auto while_op = + ops::While(scope.WithOpName("while/LoopCond"), + std::initializer_list{x, y}, cond_fn, body_fn); + auto sink_x = ops::Identity(scope.WithOpName("sink_x"), while_op[0]); + auto sink_y = ops::Identity(scope.WithOpName("sink_y"), while_op[1]); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, graph_def); + } - // Outer graph. - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto x = ops::Placeholder(scope.WithOpName("Placeholder/x"), DT_INT32); - auto y = ops::Placeholder(scope.WithOpName("Placeholder/y"), DT_INT32); - auto while_op = - ops::While(scope.WithOpName("while/LoopCond"), - std::initializer_list{x, y}, cond_fn, body_fn); - auto sink_x = ops::Identity(scope.WithOpName("sink_x"), while_op[0]); - auto sink_y = ops::Identity(scope.WithOpName("sink_y"), while_op[1]); - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - TF_EXPECT_GRAPH_EQ(expected, graph_def); - } - - // Condition graph. - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); - auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); - auto three = ops::Const(scope.WithOpName("while/cond/three") + // Condition graph. + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + auto three = ops::Const(scope.WithOpName("while/cond/three") + .WithControlDependencies(arg0.output), + 3); + auto cond_add = + ops::Add(scope.WithOpName("while/cond/Add"), arg0.output, three); + auto ten = ops::Const(scope.WithOpName("while/cond/ten") .WithControlDependencies(arg0.output), - 3); - auto cond_add = - ops::Add(scope.WithOpName("while/cond/Add"), arg0.output, three); - auto ten = ops::Const( - scope.WithOpName("while/cond/ten").WithControlDependencies(arg0.output), - 10); - auto less = ops::Less(scope.WithOpName("while/cond/Less"), cond_add, ten); - auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0); + 10); + auto less = ops::Less(scope.WithOpName("while/cond/Less"), cond_add, ten); + auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0); - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); - InstantiationResultForTest result; - TF_EXPECT_OK(InstantiateFunctionForTest(cond_fn.name(), library, &result)); + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(cond_fn.name(), library, &result)); - EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.arg_types); - EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); - } + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.arg_types); + EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } - // Body graph. - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); - auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + // Body graph. 
+ { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); - auto identity_x = ops::Identity(scope.WithOpName("while/Identity/x"), arg0); - auto identity_y = ops::Identity(scope.WithOpName("while/Identity/y"), arg1); + auto identity_x = + ops::Identity(scope.WithOpName("while/Identity/x"), arg0); + auto identity_y = + ops::Identity(scope.WithOpName("while/Identity/y"), arg1); - auto one = ops::Const( - scope.WithOpName("while/add/one").WithControlDependencies(identity_x), - 1); - auto two = ops::Const( - scope.WithOpName("while/mul/two").WithControlDependencies(identity_x), - 2); + auto one = ops::Const( + scope.WithOpName("while/add/one").WithControlDependencies(identity_x), + 1); + auto two = ops::Const( + scope.WithOpName("while/mul/two").WithControlDependencies(identity_x), + 2); - auto add = ops::Add(scope.WithOpName("while/add"), identity_x, one); - auto mul = ops::Add(scope.WithOpName("while/mul"), identity_y, two); - auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0); - auto retval1 = ops::_Retval(scope.WithOpName("_retval1_RetVal"), mul, 1); + auto add = ops::Add(scope.WithOpName("while/add"), identity_x, one); + auto mul = ops::Add(scope.WithOpName("while/mul"), identity_y, two); + auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0); + auto retval1 = ops::_Retval(scope.WithOpName("_retval1_RetVal"), mul, 1); - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); - InstantiationResultForTest result; - TF_EXPECT_OK(InstantiateFunctionForTest(body_fn.name(), library, &result)); + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(body_fn.name(), library, &result)); - EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.arg_types); - EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.ret_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.arg_types); + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } } } @@ -841,177 +911,192 @@ TEST(FunctionalizeControlFlow, Complex) { } FunctionLibraryDefinition library(OpRegistry::Global(), {}); + GraphDef optimized_graph_def; + graph.ToGraphDef(&optimized_graph_def); + TF_ASSERT_OK( + FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library)); TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library)); + GraphDef converted_graph_def; + graph.ToGraphDef(&converted_graph_def); - GraphDef graph_def; - graph.ToGraphDef(&graph_def); - - NameAttrList outer_cond_fn, outer_body_fn; - TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &outer_cond_fn, &outer_body_fn)); - - // Outer graph. 
- { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); - auto three = ops::Const(scope.WithOpName("three"), 3); - auto y = ops::Add(scope.WithOpName("y"), x, three); - - auto var = ops::VarHandleOp(scope.WithOpName("Variable"), DT_INT32, - TensorShape({})); - - auto zero = ops::Const(scope.WithOpName("outer/Const"), 0); - - auto while_op = ops::While(scope.WithOpName("outer/LoopCond"), - std::initializer_list{zero, y, x, var}, - outer_cond_fn, outer_body_fn); - auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]); - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - TF_EXPECT_GRAPH_EQ(expected, graph_def); - } - - // Outer condition graph. - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); - auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); - auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); - auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3); - - auto ten = ops::Const( - scope.WithOpName("outer/Less/y").WithControlDependencies(arg0.output), - 10); - auto less = ops::Less(scope.WithOpName("outer/Less_i"), arg0, ten); - auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0); - - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - - InstantiationResultForTest result; + for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) { + NameAttrList outer_cond_fn, outer_body_fn; TF_EXPECT_OK( - InstantiateFunctionForTest(outer_cond_fn.name(), library, &result)); + FindWhileCondAndBody(graph_def, &outer_cond_fn, &outer_body_fn)); - EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), - result.arg_types); - EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); - } + // Outer graph. + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); + auto three = ops::Const(scope.WithOpName("three"), 3); + auto y = ops::Add(scope.WithOpName("y"), x, three); - // Outer body graph. - NameAttrList inner_cond_fn, inner_body_fn; - { - InstantiationResultForTest result; - TF_EXPECT_OK( - InstantiateFunctionForTest(outer_body_fn.name(), library, &result)); + auto var = ops::VarHandleOp(scope.WithOpName("Variable"), DT_INT32, + TensorShape({})); - // Find the inner condition and body names. 
- TF_EXPECT_OK( - FindWhileCondAndBody(result.gdef, &inner_cond_fn, &inner_body_fn)); + auto zero = ops::Const(scope.WithOpName("outer/Const"), 0); - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); - auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); - auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); - auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3); + auto while_op = ops::While(scope.WithOpName("outer/LoopCond"), + std::initializer_list{zero, y, x, var}, + outer_cond_fn, outer_body_fn); + auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, graph_def); + } - auto identity_i = ops::Identity(scope.WithOpName("outer/Identity"), arg0); - auto one_j = ops::Const( - scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1); - auto while_op = - ops::While(scope.WithOpName("outer/LoopCond_1"), - std::initializer_list{one_j, arg1, arg2, arg3}, - inner_cond_fn, inner_body_fn); + // Outer condition graph. + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); + auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3); - auto one_outer = ops::Const( - scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1); - auto add_i = - ops::Add(scope.WithOpName("outer/add") - .WithControlDependencies(absl::Span{ - while_op[0].op(), while_op[1].op()}), - identity_i, one_outer); + auto ten = ops::Const( + scope.WithOpName("outer/Less/y").WithControlDependencies(arg0.output), + 10); + auto less = ops::Less(scope.WithOpName("outer/Less_i"), arg0, ten); + auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0); - auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add_i, 0); - auto retval1 = ops::_Retval(scope.WithOpName("_retval1_RetVal"), arg1, 1); - auto retval2 = ops::_Retval(scope.WithOpName("_retval2_RetVal"), arg2, 2); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(outer_cond_fn.name(), library, &result)); - EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), - result.arg_types); - EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32}), result.ret_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); - } + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), + result.arg_types); + EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } - // Inner condition graph. - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); - auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); - auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); - auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3); + // Outer body graph. 
+ NameAttrList inner_cond_fn, inner_body_fn; + { + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(outer_body_fn.name(), library, &result)); - auto five = ops::Const( - scope.WithOpName("outer/inner/Five").WithControlDependencies(arg0), 5); - auto less_j = ops::Less(scope.WithOpName("outer/inner/Less_j"), arg0, five); - auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less_j, 0); + // Find the inner condition and body names. + TF_EXPECT_OK( + FindWhileCondAndBody(result.gdef, &inner_cond_fn, &inner_body_fn)); - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); + auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3); - InstantiationResultForTest result; - TF_EXPECT_OK( - InstantiateFunctionForTest(inner_cond_fn.name(), library, &result)); + auto identity_i = ops::Identity(scope.WithOpName("outer/Identity"), arg0); + auto one_j = ops::Const( + scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1); + auto while_op = + ops::While(scope.WithOpName("outer/LoopCond_1"), + std::initializer_list{one_j, arg1, arg2, arg3}, + inner_cond_fn, inner_body_fn); - EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), - result.arg_types); - EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); - } + auto one_outer = ops::Const( + scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), + 1); + auto add_i = + ops::Add(scope.WithOpName("outer/add") + .WithControlDependencies(absl::Span{ + while_op[0].op(), while_op[1].op()}), + identity_i, one_outer); - // Inner body graph. - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); - auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); - auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); - auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3); + auto retval0 = + ops::_Retval(scope.WithOpName("_retval0_RetVal"), add_i, 0); + auto retval1 = ops::_Retval(scope.WithOpName("_retval1_RetVal"), arg1, 1); + auto retval2 = ops::_Retval(scope.WithOpName("_retval2_RetVal"), arg2, 2); - auto identity_j = - ops::Identity(scope.WithOpName("outer/inner/Identity_j"), arg0); - auto identity_k = - ops::Identity(scope.WithOpName("outer/inner/Identity_k"), arg1); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); - auto mul_jk = - ops::Mul(scope.WithOpName("outer/inner/mul"), identity_j, identity_k); - auto add_jkx = ops::Add(scope.WithOpName("outer/inner/add"), mul_jk, arg2); - auto assign = ops::AssignAddVariableOp( - scope.WithOpName("outer/inner/assign_add"), arg3, add_jkx); + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), + result.arg_types); + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32}), + result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } - auto one = ops::Const( - scope.WithOpName("outer/inner/One") - .WithControlDependencies( - absl::Span{assign.operation}), - 1); - auto add_j = - ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one); + // Inner condition graph. 
+ { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); + auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3); - auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add_j, 0); - auto retval1 = - ops::_Retval(scope.WithOpName("_retval1_RetVal"), identity_k, 1); - auto retval2 = ops::_Retval(scope.WithOpName("_retval2_RetVal"), arg2, 2); + auto five = ops::Const( + scope.WithOpName("outer/inner/Five").WithControlDependencies(arg0), + 5); + auto less_j = + ops::Less(scope.WithOpName("outer/inner/Less_j"), arg0, five); + auto retval = + ops::_Retval(scope.WithOpName("_retval0_RetVal"), less_j, 0); - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); - InstantiationResultForTest result; - TF_EXPECT_OK( - InstantiateFunctionForTest(inner_body_fn.name(), library, &result)); + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(inner_cond_fn.name(), library, &result)); - EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), - result.arg_types); - EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32}), result.ret_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), + result.arg_types); + EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } + + // Inner body graph. + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); + auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3); + + auto identity_j = + ops::Identity(scope.WithOpName("outer/inner/Identity_j"), arg0); + auto identity_k = + ops::Identity(scope.WithOpName("outer/inner/Identity_k"), arg1); + + auto mul_jk = + ops::Mul(scope.WithOpName("outer/inner/mul"), identity_j, identity_k); + auto add_jkx = + ops::Add(scope.WithOpName("outer/inner/add"), mul_jk, arg2); + auto assign = ops::AssignAddVariableOp( + scope.WithOpName("outer/inner/assign_add"), arg3, add_jkx); + + auto one = ops::Const( + scope.WithOpName("outer/inner/One") + .WithControlDependencies( + absl::Span{assign.operation}), + 1); + auto add_j = + ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one); + + auto retval0 = + ops::_Retval(scope.WithOpName("_retval0_RetVal"), add_j, 0); + auto retval1 = + ops::_Retval(scope.WithOpName("_retval1_RetVal"), identity_k, 1); + auto retval2 = ops::_Retval(scope.WithOpName("_retval2_RetVal"), arg2, 2); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(inner_body_fn.name(), library, &result)); + + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), + result.arg_types); + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32}), + result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } } } diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index d85b4f5ae0c..fa51a72aea4 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -121,7 +121,6 @@ 
tf_kernel_library( ":while_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/tf2xla/lib:batch_dot", "//tensorflow/compiler/tf2xla/lib:broadcast", "//tensorflow/compiler/tf2xla/lib:cholesky", "//tensorflow/compiler/tf2xla/lib:qr", @@ -144,7 +143,7 @@ tf_kernel_library( "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/compiler/xla/client/lib:constants", "//tensorflow/compiler/xla/client/lib:math", - "//tensorflow/compiler/xla/client/lib:numeric", + "//tensorflow/compiler/xla/client/lib:matrix", "//tensorflow/compiler/xla/client/lib:pooling", "//tensorflow/compiler/xla/client/lib:prng", "//tensorflow/compiler/xla/client/lib:sorting", @@ -196,7 +195,6 @@ cc_library( "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/compiler/xla/client/lib:constants", - "//tensorflow/compiler/xla/client/lib:numeric", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -216,7 +214,6 @@ cc_library( "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/compiler/xla/client/lib:constants", - "//tensorflow/compiler/xla/client/lib:numeric", "//tensorflow/core:framework", "//tensorflow/core/kernels:bounds_check", "//tensorflow/core/kernels:conv_ops", diff --git a/tensorflow/compiler/tf2xla/kernels/arg_op.cc b/tensorflow/compiler/tf2xla/kernels/arg_op.cc index 2db2514397d..795ea09831e 100644 --- a/tensorflow/compiler/tf2xla/kernels/arg_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/arg_op.cc @@ -50,7 +50,7 @@ class XlaArgOp : public XlaOpKernel { return; } - const XlaExpression& arg = XlaContext::Get(ctx).args()[index_]; + const XlaExpression& arg = ctx->xla_context()->args()[index_]; OP_REQUIRES(ctx, arg.kind() != XlaExpression::Kind::kInvalid, errors::InvalidArgument("Invalid/missing argument expression")); ctx->SetOutputExpression(0, arg); diff --git a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc index 4cfe946b2e6..1b254e328a8 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc @@ -13,9 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" +#include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" namespace tensorflow { namespace { @@ -28,9 +30,11 @@ class BatchMatMulOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - auto result = BatchDot(ctx->Input(0), ctx->Input(1), - /*transpose_x=*/adj_x_, /*transpose_y=*/adj_y_, - /*conjugate_x=*/adj_x_, /*conjugate_y=*/adj_y_); + auto result = + xla::BatchDot(MaybeTransposeInMinorDims( + MaybeConjugate(ctx->Input(0), adj_x_), adj_x_), + MaybeTransposeInMinorDims( + MaybeConjugate(ctx->Input(1), adj_y_), adj_y_)); ctx->SetOutput(0, result); } diff --git a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc index a267c0c72fc..0e2f335f335 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc @@ -115,9 +115,9 @@ class FusedBatchNormGradOp : public XlaOpKernel { // operators. For now, cast everything to the statistics type (which // may be more precise than the input type). auto grad_backprop = - XlaHelpers::ConvertElementType(b, ctx->Input(0), scale_dtype); + XlaHelpers::ConvertElementType(ctx->Input(0), scale_dtype); auto activations = - XlaHelpers::ConvertElementType(b, ctx->Input(1), scale_dtype); + XlaHelpers::ConvertElementType(ctx->Input(1), scale_dtype); auto scale = ctx->Input(2); auto mean = ctx->Input(3); auto var = ctx->Input(4); @@ -151,11 +151,11 @@ class FusedBatchNormGradOp : public XlaOpKernel { const DataType accumulation_type = XlaHelpers::SumAccumulationType(scale_dtype); auto converted = - XlaHelpers::ConvertElementType(b, grad_backprop, accumulation_type); + XlaHelpers::ConvertElementType(grad_backprop, accumulation_type); auto reduce = xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type), reduction_dims); - offset_backprop = XlaHelpers::ConvertElementType(b, reduce, scale_dtype); + offset_backprop = XlaHelpers::ConvertElementType(reduce, scale_dtype); // scratch1 = rsqrt(pop_var + epsilon) auto neg_half = XlaHelpers::FloatLiteral(b, scale_dtype, -0.5); @@ -165,19 +165,18 @@ class FusedBatchNormGradOp : public XlaOpKernel { // scratch2 = sum(y_backprop * (x - mean)) auto mul = xla::Mul(grad_backprop, xla::Sub(activations, mean, {feature_index})); - converted = XlaHelpers::ConvertElementType(b, mul, accumulation_type); + converted = XlaHelpers::ConvertElementType(mul, accumulation_type); reduce = xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type), reduction_dims); - auto scratch2 = XlaHelpers::ConvertElementType(b, reduce, scale_dtype); + auto scratch2 = XlaHelpers::ConvertElementType(reduce, scale_dtype); x_backprop = xla::Mul(grad_backprop, xla::Mul(scratch1, scale), {feature_index}); scale_backprop = xla::Mul(scratch1, scratch2); } - ctx->SetOutput(0, - XlaHelpers::ConvertElementType(b, x_backprop, input_dtype)); + ctx->SetOutput(0, XlaHelpers::ConvertElementType(x_backprop, input_dtype)); ctx->SetOutput(1, scale_backprop); ctx->SetOutput(2, offset_backprop); ctx->SetConstantOutput(3, Tensor()); diff --git a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc 
b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc index 41f540506ba..e7f369b761f 100644 --- a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc @@ -107,11 +107,11 @@ class BiasAddGradOp : public XlaOpKernel { const DataType accumulation_type = XlaHelpers::SumAccumulationType(input_type(0)); auto converted = - XlaHelpers::ConvertElementType(b, ctx->Input(0), accumulation_type); + XlaHelpers::ConvertElementType(ctx->Input(0), accumulation_type); auto reduce = xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type), reduce_dims); - ctx->SetOutput(0, XlaHelpers::ConvertElementType(b, reduce, input_type(0))); + ctx->SetOutput(0, XlaHelpers::ConvertElementType(reduce, input_type(0))); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index 47e517a6576..5e9280c1fe6 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -43,6 +43,9 @@ namespace { const std::vector& extend_dimensions) override { \ xla::XlaBuilder* b = ctx->builder(); \ (void)b; \ + (void)lhs_shape; \ + (void)rhs_shape; \ + (void)extend_dimensions; \ return HLO; \ } \ }; \ @@ -103,23 +106,23 @@ static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, XLA_MAKE_BINARY(FloorDiv, FloorDivImpl(b, input_type(0), lhs, rhs, broadcast_helper)); -static xla::XlaOp XlogyImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, - xla::XlaOp y, const BCast& broadcast_helper) { +xla::XlaOp XlogyImpl(xla::XlaOp x, xla::XlaOp y, + const BCast& broadcast_helper) { std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper); - auto zero = XlaHelpers::Zero(b, dtype); + auto zero = xla::ZerosLike(x); auto is_zero = xla::Eq(x, zero); return xla::Select(is_zero, zero, xla::Mul(x, xla::Log(y))); } -XLA_MAKE_BINARY(Xlogy, XlogyImpl(b, input_type(0), lhs, rhs, broadcast_helper)); +XLA_MAKE_BINARY(Xlogy, XlogyImpl(lhs, rhs, broadcast_helper)); -static xla::XlaOp XdivyImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, - xla::XlaOp y, const BCast& broadcast_helper) { +xla::XlaOp XdivyImpl(xla::XlaOp x, xla::XlaOp y, + const BCast& broadcast_helper) { std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper); - auto zero = XlaHelpers::Zero(b, dtype); + auto zero = xla::ZerosLike(x); auto is_zero = xla::Eq(x, zero); return xla::Select(is_zero, zero, xla::Div(x, y)); } -XLA_MAKE_BINARY(Xdivy, XdivyImpl(b, input_type(0), lhs, rhs, broadcast_helper)); +XLA_MAKE_BINARY(Xdivy, XdivyImpl(lhs, rhs, broadcast_helper)); // Implementation of FloorMod. Pseudo-code: // T trunc_mod = std::fmod(x, y); diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc index ad85940920e..7199b9b6feb 100644 --- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc @@ -21,10 +21,13 @@ limitations under the License. 
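The Xlogy/Xdivy rewrite above keeps the key semantic: the result is forced to zero wherever x is zero, so log(0) or a division by zero on those elements never leaks into the output (xla::ZerosLike just derives the zero from x instead of from a DataType). A minimal NumPy sketch of that select-around-the-singularity pattern, purely illustrative and not the kernel's code path:

```
import numpy as np

def xlogy(x, y):
    # x * log(y), but exactly 0 wherever x == 0, even where log(y) is -inf or nan.
    with np.errstate(divide="ignore", invalid="ignore"):
        prod = x * np.log(y)
    return np.where(x == 0, 0.0, prod)

def xdivy(x, y):
    # x / y, but exactly 0 wherever x == 0, even where y == 0.
    with np.errstate(divide="ignore", invalid="ignore"):
        quot = x / y
    return np.where(x == 0, 0.0, quot)

x = np.array([0.0, 2.0, 3.0])
y = np.array([0.0, 1.0, 0.5])
print(xlogy(x, y))  # [ 0.  0. -2.079...]  -- no nan from 0 * log(0)
print(xdivy(x, y))  # [0. 2. 6.]           -- no nan from 0 / 0
```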
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/prng.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" namespace tensorflow { namespace { @@ -57,11 +60,9 @@ class CategoricalOp : public XlaOpKernel { const int64 batch_size = logits_shape.dim_size(0); const int64 num_classes = logits_shape.dim_size(1); - xla::XlaBuilder* builder = ctx->builder(); - xla::Shape uniform_shape; int class_dimension; - if (num_samples > 1) { + if (num_samples != 1) { std::array uniform_shape_array = { {batch_size, num_samples, num_classes}}; xla::PrimitiveType uniform_xla_type; @@ -83,16 +84,16 @@ class CategoricalOp : public XlaOpKernel { xla::ShapeUtil::MakeShape(uniform_xla_type, uniform_shape_array); class_dimension = 1; } - xla::XlaOp uniforms = - xla::RngUniform(XlaHelpers::Zero(builder, input_type(0)), - XlaHelpers::One(builder, input_type(0)), uniform_shape); + xla::PrimitiveType type; + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(0), &type)); + xla::XlaOp log_uniforms = GetLogUniforms(uniform_shape, type, ctx); // Use Gumbel softmax trick to generate categorical samples. // See: // https://hips.seas.harvard.edu/blog/2013/04/06/the-gumbel-max-trick-for-discrete-distributions/ // TODO(b/68769470): Switch to using a cumulative sum approach. auto softmax_entries = - xla::Sub(logits, xla::Log(-xla::Log(uniforms)), + xla::Sub(logits, log_uniforms, /*broadcast_dimensions=*/{0, class_dimension}); xla::PrimitiveType xla_output_type; @@ -107,6 +108,16 @@ class CategoricalOp : public XlaOpKernel { ctx->SetOutput(0, argmax); } + virtual xla::XlaOp GetLogUniforms(xla::Shape uniform_shape, + xla::PrimitiveType type, + XlaOpKernelContext* ctx) { + xla::XlaBuilder* builder = ctx->builder(); + auto uniforms = + xla::RngUniform(XlaHelpers::Zero(builder, input_type(0)), + XlaHelpers::One(builder, input_type(0)), uniform_shape); + return xla::Log(-xla::Log(uniforms)); + } + private: TF_DISALLOW_COPY_AND_ASSIGN(CategoricalOp); }; @@ -115,5 +126,48 @@ class CategoricalOp : public XlaOpKernel { REGISTER_XLA_OP(Name("Multinomial").CompileTimeConstantInput("num_samples"), CategoricalOp); +class StatelessCategoricalOp : public CategoricalOp { + public: + explicit StatelessCategoricalOp(OpKernelConstruction* ctx) + : CategoricalOp(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); + } + + xla::XlaOp GetLogUniforms(xla::Shape uniform_shape, xla::PrimitiveType type, + XlaOpKernelContext* ctx) override { + xla::XlaOp seed = ctx->Input(2); + auto seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {}); + auto seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {}); + + xla::XlaBuilder* builder = ctx->builder(); + if (uniform_shape.element_type() == xla::BF16) { + uniform_shape.set_element_type(xla::F32); + } + auto uniforms = xla::StatelessRngUniform( + {seed0, seed1}, uniform_shape, XlaHelpers::Zero(builder, DT_FLOAT), + XlaHelpers::One(builder, DT_FLOAT)); + return xla::ConvertElementType(xla::Log(-xla::Log(uniforms)), type); + } + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape seed_shape = ctx->InputShape(2); + OP_REQUIRES(ctx, seed_shape.dims() == 1 && seed_shape.dim_size(0) == 
2, + errors::InvalidArgument("seed must have shape [2], not ", + seed_shape.DebugString())); + CategoricalOp::Compile(ctx); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(StatelessCategoricalOp); +}; + +REGISTER_XLA_OP(Name("StatelessMultinomial") + .CompileTimeConstantInput("num_samples") + .TypeConstraint("T", {DT_FLOAT, DT_BFLOAT16}) + .TypeConstraint("Tseed", DT_INT32), + StatelessCategoricalOp); + } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc index c9a1be49406..641fefafb35 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc @@ -24,7 +24,6 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/node_def_util.h" @@ -65,60 +64,63 @@ xla::Shape ExpandedFilterShapeForDepthwiseConvolution(const xla::Shape& shape) { // 0 0 1 1 0 0 0 0 1 1 0 0 // 0 0 0 0 1 1 0 0 0 0 1 1 // -// The first step is to create a one tensor, A, that is [3] -// 0 1 2 +// The first step is to create a iota A with iota_dimension = 2 +// 0 0 0 0 0 0 0 0 0 0 0 0 +// 1 1 1 1 1 1 1 1 1 1 1 1 +// 2 2 2 2 2 2 2 2 2 2 2 2 // -// and another tensor, B, that is [3 * 2] -// 0 1 2 3 4 5 +// 0 0 0 0 0 0 0 0 0 0 0 0 +// 1 1 1 1 1 1 1 1 1 1 1 1 +// 2 2 2 2 2 2 2 2 2 2 2 2 // -// and divide B it by 2 to get -// 0 0 1 1 2 2 +// and another iota B with iota_dimension = 3 +// 0 1 2 3 4 5 0 1 2 3 4 5 +// 0 1 2 3 4 5 0 1 2 3 4 5 +// 0 1 2 3 4 5 0 1 2 3 4 5 // -// then we broadcast the B to [2, 2, 3, 3 * 2] -// 0 0 1 1 2 2 0 0 1 1 2 2 -// 0 0 1 1 2 2 0 0 1 1 2 2 -// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 1 2 3 4 5 0 1 2 3 4 5 +// 0 1 2 3 4 5 0 1 2 3 4 5 +// 0 1 2 3 4 5 0 1 2 3 4 5 // -// 0 0 1 1 2 2 0 0 1 1 2 2 -// 0 0 1 1 2 2 0 0 1 1 2 2 -// 0 0 1 1 2 2 0 0 1 1 2 2 +// and divide B by 2 to get +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 // -// Finally compare A and broadcasted B in dimension 2 amd return the result at -// the beginning of the comment. +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 +// +// Finally compare A and B and return the result at the beginning of the +// comment. xla::XlaOp CreateExpandedFilterMask(const xla::Shape& filter_shape, xla::XlaBuilder* builder) { xla::Shape expanded_filter_shape = ExpandedFilterShapeForDepthwiseConvolution(filter_shape); int64 depthwise_multiplier = filter_shape.dimensions(filter_shape.dimensions_size() - 1); - int64 input_feature = - filter_shape.dimensions(filter_shape.dimensions_size() - 2); - // Create a M sized linspace and an M*N sized linspace that will be - // broadcasted into perpendicular dimensions and compared. - xla::XlaOp input_feature_iota = xla::Iota(builder, xla::S32, input_feature); - xla::XlaOp expanded_feature_iota = - xla::Iota(builder, xla::S32, input_feature * depthwise_multiplier); + // Create two iotas with the shape of the expanded filter, one of them with + // the iota dimension chosen as the feature dimension, and the other a iota + // with the iota dimension chosen as the expanded output feature dimension. 
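For the Multinomial/StatelessMultinomial kernels above: the refactored GetLogUniforms feeds log(-log(U)) into the Gumbel-max trick referenced in the comment, so an argmax over logits - log(-log(U)) draws one categorical sample per row. A small NumPy sketch of that trick; names here are illustrative only, the kernel works on XLA ops and the stateless variant derives U from the seed pair:

```
import numpy as np

def gumbel_max_sample(logits, num_samples, rng):
    # Subtracting log(-log(U)) with U ~ Uniform(0, 1) is the same as adding
    # standard Gumbel noise, so the argmax is a categorical sample.
    batch, num_classes = logits.shape
    u = rng.uniform(size=(batch, num_samples, num_classes))
    noisy = logits[:, None, :] - np.log(-np.log(u))
    return np.argmax(noisy, axis=-1)          # shape [batch, num_samples]

rng = np.random.default_rng(0)
logits = np.log(np.array([[0.1, 0.6, 0.3]]))
samples = gumbel_max_sample(logits, 10000, rng)
print(np.bincount(samples[0], minlength=3) / 10000)  # roughly [0.1, 0.6, 0.3]
```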
+ std::vector iota_dimensions(expanded_filter_shape.dimensions().begin(), + expanded_filter_shape.dimensions().end()); + xla::Shape iota_shape = xla::ShapeUtil::MakeShape(xla::S32, iota_dimensions); + xla::XlaOp input_feature_iota = xla::Iota( + builder, iota_shape, /*iota_dimension=*/iota_dimensions.size() - 2); + xla::XlaOp expanded_feature_iota = xla::Iota( + builder, iota_shape, /*iota_dimension=*/iota_dimensions.size() - 1); - // Divide the M*N sized linspace by the depthwise_multiplier to create - // [0 0 1 1 2 2] in the example in the function comment. + // Divide 'expanded_feature_iota' by the depthwise_multiplier to create + // [0 0 1 1 2 2] ... in the example in the function comment. expanded_feature_iota = xla::Div(expanded_feature_iota, XlaHelpers::IntegerLiteral(builder, DataType::DT_INT32, depthwise_multiplier)); - // Broadcast the N*M linspace to [H, W, ..., M, M*N]. - std::vector expanded_feature_broadcast_dims( - expanded_filter_shape.dimensions().begin(), - expanded_filter_shape.dimensions().end()); - expanded_feature_broadcast_dims.pop_back(); - auto broadcasted_expanded_feature_iota = - xla::Broadcast(expanded_feature_iota, expanded_feature_broadcast_dims); - - // Compare the broadcasted linspace to the input feature linspace in the - // input feature dimension to create a diagonal predicate. - return xla::Eq(broadcasted_expanded_feature_iota, input_feature_iota, - {expanded_filter_shape.dimensions_size() - 2}); + // Compare 'input_feature_iota' with 'expanded_feature_iota' to create a + // diagonal predicate. + return xla::Eq(expanded_feature_iota, input_feature_iota); } // Reshapes a filter of shape [H, W, ..., M, N] to [H, W, ..., 1, M*N]. Used to diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc index d820528a430..eafdba876ae 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/node_def_util.h" diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc index 49c12fc2320..ee79cbc70da 100644 --- a/tensorflow/compiler/tf2xla/kernels/diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc @@ -19,7 +19,7 @@ limitations under the License. 
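For the CreateExpandedFilterMask rewrite just above, a NumPy sketch of the two-iota construction on the trailing two dimensions only (input features by expanded output features; the real code builds both iotas with the full expanded filter shape):

```
import numpy as np

input_feature, multiplier = 3, 2
shape = (input_feature, input_feature * multiplier)

# iota along the input-feature dimension (rows) and along the expanded
# output-feature dimension (columns).
iota_feature = np.broadcast_to(np.arange(shape[0])[:, None], shape)
iota_expanded = np.broadcast_to(np.arange(shape[1]), shape)

# Divide the expanded iota by the depthwise multiplier and compare: true exactly
# where an expanded output feature belongs to that input feature.
mask = (iota_expanded // multiplier) == iota_feature
print(mask.astype(int))
# [[1 1 0 0 0 0]
#  [0 0 1 1 0 0]
#  [0 0 0 0 1 1]]
```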
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc index b2f6ef43fa9..6e6ba21daf5 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc @@ -113,8 +113,20 @@ class DynamicStitchOp : public XlaOpKernel { } } int number_of_indices = max_index + 1; - OP_REQUIRES(ctx, number_of_indices > 0, - errors::InvalidArgument("no indices supplied")); + int64 result_rank = 1 + data0_shape.dims() - indices0_shape.dims(); + if (number_of_indices == 0) { + std::vector result_shape(result_rank); + for (int d = indices0_shape.dims(); d < data0_shape.dims(); d++) { + result_shape[d - indices0_shape.dims() + 1] = data0_shape.dim_size(d); + } + xla::PrimitiveType element_type = + ctx->input_xla_type(ctx->num_inputs() - 1); + xla::Literal empty_literal = xla::Literal::CreateFromShape( + xla::ShapeUtil::MakeShape(element_type, result_shape)); + ctx->SetOutput(0, xla::ConstantLiteral(ctx->builder(), empty_literal)); + return; + } + // Construct the reverse mapping, for each index, of which slice of which // input it comes from. std::vector src_input_vector(number_of_indices); @@ -157,12 +169,9 @@ class DynamicStitchOp : public XlaOpKernel { // Set up the vectors for slicing: the first dimension will vary // slice by slice, and the rest take the full common extra shape. - std::vector slice_start(1 + data0_shape.dims() - - indices0_shape.dims()); - std::vector slice_limit(1 + data0_shape.dims() - - indices0_shape.dims()); - std::vector stride(1 + data0_shape.dims() - indices0_shape.dims(), - 1); + std::vector slice_start(result_rank); + std::vector slice_limit(result_rank); + std::vector stride(result_rank, 1); for (int d = indices0_shape.dims(); d < data0_shape.dims(); d++) { slice_limit[1 + d - indices0_shape.dims()] = data0_shape.dim_size(d); } diff --git a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc index c68b0bfd796..29687c7b82f 100644 --- a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc @@ -17,7 +17,6 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/util/tensor_format.h" diff --git a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc index cdba6680dee..142be030f73 100644 --- a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc @@ -260,19 +260,19 @@ class FakeQuantWithMinMaxVarsGradOp : public XlaOpKernel { xla::XlaOp below_min = xla::Lt(input, nudged_input_min); xla::XlaOp select1 = xla::Select(below_min, gradient, zeroes); xla::XlaOp reduce1 = xla::ReduceAll( - XlaHelpers::ConvertElementType(b, select1, accumulation_type), + XlaHelpers::ConvertElementType(select1, accumulation_type), XlaHelpers::Zero(b, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type)); - xla::XlaOp output1 = XlaHelpers::ConvertElementType(b, reduce1, data_type); + xla::XlaOp output1 = XlaHelpers::ConvertElementType(reduce1, data_type); ctx->SetOutput(1, output1); xla::XlaOp above_max = xla::Gt(input, nudged_input_max); xla::XlaOp select2 = xla::Select(above_max, gradient, zeroes); xla::XlaOp reduce2 = xla::ReduceAll( - XlaHelpers::ConvertElementType(b, select2, accumulation_type), + XlaHelpers::ConvertElementType(select2, accumulation_type), XlaHelpers::Zero(b, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type)); - xla::XlaOp output2 = XlaHelpers::ConvertElementType(b, reduce2, data_type); + xla::XlaOp output2 = XlaHelpers::ConvertElementType(reduce2, data_type); ctx->SetOutput(2, output2); } diff --git a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc index 9b06357d9b7..6df8b5367d2 100644 --- a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" @@ -50,11 +51,36 @@ class GenericFftOp : public XlaOpKernel { errors::InvalidArgument("input must be at least 1 dimensional")); std::vector fft_length; + xla::XlaOp input = ctx->Input(0); if (fft_type_ == FftType::RFFT || fft_type_ == FftType::IRFFT) { OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &fft_length)); OP_REQUIRES(ctx, fft_length.size() == fft_rank_, errors::InvalidArgument("fft_length must be length ", fft_rank_, " vector")); + + // Zero pad or truncate the axes we're doing FFT on. + absl::InlinedVector slice_sizes = input_shape.dim_sizes(); + std::vector> padding_sizes(slice_sizes.size()); + std::vector expected_sizes = fft_length; + // IRFFT wants the innermost axis to be n / 2 + 1. 
+ if (fft_type_ == FftType::IRFFT) { + expected_sizes[fft_rank_ - 1] = fft_length[fft_rank_ - 1] / 2 + 1; + } + for (int i = 0; i < fft_rank_; i++) { + int index = input_shape.dims() - fft_rank_ + i; + if (input_shape.dim_size(index) > expected_sizes[i]) { + slice_sizes[index] = expected_sizes[i]; + } else { + padding_sizes[index].second = + expected_sizes[i] - input_shape.dim_size(index); + } + } + + std::vector start_indices(input_shape.dims(), 0); + std::vector strides(input_shape.dims(), 1); + input = xla::Pad(xla::Slice(input, start_indices, slice_sizes, strides), + XlaHelpers::Zero(ctx->builder(), ctx->input_type(0)), + xla::MakeEdgePaddingConfig(padding_sizes)); } else { // Innermost axis provides the FFT length. for (int i = 0; i < fft_rank_; i++) { @@ -63,7 +89,7 @@ class GenericFftOp : public XlaOpKernel { } } - xla::XlaOp fft = xla::Fft(ctx->Input(0), fft_type_, fft_length); + xla::XlaOp fft = xla::Fft(input, fft_type_, fft_length); ctx->SetOutput(0, fft); } diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc index 56da50f1408..b5e08391255 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc @@ -72,7 +72,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { arg.shape = resource->shape(); OP_REQUIRES(ctx, arg.initialized, errors::Unimplemented("Uninitialized arguments: ", arg.name)); - arg.tensor_array_size = resource->tensor_array_size(); + arg.max_array_size = resource->max_array_size(); for (const auto& gradient : resource->tensor_array_gradients()) { arg.tensor_array_gradients.insert(gradient.first); } diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc index b49b2516d8b..e9bb0a77e99 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc @@ -191,12 +191,11 @@ class AdjustContrastOpV2 : public XlaOpKernel { DataType type = context->input_type(0); const DataType accumulation_type = XlaHelpers::SumAccumulationType(type); - auto converted = - XlaHelpers::ConvertElementType(b, input, accumulation_type); + auto converted = XlaHelpers::ConvertElementType(input, accumulation_type); auto reduce = xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type), *context->GetOrCreateAdd(accumulation_type), {height_dim, width_dim}); - auto output = XlaHelpers::ConvertElementType(b, reduce, type); + auto output = XlaHelpers::ConvertElementType(reduce, type); output = xla::Div(output, XlaHelpers::FloatLiteral(b, type, height * width)); diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc index 0c7ca602bfa..5a10c52ba8b 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc @@ -19,7 +19,6 @@ limitations under the License. 
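The GenericFftOp change above slices or zero-pads the transformed axes so their extents match fft_length (with IRFFT expecting fft_length // 2 + 1 input elements on the innermost axis) before calling xla::Fft. A NumPy sketch of that fitting step under the same assumptions; the function name is illustrative:

```
import numpy as np

def fit_fft_axes(x, fft_length, irfft=False):
    # Zero-pad or truncate only the trailing len(fft_length) axes of x.
    fft_rank = len(fft_length)
    expected = list(fft_length)
    if irfft:
        expected[-1] = fft_length[-1] // 2 + 1   # IRFFT innermost axis wants n/2 + 1
    slices, pads = [], []
    for dim, size in enumerate(x.shape):
        i = dim - (x.ndim - fft_rank)
        want = size if i < 0 else min(size, expected[i])
        slices.append(slice(0, want))
        pads.append((0, 0 if i < 0 else max(0, expected[i] - size)))
    return np.pad(x[tuple(slices)], pads)

x = np.ones((2, 5))
print(fit_fft_axes(x, [8]).shape)              # (2, 8) -- zero-padded for RFFT
print(fit_fft_axes(x, [8], irfft=True).shape)  # (2, 5) -- 8 // 2 + 1 == 5 already
```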
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc index e310db2162d..e2c05b648bb 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc @@ -30,7 +30,9 @@ limitations under the License. namespace tensorflow { namespace { -// The logic below uses a custom-call to implement argmax. +// The logic below uses a custom-call to implement argmax when possible. When +// custom-call is not allowed or input shapes are not supported, this kernel +// falls back to using XLA HLO native ArgMax. // // Also see b/29507024 for first-class XLA support for indexing ops. class ArgMaxCustomCallOp : public XlaOpKernel { @@ -50,27 +52,40 @@ class ArgMaxCustomCallOp : public XlaOpKernel { // overhead, when compiling ahead-of-time. int64 dim; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &dim)); - OP_REQUIRES(ctx, dim >= 0, errors::InvalidArgument("dim must be >= 0")); - OP_REQUIRES( - ctx, dim < input_shape.dims(), - errors::InvalidArgument("dim must be < input rank (", - input_shape.dims(), "), but got: ", dim)); - const int64 dim_size = input_shape.dim_size(dim); - OP_REQUIRES(ctx, dim_size > 0, + + const int input_dims = input_shape.dims(); + const int axis = dim < 0 ? dim + input_dims : dim; + OP_REQUIRES(ctx, axis >= 0 && axis < input_dims, + errors::InvalidArgument("Expected dimension in the range [", + -input_dims, ", ", input_dims, + "), but got ", dim)); + + const int64 axis_size = input_shape.dim_size(axis); + OP_REQUIRES(ctx, axis_size > 0, errors::InvalidArgument( "Reduction axis ", dim, " is empty in shape: ", input_shape.DebugString())); - // The output shape is the input shape contracted along dim. - TensorShape output_shape; - for (int d = 0; d < input_shape.dims() - 1; ++d) { - output_shape.AddDim(input_shape.dim_size((d < dim) ? d : d + 1)); + const DataType dtype = output_type(0); + xla::PrimitiveType output_type; + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dtype, &output_type)); + + // Fall back to XLA ArgMax HLO when CustomCall is not allowed or when input + // shape isn't supported. + if (!ctx->compiler()->options().allow_cpu_custom_calls || + (input_dims != 1 && input_dims != 2)) { + xla::XlaOp output = XlaHelpers::ArgMax(ctx->Input(0), output_type, axis); + ctx->SetOutput(0, output); + return; + } + + xla::XlaOp output; + // The output shape is the input shape contracted along axis. + TensorShape output_shape; + for (int d = 0; d < input_shape.dims() - 1; ++d) { + output_shape.AddDim(input_shape.dim_size((d < axis) ? d : d + 1)); } - // For now we use a custom-call, only for the 1d and 2d cases. - OP_REQUIRES(ctx, XlaContext::Get(ctx).allow_cpu_custom_calls(), - errors::InvalidArgument( - "ArgMax implementation requires a CustomCall on CPU")); xla::XlaBuilder& b = *ctx->builder(); // XLA passes to the function, so it is not included here. 
@@ -84,7 +99,7 @@ class ArgMaxCustomCallOp : public XlaOpKernel { args.push_back(xla::ConstantLiteral( &b, xla::LiteralUtil::CreateR1(output_shape.dim_sizes()))); args.push_back( - xla::ConstantLiteral(&b, xla::LiteralUtil::CreateR0(dim))); + xla::ConstantLiteral(&b, xla::LiteralUtil::CreateR0(axis))); } // The argmax function expects row-major layout. @@ -101,24 +116,15 @@ class ArgMaxCustomCallOp : public XlaOpKernel { } // Tell XLA to call the custom code, defined in - // index_ops_kernel_argmax_float_1d.cc. - xla::XlaOp output; - switch (input_shape.dims()) { - case 1: - output = xla::CustomCallWithLayout(&b, "argmax_float_1d_xla_impl", args, - xla_shape, arg_shapes); - break; - case 2: - output = xla::CustomCallWithLayout(&b, "argmax_float_2d_xla_impl", args, - xla_shape, arg_shapes); - break; - default: - OP_REQUIRES(ctx, false, - errors::Unimplemented( - "Argmax is only implemented for 1d and 2d tensors" - ", but got shape: ", - input_shape.DebugString())); + // index_ops_kernel_argmax_float_{1, 2}d.cc. + if (input_dims == 1) { + output = xla::CustomCallWithLayout(&b, "argmax_float_1d_xla_impl", args, + xla_shape, arg_shapes); + } else { + output = xla::CustomCallWithLayout(&b, "argmax_float_2d_xla_impl", args, + xla_shape, arg_shapes); } + output = xla::ConvertElementType(output, output_type); ctx->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc index f028e361bcc..93f029731c3 100644 --- a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc @@ -37,12 +37,11 @@ class L2LossOp : public XlaOpKernel { // output = sum(t ** 2) / 2 const DataType accumulation_type = XlaHelpers::SumAccumulationType(dtype); - auto t = - XlaHelpers::ConvertElementType(b, ctx->Input(0), accumulation_type); + auto t = XlaHelpers::ConvertElementType(ctx->Input(0), accumulation_type); auto square = xla::Mul(t, t); auto reduce = xla::Reduce(square, XlaHelpers::Zero(b, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type), dims); - auto deconverted = XlaHelpers::ConvertElementType(b, reduce, dtype); + auto deconverted = XlaHelpers::ConvertElementType(reduce, dtype); auto two = XlaHelpers::IntegerLiteral(b, dtype, 2); ctx->SetOutput(0, xla::Div(deconverted, two)); } diff --git a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc index 87ee2d3aede..987901d82b3 100644 --- a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc @@ -49,16 +49,14 @@ class LRNOp : public XlaOpKernel { // We use a window of depth_radius_ * 2 + 1, to account for the current // element and a depth_radius_ on either side. 
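Several kernels above (BiasAddGrad, L2Loss, the LRN ops, AdjustContrastV2) follow the same pattern with the new ConvertElementType signature: convert to the sum-accumulation type, reduce, convert back. A sketch using L2Loss, assuming SumAccumulationType promotes 16-bit floats to float32, which is how it is used here:

```
import numpy as np

def l2_loss(t):
    # output = sum(t ** 2) / 2, accumulated in float32 so a float16 input does
    # not lose precision in the sum, then converted back to the input type.
    acc = t.astype(np.float32)
    return (np.sum(acc * acc) / 2).astype(t.dtype)

t = np.array([1.0, 2.0, 3.0], dtype=np.float16)
print(l2_loss(t))  # 7.0
```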
auto accumulation_type = XlaHelpers::SumAccumulationType(input_type(0)); - auto converted = - XlaHelpers::ConvertElementType(builder, input, accumulation_type); + auto converted = XlaHelpers::ConvertElementType(input, accumulation_type); auto squared = xla::Mul(converted, converted); auto reduce = xla::ReduceWindow( squared, XlaHelpers::Zero(builder, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type), /* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1}, /* window_strides = */ {1, 1, 1, 1}, xla::Padding::kSame); - auto sqr_sum = - XlaHelpers::ConvertElementType(builder, reduce, input_type(0)); + auto sqr_sum = XlaHelpers::ConvertElementType(reduce, input_type(0)); auto scale = xla::Pow( xla::Add(xla::ConstantR0(builder, bias_), @@ -138,15 +136,14 @@ class LRNGradOp : public XlaOpKernel { auto accumulation_type = XlaHelpers::SumAccumulationType(input_type(0)); auto converted = - XlaHelpers::ConvertElementType(builder, in_image, accumulation_type); + XlaHelpers::ConvertElementType(in_image, accumulation_type); auto squared = xla::Mul(converted, converted); auto reduce = xla::ReduceWindow( squared, XlaHelpers::Zero(builder, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type), /* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1}, /* window_strides = */ {1, 1, 1, 1}, xla::Padding::kSame); - auto sqr_sum = - XlaHelpers::ConvertElementType(builder, reduce, input_type(0)); + auto sqr_sum = XlaHelpers::ConvertElementType(reduce, input_type(0)); auto norm = xla::Add(xla::ConstantR0(builder, bias_), @@ -157,15 +154,13 @@ class LRNGradOp : public XlaOpKernel { xla::Div(out_image, norm)), in_grads); - auto converted_dy = - XlaHelpers::ConvertElementType(builder, dy, accumulation_type); + auto converted_dy = XlaHelpers::ConvertElementType(dy, accumulation_type); auto dy_reduce = xla::ReduceWindow( converted_dy, XlaHelpers::Zero(builder, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type), /* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1}, /* window_strides = */ {1, 1, 1, 1}, xla::Padding::kSame); - auto dy_reduced = - XlaHelpers::ConvertElementType(builder, dy_reduce, input_type(0)); + auto dy_reduced = XlaHelpers::ConvertElementType(dy_reduce, input_type(0)); xla::XlaOp gradients = xla::Add( xla::Mul(in_image, dy_reduced), diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc index 8dfd7de591c..2dd0a710e47 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/framework/tensor_shape.h" namespace tensorflow { @@ -61,11 +61,11 @@ class MatrixBandPartOp : public XlaOpKernel { // Compute 'offset', which is how many diagonals we are above/below the // diagonal. 
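The matrix_band_part change below replaces the broadcast-and-subtract with two 2-D iotas whose difference is the diagonal offset. A NumPy sketch of the same offset/band computation (illustrative; negative bounds keep everything on that side, as in the TF op):

```
import numpy as np

def matrix_band_part(x, num_lower, num_upper):
    m, n = x.shape[-2:]
    iota_m = np.arange(m).reshape(-1, 1)   # like the iota with iota_dimension = 0
    iota_n = np.arange(n).reshape(1, -1)   # like the iota with iota_dimension = 1
    offset = iota_n - iota_m               # how far above (+) / below (-) the diagonal
    keep = (((num_lower < 0) | (offset >= -num_lower)) &
            ((num_upper < 0) | (offset <= num_upper)))
    return np.where(keep, x, 0)

x = np.arange(1, 17).reshape(4, 4)
print(matrix_band_part(x, 1, 0))  # keeps the main diagonal and one sub-diagonal
```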
- xla::XlaOp iota_m = xla::Iota(builder, index_xla_type, m); - xla::XlaOp iota_n = xla::Iota(builder, index_xla_type, n); + xla::Shape iota_shape = xla::ShapeUtil::MakeShape(index_xla_type, {m, n}); + xla::XlaOp iota_m = xla::Iota(builder, iota_shape, /*iota_dimension=*/0); + xla::XlaOp iota_n = xla::Iota(builder, iota_shape, /*iota_dimension=*/1); - auto offset = xla::Sub(xla::Broadcast(iota_n, {m}), iota_m, - /*broadcast_dimensions=*/{0}); + auto offset = xla::Sub(iota_n, iota_m); // If num_lower or num_upper are negative, include all lower/upper // diagonals. diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc index c0ca881ff82..4f980b6d14e 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc @@ -16,7 +16,6 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_builder.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/permute_op.cc b/tensorflow/compiler/tf2xla/kernels/permute_op.cc index 94b51e1a586..71920bf5c1e 100644 --- a/tensorflow/compiler/tf2xla/kernels/permute_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/permute_op.cc @@ -75,8 +75,7 @@ class DataFormatVecPermuteOp : public XlaOpKernel { } auto keys = xla::ConstantR1(builder, absl::Span(dst_indices)); if (input_rank == 2) { - keys = xla::BroadcastInDim( - keys, xla::ShapeUtil::MakeShape(xla::S32, {4, 2}), {0}); + keys = xla::BroadcastInDim(keys, {4, 2}, {0}); } auto sorted = xla::Sort(keys, {ctx->Input(0)}, 0); auto output = xla::GetTupleElement(sorted, 1); diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index a259da6383d..06c6cc37ec9 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -152,7 +152,12 @@ class MaxPoolOp : public PoolingOp { public: MaxPoolOp(OpKernelConstruction* ctx, int num_spatial_dims) : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims, - /*reduction_type=*/ctx->input_type(0)) {} + /*reduction_type=*/ctx->input_type(0)) { + string data_format_str; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); + OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), + errors::InvalidArgument("Invalid data format")); + } void Compile(XlaOpKernelContext* ctx) override { auto ksize_or_error = GetKernelSize(ctx); @@ -180,10 +185,6 @@ class MaxPool2DOp : public MaxPoolOp { public: explicit MaxPool2DOp(OpKernelConstruction* ctx) : MaxPoolOp(ctx, /*num_spatial_dims=*/2) { - string data_format_str; - OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); - OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), - errors::InvalidArgument("Invalid data format")); } }; REGISTER_XLA_OP(Name("MaxPool"), MaxPool2DOp); @@ -204,7 +205,12 @@ class AvgPoolOp : public PoolingOp { AvgPoolOp(OpKernelConstruction* ctx, int num_spatial_dims) : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims, /*reduction_type=*/ - XlaHelpers::SumAccumulationType(ctx->input_type(0))) {} + XlaHelpers::SumAccumulationType(ctx->input_type(0))) { + string data_format_str; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); + 
OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), + errors::InvalidArgument("Invalid data format")); + } void Compile(XlaOpKernelContext* ctx) override { auto ksize_or_error = GetKernelSize(ctx); @@ -241,10 +247,6 @@ class AvgPool2DOp : public AvgPoolOp { public: explicit AvgPool2DOp(OpKernelConstruction* ctx) : AvgPoolOp(ctx, /*num_spatial_dims=*/2) { - string data_format_str; - OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); - OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), - errors::InvalidArgument("Invalid data format")); } }; REGISTER_XLA_OP(Name("AvgPool"), AvgPool2DOp); @@ -390,6 +392,11 @@ class AvgPoolGradOp : public XlaOpKernel { OP_REQUIRES(ctx, ksize_[0] == 1 && stride_[0] == 1, errors::Unimplemented( "Pooling is not yet supported on the batch dimension.")); + + string data_format; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); + OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), + errors::InvalidArgument("Invalid data format")); } int num_dims() const { return num_spatial_dims_ + 2; } @@ -449,10 +456,6 @@ class AvgPool2DGradOp : public AvgPoolGradOp { public: explicit AvgPool2DGradOp(OpKernelConstruction* ctx) : AvgPoolGradOp(ctx, /*num_spatial_dims=*/2) { - string data_format; - OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); - OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), - errors::InvalidArgument("Invalid data format")); } }; REGISTER_XLA_OP( diff --git a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc index 6f4ed496a17..7fe102428db 100644 --- a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/core/platform/macros.h" @@ -26,12 +27,26 @@ limitations under the License. namespace tensorflow { namespace { +enum QuantizerRoundMode { + // Round half up: if the fraction of y is exactly 0.5, then + // round(y) = y + 0.5 + // E.g., -5.5 gets rounded to -5, -5.4 goes to -5, + // 5.4 goes to 5, and 5.5 goes to 6. + ROUND_HALF_UP, + // Round half to even: if the fraction of y is exactly 0.5, then round(y) is + // the nearest even integer to y. + // E.g., 23.5 gets rounded to 24, 24.5 gets rounded to 24, while -23.5 becomes + // -24, and -24.5 gets rounded to 24. + ROUND_HALF_TO_EVEN, +}; + class QuantizeAndDequantizeOp : public XlaOpKernel { public: explicit QuantizeAndDequantizeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("signed_input", &signed_input_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("range_given", &range_given_)); + round_mode_ = ROUND_HALF_TO_EVEN; } void Compile(XlaOpKernelContext* ctx) override { @@ -117,8 +132,17 @@ class QuantizeAndDequantizeOp : public XlaOpKernel { // in that case they were measured from the tensor. 
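The QuantizerRoundMode enum above distinguishes round-half-up from round-half-to-even, and the switch added further down in Compile applies one or the other to input * scale. A simplified NumPy sketch of just that rounding step (it omits the range nudging and clamping the real op performs first):

```
import numpy as np

def quantize_dequantize(x, scale, round_mode="HALF_TO_EVEN"):
    # Round x * scale with the chosen mode, then undo the scaling.
    scaled = x * scale
    if round_mode == "HALF_TO_EVEN":
        rounded = np.round(scaled)           # NumPy rounds exact halves to even
    elif round_mode == "HALF_UP":
        rounded = np.floor(scaled + 0.5)
    else:
        raise ValueError(round_mode)
    return rounded / scale

x = np.array([0.75, 1.25, 1.75, -1.25])
print(quantize_dequantize(x, 2.0, "HALF_TO_EVEN"))  # [ 1.   1.   2.  -1. ]
print(quantize_dequantize(x, 2.0, "HALF_UP"))       # [ 1.   1.5  2.  -1. ]
```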
input = Clamp(min_range, input, max_range); } - xla::XlaOp result = - Floor((input - min_range) * scale + half) * inverse_scale + min_range; + xla::XlaOp result; + switch (round_mode_) { + case ROUND_HALF_TO_EVEN: { + result = xla::RoundToEven(input * scale) * inverse_scale; + break; + } + case ROUND_HALF_UP: { + result = Floor(input * scale + half) * inverse_scale; + break; + } + } ctx->SetOutput(0, result); } @@ -126,6 +150,7 @@ class QuantizeAndDequantizeOp : public XlaOpKernel { int64 num_bits_ = -1; bool signed_input_; bool range_given_; + QuantizerRoundMode round_mode_; }; class QuantizeAndDequantizeV2Op : public QuantizeAndDequantizeOp { @@ -136,6 +161,20 @@ class QuantizeAndDequantizeV2Op : public QuantizeAndDequantizeOp { OP_REQUIRES(ctx, num_bits_ > 0 && num_bits_ < (signed_input_ ? 62 : 63), errors::InvalidArgument("num_bits is out of range: ", num_bits_, " with signed_input_ ", signed_input_)); + string round_mode_string; + OP_REQUIRES_OK(ctx, ctx->GetAttr("round_mode", &round_mode_string)); + OP_REQUIRES( + ctx, + (round_mode_string == "HALF_UP" || round_mode_string == "HALF_TO_EVEN"), + errors::InvalidArgument("Round mode string must be " + "'HALF_UP' or " + "'HALF_TO_EVEN', is '" + + round_mode_string + "'")); + if (round_mode_string == "HALF_UP") { + round_mode_ = ROUND_HALF_UP; + } else if (round_mode_string == "HALF_TO_EVEN") { + round_mode_ = ROUND_HALF_TO_EVEN; + } } }; diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index 415ce9b77ff..8822e29f7e7 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -26,7 +26,6 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc index 107fa62967a..65e158d64fd 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc @@ -113,12 +113,21 @@ class MeanOp : public XlaReductionOp { xla::Add(scalar_lhs, scalar_rhs); } - xla::XlaOp BuildFinalizer(xla::XlaBuilder* builder, - const xla::XlaOp& reduce_output, - int64 num_elements_reduced) override { - auto divisor = XlaHelpers::IntegerLiteral(builder, input_type(0), - num_elements_reduced); - return reduce_output / divisor; + xla::XlaOp BuildFinalizer( + xla::XlaBuilder* /*builder*/, const xla::XlaOp& input, + const xla::XlaOp& reduce_output, + const std::vector& dimensions_to_reduce) override { + if (dimensions_to_reduce.empty()) { + return reduce_output; + } + auto divisor = xla::GetDimensionSize(input, dimensions_to_reduce[0]); + for (int i = 1; i < dimensions_to_reduce.size(); i++) { + auto size = xla::GetDimensionSize(input, dimensions_to_reduce[i]); + divisor = xla::Mul(divisor, size); + } + divisor = xla::ConvertElementType(divisor, xla_reduction_type_); + return XlaHelpers::ConvertElementType(reduce_output / divisor, + input_type(0)); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h index 466e79828d1..af716eab798 100644 --- 
a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h @@ -48,13 +48,14 @@ class XlaReductionOp : public XlaOpKernel { const xla::XlaOp& scalar_rhs) = 0; // Applies a transformation to the output of the reduction. The desired - // computation should be added to 'builder'. Argument 'reduce_output' is the - // output of the reduction. 'num_elements_reduced' is the number of elements - // that contributed to the reduction. Returns the transformed reduction - // output, Defaults to returning 'reduce_output' unchanged. - virtual xla::XlaOp BuildFinalizer(xla::XlaBuilder* builder, - const xla::XlaOp& reduce_output, - int64 num_elements_reduced); + // computation should be added to 'builder'. Argument 'input' is the original + // input of the reduction; 'reduce_output' is the output of the reduction. + // Returns the transformed reduction output. Defaults to returning + // 'reduce_output' converted to the input type. + virtual xla::XlaOp BuildFinalizer( + xla::XlaBuilder* builder, const xla::XlaOp& input, + const xla::XlaOp& reduce_output, + const std::vector& dimensions_to_reduce); void Compile(XlaOpKernelContext* ctx) override; diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc index 118f2798d55..2ca2a85244b 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc @@ -35,12 +35,13 @@ XlaReductionOp::XlaReductionOp(OpKernelConstruction* ctx, ctx, DataTypeToPrimitiveType(reduction_type_, &xla_reduction_type_)); } -// Unless BuildFinalizer is overridden the reduction has no -// finalizer. -xla::XlaOp XlaReductionOp::BuildFinalizer(xla::XlaBuilder* builder, - const xla::XlaOp& reduce_output, - int64 num_elements_reduced) { - return reduce_output; +// The default finalizer converts the results back into the input type. This can +// be overridden. +xla::XlaOp XlaReductionOp::BuildFinalizer( + xla::XlaBuilder* /*builder*/, const xla::XlaOp& /*input*/, + const xla::XlaOp& reduce_output, + const std::vector& /*dimensions_to_reduce*/) { + return XlaHelpers::ConvertElementType(reduce_output, input_type(0)); } void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { @@ -71,7 +72,6 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { absl::InlinedVector bitmap(data_shape.dims(), false); std::vector xla_axes; - int64 num_elements_reduced = 1LL; for (int64 i = 0; i < axes_tensor_shape.num_elements(); ++i) { int64 index = axes[i]; OP_REQUIRES(ctx, @@ -82,7 +82,6 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { index = (index + data_shape.dims()) % data_shape.dims(); bitmap[index] = true; xla_axes.push_back(index); - num_elements_reduced *= data_shape.dim_size(index); } std::vector final_shape; @@ -118,8 +117,7 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { xla::XlaComputation reduction_computation = r.Build().ConsumeValueOrDie(); auto reduce = xla::Reduce(data, initial, reduction_computation, xla_axes); - auto deconverted = XlaHelpers::ConvertElementType(b, reduce, input_type(0)); - auto finalized = BuildFinalizer(b, deconverted, num_elements_reduced); + auto finalized = BuildFinalizer(b, data, reduce, xla_axes); auto result = keep_dims_ ? 
xla::Reshape(finalized, final_shape) : finalized; ctx->SetOutput(0, result); } diff --git a/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc b/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc index 847704608fb..54d34a38abc 100644 --- a/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc @@ -24,7 +24,6 @@ limitations under the License. #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -44,9 +43,6 @@ namespace { using xla::XlaOp; -// TODO(b/112295522): note that sampling from image boundary is not currently -// being handled properly. - // Calculates the bilinear weight tensor, given basis ratio (px, py) of the // sampling position: // W = [(1-px)*(1-py), px*(1-py), (1-px)*py, px*py] @@ -70,11 +66,8 @@ XlaOp BilinearWeights(XlaOpKernelContext* ctx, XlaOp ratio, std::vector last_two_dims_indices = {(broadcast_dims_size - 2), (broadcast_dims_size - 1)}; - xla::Shape broadcast_shape = - xla::ShapeUtil::MakeShape(xla_type, broadcast_dims); - auto broadcast_first_term = - xla::BroadcastInDim(first_term, broadcast_shape, last_two_dims_indices); + xla::BroadcastInDim(first_term, broadcast_dims, last_two_dims_indices); // Ratio is of the same dimension as warp, which is [batch, dim_0,... dim_n, // 2], we broadcast ratio tensor to 'broadcast_dim' by keeping the @@ -85,7 +78,7 @@ XlaOp BilinearWeights(XlaOpKernelContext* ctx, XlaOp ratio, ratio_broadcast_indices.erase(ratio_broadcast_indices.end() - 2); auto broadcast_ratio = - xla::BroadcastInDim(ratio, broadcast_shape, ratio_broadcast_indices); + xla::BroadcastInDim(ratio, broadcast_dims, ratio_broadcast_indices); auto first_term_subtract_weights = broadcast_first_term - broadcast_ratio; @@ -96,7 +89,7 @@ XlaOp BilinearWeights(XlaOpKernelContext* ctx, XlaOp ratio, sign_change = xla::ConvertElementType(sign_change, xla_type); auto broadcast_sign_change = - xla::BroadcastInDim(sign_change, broadcast_shape, last_two_dims_indices); + xla::BroadcastInDim(sign_change, broadcast_dims, last_two_dims_indices); auto flipped = first_term_subtract_weights * broadcast_sign_change; @@ -232,21 +225,19 @@ XlaOp CalculateGradData(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, std::vector weights_with_channels_dims = reshaped_weights_dims; weights_with_channels_dims.push_back(data_channels); - auto weights_with_channels_shape = - xla::ShapeUtil::MakeShape(warp_type, weights_with_channels_dims); std::vector reshaped_weights_indices(reshaped_weights_dims.size()); std::iota(reshaped_weights_indices.begin(), reshaped_weights_indices.end(), 0); // The dimension is [batch, dim_0, ..., dim_n, 2, 2, data_channel]. 
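Stepping back to the reduction_ops change above: MeanOp's new BuildFinalizer divides the summed output by the product of xla::GetDimensionSize over the reduced axes instead of a precomputed element count. A NumPy sketch of that finalizer, illustrative only:

```
import numpy as np

def mean_via_sum(x, axes):
    # Reduce with a sum, then divide by the product of the sizes of the reduced
    # dimensions, taken from the input itself.
    reduced = np.sum(x, axis=tuple(axes))
    divisor = 1
    for a in axes:
        divisor *= x.shape[a]
    return reduced / divisor

x = np.arange(24, dtype=np.float32).reshape(2, 3, 4)
print(np.allclose(mean_via_sum(x, [0, 2]), np.mean(x, axis=(0, 2))))  # True
```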
auto broadcast_reshaped_weights = xla::BroadcastInDim( - reshaped_weights, weights_with_channels_shape, reshaped_weights_indices); + reshaped_weights, weights_with_channels_dims, reshaped_weights_indices); std::vector grad_output_indices(warp_dims_without_last_dims.size()); std::iota(grad_output_indices.begin(), grad_output_indices.end(), 0); grad_output_indices.push_back(weights_with_channels_dims.size() - 1); XlaOp broadcast_grad_output = xla::BroadcastInDim( - grad_output, weights_with_channels_shape, grad_output_indices); + grad_output, weights_with_channels_dims, grad_output_indices); auto grad_output_multiply_weights = broadcast_grad_output * broadcast_reshaped_weights; @@ -294,13 +285,10 @@ XlaOp CalculateGradWarp(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, std::vector warp_dims_without_last_dims(warp_dims.begin(), warp_dims.end() - 1); + // With dimension [batch, dim_0, ...dim_n, 4] std::vector neighbor_broadcast_dims = warp_dims_without_last_dims; neighbor_broadcast_dims.push_back(4); - // With dimension [batch, dim_0, ...dim_n, 4] - auto neighbor_broadcast_shape = - xla::ShapeUtil::MakeShape(data_type, neighbor_broadcast_dims); - // The dimension is [batch, dim_0, ... dim_n, 4, data_channels] auto neighbors_data = Gather2by2Neighbors( ctx->builder(), data, gather_indices, data_channels, warp_shape.dims()); @@ -326,7 +314,7 @@ XlaOp CalculateGradWarp(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, xla::BroadcastInDim( xla::ConvertElementType( xla::ConstantR1(ctx->builder(), {0, 0, -1, 1}), data_type), - neighbor_broadcast_shape, {last_warp_dim}), + neighbor_broadcast_dims, {last_warp_dim}), neighbors_data, dot_dims, /*precision_config=*/nullptr); // img_cxfy - img_fxfy @@ -334,7 +322,7 @@ XlaOp CalculateGradWarp(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, xla::BroadcastInDim( xla::ConvertElementType( xla::ConstantR1(ctx->builder(), {-1, 1, 0, 0}), data_type), - neighbor_broadcast_shape, {last_warp_dim}), + neighbor_broadcast_dims, {last_warp_dim}), neighbors_data, dot_dims, /*precision_config=*/nullptr); // img_cxcy - img_cxfy @@ -342,7 +330,7 @@ XlaOp CalculateGradWarp(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, xla::BroadcastInDim( xla::ConvertElementType( xla::ConstantR1(ctx->builder(), {0, -1, 0, 1}), data_type), - neighbor_broadcast_shape, {last_warp_dim}), + neighbor_broadcast_dims, {last_warp_dim}), neighbors_data, dot_dims, /*precision_config=*/nullptr); // img_fxcy - img_fxfy @@ -350,7 +338,7 @@ XlaOp CalculateGradWarp(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, xla::BroadcastInDim( xla::ConvertElementType( xla::ConstantR1(ctx->builder(), {-1, 0, 1, 0}), data_type), - neighbor_broadcast_shape, {last_warp_dim}), + neighbor_broadcast_dims, {last_warp_dim}), neighbors_data, dot_dims, /*precision_config=*/nullptr); // Slice out x and y. @@ -421,12 +409,13 @@ class ResamplerOp : public XlaOpKernel { OP_REQUIRES(ctx, warp_shape.dim_size(last_warp_dim) == 2, errors::InvalidArgument( "the last dimension of warp must be exactly size 2.")); + xla::PrimitiveType warp_type = ctx->input_xla_type(1); XlaOp data = ctx->Input("data"); XlaOp warp = ctx->Input("warp"); // Find the coordinates of the top left corner for the 2x2 region to be - // sampled from. The dimensions are (batch, dim_0, ... dim_n, 2) where the + // sampled from. The dimensions are [batch, dim_0, ... dim_n, 2] where the // last dimension of size 2 in turn is [x, y]. 
XlaOp top_left = xla::ConvertElementType(warp, xla::U32); @@ -457,10 +446,54 @@ class ResamplerOp : public XlaOpKernel { dot_dims.add_lhs_contracting_dimensions(warp_shape.dims() - 1); dot_dims.add_rhs_contracting_dimensions(warp_shape.dims() - 1); + // The dimension is [batch, dim_0, ...dim_n, data_channels]. auto blended_pixels = xla::DotGeneral(weights, neighbors_data, dot_dims, /*precision_config=*/nullptr); - ctx->SetOutput(0, blended_pixels); + // Handle out of boundary cases by constructing a predicate mask array based + // on the in-bound condition, and output 0 for the blended pixel value if + // out-bound. The dimension is the same as top_left: [batch, dim_0, + // ...dim_n, 2] where the last dimension of size 2 is the [x, y] coordinate. + + auto is_ge_zero = xla::Ge(warp, xla::ZerosLike(warp)); + + auto is_lt_image_size = xla::Lt( + warp, + xla::ConvertElementType( + xla::ConstantR1( + ctx->builder(), + {/*width=*/static_cast(data_shape.dim_size(2) - 1), + /*height=*/static_cast(data_shape.dim_size(1) - 1)}), + warp_type), + /*broadcast_dimensions=*/{warp_shape.dims() - 1}); + + auto is_in_bound_x_y = xla::And(is_ge_zero, is_lt_image_size); + // Reduce along last dimension. The resulting dimension is: + // [batch, dim_0, ...dim_n]. + auto is_in_bound = xla::Reduce( + is_in_bound_x_y, xla::ConstantR0(ctx->builder(), true), + xla::CreateScalarAndComputation(xla::PrimitiveType::PRED, + ctx->builder()), + {last_warp_dim}); + + // Broadcast 'is_in_bound' to the same dimension as 'blended_pixels', which + // is the dimension of the result: + // [batch, dim_0, ...dim_n, data_channels]. + auto warp_dims = warp_shape.dim_sizes(); + std::vector result_dims(warp_dims.begin(), warp_dims.end() - 1); + result_dims.push_back(data_channels); + + std::vector broadcasted_dims(warp_dims.size() - 1); + std::iota(broadcasted_dims.begin(), broadcasted_dims.end(), 0); + auto broadcasted_is_in_bound = + xla::BroadcastInDim(is_in_bound, result_dims, broadcasted_dims); + + // Set out of bound samples to zero. + auto zeros = + xla::Broadcast(xla::Zero(ctx->builder(), data_type), result_dims); + auto result = xla::Select(broadcasted_is_in_bound, blended_pixels, zeros); + + ctx->SetOutput(0, result); } }; @@ -473,6 +506,8 @@ class ResamplerGradOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &output_dtype)); } + // TODO(b/112295522): note that sampling from image boundary is not currently + // being handled properly. void Compile(XlaOpKernelContext* ctx) override { TensorShape data_shape_tf = ctx->InputShape("data"); OP_REQUIRES(ctx, data_shape_tf.dims() == 4, diff --git a/tensorflow/compiler/tf2xla/kernels/retval_op.cc b/tensorflow/compiler/tf2xla/kernels/retval_op.cc index 6970dd0a006..e4046c79557 100644 --- a/tensorflow/compiler/tf2xla/kernels/retval_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/retval_op.cc @@ -47,8 +47,7 @@ class RetvalOp : public XlaOpKernel { // compilation. OP_REQUIRES_OK(ctx, frame->SetRetval(index_, input)); } else { - XlaContext& xla_context = XlaContext::Get(ctx); - xla_context.SetRetval(index_, ctx->InputExpression(0)); + ctx->xla_context()->SetRetval(index_, ctx->InputExpression(0)); } } diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc index 7ff3e916381..d7b38e86cc9 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc @@ -18,7 +18,6 @@ limitations under the License. 
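The ResamplerOp change above blends the 2x2 neighborhood with bilinear weights and now zeroes out samples whose warp coordinates fall outside [0, size - 1). A NumPy sketch of that forward pass for a single [H, W, C] image (the kernel keeps a batch dimension and builds the mask with xla::Select):

```
import numpy as np

def resample_2d(data, warp):
    # data: [H, W, C]; warp: [N, 2] fractional [x, y] positions.
    h, w, c = data.shape
    out = np.zeros((warp.shape[0], c), dtype=data.dtype)
    for i, (x, y) in enumerate(warp):
        if not (0 <= x < w - 1 and 0 <= y < h - 1):
            continue                       # out of bound -> stays zero
        x0, y0 = int(np.floor(x)), int(np.floor(y))
        px, py = x - x0, y - y0
        # W = [(1-px)*(1-py), px*(1-py), (1-px)*py, px*py]
        weights = np.array([(1 - px) * (1 - py), px * (1 - py),
                            (1 - px) * py,       px * py])
        neighbors = np.stack([data[y0, x0], data[y0, x0 + 1],
                              data[y0 + 1, x0], data[y0 + 1, x0 + 1]])
        out[i] = weights @ neighbors
    return out

data = np.arange(12, dtype=float).reshape(3, 4, 1)
warp = np.array([[0.5, 0.5], [10.0, 0.0]])
print(resample_2d(data, warp))  # [[2.5], [0.]]
```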
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor_shape.h" diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc index b5fd7850bfc..4b9e1a578be 100644 --- a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc @@ -39,8 +39,8 @@ namespace { // TODO(phawkins): implement double-sized windowed reductions in XLA and remove // the type constraint. -constexpr std::array kScanOpTypes = { - {DT_HALF, DT_BFLOAT16, DT_FLOAT}}; +constexpr std::array kScanOpTypes = { + {DT_HALF, DT_BFLOAT16, DT_FLOAT, DT_INT32}}; class ScanOp : public XlaOpKernel { public: @@ -103,11 +103,10 @@ class ScanOp : public XlaOpKernel { reducer = ctx->GetOrCreateMul(dtype); } auto output = xla::ReduceWindowWithGeneralPadding( - XlaHelpers::ConvertElementType(builder, ctx->Input(0), dtype), init, - *reducer, window_dims, window_strides, + XlaHelpers::ConvertElementType(ctx->Input(0), dtype), init, *reducer, + window_dims, window_strides, /*base_dilations=*/{}, /*window_dilations=*/{}, padding); - output = - XlaHelpers::ConvertElementType(builder, output, ctx->input_type(0)); + output = XlaHelpers::ConvertElementType(output, ctx->input_type(0)); // In exclusive mode, we have computed an extra element containing the sum // of all the input elements. Slice off this extra "last" element. diff --git a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc index a7f5a8f1698..84470b230d4 100644 --- a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc @@ -42,7 +42,7 @@ SendOp::SendOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { } void SendOp::Compile(XlaOpKernelContext* ctx) { - XlaCompiler* compiler = XlaContext::Get(ctx).compiler(); + XlaCompiler* compiler = ctx->compiler(); xla::ChannelHandle channel; OP_REQUIRES_OK(ctx, compiler->GetChannelHandle(tensor_name_, &channel)); xla::Send(ctx->Input(0), channel); @@ -73,7 +73,7 @@ RecvOp::RecvOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { } void RecvOp::Compile(XlaOpKernelContext* ctx) { - XlaCompiler* compiler = XlaContext::Get(ctx).compiler(); + XlaCompiler* compiler = ctx->compiler(); xla::ChannelHandle channel; OP_REQUIRES_OK(ctx, compiler->GetChannelHandle(tensor_name_, &channel)); ctx->SetOutput(0, xla::Recv(ctx->builder(), shape_, channel)); diff --git a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc index 60b011ba6d9..b1fa2915d59 100644 --- a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc @@ -18,7 +18,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc index d6bd927135c..20da8033536 100644 --- a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc @@ -71,7 +71,7 @@ class SoftmaxOp : public XlaOpKernel { auto reduce = xla::Reduce(converted, xla::Zero(b, xla_accumulation_type), *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); - auto sum = XlaHelpers::ConvertElementType(b, reduce, type); + auto sum = XlaHelpers::ConvertElementType(reduce, type); auto softmax = log_ // softmax = shifted_logits - log(sum(exp(shifted_logits))) @@ -111,11 +111,11 @@ std::pair CrossEntropyWithLogits( // sum_{class} (exp(logits - max_logits)) const DataType accumulation_type = XlaHelpers::SumAccumulationType(type); auto converted = - XlaHelpers::ConvertElementType(b, exp_shifted_logits, accumulation_type); + XlaHelpers::ConvertElementType(exp_shifted_logits, accumulation_type); auto reduce = xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); - auto sum_exp = XlaHelpers::ConvertElementType(b, reduce, type); + auto sum_exp = XlaHelpers::ConvertElementType(reduce, type); // log(sum(exp(logits - max_logits))) auto log_sum_exp = xla::Log(sum_exp); @@ -126,11 +126,10 @@ std::pair CrossEntropyWithLogits( // (The subtraction broadcasts along the batch dimension.) auto sub = xla::Sub(shifted_logits, log_sum_exp, {kBatchDim}); auto mul = xla::Mul(xla::Neg(labels), sub); - auto sum = - xla::Reduce(XlaHelpers::ConvertElementType(b, mul, accumulation_type), - XlaHelpers::Zero(b, accumulation_type), - *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); - auto loss = XlaHelpers::ConvertElementType(b, sum, type); + auto sum = xla::Reduce(XlaHelpers::ConvertElementType(mul, accumulation_type), + XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); + auto loss = XlaHelpers::ConvertElementType(sum, type); // backprop: prob - labels, where // prob = exp(logits - max_logits) / sum(exp(logits - max_logits)) diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc index 7b96b43ad83..8e9e4daf99d 100644 --- a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc @@ -69,7 +69,7 @@ Status MaybeInitializeStack(xla::XlaBuilder* builder, XlaResource* resource, } TensorShape stack_shape; - stack_shape.AddDim(resource->tensor_array_size()); + stack_shape.AddDim(resource->max_array_size()); stack_shape.AppendShape(elem_shape); if (!resource->initialized()) { @@ -97,10 +97,10 @@ class StackOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - int64 size; - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(0, &size)); + int64 max_size; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(0, &max_size)); OP_REQUIRES( - ctx, size >= 0, + ctx, max_size >= 0, errors::InvalidArgument( "XLA compilation requires a fixed stack size upper bound. 
If " "you are using tf.while_loop, set the maximum_iterations parameter " @@ -108,14 +108,9 @@ class StackOp : public XlaOpKernel { // We defer initializing the Stack resource until we see the first push. // Otherwise we do not know the shape of the stack elements. - xla::XlaOp value; - XlaContext& xc = XlaContext::Get(ctx); - XlaResource* resource; - string name = absl::StrCat("Stack: ", stack_name_); - OP_REQUIRES_OK( - ctx, xc.CreateResource(XlaResource::kStack, -1, std::move(name), dtype_, - TensorShape(), value, /*tensor_array_size=*/size, - /*tensor_array_gradients=*/{}, &resource)); + XlaResource* resource = + ctx->xla_context()->AddResource(XlaResource::CreateStack( + /*name=*/absl::StrCat("Stack: ", stack_name_), dtype_, max_size)); ctx->SetResourceOutput(0, resource); } diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc index 5db52781be4..50653d7b397 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/math.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/lib/prng.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index 252967a7464..939d7e19515 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -61,8 +61,8 @@ Status MaybeInitializeTensorArray(xla::XlaBuilder* builder, " but op has dtype ", DataTypeString(dtype), "."); } - TF_RET_CHECK(resource->tensor_array_size() >= 0) - << resource->name() << " size " << resource->tensor_array_size(); + TF_RET_CHECK(resource->max_array_size() >= 0) + << resource->name() << " size " << resource->max_array_size(); if (!resource->initialized()) { TF_RETURN_IF_ERROR(resource->SetTypeAndShape(dtype, elem_shape)); @@ -78,7 +78,7 @@ Status MaybeInitializeTensorArray(xla::XlaBuilder* builder, XLAShapeToTensorShape(shape_or_status.ValueOrDie(), &shape)); TensorShape ta_shape; - ta_shape.AddDim(resource->tensor_array_size()); + ta_shape.AddDim(resource->max_array_size()); ta_shape.AppendShape(elem_shape); if (ta_shape != shape) { return errors::InvalidArgument( @@ -114,7 +114,7 @@ Status CheckTensorArrayIsInitialized(const string& op_name, Status GetTensorArrayShape(const XlaResource* resource, xla::XlaBuilder* builder, TensorShape* shape) { *shape = resource->shape(); - shape->InsertDim(0, resource->tensor_array_size()); + shape->InsertDim(0, resource->max_array_size()); return Status::OK(); } @@ -166,13 +166,10 @@ class TensorArrayOp : public XlaOpKernel { value = xla::Broadcast(zero, ta_shape.dim_sizes()); } - XlaContext& xc = XlaContext::Get(ctx); - XlaResource* var; - string name = absl::StrCat("TensorArray: ", tensor_array_name_); - OP_REQUIRES_OK( - ctx, xc.CreateResource(XlaResource::kTensorArray, -1, std::move(name), - dtype_, shape, value, /*tensor_array_size=*/size, - /*tensor_array_gradients=*/{}, &var)); + XlaResource* var = + ctx->xla_context()->AddResource(XlaResource::CreateTensorArray( + /*name=*/absl::StrCat("TensorArray: ", tensor_array_name_), dtype_, + 
shape, /*initial_value=*/value, /*max_array_size=*/size)); ctx->SetResourceOutput(0, var); Tensor flow(DT_FLOAT, TensorShape({})); @@ -517,14 +514,13 @@ class TensorArraySplitOp : public XlaOpKernel { xla::XlaOp ta = resource->value(); TensorShape ta_shape; - ta_shape.AddDim(resource->tensor_array_size()); + ta_shape.AddDim(resource->max_array_size()); ta_shape.AppendShape(elem_shape); - OP_REQUIRES( - ctx, lengths.size() == resource->tensor_array_size(), - errors::InvalidArgument( - "TensorArray's size is not equal to the size of lengths (", - lengths.size(), " vs. ", resource->tensor_array_size(), ")")); + OP_REQUIRES(ctx, lengths.size() == resource->max_array_size(), + errors::InvalidArgument( + "TensorArray's size is not equal to the size of lengths (", + lengths.size(), " vs. ", resource->max_array_size(), ")")); const xla::XlaOp value = ctx->Input(1); const xla::XlaOp flow = ctx->Input(3); @@ -562,8 +558,7 @@ class TensorArraySizeOp : public XlaOpKernel { XlaResource* var; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &var)); Tensor size_tensor(DT_INT32, {}); - size_tensor.scalar()() = - static_cast(var->tensor_array_size()); + size_tensor.scalar()() = static_cast(var->max_array_size()); ctx->SetConstantOutput(0, size_tensor); } diff --git a/tensorflow/compiler/tf2xla/kernels/topk_op.cc b/tensorflow/compiler/tf2xla/kernels/topk_op.cc index 8a0c94cfae1..ee3bdf3394e 100644 --- a/tensorflow/compiler/tf2xla/kernels/topk_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/topk_op.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/lib/sorting.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc index 7077c2e3a54..960c1462ceb 100644 --- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc @@ -320,9 +320,8 @@ class ResourceApplyAdagradDA : public XlaOpKernel { xla::XlaOp lr = ctx->Input(4); xla::XlaOp l1 = ctx->Input(5); xla::XlaOp l2 = ctx->Input(6); - xla::XlaBuilder* const b = ctx->builder(); xla::XlaOp global_step = - XlaHelpers::ConvertElementType(b, ctx->Input(7), dtype_); + XlaHelpers::ConvertElementType(ctx->Input(7), dtype_); accum = accum + grad; squared_accum = squared_accum + xla::Square(grad); diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc index 559414eeaa5..ce007fc04a8 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -64,7 +64,7 @@ Status MakeXlaCompilerArgumentsFromInputs( if (!arg.initialized) { *has_uninitialized_vars = true; } - arg.tensor_array_size = resource->tensor_array_size(); + arg.max_array_size = resource->max_array_size(); for (const auto& gradient : resource->tensor_array_gradients()) { arg.tensor_array_gradients.insert(gradient.first); } diff --git a/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc index a9f88a6df25..ad8e707e111 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc @@ -89,13 +89,10 @@ class XlaBroadcastHelperOp : public XlaOpKernel { 
lhs_shape.DebugString(), " and ", rhs_shape.DebugString())); broadcast_shape[dim] = min_rank_shape->dim_size(i); } - xla::PrimitiveType type = context->input_xla_type(0); - xla::Shape broadcast_xla_shape = - xla::ShapeUtil::MakeShape(type, broadcast_shape); if (broadcast_lhs) { - lhs = xla::BroadcastInDim(lhs, broadcast_xla_shape, broadcast_dims); + lhs = xla::BroadcastInDim(lhs, broadcast_shape, broadcast_dims); } else { - rhs = xla::BroadcastInDim(rhs, broadcast_xla_shape, broadcast_dims); + rhs = xla::BroadcastInDim(rhs, broadcast_shape, broadcast_dims); } context->SetOutput(0, lhs); context->SetOutput(1, rhs); diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD index 1ce3930fd1c..422781d536a 100644 --- a/tensorflow/compiler/tf2xla/lib/BUILD +++ b/tensorflow/compiler/tf2xla/lib/BUILD @@ -17,20 +17,6 @@ filegroup( load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") -cc_library( - name = "batch_dot", - srcs = ["batch_dot.cc"], - hdrs = ["batch_dot.h"], - deps = [ - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/core:lib", - ], -) - cc_library( name = "broadcast", srcs = ["broadcast.cc"], @@ -52,7 +38,6 @@ cc_library( srcs = ["cholesky.cc"], hdrs = ["cholesky.h"], deps = [ - ":batch_dot", ":triangular_solve", ":util", ":while_loop", @@ -63,6 +48,8 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/client/lib:matrix", + "//tensorflow/compiler/xla/client/lib:slicing", "//tensorflow/core:lib", ], ) @@ -87,7 +74,6 @@ cc_library( srcs = ["qr.cc"], hdrs = ["qr.h"], deps = [ - ":batch_dot", ":util", ":while_loop", "//tensorflow/compiler/xla:literal_util", @@ -99,7 +85,8 @@ cc_library( "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/compiler/xla/client/lib:constants", "//tensorflow/compiler/xla/client/lib:math", - "//tensorflow/compiler/xla/client/lib:numeric", + "//tensorflow/compiler/xla/client/lib:matrix", + "//tensorflow/compiler/xla/client/lib:slicing", "//tensorflow/core:lib", ], ) @@ -129,7 +116,6 @@ cc_library( srcs = ["triangular_solve.cc"], hdrs = ["triangular_solve.h"], deps = [ - ":batch_dot", ":util", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -140,7 +126,9 @@ cc_library( "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:constants", - "//tensorflow/compiler/xla/client/lib:numeric", + "//tensorflow/compiler/xla/client/lib:math", + "//tensorflow/compiler/xla/client/lib:matrix", + "//tensorflow/compiler/xla/client/lib:slicing", "//tensorflow/core:lib", ], ) @@ -187,29 +175,6 @@ cc_library( ], ) -xla_test( - name = "util_test", - srcs = ["util_test.cc"], - deps = [ - ":batch_dot", - ":util", - "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:global_data", - "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/tests:client_library_test_base", - "//tensorflow/compiler/xla/tests:literal_test_util", 
- "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:lib", - "//tensorflow/core:test", - ], -) - cc_library( name = "while_loop", srcs = ["while_loop.cc"], diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc deleted file mode 100644 index 5400e8834cb..00000000000 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.cc +++ /dev/null @@ -1,115 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" - -#include -#include - -#include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/core/lib/core/errors.h" - -namespace tensorflow { - -xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x, - bool transpose_y, bool conjugate_x, bool conjugate_y, - xla::PrecisionConfig::Precision precision) { - xla::XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x)); - TF_ASSIGN_OR_RETURN(xla::Shape y_shape, builder->GetShape(y)); - - // Check that both tensors have the same number of dimensions. There must be - // at least two (the batch dimensions can be empty). - if (xla::ShapeUtil::Rank(x_shape) != xla::ShapeUtil::Rank(y_shape)) { - return errors::InvalidArgument( - "Arguments to BatchedDot have different ranks: ", - xla::ShapeUtil::HumanString(x_shape), " vs. ", - xla::ShapeUtil::HumanString(y_shape)); - } - const int ndims = xla::ShapeUtil::Rank(x_shape); - if (ndims < 2) { - return errors::InvalidArgument( - "Arguments to BatchedDot must have rank >= 2: ", ndims); - } - - // The batch dimensions must be equal and the matrix dimensions must be - // valid. - std::vector batch_dimension_numbers; - for (int i = 0; i < ndims - 2; ++i) { - if (x_shape.dimensions(i) != y_shape.dimensions(i)) { - return errors::InvalidArgument( - "Dimension ", i, " of inputs to BatchedDot must be equal: ", - xla::ShapeUtil::HumanString(x_shape), " vs ", - xla::ShapeUtil::HumanString(y_shape)); - } - batch_dimension_numbers.push_back(i); - } - - int x_inner_dim = transpose_x ? (ndims - 2) : (ndims - 1); - int y_inner_dim = transpose_y ? (ndims - 1) : (ndims - 2); - if (x_shape.dimensions(x_inner_dim) != y_shape.dimensions(y_inner_dim)) { - return errors::InvalidArgument( - "Dimensions ", x_inner_dim, " and ", y_inner_dim, - " of arguments to BatchedDot must be equal: ", - xla::ShapeUtil::HumanString(x_shape), " transpose: ", transpose_x, - " vs. ", xla::ShapeUtil::HumanString(y_shape), - " transpose: ", transpose_y); - } - - // Check for zero lhs/rhs dim size. 
- if (xla::ShapeUtil::IsZeroElementArray(x_shape) || - xla::ShapeUtil::IsZeroElementArray(y_shape)) { - std::vector dimensions(batch_dimension_numbers.size()); - for (int i = 0; i < batch_dimension_numbers.size(); ++i) { - dimensions[i] = x_shape.dimensions(batch_dimension_numbers[i]); - } - int x_outer_dim = transpose_x ? (ndims - 1) : (ndims - 2); - int y_outer_dim = transpose_y ? (ndims - 2) : (ndims - 1); - dimensions.push_back(x_shape.dimensions(x_outer_dim)); - dimensions.push_back(y_shape.dimensions(y_outer_dim)); - return xla::Broadcast( - xla::ConstantLiteral(builder, - xla::LiteralUtil::Zero(x_shape.element_type())), - dimensions); - } - - if (x_shape.element_type() == xla::C64 && conjugate_x) { - x = xla::Conj(x); - } - if (y_shape.element_type() == xla::C64 && conjugate_y) { - y = xla::Conj(y); - } - - xla::PrecisionConfig precision_proto; - precision_proto.add_operand_precision(precision); - precision_proto.add_operand_precision(precision); - - xla::DotDimensionNumbers dot_dnums; - dot_dnums.add_lhs_contracting_dimensions(x_inner_dim); - dot_dnums.add_rhs_contracting_dimensions(y_inner_dim); - for (auto batch_dimension_number : batch_dimension_numbers) { - dot_dnums.add_lhs_batch_dimensions(batch_dimension_number); - dot_dnums.add_rhs_batch_dimensions(batch_dimension_number); - } - - return xla::DotGeneral(x, y, dot_dnums, &precision_proto); - }); -} - -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.h b/tensorflow/compiler/tf2xla/lib/batch_dot.h deleted file mode 100644 index 6edd63a4d3b..00000000000 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.h +++ /dev/null @@ -1,54 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_ -#define TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_ - -#include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" - -namespace tensorflow { - -// Multiplies slices of two tensors in batches. - -// Multiplies all slices of `Tensor` `x` and `y` (each slice can be -// viewed as an element of a batch), and arranges the individual results -// in a single output tensor of the same batch size. Each of the -// individual slices can optionally be transposed before multiplication by -// setting the `transpose_x` or `transpose_y` flag to `true`. Similarly, each -// can be elementwise-complex-conjugated by setting the `conjugate_x` or -// `conjugate_y` flag to `true`. To apply a Hermitian adjoint to `x`, set both -// `transpose_x` and `conjugate_x` to `true`, and analogously for `y`. -// -// The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]` -// and `[..., r_y, c_y]`. 
-// -// The output tensor is 2-D or higher with shape `[..., r_o, c_o]`, where: -// -// r_o = c_x if transpose_x else r_x -// c_o = r_y if transpose_y else c_y -// -// It is computed as: -// -// output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) -xla::XlaOp BatchDot( - xla::XlaOp x, xla::XlaOp y, bool transpose_x = false, - bool transpose_y = false, bool conjugate_x = false, - bool conjugate_y = false, - xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT); - -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_ diff --git a/tensorflow/compiler/tf2xla/lib/broadcast.cc b/tensorflow/compiler/tf2xla/lib/broadcast.cc index 3e402ef855c..be31f116686 100644 --- a/tensorflow/compiler/tf2xla/lib/broadcast.cc +++ b/tensorflow/compiler/tf2xla/lib/broadcast.cc @@ -80,10 +80,8 @@ xla::StatusOr BroadcastTo(xla::XlaOp input, broadcast_dim = broadcast_shape_size - broadcast_dim - 1; } absl::c_reverse(broadcast_shape); - xla::XlaOp output = xla::BroadcastInDim( - input, - xla::ShapeUtil::MakeShape(input_shape.element_type(), broadcast_shape), - broadcast_dims); + xla::XlaOp output = + xla::BroadcastInDim(input, broadcast_shape, broadcast_dims); if (broadcast_shape != output_dims) { output = xla::Reshape(output, output_dims); } diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc index ab3d0a56683..7ef8659992f 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.cc +++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc @@ -18,11 +18,12 @@ limitations under the License. #include #include -#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" #include "tensorflow/compiler/tf2xla/lib/triangular_solve.h" #include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/lib/slicing.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -101,10 +102,7 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a, // a[..., i, i] auto a_ii = DynamicSliceInMinorDims(body_a, {i, i}, {1, 1}); // np.dot(row, np.swapaxes(row, -1, -2)) - auto diag_dot = BatchDot(row, row, - /*transpose_x=*/false, - /*transpose_y=*/true, /*conjugate_x=*/false, - /*conjugate_y=*/false, precision); + auto diag_dot = BatchDot(row, TransposeInMinorDims(row), precision); // l[..., i, i] = np.sqrt(a[..., i, i] - np.dot(row, // np.swapaxes(row, -1, -2))) auto l_ii = @@ -122,10 +120,7 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a, // The columns in [i, n] are zeroed out in `row`, so we just have to // zero out rows above i+1 after the BatchDot. 
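// A rough sketch (not part of the patch): the tf2xla BatchDot deleted above
// took transpose/conjugate flags; the call sites in cholesky, qr and
// triangular_solve switch to the client-library BatchDot from
// xla/client/lib/matrix.h and spell the transpose out via
// TransposeInMinorDims. Assuming those helpers (BatchDot,
// TransposeInMinorDims, MaybeConjugate) live in namespace xla with the
// signatures the new call sites imply, the old flag-based behaviour maps to:
xla::XlaOp BatchDotCompat(xla::XlaOp x, xla::XlaOp y, bool transpose_x,
                          bool transpose_y, bool conjugate_x, bool conjugate_y,
                          xla::PrecisionConfig::Precision precision) {
  x = xla::MaybeConjugate(x, conjugate_x);  // conjugates only complex inputs
  y = xla::MaybeConjugate(y, conjugate_y);
  if (transpose_x) x = xla::TransposeInMinorDims(x);
  if (transpose_y) y = xla::TransposeInMinorDims(y);
  return xla::BatchDot(x, y, precision);
}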
np.dot(l[..., :, :i], // r.T) - auto dot = BatchDot(body_l, row, - /*transpose_x=*/false, - /*transpose_y=*/true, /*conjugate_x=*/false, - /*conjugate_y=*/false, precision); + auto dot = BatchDot(body_l, TransposeInMinorDims(row), precision); // np.dot(l[..., i+1:, :i], r.T) auto dot_ip1 = xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, dot); @@ -185,9 +180,7 @@ xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size, // a[i:, i:i+k] -= np.dot(l[i:, :i], np.transpose(l[i:i+k, :i])) auto lhs = SliceInMinorDims(l, {i, 0}, {n, i}); auto rhs = SliceInMinorDims(l, {i, 0}, {i + k, i}); - auto delta = BatchDot(lhs, rhs, /*transpose_x=*/false, - /*transpose_y=*/true, /*conjugate_x=*/false, - /*conjugate_y=*/false, precision); + auto delta = BatchDot(lhs, TransposeInMinorDims(rhs), precision); auto before = SliceInMinorDims(a, {i, i}, {n, i + k}); a = UpdateSliceInMinorDims(a, before - delta, {i, i}); } diff --git a/tensorflow/compiler/tf2xla/lib/qr.cc b/tensorflow/compiler/tf2xla/lib/qr.cc index 6b3f2b6e065..d6007748609 100644 --- a/tensorflow/compiler/tf2xla/lib/qr.cc +++ b/tensorflow/compiler/tf2xla/lib/qr.cc @@ -18,13 +18,13 @@ limitations under the License. #include #include -#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" #include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/math.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/lib/slicing.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -191,12 +191,8 @@ xla::StatusOr QRBlock( auto v_broadcast = xla::Reshape(v, shape); // a[:, :] -= tau * np.dot(v[:, np.newaxis], // np.dot(v[np.newaxis, :], a[:, :])) - auto vva = - BatchDot(v_broadcast, a, /*transpose_x=*/false, /*transpose_y=*/false, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); - vva = - BatchDot(v_broadcast, vva, /*transpose_x=*/true, /*transpose_y=*/false, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); + auto vva = BatchDot(v_broadcast, a, precision); + vva = BatchDot(TransposeInMinorDims(v_broadcast), vva, precision); a = a - xla::Mul(tau, vva, /*broadcast_dimensions=*/batch_dim_indices); @@ -278,12 +274,9 @@ xla::StatusOr ComputeWYRepresentation( auto beta = DynamicSliceInMinorDims(taus, {j}, {1}); // yv has shape [..., n, 1] - auto yv = BatchDot(y, v, /*transpose_x=*/true, /*transpose_y=*/false, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); + auto yv = BatchDot(TransposeInMinorDims(y), v, precision); // wyv has shape [..., m, 1] - auto wyv = - BatchDot(w, yv, /*transpose_x=*/false, /*transpose_y=*/false, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); + auto wyv = BatchDot(w, yv, precision); auto z = xla::Mul( -beta, v + wyv, @@ -375,23 +368,15 @@ xla::StatusOr QRDecomposition( // a[i:, i+k:] += np.dot(Y, np.dot(W.T, a[i:, i+k:])) auto a_panel = SliceInMinorDims(a, {i, i + k}, {m, n}); - auto a_update = - BatchDot(w, a_panel, /*transpose_x=*/true, /*transpose_y=*/false, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); - a_update = - BatchDot(y, a_update, /*transpose_x=*/false, /*transpose_y=*/false, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); + auto a_update = 
BatchDot(TransposeInMinorDims(w), a_panel, precision); + a_update = BatchDot(y, a_update, precision); a_panel = a_panel + a_update; a = UpdateSliceInMinorDims(a, a_panel, {i, i + k}); // q[:, i:] += np.dot(np.dot(q[:, i:], W), Y.T)) auto q_panel = SliceInMinorDims(q, {0, i}, {m, m}); - auto q_update = - BatchDot(q_panel, w, /*transpose_x=*/false, /*transpose_y=*/false, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); - q_update = BatchDot(q_update, y, /*transpose_x=*/false, - /*transpose_y=*/true, /*conjugate_x=*/false, - /*conjugate_y=*/false, precision); + auto q_update = BatchDot(q_panel, w, precision); + q_update = BatchDot(q_update, TransposeInMinorDims(y), precision); q_panel = q_panel + q_update; q = UpdateSliceInMinorDims(q, q_panel, {0, i}); } diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc index 6524c2a9b1a..192a61dca26 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc @@ -18,10 +18,11 @@ limitations under the License. #include #include -#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" #include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/lib/slicing.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" @@ -311,13 +312,13 @@ xla::XlaOp SolveWithInvertedDiagonalBlocks( auto a_row = MaybeConjugate(SliceInMinorDims(a, start, end), conjugate_a); if (left_side) { - remainder = b_row - BatchDot(a_row, x, transpose_a, false, - /*conjugate_x=*/false, - /*conjugate_y=*/false, precision); + remainder = + b_row - BatchDot(MaybeTransposeInMinorDims(a_row, transpose_a), x, + precision); } else { - remainder = b_row - BatchDot(x, a_row, false, transpose_a, - /*conjugate_x=*/false, - /*conjugate_y=*/false, precision); + remainder = + b_row - BatchDot(x, MaybeTransposeInMinorDims(a_row, transpose_a), + precision); } } @@ -327,13 +328,12 @@ xla::XlaOp SolveWithInvertedDiagonalBlocks( xla::ConstantR0WithType(builder, xla::S32, j * block_size); std::vector update_starts = {start_index, zero}; if (left_side) { - x_update = - BatchDot(inv_block, remainder, transpose_a, false, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); + x_update = BatchDot(MaybeTransposeInMinorDims(inv_block, transpose_a), + remainder, precision); } else { - x_update = - BatchDot(remainder, inv_block, false, transpose_a, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); + x_update = BatchDot(remainder, + MaybeTransposeInMinorDims(inv_block, transpose_a), + precision); std::swap(update_starts[0], update_starts[1]); } x = DynamicUpdateSliceInMinorDims(x, x_update, /*starts=*/update_starts); diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc index 804671fbc75..c0bd172d17c 100644 --- a/tensorflow/compiler/tf2xla/lib/util.cc +++ b/tensorflow/compiler/tf2xla/lib/util.cc @@ -113,36 +113,6 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, return xla::ConstantLiteral(builder, literal); } -xla::XlaOp SliceInMinorDims(xla::XlaOp x, absl::Span start, - absl::Span end) { - xla::XlaBuilder* builder = x.builder(); - return 
builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_RET_CHECK(start.size() == end.size()); - int64 n_minor_dims = start.size(); - - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - - const int64 n_dims = xla::ShapeUtil::Rank(shape); - TF_RET_CHECK(n_minor_dims <= n_dims); - auto major_dims = xla::AsInt64Slice(shape.dimensions()) - .subspan( - /*pos=*/0, - /*len=*/n_dims - n_minor_dims); - - // Prepends 0s in the major dim - std::vector padded_start(n_dims, 0); - std::copy(start.begin(), start.end(), - padded_start.begin() + major_dims.size()); - - // Prepends the shape of the major dims. - std::vector padded_end(n_dims); - std::copy(major_dims.begin(), major_dims.end(), padded_end.begin()); - std::copy(end.begin(), end.end(), padded_end.begin() + major_dims.size()); - - std::vector strides(n_dims, 1); - return xla::Slice(x, padded_start, padded_end, strides); - }); -} std::vector ConcatVectors(absl::Span xs, absl::Span ys) { @@ -152,100 +122,4 @@ std::vector ConcatVectors(absl::Span xs, return output; } -xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x, - absl::Span starts, - absl::Span sizes) { - xla::XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - int64 n_minor_dims = starts.size(); - TF_RET_CHECK(n_minor_dims == sizes.size()); - TF_RET_CHECK(n_minor_dims <= n_dims); - auto major_dims = xla::AsInt64Slice(shape.dimensions()) - .subspan( - /*pos=*/0, - /*len=*/n_dims - sizes.size()); - auto padded_starts = PrependZerosInMajorDims(x, starts); - auto padded_sizes = ConcatVectors(major_dims, sizes); - return xla::DynamicSlice(x, padded_starts, padded_sizes); - }); -} - -xla::XlaOp UpdateSlice(xla::XlaOp x, xla::XlaOp update, - absl::Span start) { - xla::XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - // TODO(phawkins): make int64 work on all backends, remove the int32 cast. 
- std::vector start_as_int32(start.begin(), start.end()); - auto start_constant = xla::ConstantR1(builder, start_as_int32); - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - TF_ASSIGN_OR_RETURN(xla::Shape start_constant_shape, - builder->GetShape(start_constant)); - const int64 start_length = - xla::ShapeUtil::GetDimension(start_constant_shape, -1); - TF_RET_CHECK(start_length == n_dims); - return xla::DynamicUpdateSlice(x, update, start_constant); - }); -} - -xla::XlaOp UpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, - absl::Span start) { - xla::XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - const int64 n_minor_dims = start.size(); - TF_RET_CHECK(n_minor_dims <= n_dims); - std::vector padded_start(n_dims, 0); - std::copy(start.begin(), start.end(), - padded_start.begin() + (n_dims - n_minor_dims)); - return UpdateSlice(x, update, padded_start); - }); -} - -xla::XlaOp DynamicUpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, - absl::Span starts) { - auto padded_starts = PrependZerosInMajorDims(x, starts); - return xla::DynamicUpdateSlice(x, update, padded_starts); -} - -xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x, - absl::Span starts) { - xla::XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - auto zero = xla::Reshape(xla::ConstantR0(builder, 0), {1}); - std::vector padded_starts(n_dims, zero); - for (int i = 0; i < starts.size(); ++i) { - padded_starts[n_dims - starts.size() + i] = xla::Reshape(starts[i], {1}); - } - return xla::ConcatInDim(builder, padded_starts, 0); - }); -} - -xla::XlaOp TransposeInMinorDims(xla::XlaOp x) { - xla::XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - TF_RET_CHECK(n_dims >= 2); - std::vector permutation(n_dims); - std::iota(permutation.begin(), permutation.end(), 0); - std::swap(permutation[n_dims - 1], permutation[n_dims - 2]); - return xla::Transpose(x, permutation); - }); -} - -xla::XlaOp MaybeConjugate(xla::XlaOp x, bool conjugate) { - xla::XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - auto perform_conj = shape.element_type() == xla::C64 && conjugate; - return perform_conj ? xla::Conj(x) : x; - }); -} - } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/util.h b/tensorflow/compiler/tf2xla/lib/util.h index 80e9e5b002d..aec8061cb43 100644 --- a/tensorflow/compiler/tf2xla/lib/util.h +++ b/tensorflow/compiler/tf2xla/lib/util.h @@ -38,44 +38,10 @@ xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x, xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, int64 value); -// Builds a vector of zeros of length rank(x) with the last values being -// those in `starts`. -xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x, - absl::Span starts); - -// Performs a slice in the minor dimensions of a Tensor. -xla::XlaOp SliceInMinorDims(xla::XlaOp x, absl::Span start, - absl::Span end); - // Returns the concatenation of `xs` and `ys`. 
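// A rough sketch (not part of the patch): the slice/update helpers removed
// from tf2xla/lib/util.cc above have counterparts in the XLA client library
// (xla/client/lib/slicing.h), which the BUILD changes earlier in this patch
// add as a dependency. Assuming the counterparts keep the same names, a
// caller migrates like this (ZeroLowerRightBlock is a made-up example):
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"

xla::XlaOp ZeroLowerRightBlock(xla::XlaOp x, int64 n) {
  // Read x[..., n:2n, n:2n] of a batched matrix, then write zeros back into
  // that block, addressing only the two minor (matrix) dimensions.
  xla::XlaOp block =
      xla::SliceInMinorDims(x, /*start=*/{n, n}, /*end=*/{2 * n, 2 * n});
  xla::XlaOp zeros = xla::ZerosLike(block);
  return xla::UpdateSliceInMinorDims(x, zeros, /*start=*/{n, n});
}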
std::vector ConcatVectors(absl::Span xs, absl::Span ys); -// Performs a dynamic slice in the minor dimensions of a Tensor. -xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x, - absl::Span starts, - absl::Span sizes); - -// Updates a slice of 'x', i.e., -// x[start[0], ..., start[n]] = update -xla::XlaOp UpdateSlice(xla::XlaOp x, xla::XlaOp update, - absl::Span start); - -// Updates a slice of 'x', where 'start' contains a list of minor dimensions: -// x[..., start[0], ..., start[n]] = update -xla::XlaOp UpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, - absl::Span start); - -xla::XlaOp DynamicUpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, - absl::Span starts); - -// Transposes a stack of matrices `x` by swapping the last two dimensions. -xla::XlaOp TransposeInMinorDims(xla::XlaOp x); - -// Applies a complex conjugation operation if `a` is complex and `conjugate_a` -// is true, otherwise returns its argument. -xla::XlaOp MaybeConjugate(xla::XlaOp x, bool conjugate); - } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/python/BUILD b/tensorflow/compiler/tf2xla/python/BUILD index c9f486edc8d..fef97b98c37 100644 --- a/tensorflow/compiler/tf2xla/python/BUILD +++ b/tensorflow/compiler/tf2xla/python/BUILD @@ -1,11 +1,13 @@ licenses(["notice"]) # Apache 2.0 +package_group( + name = "friends", + includes = ["//tensorflow:internal"], +) + package( default_visibility = [ - "//learning/deepmind/public/wavenet/python:__subpackages__", - "//learning/deepmind/research/alphastar:__subpackages__", - "//learning/tfx:__subpackages__", - "//tensorflow:internal", + ":friends", ], ) diff --git a/tensorflow/compiler/tf2xla/shape_util.h b/tensorflow/compiler/tf2xla/shape_util.h index f7e34a5b40c..0b231ea8e7a 100644 --- a/tensorflow/compiler/tf2xla/shape_util.h +++ b/tensorflow/compiler/tf2xla/shape_util.h @@ -18,6 +18,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_SHAPE_UTIL_H_ #define TENSORFLOW_COMPILER_TF2XLA_SHAPE_UTIL_H_ +#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h index 425e769346f..c7341cf8b9e 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h @@ -26,7 +26,7 @@ limitations under the License. // Forward-declare, rather than include, to reduce code size for users that // never use this functionality. namespace xla { -class ProgramShape; +class ProgramShapeProto; class HloProfilePrinterData; } @@ -84,7 +84,7 @@ class XlaCompiledCpuFunction { void set_result_names(const char** result_names) { result_names_ = result_names; } - void set_program_shape(const xla::ProgramShape* program_shape) { + void set_program_shape(const xla::ProgramShapeProto* program_shape) { program_shape_ = program_shape; } const xla::HloProfilePrinterData* hlo_profile_printer_data() const { @@ -122,7 +122,7 @@ class XlaCompiledCpuFunction { const char** result_names_ = nullptr; // [Optional] Arg and result shapes. - const xla::ProgramShape* program_shape_ = nullptr; + const xla::ProgramShapeProto* program_shape_ = nullptr; // [Optional] Profile printer data. Null if profiling is disabled. 
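// A rough sketch (not part of the patch): the hunk that follows relaxes
// set_arg_data() to take `const void*`, since XLA-generated code never writes
// through argument buffers. On the caller side a read-only input no longer
// needs a cast; RunOnce and the float argument type are made up for the
// example.
#include <vector>

void RunOnce(XlaCompiledCpuFunction* computation,
             const std::vector<float>& input) {
  computation->set_arg_data(/*index=*/0, input.data());  // const float* is accepted
  computation->Run();
}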
const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr; @@ -206,8 +206,14 @@ class XlaCompiledCpuFunction { // // Aliasing of argument and result buffers is not allowed, and results in // undefined behavior. - void set_arg_data(size_t index, void* data) { - buffer_table_[arg_index_table_[index]] = data; + void set_arg_data(size_t index, const void* data) { + // The const_cast is safe because the generated code does not write to arg + // buffers. + // + // buffer_table_ contains pointers to buffers that _will_ be written to by + // generated code so it would be misleading to make buffer_table_ a `const + // void**`. + buffer_table_[arg_index_table_[index]] = const_cast(data); } // ------------------------------ @@ -264,7 +270,7 @@ class XlaCompiledCpuFunction { // Returns the shape of the args and results. May return nullptr if the // program shape isn't available. - const xla::ProgramShape* ProgramShape() const { return program_shape_; } + const xla::ProgramShapeProto* ProgramShape() const { return program_shape_; } bool hlo_profiling_enabled() const { return hlo_profile_printer_data_ != nullptr; @@ -287,11 +293,6 @@ class XlaCompiledCpuFunction { // Argument i needs to be placed in buffer_table_[arg_index_to_temp_index_[i]] // for XLA generated code to be able to find it. - // - // For now we need to keep around the args_ array because there is code that - // depends on args() returning a void**. However, in the future we may remove - // args_ in favor of using buffer_table_ as the sole storage for the - // arguments. const int32* const arg_index_table_; // The number of incoming arguments. @@ -310,7 +311,7 @@ class XlaCompiledCpuFunction { // Optional metadata. const char** arg_names_ = nullptr; const char** result_names_ = nullptr; - const xla::ProgramShape* program_shape_ = nullptr; + const xla::ProgramShapeProto* program_shape_ = nullptr; const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr; }; diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index a08d030ce71..ee461a3c07d 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -158,7 +158,8 @@ Status BuildComputation( xla::XlaBuilder* builder, xla::XlaComputation* computation, int* num_computation_outputs, int* num_nonconst_outputs, std::vector* outputs, - std::vector* resource_updates) { + std::vector* resource_updates, + xla::Shape* output_shape) { // Attach a common operator name as metadata. This has no semantic effect ā€” it // merely makes the HLO graph more readable when visualized via TensorBoard, // since TensorBoard forms groups out of operators with similar names. @@ -176,6 +177,10 @@ Status BuildComputation( std::vector elems; elems.reserve(retvals.size()); + + // Keeps track of which retvals have layout to update. The first element is + // the output index, second element is the new layout. + std::vector> retval_to_update_layout; for (int i = 0; i < retvals.size(); ++i) { XlaCompiler::OutputDescription& output = (*outputs)[i]; const XlaExpression& retval = retvals[i]; @@ -202,10 +207,12 @@ Status BuildComputation( TF_ASSIGN_OR_RETURN(xla::Shape shape, shape_representation_fn( output.shape, output.type)); value = xla::Reshape(value, xla::AsInt64Slice(shape.dimensions())); + retval_to_update_layout.emplace_back(elems.size(), shape.layout()); } else if (it != retval_cores.end()) { // Apply the sharding to the output, if there is a core assignment. 
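// A rough sketch (not part of the patch): BuildComputation now records, for
// each reshaped return value, the layout chosen by shape_representation_fn
// and patches it into the computation's result shape (next hunk), instead of
// resetting the output to a default major-to-minor layout afterwards. The
// patching step in isolation, with ApplyRetvalLayouts as a made-up name:
#include <utility>
#include <vector>

void ApplyRetvalLayouts(
    const std::vector<std::pair<int64, xla::Layout>>& retval_to_update_layout,
    bool output_is_tuple, xla::Shape* output_shape) {
  for (const auto& update : retval_to_update_layout) {
    xla::Shape* sub_shape =
        output_is_tuple
            ? xla::ShapeUtil::GetMutableSubshape(output_shape, {update.first})
            : output_shape;  // a single untupled result is patched directly
    *sub_shape->mutable_layout() = update.second;
  }
}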
value = identity_op(value); } + elems.push_back(value); break; } @@ -297,6 +304,21 @@ Status BuildComputation( return computation_status.status(); } *computation = computation_status.ConsumeValueOrDie(); + + TF_ASSIGN_OR_RETURN(const auto& program_shape, + computation->GetProgramShape()); + *output_shape = program_shape.result(); + // Update the output layout to the layout of retval. + for (auto& update : retval_to_update_layout) { + if (!always_return_tuple && elems.size() == 1) { + *output_shape->mutable_layout() = update.second; + continue; + } + + xla::Shape* output_sub_shape = + xla::ShapeUtil::GetMutableSubshape(output_shape, {update.first}); + *output_sub_shape->mutable_layout() = update.second; + } return Status::OK(); } @@ -304,10 +326,10 @@ Status BuildComputation( bool XlaCompiler::Argument::operator==( const XlaCompiler::Argument& other) const { - if (std::tie(kind, resource_kind, type, name, initialized, tensor_array_size, + if (std::tie(kind, resource_kind, type, name, initialized, max_array_size, tensor_array_gradients) != std::tie(other.kind, other.resource_kind, other.type, other.name, - other.initialized, other.tensor_array_size, + other.initialized, other.max_array_size, other.tensor_array_gradients)) { return false; } @@ -337,8 +359,8 @@ string XlaCompiler::Argument::HumanString() const { string output = absl::StrCat("kind=resource", common, " resource_kind=", XlaResource::KindToString(resource_kind), " initialized=", initialized); - if (tensor_array_size >= 0) { - absl::StrAppend(&output, " tensor_array_size=", tensor_array_size); + if (max_array_size >= 0) { + absl::StrAppend(&output, " max_array_size=", max_array_size); } if (!tensor_array_gradients.empty()) { absl::StrAppend(&output, " tensor_array_gradients=", @@ -358,7 +380,7 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options) initialization_status_(Status::OK()), next_step_id_(1), device_(new XlaCompilationDevice(SessionOptions(), options_.device_type)), - device_mgr_({device_}) { + device_mgr_(absl::WrapUnique(device_)) { CHECK(!options_.device_type.type_string().empty()); if (options_.populate_resource_manager) { initialization_status_ = @@ -545,12 +567,12 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, return Status::OK(); } case XlaResource::kTensorArray: { - if (arg.tensor_array_size < 0) { + if (arg.max_array_size < 0) { return errors::InvalidArgument( - "Negative tensor_array_size in XLAShapeForArgument"); + "Negative max_array_size in XLAShapeForArgument"); } TensorShape shape; - shape.AddDim(arg.tensor_array_size); + shape.AddDim(arg.max_array_size); shape.AppendShape(arg.shape); TF_RETURN_IF_ERROR(TensorShapeToXLAShape(arg.type, shape, xla_shape)); @@ -562,12 +584,12 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, return Status::OK(); } case XlaResource::kStack: { - if (arg.tensor_array_size < 0) { + if (arg.max_array_size < 0) { return errors::InvalidArgument( - "Negative tensor_array_size in XLAShapeForArgument"); + "Negative max_array_size in XLAShapeForArgument"); } TensorShape shape; - shape.AddDim(arg.tensor_array_size); + shape.AddDim(arg.max_array_size); shape.AppendShape(arg.shape); xla::Shape buffer_shape; TF_RETURN_IF_ERROR( @@ -613,21 +635,23 @@ Status XlaCompiler::BuildArguments( const XlaCompiler::Argument& arg = args[i]; XlaExpression& arg_expression = (*arg_expressions)[i]; switch (arg.kind) { - case XlaCompiler::Argument::kResource: + case XlaCompiler::Argument::kResource: { TF_RET_CHECK(arg.resource_kind != 
XlaResource::kInvalid); // TODO(phawkins): this code assumes that resource arguments do not // alias. - XlaResource* resource; - TF_RETURN_IF_ERROR(context->CreateResource( - arg.resource_kind, i, arg.name, arg.type, arg.shape, xla::XlaOp(), - /*tensor_array_size=*/arg.tensor_array_size, - /*tensor_array_gradients=*/arg.tensor_array_gradients, &resource)); + XlaResource* resource = + context->AddResource(absl::make_unique( + arg.resource_kind, i, arg.name, arg.type, arg.shape, + xla::XlaOp(), + /*max_array_size=*/arg.max_array_size, + /*tensor_array_gradients=*/arg.tensor_array_gradients, + /*tensor_array_multiple_writes_aggregate=*/true)); arg_expression = XlaExpression::Resource(resource); if (arg.initialized) { input_mapping->push_back(i); } - break; + } case XlaCompiler::Argument::kParameter: case XlaCompiler::Argument::kToken: { input_mapping->push_back(i); @@ -901,9 +925,7 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, options_.device_type, name)); xla::XlaBuilder builder(name); - XlaContext* context = - new XlaContext(this, &builder, options_.allow_cpu_custom_calls, - &options_.shape_representation_fn); + XlaContext* context = new XlaContext(this, &builder); core::ScopedUnref context_unref(context); std::vector real_args(args.begin(), args.end()); @@ -988,23 +1010,12 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, options.return_updated_values_for_all_resources, options.always_return_tuple, &builder, result->computation.get(), &num_computation_outputs, &num_nonconst_outputs, &result->outputs, - &result->resource_updates)); + &result->resource_updates, &result->xla_output_shape)); VLOG(2) << "Outputs: total: " << context->retvals().size() << " nonconstant: " << num_nonconst_outputs; - - // Compute the XLA output shape, if there is a computation with non-constant - // outputs. - TF_ASSIGN_OR_RETURN(std::unique_ptr computation_shape, - client()->GetComputationShape(*result->computation)); - - result->xla_output_shape.Swap(computation_shape->mutable_result()); VLOG(2) << "XLA output shape: " - << xla::ShapeUtil::HumanString(result->xla_output_shape); - - // Tensorflow expects a major-to-minor order of results. - xla::LayoutUtil::SetToDefaultLayout(&result->xla_output_shape); - + << xla::ShapeUtil::HumanStringWithLayout(result->xla_output_shape); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index 63426124686..0d801b73a8c 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -150,7 +150,7 @@ class XlaCompiler { // For a TensorArray or Stack resource, what is the array's declared size? // (Used for lazy initialization.) - int64 tensor_array_size = -1; + int64 max_array_size = -1; // TensorArray resource parameters are passed as (array, gradient array 0, // ..., gradient array k), where the gradient arrays are in the same order diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index aaee208f634..fe2a5f5b0c9 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -20,6 +20,7 @@ limitations under the License. 
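// A rough sketch (not part of the patch): the max_array_size field renamed
// just above sizes the leading dimension of a TensorArray or Stack argument's
// XLA shape, as XLAShapeForArgument does earlier in this patch. Reduced to
// its core (TensorArrayXlaShape is a made-up name; TensorShapeToXLAShape is
// the existing tf2xla/shape_util.h helper):
Status TensorArrayXlaShape(const XlaCompiler::Argument& arg,
                           xla::Shape* xla_shape) {
  if (arg.max_array_size < 0) {
    return errors::InvalidArgument("Negative max_array_size in argument");
  }
  TensorShape shape;
  shape.AddDim(arg.max_array_size);
  shape.AppendShape(arg.shape);
  return TensorShapeToXLAShape(arg.type, shape, xla_shape);
}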
#include "tensorflow/cc/ops/function_ops.h" #include "tensorflow/cc/ops/resource_variable_ops.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/side_effect_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -649,7 +650,7 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) { args[0].initialized = true; args[0].type = DT_INT32; args[0].shape = TensorShape({}); - args[0].tensor_array_size = 2; + args[0].max_array_size = 2; args[0].tensor_array_gradients = {"grad2"}; // Compiles the graph. @@ -708,7 +709,7 @@ TEST_F(XlaCompilerTest, UnwrittenTensorArrayGradientsAreNotComputationOutputs) { args[0].initialized = true; args[0].type = DT_INT32; args[0].shape = TensorShape({}); - args[0].tensor_array_size = 2; + args[0].max_array_size = 2; args[0].tensor_array_gradients = {"grad1"}; // Compiles the graph. @@ -740,7 +741,7 @@ TEST_F(XlaCompilerTest, NewTensorArrayGradientsAreComputationOutputs) { args[0].initialized = true; args[0].type = DT_INT32; args[0].shape = TensorShape({}); - args[0].tensor_array_size = 2; + args[0].max_array_size = 2; args[0].tensor_array_gradients = {"grad1"}; // Compiles the graph. @@ -910,6 +911,82 @@ TEST_F(XlaCompilerTest, Variables) { RunAndCheckVariablesComputation(client_, result); } +TEST_F(XlaCompilerTest, ResultLayoutSingle) { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); + auto b = ops::_Retval(scope.WithOpName("RET"), a, 0); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. + std::vector args(1); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({2, 3}); + + auto options = DefaultOptions(); + // Sets the representation function to return a non-default layout. + options.shape_representation_fn = + [](const TensorShape& shape, DataType type) -> xla::StatusOr { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &xla_shape)); + *xla_shape.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1}); + return xla_shape; + }; + + // Compiles the graph. + XlaCompiler compiler(options); + + XlaCompiler::CompilationResult result; + auto compile_options = XlaCompiler::CompileOptions(); + compile_options.always_return_tuple = false; + TF_ASSERT_OK(compiler.CompileGraph(compile_options, "id", std::move(graph), + args, &result)); + EXPECT_TRUE(xla::ShapeUtil::Equal( + result.xla_output_shape, + xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1}))); +} + +TEST_F(XlaCompilerTest, ResultLayoutMultiple) { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); + auto b = ops::_Retval(scope.WithOpName("RET1"), a, 0); + auto c = ops::_Retval(scope.WithOpName("RET2"), a, 1); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. + std::vector args(1); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({2, 3}); + + auto options = DefaultOptions(); + // Sets the representation function to return a non-default layout. 
+ options.shape_representation_fn = + [](const TensorShape& shape, DataType type) -> xla::StatusOr { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &xla_shape)); + *xla_shape.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1}); + return xla_shape; + }; + + // Compiles the graph. + XlaCompiler compiler(options); + + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "id", + std::move(graph), args, &result)); + xla::Shape result_shape = + xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1}); + + EXPECT_TRUE(xla::ShapeUtil::Equal( + result.xla_output_shape, + xla::ShapeUtil::MakeTupleShape({result_shape, result_shape}))); +} + // Tests a simple graph that reads and writes a variable. TEST_F(XlaCompilerTest, ReturnResourceHandleOnly) { Scope scope = Scope::NewRootScope().ExitOnError(); diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index 43095fbb473..a69af705033 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -54,25 +54,14 @@ const char XlaContext::kXlaContextResourceName[] = "_xla_context"; return *context; } -/* static */ XlaContext& XlaContext::Get(const XlaOpKernelContext* ctx) { - return Get(ctx->op_kernel_context()); -} - void XlaContext::set_args(std::vector args) { args_ = std::move(args); } -XlaContext::XlaContext( - XlaCompiler* compiler, xla::XlaBuilder* builder, - bool allow_cpu_custom_calls, - const std::function( - const TensorShape&, DataType)>* shape_representation_fn) - : compiler_(compiler), - builder_(builder), - allow_cpu_custom_calls_(allow_cpu_custom_calls), - shape_representation_fn_(shape_representation_fn) {} +XlaContext::XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder) + : compiler_(compiler), builder_(builder) {} -string XlaContext::DebugString() { return "TLA JIT context"; } +string XlaContext::DebugString() { return "XLA JIT context"; } void XlaContext::SetRetval(int index, const XlaExpression& expression) { if (retvals_.size() <= index) { @@ -81,21 +70,9 @@ void XlaContext::SetRetval(int index, const XlaExpression& expression) { retvals_[index] = expression; } -Status XlaContext::CreateResource( - XlaResource::Kind kind, int arg_num, string name, DataType type, - TensorShape shape, const xla::XlaOp& handle, int64 tensor_array_size, - const std::set& tensor_array_gradients, XlaResource** resource) { - resources_.emplace_back( - new XlaResource(kind, arg_num, std::move(name), type, std::move(shape), - handle, tensor_array_size, tensor_array_gradients, - /*tensor_array_multiple_writes_aggregate=*/false)); - *resource = resources_.back().get(); - return Status::OK(); -} - -xla::StatusOr XlaContext::RepresentationShape( - const TensorShape& shape, DataType type) const { - return (*shape_representation_fn_)(shape, type); +XlaResource* XlaContext::AddResource(std::unique_ptr resource) { + resources_.push_back(std::move(resource)); + return resources_.back().get(); } const xla::XlaComputation* XlaContext::GetOrCreateMax(const DataType type) { diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index dbfd344c9ba..0767d1faac1 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -41,14 +41,10 @@ class XlaContext : public ResourceBase { public: // Retrieves the XlaContext of the current compilation. 
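// A rough sketch (not part of the patch): resource creation moves from the
// removed XlaContext::CreateResource to the ownership-transferring
// AddResource, typically fed by the XlaResource factory functions used by the
// stack and tensor-array kernels in this patch. A stack-flavoured kernel body
// then reduces to roughly this (dtype and max_size stand in for kernel state):
void CompileStackLikeOp(XlaOpKernelContext* ctx, DataType dtype,
                        int64 max_size) {
  XlaResource* resource = ctx->xla_context()->AddResource(
      XlaResource::CreateStack(/*name=*/"Stack: example", dtype, max_size));
  ctx->SetResourceOutput(0, resource);
}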
static XlaContext& Get(const OpKernelContext* ctx); - static XlaContext& Get(const XlaOpKernelContext* ctx); // Creates a new XlaContext. See the documentation on the class data fields // for descriptions of the arguments. - XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder, - bool allow_cpu_custom_calls, - const std::function( - const TensorShape&, DataType)>* shape_representation_fn); + XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder); // Virtual method defined by ResourceBase. string DebugString() override; @@ -58,8 +54,6 @@ class XlaContext : public ResourceBase { // Returns the XlaBuilder that Ops use for compiling new expressions. xla::XlaBuilder* builder() { return builder_; } - bool allow_cpu_custom_calls() const { return allow_cpu_custom_calls_; } - const std::vector& args() const { return args_; } void set_args(std::vector args); @@ -70,25 +64,13 @@ class XlaContext : public ResourceBase { // grows the return values vector to size index+1 if it is smaller. void SetRetval(int index, const XlaExpression& expression); - // Creates a resource with resource `kind` and initial value `handle`. `name` - // is a descriptive name for use in error messages. See the `XlaResource` - // constructor for a description of the remaining arguments. - // Fails if the resource already exists. - Status CreateResource(XlaResource::Kind kind, int arg_num, string name, - DataType type, TensorShape shape, - const xla::XlaOp& handle, int64 tensor_array_size, - const std::set& tensor_array_gradients, - XlaResource** resource); + // Adds 'resource' to the set of resources owned by the context. + XlaResource* AddResource(std::unique_ptr resource); const std::vector>& resources() { return resources_; } - // Returns the XLA shape to be used to represent a variable of TF `shape` - // and `type`, or of an argument or return value of a top-level computation. - xla::StatusOr RepresentationShape(const TensorShape& shape, - DataType type) const; - // Get an XLA lambda to compute Max. This is cached in the // XlaContext since it may be used by multiple Ops. There is a // separate specialization of the computation for each DataType. @@ -118,9 +100,6 @@ class XlaContext : public ResourceBase { // The XlaBuilder used to construct the subgraph's compiled representation. xla::XlaBuilder* builder_; - // Allow ops to emit CustomCall operations for CPU. - const bool allow_cpu_custom_calls_; - // Arguments to the Tensorflow graph, indexed by _Arg index. // Includes both compile-time constant arguments and runtime parameters. std::vector args_; @@ -131,11 +110,6 @@ class XlaContext : public ResourceBase { // Holds ownership of resources. The resources are not ordered. std::vector> resources_; - // Describes the on-host shapes of parameters and return values. Also see: - // XlaDevice::Options::shape_representation_fn. - const std::function(const TensorShape&, DataType)>* - shape_representation_fn_; - // Cache of prebuilt computations indexed by their type. using ComputationMap = std::map; diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index 9a34cd8c6ae..c2c07512111 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -26,7 +26,6 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/types.h" @@ -216,8 +215,7 @@ DataType XlaHelpers::SumAccumulationType(const DataType& dtype) { return dtype; } -xla::XlaOp XlaHelpers::ConvertElementType(xla::XlaBuilder* const builder, - const xla::XlaOp& operand, +xla::XlaOp XlaHelpers::ConvertElementType(const xla::XlaOp& operand, const DataType new_element_type) { xla::PrimitiveType convert_to; TF_CHECK_OK(DataTypeToPrimitiveType(new_element_type, &convert_to)); diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h index 39578144caa..4858dfee55a 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.h +++ b/tensorflow/compiler/tf2xla/xla_helpers.h @@ -80,8 +80,7 @@ class XlaHelpers { // A helper for creating a ConvertElementType xla op given a DataType rather // than the xla::PrimitiveType. - static xla::XlaOp ConvertElementType(xla::XlaBuilder* const builder, - const xla::XlaOp& operand, + static xla::XlaOp ConvertElementType(const xla::XlaOp& operand, const DataType new_element_type); }; diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc index 86a78ee429e..fabbcd04fed 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc @@ -133,7 +133,8 @@ XlaJitCompiledCpuFunction::Compile( jit->executable_ = std::move(executable); jit->buffer_infos_ = std::move(buffer_infos); jit->arg_index_table_ = std::move(arg_index_table); - jit->program_shape_ = std::move(program_shape); + jit->program_shape_ = + absl::make_unique(program_shape->ToProto()); jit->static_data_.set_raw_function(raw_function); jit->static_data_.set_buffer_infos(jit->buffer_infos_.data()); jit->static_data_.set_num_buffers(jit->buffer_infos_.size()); diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h index d3c8f22a807..a5392057177 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h @@ -80,8 +80,10 @@ class XlaJitCompiledCpuFunction { std::vector arg_names_; std::vector result_names_; - // The backing data for the program shape. - std::unique_ptr program_shape_; + // The backing data for the program shape. The proto form of program shape is + // used because the program shape is serialized and embedded in the object + // file. + std::unique_ptr program_shape_; }; } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc index 6d49298a6f3..8846088678b 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc @@ -116,13 +116,13 @@ TEST(XlaJitCompiledCpuFunction, Sum) { // Check program shape. 
using xla::ShapeUtil; const xla::Shape s32 = ShapeUtil::MakeShape(xla::S32, {}); - const xla::ProgramShape* program_shape = function.ProgramShape(); - ASSERT_TRUE(program_shape != nullptr); - ASSERT_EQ(program_shape->parameters_size(), 2); - EXPECT_TRUE(ShapeUtil::Compatible(program_shape->parameters(0), s32)); - EXPECT_TRUE(ShapeUtil::Compatible(program_shape->parameters(1), s32)); + ASSERT_TRUE(function.ProgramShape() != nullptr); + const xla::ProgramShape program_shape(*function.ProgramShape()); + ASSERT_EQ(program_shape.parameters_size(), 2); + EXPECT_TRUE(ShapeUtil::Compatible(program_shape.parameters(0), s32)); + EXPECT_TRUE(ShapeUtil::Compatible(program_shape.parameters(1), s32)); - const xla::Shape& result = program_shape->result(); + const xla::Shape& result = program_shape.result(); ASSERT_EQ(result.element_type(), xla::TUPLE); ASSERT_EQ(ShapeUtil::TupleElementCount(result), 1); const xla::Shape& result0 = ShapeUtil::GetTupleElementShape(result, 0); diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 8dd8def0549..58808c76de6 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -36,8 +36,16 @@ bool XlaOpKernelContext::ValidateInputsAreSameShape(OpKernel* op) { return context_->ValidateInputsAreSameShape(op); } +XlaContext* XlaOpKernelContext::xla_context() const { + return &XlaContext::Get(context_); +} + xla::XlaBuilder* XlaOpKernelContext::builder() const { - return XlaContext::Get(this).builder(); + return xla_context()->builder(); +} + +XlaCompiler* XlaOpKernelContext::compiler() const { + return xla_context()->compiler(); } // Retrieves an XlaExpression that was allocated by a previous Op. @@ -338,8 +346,8 @@ Status XlaOpKernelContext::ConstantInputList( namespace { Status ReadVariableInputTensor(const Tensor& tensor, DataType type, - const OpKernelContext* ctx, TensorShape* shape, - xla::XlaOp* value) { + const XlaOpKernelContext* ctx, + TensorShape* shape, xla::XlaOp* value) { const XlaExpression* expression = CastExpressionFromTensor(tensor); XlaResource* variable = expression->resource(); TF_RET_CHECK(variable != nullptr); @@ -357,10 +365,9 @@ Status ReadVariableInputTensor(const Tensor& tensor, DataType type, *shape = variable->shape(); } - XlaContext& xla_context = XlaContext::Get(ctx); - TF_ASSIGN_OR_RETURN( - xla::Shape representation_shape, - xla_context.RepresentationShape(variable->shape(), variable->type())); + TF_ASSIGN_OR_RETURN(xla::Shape representation_shape, + ctx->compiler()->options().shape_representation_fn( + variable->shape(), variable->type())); xla::Shape xla_shape; TF_RETURN_IF_ERROR( TensorShapeToXLAShape(variable->type(), variable->shape(), &xla_shape)); @@ -377,15 +384,15 @@ Status ReadVariableInputTensor(const Tensor& tensor, DataType type, Status XlaOpKernelContext::ReadVariableInput(int index, DataType type, TensorShape* shape, xla::XlaOp* value) { - return ReadVariableInputTensor(context_->input(index), type, context_, shape, + return ReadVariableInputTensor(context_->input(index), type, this, shape, value); } Status XlaOpKernelContext::ReadVariableInput(absl::string_view name, DataType type, TensorShape* shape, xla::XlaOp* value) { - return ReadVariableInputTensor(GetInputTensorByName(name), type, context_, - shape, value); + return ReadVariableInputTensor(GetInputTensorByName(name), type, this, shape, + value); } Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type, @@ -464,7 +471,7 @@ Status 
XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) { namespace { Status AssignVariableTensor(const Tensor& tensor, DataType type, - const OpKernelContext* ctx, xla::XlaOp handle, + const XlaOpKernelContext* ctx, xla::XlaOp handle, xla::XlaBuilder* builder) { const XlaExpression* expression = CastExpressionFromTensor(tensor); XlaResource* variable = expression->resource(); @@ -481,9 +488,9 @@ Status AssignVariableTensor(const Tensor& tensor, DataType type, TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape)); - XlaContext& xla_context = XlaContext::Get(ctx); - TF_ASSIGN_OR_RETURN(xla::Shape representation_shape, - xla_context.RepresentationShape(shape, type)); + TF_ASSIGN_OR_RETURN( + xla::Shape representation_shape, + ctx->compiler()->options().shape_representation_fn(shape, type)); xla::Shape xla_shape; TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &xla_shape)); if (!xla::ShapeUtil::Compatible(xla_shape, representation_shape)) { @@ -498,19 +505,15 @@ Status AssignVariableTensor(const Tensor& tensor, DataType type, Status XlaOpKernelContext::AssignVariable(int input_index, DataType type, xla::XlaOp handle) { TF_RET_CHECK(handle.valid()); - return AssignVariableTensor(context_->input(input_index), type, context_, - handle, builder()); + return AssignVariableTensor(context_->input(input_index), type, this, handle, + builder()); } Status XlaOpKernelContext::AssignVariable(absl::string_view name, DataType type, xla::XlaOp handle) { TF_RET_CHECK(handle.valid()); - return AssignVariableTensor(GetInputTensorByName(name), type, context_, - handle, builder()); -} - -XlaCompiler* XlaOpKernelContext::compiler() const { - return XlaContext::Get(context_).compiler(); + return AssignVariableTensor(GetInputTensorByName(name), type, this, handle, + builder()); } void XlaOpKernelContext::CtxFailure(const Status& s) { @@ -530,22 +533,22 @@ void XlaOpKernelContext::CtxFailureWithWarning(const char* file, int line, const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMax( const DataType type) { - return XlaContext::Get(context_).GetOrCreateMax(type); + return xla_context()->GetOrCreateMax(type); } const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMin( const DataType type) { - return XlaContext::Get(context_).GetOrCreateMin(type); + return xla_context()->GetOrCreateMin(type); } const xla::XlaComputation* XlaOpKernelContext::GetOrCreateAdd( const DataType type) { - return XlaContext::Get(context_).GetOrCreateAdd(type); + return xla_context()->GetOrCreateAdd(type); } const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMul( const DataType type) { - return XlaContext::Get(context_).GetOrCreateMul(type); + return xla_context()->GetOrCreateMul(type); } const Tensor& XlaOpKernelContext::GetInputTensorByName(absl::string_view name) { diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index c06efa2c474..1858844bc05 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -60,6 +60,8 @@ class XlaOpKernelContext { public: explicit XlaOpKernelContext(OpKernelContext* context); + XlaContext* xla_context() const; + // Returns the XLA XlaBuilder containing the output of compilation. 
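With `xla_context()` and `compiler()` exposed on `XlaOpKernelContext`, the variable read/write paths above fetch shared state through the kernel context instead of `XlaContext::Get` and `XlaContext::RepresentationShape`. A hedged sketch of the pattern, as it might appear in a kernel's `Compile` body (the helper name and the `{2, 3}` float shape are placeholders):

```
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/platform/logging.h"

// Hypothetical helper called from an op kernel's Compile() method.
void IllustrateContextAccessors(XlaOpKernelContext* ctx) {
  // Builder and compiler both hang off the per-compilation XlaContext.
  xla::XlaBuilder* b = ctx->xla_context()->builder();
  CHECK_EQ(b, ctx->builder());

  // Representation shapes are now looked up through the compiler options
  // rather than through XlaContext::RepresentationShape().
  auto shape_or = ctx->compiler()->options().shape_representation_fn(
      TensorShape({2, 3}), DT_FLOAT);
  OP_REQUIRES_OK(ctx, shape_or.status());
  const xla::Shape& representation = shape_or.ValueOrDie();
  VLOG(1) << "representation shape: "
          << xla::ShapeUtil::HumanStringWithLayout(representation);
}
```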
xla::XlaBuilder* builder() const; diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index dcd0e9c5c1f..14237df6908 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" +#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" @@ -130,8 +130,7 @@ XlaOpRegistry::~XlaOpRegistry() = default; // Lazily register the CPU and GPU JIT devices the first time // GetCompilationDevice is called. static void* registration_init = [®istry]() { - legacy_flags::MarkForCompilationPassFlags* flags = - legacy_flags::GetMarkForCompilationPassFlags(); + MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); bool cpu_global_jit = flags->tf_xla_cpu_global_jit; mutex_lock lock(registry.mutex_); diff --git a/tensorflow/compiler/tf2xla/xla_resource.cc b/tensorflow/compiler/tf2xla/xla_resource.cc index a322eb9015e..48a3c012727 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.cc +++ b/tensorflow/compiler/tf2xla/xla_resource.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/sharding_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" @@ -39,9 +40,29 @@ namespace tensorflow { } } +/*static*/ std::unique_ptr XlaResource::CreateStack( + string name, DataType type, int64 max_size) { + return absl::make_unique( + XlaResource::kStack, /*arg_num=*/-1, std::move(name), type, TensorShape(), + /*initial_value=*/xla::XlaOp(), + /*max_array_size=*/max_size, + /*tensor_array_gradients=*/std::set{}, + /*tensor_array_multiple_writes_aggregate=*/false); +} + +/*static*/ std::unique_ptr XlaResource::CreateTensorArray( + string name, DataType type, TensorShape shape, xla::XlaOp initial_value, + int64 max_array_size) { + return absl::make_unique( + XlaResource::kTensorArray, /*arg_num=*/-1, std::move(name), type, shape, + initial_value, max_array_size, + /*tensor_array_gradients=*/std::set{}, + /*tensor_array_multiple_writes_aggregate=*/false); +} + XlaResource::XlaResource(Kind kind, int arg_num, string name, DataType type, TensorShape shape, const xla::XlaOp& initial_value, - int64 tensor_array_size, + int64 max_array_size, const std::set& tensor_array_gradients, bool tensor_array_multiple_writes_aggregate) : kind_(kind), @@ -51,7 +72,7 @@ XlaResource::XlaResource(Kind kind, int arg_num, string name, DataType type, shape_(std::move(shape)), value_(initial_value), initial_value_(initial_value), - tensor_array_size_(tensor_array_size), + max_array_size_(max_array_size), tensor_array_multiple_writes_aggregate_( tensor_array_multiple_writes_aggregate) { CHECK(kind_ != kInvalid); @@ -60,7 +81,7 @@ XlaResource::XlaResource(Kind kind, int arg_num, string name, DataType type, tensor_array_gradients_[gradient].reset(new XlaResource( /*kind=*/kTensorArray, /*arg_num=*/-1, /*name=*/absl::StrCat("TensorArrayGrad: ", name_), type_, shape_, - xla::XlaOp(), tensor_array_size_, /*tensor_array_gradients=*/{}, + xla::XlaOp(), max_array_size_, /*tensor_array_gradients=*/{}, /*tensor_array_multiple_writes_aggregate=*/true)); } } @@ -113,7 +134,7 @@ Status 
XlaResource::SetZeroValue(xla::XlaBuilder* builder) { } case kTensorArray: { TensorShape ta_shape; - ta_shape.AddDim(tensor_array_size_); + ta_shape.AddDim(max_array_size_); ta_shape.AppendShape(shape_); value_ = xla::Broadcast(XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes()); @@ -121,7 +142,7 @@ Status XlaResource::SetZeroValue(xla::XlaBuilder* builder) { } case kStack: { TensorShape ta_shape; - ta_shape.AddDim(tensor_array_size_); + ta_shape.AddDim(max_array_size_); ta_shape.AppendShape(shape_); value_ = xla::Tuple(builder, {xla::Broadcast(XlaHelpers::Zero(builder, type_), @@ -146,14 +167,14 @@ Status XlaResource::GetOrCreateTensorArrayGradient(const string& source, std::unique_ptr& gradient = tensor_array_gradients_[source]; if (!gradient) { TensorShape ta_shape; - ta_shape.AddDim(tensor_array_size_); + ta_shape.AddDim(max_array_size_); ta_shape.AppendShape(shape_); xla::XlaOp gradient_value = xla::Broadcast(XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes()); gradient.reset( new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1, /*name=*/absl::StrCat("TensorArrayGrad: ", name_), - type_, shape_, gradient_value, tensor_array_size_, + type_, shape_, gradient_value, max_array_size_, /*tensor_array_gradients=*/{}, /*tensor_array_multiple_writes_aggregate=*/true)); } diff --git a/tensorflow/compiler/tf2xla/xla_resource.h b/tensorflow/compiler/tf2xla/xla_resource.h index 857b9a928bb..736588bb8b8 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.h +++ b/tensorflow/compiler/tf2xla/xla_resource.h @@ -38,9 +38,18 @@ class XlaResource { }; static absl::string_view KindToString(Kind kind); + // Creates a new Stack resource. + static std::unique_ptr CreateStack(string name, DataType type, + int64 max_size); + + // Creates a new TensorArray resource. + static std::unique_ptr CreateTensorArray( + string name, DataType type, TensorShape shape, xla::XlaOp initial_value, + int64 max_array_size); + XlaResource(Kind kind, int arg_num, string name, DataType type, TensorShape shape, const xla::XlaOp& initial_value, - int64 tensor_array_size, + int64 max_array_size, const std::set& tensor_array_gradients, bool tensor_array_multiple_writes_aggregate); @@ -119,12 +128,12 @@ class XlaResource { // TODO(phawkins): refactor this code to use subclasses, rather than putting // kind-specific fields in XlaResource. - // 'tensor_array_size' stores the expected size of the TensorArray or Stack. + // 'max_array_size' stores the expected size of the TensorArray or Stack. // We need to store this since sometimes TensorArrays must be initialized // lazily since we do not know the element shape at construction time. // Used by both TensorArrays and Stacks. 
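The new `CreateStack`/`CreateTensorArray` factories together with `XlaContext::AddResource` replace the old `XlaContext::CreateResource` entry point, and `max_array_size` replaces `tensor_array_size`. A sketch of how a kernel might now allocate and register a TensorArray resource; the element shape `{128}`, the size 16, and the zero-filled initial value are placeholders:

```
#include "absl/memory/memory.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_resource.h"

XlaResource* MakeExampleTensorArray(XlaOpKernelContext* ctx) {
  xla::XlaBuilder* b = ctx->builder();
  // Zero-filled [max_array_size, element_shape] buffer as the initial value.
  xla::XlaOp init = xla::Broadcast(XlaHelpers::Zero(b, DT_FLOAT), {16, 128});
  std::unique_ptr<XlaResource> ta = XlaResource::CreateTensorArray(
      /*name=*/"TensorArray: example", /*type=*/DT_FLOAT,
      /*shape=*/TensorShape({128}), /*initial_value=*/init,
      /*max_array_size=*/16);
  // Ownership moves to the per-compilation context; the returned raw
  // pointer stays valid for the rest of the compilation.
  XlaResource* resource = ctx->xla_context()->AddResource(std::move(ta));
  CHECK_EQ(resource->max_array_size(), 16);  // renamed from tensor_array_size()
  return resource;
}
```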
- int64 tensor_array_size() const { return tensor_array_size_; } - void set_tensor_array_size(int64 size) { tensor_array_size_ = size; } + int64 max_array_size() const { return max_array_size_; } + void set_max_array_size(int64 size) { max_array_size_ = size; } bool tensor_array_multiple_writes_aggregate() const { return tensor_array_multiple_writes_aggregate_; @@ -151,7 +160,7 @@ class XlaResource { xla::XlaOp value_; xla::XlaOp initial_value_; - int64 tensor_array_size_ = -1; + int64 max_array_size_ = -1; bool tensor_array_multiple_writes_aggregate_ = false; std::map> tensor_array_gradients_; diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 91096cf1d04..4360e085796 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -226,12 +226,14 @@ cc_library( "index_util.cc", "layout_util.cc", "primitive_util.cc", + "shape.cc", "shape_util.cc", ], hdrs = [ "index_util.h", "layout_util.h", "primitive_util.h", + "shape.h", "shape_util.h", ], visibility = ["//visibility:public"], @@ -254,6 +256,23 @@ cc_library( ], ) +tf_cc_test( + name = "shape_test", + srcs = ["shape_test.cc"], + deps = [ + ":shape_util", + ":status_macros", + ":test", + ":test_helpers", + ":types", + ":util", + ":xla_data_proto", + "//tensorflow/core:lib", + "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", + ], +) + tf_cc_test( name = "shape_util_test", srcs = ["shape_util_test.cc"], @@ -745,6 +764,8 @@ cc_library( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/compiler/xla/array2d.h b/tensorflow/compiler/xla/array2d.h index 782c966b4c5..e4aca98f67d 100644 --- a/tensorflow/compiler/xla/array2d.h +++ b/tensorflow/compiler/xla/array2d.h @@ -104,7 +104,7 @@ std::unique_ptr> MakeLinspaceArray2D(double from, double to, int64 count = n1 * n2; NativeT step = static_cast((count > 1) ? 
(to - from) / (count - 1) : 0); - auto set = [&array, n1, n2](int64 index, NativeT value) { + auto set = [&array, n2](int64 index, NativeT value) { (*array)(index / n2, index % n2) = value; }; for (int64 i = 0; i < count - 1; ++i) { diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index 42da0ebf499..fe99564d3c6 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -81,6 +81,7 @@ cc_library( "//tensorflow/core:lib", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], ) @@ -90,11 +91,12 @@ cc_library( srcs = ["executable_build_options.cc"], hdrs = ["executable_build_options.h"], deps = [ + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla:xla_proto", "//tensorflow/compiler/xla/service:device_memory_allocator", - "//tensorflow/core:lib", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:optional", @@ -191,6 +193,7 @@ cc_library( hdrs = ["xla_computation.h"], visibility = ["//visibility:public"], deps = [ + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index eef2844e0df..74b76f92994 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/execution_options_util.h" @@ -42,7 +43,7 @@ StatusOr Client::Transfer(const GlobalData& data, TransferToClientRequest request; *request.mutable_data() = data.handle(); if (shape_with_layout != nullptr) { - *request.mutable_shape_with_layout() = *shape_with_layout; + *request.mutable_shape_with_layout() = shape_with_layout->ToProto(); } TransferToClientResponse response; @@ -123,7 +124,7 @@ StatusOr Client::TransferFromOutfeed( } request.set_replica_id(replica_id); if (shape_with_layout != nullptr) { - *request.mutable_shape_with_layout() = *shape_with_layout; + *request.mutable_shape_with_layout() = shape_with_layout->ToProto(); } TransferFromOutfeedResponse response; @@ -170,11 +171,14 @@ StatusOr Client::ExecuteAndTransfer( std::unique_ptr data, Execute(computation, arguments, execution_options, execution_profile)); - const Shape* shape_with_output_layout = nullptr; + absl::optional shape_with_output_layout; if (execution_options && execution_options->has_shape_with_output_layout()) { - shape_with_output_layout = &execution_options->shape_with_output_layout(); + shape_with_output_layout = + Shape(execution_options->shape_with_output_layout()); } - return Transfer(*data, shape_with_output_layout); + return Transfer(*data, shape_with_output_layout.has_value() + ? &(*shape_with_output_layout) + : nullptr); } StatusOr Client::ComputeConstant(const XlaComputation& computation, @@ -229,7 +233,7 @@ StatusOr Client::Compile( // The argument shapes affect how the computation is compiled. 
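The client changes above all follow one pattern: `xla::Shape` is now a C++ class, so values crossing the RPC boundary are converted explicitly with `Shape::ToProto()` and rebuilt with the `Shape(const ShapeProto&)` constructor. A minimal round-trip sketch, assuming the `ShapeProto` message that accompanies the new `shape.h`/`shape.cc` target; the concrete `F32[2,3]` shape is arbitrary:

```
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"

void ShapeProtoRoundTrip() {
  xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 3});
  // The proto form is what gets written into request fields such as
  // *request.mutable_shape_with_layout() above.
  xla::ShapeProto proto = shape.ToProto();
  // The Shape(ShapeProto) constructor rebuilds the class form, as in
  // `return Shape(response.shape());` in Client::GetShape.
  xla::Shape round_tripped(proto);
  CHECK(xla::ShapeUtil::Equal(shape, round_tripped));
}
```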
for (const auto& arg_shape : argument_shapes) { - *request.add_input_shape_with_layout() = arg_shape; + *request.add_input_shape_with_layout() = arg_shape.ToProto(); } CompileResponse response; @@ -458,7 +462,7 @@ StatusOr Client::GetShape(const GlobalData& data) { return s; } - return response.shape(); + return Shape(response.shape()); } StatusOr Client::ExecutionStatsAsString( diff --git a/tensorflow/compiler/xla/client/executable_build_options.cc b/tensorflow/compiler/xla/client/executable_build_options.cc index 0f1745366b7..1f594e551af 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.cc +++ b/tensorflow/compiler/xla/client/executable_build_options.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "absl/strings/str_format.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/shape_util.h" namespace xla { @@ -39,6 +40,13 @@ ExecutableBuildOptions& ExecutableBuildOptions::set_device_ordinal( int ExecutableBuildOptions::device_ordinal() const { return device_ordinal_; } +DebugOptions* ExecutableBuildOptions::mutable_debug_options() { + if (!has_debug_options()) { + debug_options_ = GetDebugOptionsFromFlags(); + } + return &debug_options_.value(); +} + ExecutableBuildOptions& ExecutableBuildOptions::set_result_layout( const Shape& shape_with_layout) { result_layout_set_ = true; @@ -55,68 +63,10 @@ string ExecutableBuildOptions::ToString() const { if (result_layout_set_) { result_layout = ShapeUtil::HumanStringWithLayout(result_layout_); } - string generate_hlo_graph = "nullopt"; - if (generate_hlo_graph_.has_value()) { - generate_hlo_graph = generate_hlo_graph_.value(); - } return absl::StrFormat( "ExecutableBuildOptions{device_ordinal=%d, result_layout=%s, " "generate_hlo_graph=%s}", - device_ordinal_, result_layout, generate_hlo_graph); -} - -ExecutableBuildOptions& ExecutableBuildOptions::set_generate_hlo_graph( - string regex) { - generate_hlo_graph_ = std::move(regex); - return *this; -} - -const absl::optional& ExecutableBuildOptions::generate_hlo_graph() - const { - return generate_hlo_graph_; -} - -ExecutableBuildOptions& ExecutableBuildOptions::set_dump_optimized_hlo_proto_to( - absl::string_view dirpath) { - dump_optimized_hlo_proto_to_ = string(dirpath); - return *this; -} - -const absl::optional& -ExecutableBuildOptions::dump_optimized_hlo_proto_to() const { - return dump_optimized_hlo_proto_to_; -} - -ExecutableBuildOptions& -ExecutableBuildOptions::set_dump_unoptimized_hlo_proto_to( - absl::string_view dirpath) { - dump_unoptimized_hlo_proto_to_ = string(dirpath); - return *this; -} - -const absl::optional& -ExecutableBuildOptions::dump_unoptimized_hlo_proto_to() const { - return dump_unoptimized_hlo_proto_to_; -} - -ExecutableBuildOptions& ExecutableBuildOptions::set_dump_per_pass_hlo_proto_to( - absl::string_view dirpath) { - dump_per_pass_hlo_proto_to_ = string(dirpath); - return *this; -} - -const absl::optional& -ExecutableBuildOptions::dump_per_pass_hlo_proto_to() const { - return dump_per_pass_hlo_proto_to_; -} - -ExecutableBuildOptions& ExecutableBuildOptions::set_hlo_profile(bool enabled) { - hlo_profile_ = enabled; - return *this; -} - -absl::optional ExecutableBuildOptions::hlo_profile() const { - return hlo_profile_; + device_ordinal_, result_layout, debug_options().xla_generate_hlo_graph()); } } // namespace xla diff --git a/tensorflow/compiler/xla/client/executable_build_options.h 
b/tensorflow/compiler/xla/client/executable_build_options.h index 93334db88bc..a58090253bf 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.h +++ b/tensorflow/compiler/xla/client/executable_build_options.h @@ -19,7 +19,9 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -44,6 +46,12 @@ class ExecutableBuildOptions { ExecutableBuildOptions& set_result_layout(const Shape& shape_with_layout); const Shape* result_layout() const; + // Expose access to the XLA debug options which will be passed to the + // compilation process. + bool has_debug_options() const { return debug_options_.has_value(); } + const DebugOptions& debug_options() const { return *debug_options_; } + DebugOptions* mutable_debug_options(); + // If set, this specifies an allocator that can be used to allocate temporary // space on the device during compilation. For example, the compiler might // want to run various algorithms on the device and pick the fastest one -- it @@ -55,56 +63,16 @@ class ExecutableBuildOptions { DeviceMemoryAllocator* allocator); DeviceMemoryAllocator* device_allocator() const; - // If set, specifies a regexp of HLO graphs to dump (as in DebugOptions). - ExecutableBuildOptions& set_generate_hlo_graph(string regex); - const absl::optional& generate_hlo_graph() const; - - // If set, specifies a dirpath to dump the end-of-optimization-pipeline HLO - // protobuf to (as in DebugOptions). - ExecutableBuildOptions& set_dump_optimized_hlo_proto_to( - absl::string_view dirpath); - const absl::optional& dump_optimized_hlo_proto_to() const; - - // If set, specifies a dirpath to dump the start-of-optimization-pipeline HLO - // protobuf to (as in DebugOptions). - ExecutableBuildOptions& set_dump_unoptimized_hlo_proto_to( - absl::string_view dirpath); - const absl::optional& dump_unoptimized_hlo_proto_to() const; - - // If set, specifies a dirpath to dump the per-pass-in-pipeline HLO protobufs - // to (as in DebugOptions). - ExecutableBuildOptions& set_dump_per_pass_hlo_proto_to( - absl::string_view dirpath); - const absl::optional& dump_per_pass_hlo_proto_to() const; - - // If true, specifies that we should record an HLO profile during execution - // and log it after execution (as in DebugOptions). If nullopt the default is - // used. - ExecutableBuildOptions& set_hlo_profile(bool enabled); - absl::optional hlo_profile() const; - - void add_disabled_hlo_pass(absl::string_view pass_name) { - disabled_hlo_passes_.push_back(std::string(pass_name)); - } - const absl::Span disabled_hlo_passes() const { - return disabled_hlo_passes_; - } - // Returns a string representation of the build options, suitable for // debugging. 
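In place of the removed `set_generate_hlo_graph`, `set_dump_*_hlo_proto_to`, and `set_hlo_profile` setters, callers now reach the same knobs through the embedded `DebugOptions` proto. A sketch of the replacement pattern; the proto setters used here are the standard generated accessors for fields such as `xla_generate_hlo_graph` (the field `ToString()` reads above) and are assumed rather than shown in the diff:

```
#include "tensorflow/compiler/xla/client/executable_build_options.h"

xla::ExecutableBuildOptions MakeBuildOptions() {
  xla::ExecutableBuildOptions options;
  options.set_device_ordinal(0);
  // The first call to mutable_debug_options() seeds the proto from the
  // command-line flags (GetDebugOptionsFromFlags); individual fields can
  // then be overridden through the generated proto accessors.
  options.mutable_debug_options()->set_xla_generate_hlo_graph(".*");
  options.mutable_debug_options()->set_xla_hlo_profile(true);
  return options;
}
```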
string ToString() const; private: - absl::optional hlo_profile_; int device_ordinal_ = -1; Shape result_layout_; bool result_layout_set_ = false; - absl::optional generate_hlo_graph_; - absl::optional dump_optimized_hlo_proto_to_; - absl::optional dump_unoptimized_hlo_proto_to_; - absl::optional dump_per_pass_hlo_proto_to_; + absl::optional debug_options_; DeviceMemoryAllocator* device_allocator_ = nullptr; - std::vector disabled_hlo_passes_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index f833ddcd323..f0f530d7d77 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -104,13 +104,17 @@ xla_test( ) cc_library( - name = "numeric", - srcs = ["numeric.cc"], - hdrs = ["numeric.h"], + name = "matrix", + srcs = ["matrix.cc"], + hdrs = ["matrix.h"], deps = [ ":arithmetic", ":constants", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", "@com_google_absl//absl/types:span", @@ -118,11 +122,12 @@ cc_library( ) xla_test( - name = "numeric_test", - srcs = ["numeric_test.cc"], + name = "matrix_test", + srcs = ["matrix_test.cc"], tags = ["enable_for_xla_interpreter"], deps = [ - ":numeric", + ":matrix", + ":slicing", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", @@ -164,7 +169,6 @@ cc_library( deps = [ ":constants", ":math", - ":numeric", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", @@ -173,13 +177,46 @@ cc_library( ], ) +cc_library( + name = "slicing", + srcs = ["slicing.cc"], + hdrs = ["slicing.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "@com_google_absl//absl/types:span", + ], +) + +xla_test( + name = "slicing_test", + srcs = ["slicing_test.cc"], + tags = ["enable_for_xla_interpreter"], + deps = [ + ":slicing", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + cc_library( name = "sorting", srcs = ["sorting.cc"], hdrs = ["sorting.h"], deps = [ - ":numeric", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", ], @@ -188,10 +225,6 @@ cc_library( xla_test( name = "sorting_test", srcs = ["sorting_test.cc"], - blacklisted_backends = [ - "cpu", - "gpu", - ], tags = ["enable_for_xla_interpreter"], deps = [ ":sorting", diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc index 08a887a6e46..36fdda39b41 100644 --- a/tensorflow/compiler/xla/client/lib/math.cc +++ b/tensorflow/compiler/xla/client/lib/math.cc @@ -268,17 +268,16 @@ XlaOp Digamma(XlaOp input) 
{ // Implements Banker's rounding: numbers that are equidistant between two // integers are rounded towards even. XlaOp RoundToEven(XlaOp x) { - auto half = xla::ScalarLike(x, 0.5); - auto one = xla::ScalarLike(x, 1.0); - auto two = xla::ScalarLike(x, 2.0); + auto half = ScalarLike(x, 0.5); + auto one = ScalarLike(x, 1.0); + auto two = ScalarLike(x, 2.0); - auto round_val = xla::Floor(x); + auto round_val = Floor(x); auto fraction = x - round_val; - auto nearest_even_int = round_val - two * xla::Floor(half * x); - auto is_odd = xla::Eq(nearest_even_int, one); - return xla::Select(xla::Or(xla::Gt(fraction, half), - xla::And(xla::Eq(fraction, half), is_odd)), - round_val + one, round_val); + auto nearest_even_int = round_val - two * Floor(half * x); + auto is_odd = Eq(nearest_even_int, one); + return Select(Or(Gt(fraction, half), And(Eq(fraction, half), is_odd)), + round_val + one, round_val); } // Trigonometric functions. @@ -320,4 +319,13 @@ XlaOp Cosh(XlaOp x) { return (Exp(x) + Exp(-x)) * ScalarLike(x, 0.5); } XlaOp Sinh(XlaOp x) { return (Exp(x) - Exp(-x)) * ScalarLike(x, 0.5); } +XlaOp MaybeConjugate(XlaOp x, bool conjugate) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + auto perform_conj = shape.element_type() == C64 && conjugate; + return perform_conj ? Conj(x) : x; + }); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/math.h b/tensorflow/compiler/xla/client/lib/math.h index 3f06d04b9ae..17612bf9fdc 100644 --- a/tensorflow/compiler/xla/client/lib/math.h +++ b/tensorflow/compiler/xla/client/lib/math.h @@ -86,6 +86,10 @@ XlaOp Cosh(XlaOp x); // Computes the hyperbolic sine of 'x'. XlaOp Sinh(XlaOp x); +// Applies a complex conjugation operation if `x` is complex and `conjugate` +// is true, otherwise returns its argument. +xla::XlaOp MaybeConjugate(xla::XlaOp x, bool conjugate); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATH_H_ diff --git a/tensorflow/compiler/xla/client/lib/matrix.cc b/tensorflow/compiler/xla/client/lib/matrix.cc new file mode 100644 index 00000000000..ffd744d1908 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/matrix.cc @@ -0,0 +1,185 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/matrix.h" + +#include +#include + +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" + +namespace xla { + +XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m, + int64 n) { + auto a = Iota(builder, type, m); + auto b = Iota(builder, type, n); + auto indicator = Eq(a, Broadcast(b, {m}), /*broadcast_dimensions=*/{0}); + return ConvertElementType(indicator, type); +} + +XlaOp GetMatrixDiagonal(XlaOp x) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + const int64 n_dims = ShapeUtil::Rank(shape); + TF_RET_CHECK(n_dims >= 2); + const int64 m = shape.dimensions(n_dims - 2); + const int64 n = shape.dimensions(n_dims - 1); + absl::Span major_dims = + AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2); + auto a = Iota(builder, U32, n); + auto b = Iota(builder, U32, m); + auto indicator = Eq(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); + auto mask = Broadcast(indicator, major_dims); + + // TPUs don't support S64 add reduction at the moment. But fortunately + // OR-reductions work just as well for integers. + XlaComputation reducer = + primitive_util::IsIntegralType(shape.element_type()) + ? CreateScalarOrComputation(shape.element_type(), builder) + : CreateScalarAddComputation(shape.element_type(), builder); + + return Reduce(Select(mask, x, Zeros(builder, shape)), ScalarLike(x, 0), + reducer, {m >= n ? n_dims - 2 : n_dims - 1}); + }); +} + +XlaOp Triangle(XlaOp x, bool lower) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + const int64 n_dims = ShapeUtil::Rank(shape); + TF_RET_CHECK(n_dims >= 2); + const int64 m = shape.dimensions(n_dims - 2); + const int64 n = shape.dimensions(n_dims - 1); + absl::Span major_dims = + AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2); + auto a = Iota(builder, U32, n); + auto b = Iota(builder, U32, m); + XlaOp indicator; + if (lower) { + indicator = Ge(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); + } else { + indicator = Le(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); + } + auto mask = Broadcast(indicator, major_dims); + + return Select(mask, x, Zeros(builder, shape)); + }); +} + +XlaOp UpperTriangle(XlaOp x) { return Triangle(x, false); } + +XlaOp LowerTriangle(XlaOp x) { return Triangle(x, true); } + +XlaOp BatchDot(XlaOp x, XlaOp y, PrecisionConfig::Precision precision) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x)); + TF_ASSIGN_OR_RETURN(Shape y_shape, builder->GetShape(y)); + + // Check that both tensors have the same number of dimensions. There must be + // at least two (the batch dimensions can be empty). + if (ShapeUtil::Rank(x_shape) != ShapeUtil::Rank(y_shape)) { + return InvalidArgument( + "Arguments to BatchDot have different ranks: %s vs. 
%s", + ShapeUtil::HumanString(x_shape), ShapeUtil::HumanString(y_shape)); + } + const int ndims = ShapeUtil::Rank(x_shape); + if (ndims < 2) { + return InvalidArgument( + "Arguments to BatchDot must have rank >= 2: got %d", ndims); + } + + // The batch dimensions must be equal and the matrix dimensions must be + // valid. + std::vector batch_dimension_numbers; + for (int i = 0; i < ndims - 2; ++i) { + if (x_shape.dimensions(i) != y_shape.dimensions(i)) { + return InvalidArgument( + "Dimension %d of inputs to BatchDot must be equal: shapes %s vs %s", + i, ShapeUtil::HumanString(x_shape), + ShapeUtil::HumanString(y_shape)); + } + batch_dimension_numbers.push_back(i); + } + + int x_inner_dim = ndims - 1; + int y_inner_dim = ndims - 2; + if (x_shape.dimensions(x_inner_dim) != y_shape.dimensions(y_inner_dim)) { + return InvalidArgument( + "Dimensions %d and %d of arguments to BatchDot must be equal: " + "shapes %s vs %s", + x_inner_dim, y_inner_dim, ShapeUtil::HumanString(x_shape), + ShapeUtil::HumanString(y_shape)); + } + + // Check for zero lhs/rhs dim size. + if (ShapeUtil::IsZeroElementArray(x_shape) || + ShapeUtil::IsZeroElementArray(y_shape)) { + std::vector dimensions(batch_dimension_numbers.size()); + for (int i = 0; i < batch_dimension_numbers.size(); ++i) { + dimensions[i] = x_shape.dimensions(batch_dimension_numbers[i]); + } + int x_outer_dim = ndims - 2; + int y_outer_dim = ndims - 1; + dimensions.push_back(x_shape.dimensions(x_outer_dim)); + dimensions.push_back(y_shape.dimensions(y_outer_dim)); + return Broadcast( + ConstantLiteral(builder, LiteralUtil::Zero(x_shape.element_type())), + dimensions); + } + + PrecisionConfig precision_proto; + precision_proto.add_operand_precision(precision); + precision_proto.add_operand_precision(precision); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(x_inner_dim); + dot_dnums.add_rhs_contracting_dimensions(y_inner_dim); + for (auto batch_dimension_number : batch_dimension_numbers) { + dot_dnums.add_lhs_batch_dimensions(batch_dimension_number); + dot_dnums.add_rhs_batch_dimensions(batch_dimension_number); + } + + return DotGeneral(x, y, dot_dnums, &precision_proto); + }); +} + +XlaOp TransposeInMinorDims(XlaOp x) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + const int64 n_dims = ShapeUtil::Rank(shape); + TF_RET_CHECK(n_dims >= 2); + std::vector permutation(n_dims); + std::iota(permutation.begin(), permutation.end(), 0); + std::swap(permutation[n_dims - 1], permutation[n_dims - 2]); + return Transpose(x, permutation); + }); +} + +XlaOp MaybeTransposeInMinorDims(XlaOp x, bool transpose) { + return transpose ? TransposeInMinorDims(x) : x; +} +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/numeric.h b/tensorflow/compiler/xla/client/lib/matrix.h similarity index 56% rename from tensorflow/compiler/xla/client/lib/numeric.h rename to tensorflow/compiler/xla/client/lib/matrix.h index efd8cdc2572..8856f99c7a0 100644 --- a/tensorflow/compiler/xla/client/lib/numeric.h +++ b/tensorflow/compiler/xla/client/lib/matrix.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_NUMERIC_H_ -#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_NUMERIC_H_ +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATRIX_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATRIX_H_ #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/types.h" @@ -22,9 +22,6 @@ limitations under the License. namespace xla { -// Returns a rank 1 tensor of `type` containing values [0, 1, 2, ...]. -XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size); - // Returns an m x n matrix with 1s on the diagonal elements, zeros everywhere // else. XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m, int64 n); @@ -43,6 +40,34 @@ XlaOp UpperTriangle(XlaOp x); // Get the lower triangle part of the last two dimensions XlaOp LowerTriangle(XlaOp x); +// Multiplies slices of two tensors in batches. + +// Multiplies all slices of `Tensor` `x` and `y` (each slice can be +// viewed as an element of a batch), and arranges the individual results +// in a single output tensor of the same batch size. +// +// The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]` +// and `[..., r_y, c_y]`. +// +// The output tensor is 2-D or higher with shape `[..., r_o, c_o]`, where: +// +// r_o = c_x if transpose_x else r_x +// c_o = r_y if transpose_y else c_y +// +// It is computed as: +// +// output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) +xla::XlaOp BatchDot( + xla::XlaOp x, xla::XlaOp y, + xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT); + +// Transposes a stack of matrices `x` by swapping the last two dimensions. +xla::XlaOp TransposeInMinorDims(xla::XlaOp x); + +// Transposes `x` in its minor dimensions if `transpose` is true, otherwise +// returns `x` unchanged. +xla::XlaOp MaybeTransposeInMinorDims(xla::XlaOp x, bool transpose); + } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_NUMERIC_H_ +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATRIX_H_ diff --git a/tensorflow/compiler/xla/client/lib/numeric_test.cc b/tensorflow/compiler/xla/client/lib/matrix_test.cc similarity index 53% rename from tensorflow/compiler/xla/client/lib/numeric_test.cc rename to tensorflow/compiler/xla/client/lib/matrix_test.cc index 7d6aedd4946..0593a7517ac 100644 --- a/tensorflow/compiler/xla/client/lib/numeric_test.cc +++ b/tensorflow/compiler/xla/client/lib/matrix_test.cc @@ -13,7 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" + +#include "tensorflow/compiler/xla/client/lib/slicing.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -24,13 +26,13 @@ limitations under the License. 
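The new matrix.h documents `BatchDot` as a batched matmul over the two minor dimensions; combined with `TransposeInMinorDims` it covers the cases the removed `transpose_x`/`transpose_y` flags used to express. A sketch mirroring the RowBatchDot test below; the shapes in the comments are illustrative:

```
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"

// x: [batch, 1, 4], y: [batch, 1, 4]  ->  result: [batch, 1, 1]
xla::XlaOp RowDot(xla::XlaOp x, xla::XlaOp y) {
  // Transposing y in its minor dims gives [batch, 4, 1], so the inner
  // dimensions of the contraction line up.
  return xla::BatchDot(x, xla::TransposeInMinorDims(y));
}
```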
namespace xla { namespace { -class NumericTest : public ClientLibraryTestBase { +class MatrixTest : public ClientLibraryTestBase { protected: template void TestMatrixDiagonal(); }; -XLA_TEST_F(NumericTest, Triangle) { +XLA_TEST_F(MatrixTest, Triangle) { XlaBuilder builder(TestName()); Array3D input(2, 3, 4); input.FillIota(0); @@ -45,7 +47,7 @@ XLA_TEST_F(NumericTest, Triangle) { } template -void NumericTest::TestMatrixDiagonal() { +void MatrixTest::TestMatrixDiagonal() { XlaBuilder builder("GetMatrixDiagonal"); Array3D input(2, 3, 4); input.FillIota(0); @@ -58,11 +60,46 @@ void NumericTest::TestMatrixDiagonal() { ComputeAndCompareR2(&builder, expected, {a_data.get()}); } -XLA_TEST_F(NumericTest, GetMatrixDiagonal_S32) { TestMatrixDiagonal(); } +XLA_TEST_F(MatrixTest, GetMatrixDiagonal_S32) { TestMatrixDiagonal(); } -XLA_TEST_F(NumericTest, GetMatrixDiagonal_S64) { TestMatrixDiagonal(); } +XLA_TEST_F(MatrixTest, GetMatrixDiagonal_S64) { TestMatrixDiagonal(); } -XLA_TEST_F(NumericTest, GetMatrixDiagonal_F32) { TestMatrixDiagonal(); } +XLA_TEST_F(MatrixTest, GetMatrixDiagonal_F32) { TestMatrixDiagonal(); } +Array3D BatchedAValsFull() { + return {{ + {2, 0, 1, 2}, + {3, 6, 0, 1}, + {4, 7, 9, 0}, + {5, 8, 10, 11}, + }, + { + {16, 24, 8, 12}, + {24, 61, 82, 48}, + {8, 82, 456, 106}, + {12, 48, 106, 62}, + }}; +} + +XLA_TEST_F(MatrixTest, RowBatchDot) { + XlaBuilder builder(TestName()); + + int n = 4; + + XlaOp a, row, index; + auto a_data = + CreateR3Parameter(BatchedAValsFull(), 0, "a", &builder, &a); + auto row_data = CreateR3Parameter({{{9, 1, 0, 0}}, {{2, 4, 0, 0}}}, 1, + "row", &builder, &row); + // Select {{3, 6, 0, 1}, {24, 61, 82, 48}} out of BatchedAValsFull(). + auto index_data = CreateR0Parameter(1, 2, "index", &builder, &index); + + auto l_index = DynamicSliceInMinorDims( + a, {index, ConstantR0(&builder, 0)}, {1, n}); + BatchDot(l_index, TransposeInMinorDims(row)); + + ComputeAndCompareR3(&builder, {{{33}}, {{292}}}, + {a_data.get(), row_data.get(), index_data.get()}); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/numeric.cc b/tensorflow/compiler/xla/client/lib/numeric.cc deleted file mode 100644 index 377654220b5..00000000000 --- a/tensorflow/compiler/xla/client/lib/numeric.cc +++ /dev/null @@ -1,89 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include -#include - -#include "absl/types/span.h" -#include "tensorflow/compiler/xla/client/lib/arithmetic.h" -#include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" - -namespace xla { - -XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m, - int64 n) { - auto a = Iota(builder, type, m); - auto b = Iota(builder, type, n); - auto indicator = Eq(a, Broadcast(b, {m}), /*broadcast_dimensions=*/{0}); - return ConvertElementType(indicator, type); -} - -XlaOp GetMatrixDiagonal(XlaOp x) { - XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { - TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); - const int64 n_dims = ShapeUtil::Rank(shape); - TF_RET_CHECK(n_dims >= 2); - const int64 m = shape.dimensions(n_dims - 2); - const int64 n = shape.dimensions(n_dims - 1); - absl::Span major_dims = - AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2); - auto a = Iota(builder, U32, n); - auto b = Iota(builder, U32, m); - auto indicator = Eq(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); - auto mask = Broadcast(indicator, major_dims); - - // TPUs don't support S64 add reduction at the moment. But fortunately - // OR-reductions work just as well for integers. - XlaComputation reducer = - primitive_util::IsIntegralType(shape.element_type()) - ? CreateScalarOrComputation(shape.element_type(), builder) - : CreateScalarAddComputation(shape.element_type(), builder); - - return Reduce(Select(mask, x, Zeros(builder, shape)), ScalarLike(x, 0), - reducer, {m >= n ? n_dims - 2 : n_dims - 1}); - }); -} - -XlaOp Triangle(XlaOp x, bool lower) { - XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { - TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); - const int64 n_dims = ShapeUtil::Rank(shape); - TF_RET_CHECK(n_dims >= 2); - const int64 m = shape.dimensions(n_dims - 2); - const int64 n = shape.dimensions(n_dims - 1); - absl::Span major_dims = - AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2); - auto a = Iota(builder, U32, n); - auto b = Iota(builder, U32, m); - xla::XlaOp indicator; - if (lower) { - indicator = Ge(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); - } else { - indicator = Le(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); - } - auto mask = Broadcast(indicator, major_dims); - - return Select(mask, x, Zeros(builder, shape)); - }); -} - -XlaOp UpperTriangle(XlaOp x) { return Triangle(x, false); } - -XlaOp LowerTriangle(XlaOp x) { return Triangle(x, true); } - -} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/prng.cc b/tensorflow/compiler/xla/client/lib/prng.cc index c6f68c8ee2f..85b9e1827dc 100644 --- a/tensorflow/compiler/xla/client/lib/prng.cc +++ b/tensorflow/compiler/xla/client/lib/prng.cc @@ -18,7 +18,6 @@ limitations under the License. 
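With numeric.cc deleted, `IdentityMatrix`, the triangle helpers, and `GetMatrixDiagonal` live in matrix.h, and the dropped `Iota` declaration is assumed to be served directly by the builder API. A small sketch of the relocated helpers; the 4x4 size is arbitrary:

```
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"

xla::XlaOp BuildMatrixExample(xla::XlaBuilder* b) {
  // 4x4 identity, then keep only its lower-triangular part (a no-op here).
  xla::XlaOp eye = xla::IdentityMatrix(b, xla::F32, 4, 4);
  xla::XlaOp lower = xla::LowerTriangle(eye);
  // The main diagonal of the masked matrix: a vector of four 1.0f values.
  return xla::GetMatrixDiagonal(lower);
}
```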
#include "absl/base/casts.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/math.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/util.h" diff --git a/tensorflow/compiler/xla/client/lib/slicing.cc b/tensorflow/compiler/xla/client/lib/slicing.cc new file mode 100644 index 00000000000..f8c7df3ff51 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/slicing.cc @@ -0,0 +1,134 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/slicing.h" + +namespace xla { + +XlaOp SliceInMinorDims(XlaOp x, absl::Span start, + absl::Span end) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_RET_CHECK(start.size() == end.size()); + int64 n_minor_dims = start.size(); + + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + + const int64 n_dims = ShapeUtil::Rank(shape); + TF_RET_CHECK(n_minor_dims <= n_dims); + auto major_dims = AsInt64Slice(shape.dimensions()) + .subspan( + /*pos=*/0, + /*len=*/n_dims - n_minor_dims); + + // Prepends 0s in the major dim + std::vector padded_start(n_dims, 0); + std::copy(start.begin(), start.end(), + padded_start.begin() + major_dims.size()); + + // Prepends the shape of the major dims. + std::vector padded_end(n_dims); + std::copy(major_dims.begin(), major_dims.end(), padded_end.begin()); + std::copy(end.begin(), end.end(), padded_end.begin() + major_dims.size()); + + std::vector strides(n_dims, 1); + return Slice(x, padded_start, padded_end, strides); + }); +} + +XlaOp UpdateSlice(XlaOp x, XlaOp update, absl::Span start) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + // TODO(phawkins): make int64 work on all backends, remove the int32 cast. 
+ std::vector start_as_int32(start.begin(), start.end()); + auto start_constant = ConstantR1(builder, start_as_int32); + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + const int64 n_dims = ShapeUtil::Rank(shape); + TF_ASSIGN_OR_RETURN(Shape start_constant_shape, + builder->GetShape(start_constant)); + const int64 start_length = + ShapeUtil::GetDimension(start_constant_shape, -1); + TF_RET_CHECK(start_length == n_dims); + return DynamicUpdateSlice(x, update, start_constant); + }); +} + +XlaOp UpdateSliceInMinorDims(XlaOp x, XlaOp update, + absl::Span start) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + const int64 n_dims = ShapeUtil::Rank(shape); + const int64 n_minor_dims = start.size(); + TF_RET_CHECK(n_minor_dims <= n_dims); + std::vector padded_start(n_dims, 0); + std::copy(start.begin(), start.end(), + padded_start.begin() + (n_dims - n_minor_dims)); + return UpdateSlice(x, update, padded_start); + }); +} + +namespace { + +std::vector ConcatVectors(absl::Span xs, + absl::Span ys) { + std::vector output(xs.size() + ys.size()); + std::copy(xs.begin(), xs.end(), output.begin()); + std::copy(ys.begin(), ys.end(), output.begin() + xs.size()); + return output; +} + +XlaOp PrependZerosInMajorDims(XlaOp x, absl::Span starts) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + const int64 n_dims = ShapeUtil::Rank(shape); + auto zero = Reshape(ConstantR0(builder, 0), {1}); + std::vector padded_starts(n_dims, zero); + for (int i = 0; i < starts.size(); ++i) { + padded_starts[n_dims - starts.size() + i] = Reshape(starts[i], {1}); + } + return ConcatInDim(builder, padded_starts, 0); + }); +} + +} // namespace + +XlaOp DynamicSliceInMinorDims(XlaOp x, absl::Span starts, + absl::Span sizes) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + const int64 n_dims = ShapeUtil::Rank(shape); + int64 n_minor_dims = starts.size(); + TF_RET_CHECK(n_minor_dims == sizes.size()); + TF_RET_CHECK(n_minor_dims <= n_dims); + auto major_dims = AsInt64Slice(shape.dimensions()) + .subspan( + /*pos=*/0, + /*len=*/n_dims - sizes.size()); + auto padded_starts = PrependZerosInMajorDims(x, starts); + auto padded_sizes = ConcatVectors(major_dims, sizes); + return DynamicSlice(x, padded_starts, padded_sizes); + }); +} + +XlaOp DynamicUpdateSliceInMinorDims(XlaOp x, XlaOp update, + absl::Span starts) { + auto padded_starts = PrependZerosInMajorDims(x, starts); + return DynamicUpdateSlice(x, update, padded_starts); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/slicing.h b/tensorflow/compiler/xla/client/lib/slicing.h new file mode 100644 index 00000000000..6c482a38b54 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/slicing.h @@ -0,0 +1,48 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/types.h" + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SLICING_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SLICING_H_ + +namespace xla { + +// Updates a slice of 'x', i.e., +// x[start[0], ..., start[n]] = update +XlaOp UpdateSlice(XlaOp x, XlaOp update, absl::Span start); + +// Performs a slice in the minor dimensions of a tensor. +// x[..., start[0]:end[0], ..., start[n]:end[n]] +XlaOp SliceInMinorDims(XlaOp x, absl::Span start, + absl::Span end); + +// Updates a slice of 'x', where 'start' contains a list of minor dimensions: +// x[..., start[0]:..., ..., start[n]:...] = update +XlaOp UpdateSliceInMinorDims(XlaOp x, XlaOp update, + absl::Span start); + +// Performs a dynamic slice in the minor dimensions of a tensor. +XlaOp DynamicSliceInMinorDims(XlaOp x, absl::Span starts, + absl::Span sizes); + +XlaOp DynamicUpdateSliceInMinorDims(XlaOp x, XlaOp update, + absl::Span starts); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SLICING_H_ diff --git a/tensorflow/compiler/tf2xla/lib/util_test.cc b/tensorflow/compiler/xla/client/lib/slicing_test.cc similarity index 67% rename from tensorflow/compiler/tf2xla/lib/util_test.cc rename to tensorflow/compiler/xla/client/lib/slicing_test.cc index 442fe92c34c..8d362119e01 100644 --- a/tensorflow/compiler/tf2xla/lib/util_test.cc +++ b/tensorflow/compiler/xla/client/lib/slicing_test.cc @@ -13,28 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/xla/client/lib/slicing.h" -#include -#include -#include - -#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" -#include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" -#include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/status_test_util.h" -namespace tensorflow { +namespace xla { namespace { -using UtilTest = xla::ClientLibraryTestBase; -using UtilLeftLookingTest = xla::ClientLibraryTestBase; +using SlicingTest = xla::ClientLibraryTestBase; xla::Array2D BValsRight() { return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}; @@ -63,7 +54,7 @@ xla::Array3D BatchedAValsFull() { }}; } -XLA_TEST_F(UtilTest, Simple2dLookup) { +XLA_TEST_F(SlicingTest, Simple2dLookup) { xla::XlaBuilder builder(TestName()); xla::XlaOp a, x, y; @@ -77,7 +68,7 @@ XLA_TEST_F(UtilTest, Simple2dLookup) { xla::ErrorSpec(1e-2, 1e-2)); } -XLA_TEST_F(UtilTest, Simple3dLookup) { +XLA_TEST_F(SlicingTest, Simple3dLookup) { xla::XlaBuilder builder(TestName()); xla::XlaOp a, index; @@ -92,7 +83,7 @@ XLA_TEST_F(UtilTest, Simple3dLookup) { {a_data.get(), index_data.get()}); } -XLA_TEST_F(UtilTest, SimpleSliceUpdate) { +XLA_TEST_F(SlicingTest, SimpleSliceUpdate) { xla::XlaBuilder builder(TestName()); xla::XlaOp a, b, x, y; @@ -111,26 +102,5 @@ XLA_TEST_F(UtilTest, SimpleSliceUpdate) { {a_data.get(), b_data.get(), x_data.get(), y_data.get()}); } -XLA_TEST_F(UtilTest, RowBatchDot) { - xla::XlaBuilder builder(TestName()); - - int n = 4; - - xla::XlaOp a, row, index; - auto a_data = - CreateR3Parameter(BatchedAValsFull(), 0, "a", &builder, &a); - auto row_data = CreateR3Parameter({{{9, 1, 0, 0}}, {{2, 4, 0, 0}}}, 1, - "row", &builder, &row); - // Select {{3, 6, 0, 1}, {24, 61, 82, 48}} out of BatchedAValsFull(). - auto index_data = CreateR0Parameter(1, 2, "index", &builder, &index); - - auto l_index = DynamicSliceInMinorDims( - a, {index, xla::ConstantR0(&builder, 0)}, {1, n}); - BatchDot(l_index, row, /*transpose_x=*/false, /*transpose_y=*/true); - - ComputeAndCompareR3(&builder, {{{33}}, {{292}}}, - {a_data.get(), row_data.get(), index_data.get()}); -} - } // namespace -} // namespace tensorflow +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/sorting.cc b/tensorflow/compiler/xla/client/lib/sorting.cc index 0475fd9c94f..e8553a08bb0 100644 --- a/tensorflow/compiler/xla/client/lib/sorting.cc +++ b/tensorflow/compiler/xla/client/lib/sorting.cc @@ -14,7 +14,9 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/client/lib/sorting.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/util.h" namespace xla { @@ -23,13 +25,12 @@ XlaOp TopK(XlaOp input, int64 k) { return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input)); int last_dim = input_shape.dimensions_size() - 1; - int last_dim_size = input_shape.dimensions(last_dim); - XlaOp iota_s32 = Iota(builder, S32, last_dim_size); + Shape iota_shape = + ShapeUtil::MakeShape(S32, AsInt64Slice(input_shape.dimensions())); + XlaOp iota_s32 = Iota(builder, iota_shape, last_dim); auto input_dims = input_shape.dimensions(); - std::vector broadcast_dims(input_dims.begin(), input_dims.end() - 1); - XlaOp broadcast_s32 = Broadcast(iota_s32, broadcast_dims); - XlaOp sort_result = Sort(Neg(input), {broadcast_s32}); + XlaOp sort_result = Sort(Neg(input), {iota_s32}); std::vector start_indices(input_shape.dimensions_size(), 0); std::vector limit_indices(input_dims.begin(), input_dims.end()); limit_indices[last_dim] = k; diff --git a/tensorflow/compiler/xla/client/lib/sorting_test.cc b/tensorflow/compiler/xla/client/lib/sorting_test.cc index fef98c99230..27ff36c7491 100644 --- a/tensorflow/compiler/xla/client/lib/sorting_test.cc +++ b/tensorflow/compiler/xla/client/lib/sorting_test.cc @@ -14,6 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/client/lib/sorting.h" + +#include + #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -41,6 +44,28 @@ XLA_TEST_F(SortingTest, TopK3From8Indices) { ComputeAndCompareR1(&builder, {0, 1, 2}, {}); } +// TODO(b/119930279): enable this test. +XLA_TEST_F(SortingTest, DISABLED_TopKFullSortMinInt) { + XlaBuilder builder(TestName()); + auto x_rev = ConstantR1(&builder, {std::numeric_limits::min(), + std::numeric_limits::min() + 1, + std::numeric_limits::max()}); + xla::GetTupleElement(xla::TopK(x_rev, 3), 1); + ComputeAndCompareR1(&builder, {2, 1, 0}, {}); +} + +XLA_TEST_F(SortingTest, NOT_TopKFullSortMinInt) { + XlaBuilder builder(TestName()); + auto x_rev = ConstantR1(&builder, {std::numeric_limits::min(), + std::numeric_limits::min() + 1, + std::numeric_limits::max()}); + xla::GetTupleElement(xla::TopK(x_rev, 3), 1); + // TopK currently negates the keys, which doesn't work correctly for + // std::numeric_limits::min(). Therefore, it will sort this key to the + // front instead of to the back. 
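The comment above is the crux of why the first variant stays disabled: TopK sorts by negated keys, and the most negative int32 has no positive counterpart. A standalone sketch of the overflow, illustrative only and not part of the patch:

```
#include <cstdint>
#include <iostream>
#include <limits>

int main() {
  const int32_t min32 = std::numeric_limits<int32_t>::min();
  const int32_t max32 = std::numeric_limits<int32_t>::max();
  // Negating min32 directly would overflow (undefined behavior for signed
  // integers), so widen to int64 to show the value the negation would need.
  const int64_t needed = -static_cast<int64_t>(min32);  // 2147483648
  std::cout << "int32 min = " << min32 << ", int32 max = " << max32 << "\n"
            << "-(int32 min) = " << needed << " does not fit in int32, so\n"
            << "the negated-key comparison mis-orders this element.\n";
  return 0;
}
```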
+ ComputeAndCompareR1(&builder, {0, 2, 1}, {}); +} + XLA_TEST_F(SortingTest, TopKFullSort) { XlaBuilder builder(TestName()); const int kSize = 16; @@ -56,5 +81,13 @@ XLA_TEST_F(SortingTest, TopKFullSort) { ComputeAndCompareR1(&builder, inputs, {}); } +XLA_TEST_F(SortingTest, TopKFullSortWithDuplicates) { + XlaBuilder builder(TestName()); + XlaOp a; + auto a_data = CreateR1Parameter({1, 1, 2, 2, 1}, 0, "a", &builder, &a); + xla::GetTupleElement(xla::TopK(a, 5), 1); + ComputeAndCompareR1(&builder, {2, 3, 0, 1, 4}, {a_data.get()}); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc index a44681f5862..a95bbf2c8c8 100644 --- a/tensorflow/compiler/xla/client/lib/testing.cc +++ b/tensorflow/compiler/xla/client/lib/testing.cc @@ -66,7 +66,7 @@ std::unique_ptr MakeFakeDataViaDeviceOrDie(const Shape& shape, XlaComputation computation = b.Build().ConsumeValueOrDie(); auto execution_options = CreateDefaultExecutionOptions(); - *execution_options.mutable_shape_with_output_layout() = shape; + *execution_options.mutable_shape_with_output_layout() = shape.ToProto(); return client->Execute(computation, /*arguments=*/{}, &execution_options) .ConsumeValueOrDie(); } @@ -98,8 +98,8 @@ std::vector> MakeFakeArgumentsOrDie( auto program_shape = computation.proto().host_program_shape(); std::vector> results; - for (const Shape& shape : program_shape.parameters()) { - results.push_back(MakeFakeDataOrDie(shape, client)); + for (const ShapeProto& shape : program_shape.parameters()) { + results.push_back(MakeFakeDataOrDie(Shape(shape), client)); } return results; } diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index f96b6c9c261..aaa5d6989ee 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -310,4 +310,28 @@ StatusOr LocalClient::ReplicaNumberToDeviceOrdinal(int replica_number) { return local_service_->ReplicaNumberToDeviceOrdinal(replica_number); } +StatusOr LocalClient::TransferToLocalServer( + const ::xla::BorrowingLiteral& literal, int device_oridinal) { + const ::xla::Shape& shape = literal.shape(); + + TF_ASSIGN_OR_RETURN( + ::xla::ScopedShapedBuffer shaped_buffer, + backend().transfer_manager()->AllocateScopedShapedBuffer( + shape, backend().memory_allocator(), device_oridinal)); + TF_ASSIGN_OR_RETURN(auto stream, + mutable_backend()->BorrowStream(device_oridinal)); + TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice( + stream.get(), literal, shaped_buffer)); + std::vector<::xla::ScopedShapedBuffer> replicated_buffer; + replicated_buffer.emplace_back(std::move(shaped_buffer)); + ::xla::TransferToServerResponse result; + TF_ASSIGN_OR_RETURN(*result.mutable_data(), + local_service_->RegisterReplicatedBuffers( + std::move(replicated_buffer), + absl::StrCat("TransferToServer literal of shape ", + ::xla::ShapeUtil::HumanString(shape)))); + + return result; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index e49451ca970..ddb36680e8b 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -129,6 +129,10 @@ class LocalClient : public Client { const Literal& literal, int device_ordinal, DeviceMemoryAllocator* allocator = nullptr); + // Transfer the BorrowingLiteral to the device with the given ordinal. 
+ StatusOr TransferToLocalServer( + const ::xla::BorrowingLiteral& literal, int device_oridinal); + // Copy the data from the device contained in the given ShapedBuffer and // return as a Literal. StatusOr ShapedBufferToLiteral(const ShapedBuffer& shaped_buffer); diff --git a/tensorflow/compiler/xla/client/sharding_builder.cc b/tensorflow/compiler/xla/client/sharding_builder.cc index 176802b33ef..fb9ea6ec3fc 100644 --- a/tensorflow/compiler/xla/client/sharding_builder.cc +++ b/tensorflow/compiler/xla/client/sharding_builder.cc @@ -36,7 +36,7 @@ OpSharding Tile(const Shape& tile_shape, const TileAssignment& tile_assignment) { OpSharding result; result.set_type(OpSharding::Type::OpSharding_Type_OTHER); - *result.mutable_tile_shape() = tile_shape; + *result.mutable_tile_shape() = tile_shape.ToProto(); for (int64 dim : tile_assignment.dimensions()) { result.add_tile_assignment_dimensions(dim); } @@ -52,7 +52,7 @@ OpSharding Tile1D(const Shape& tile_shape, int64 num_tiles) { CHECK_EQ(ShapeUtil::Rank(tile_shape), 1); std::vector dimensions(1, num_tiles); - *result.mutable_tile_shape() = tile_shape; + *result.mutable_tile_shape() = tile_shape.ToProto(); auto& tile_dimension = (*result.mutable_tile_shape()->mutable_dimensions())[0]; tile_dimension = CeilOfRatio(static_cast(tile_dimension), num_tiles); diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 0a587725d20..60df2ec3959 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -102,7 +102,7 @@ StatusOr XlaBuilder::GetShape(const XlaOp& op) const { TF_RETURN_IF_ERROR(first_error_); TF_ASSIGN_OR_RETURN(auto instr, LookUpInstruction(op)); - return instr->shape(); + return Shape(instr->shape()); } StatusOr> XlaBuilder::GetOperandShapes( @@ -155,7 +155,7 @@ StatusOr XlaBuilder::GetProgramShape(int64 root_id) const { ProgramShape program_shape; - *program_shape.mutable_result() = root_proto->shape(); + *program_shape.mutable_result() = Shape(root_proto->shape()); // Check that the parameter numbers are continuous from 0, and add parameter // shapes and names to the program shape. 
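The hunks here and throughout the rest of xla_builder.cc follow one mechanical pattern: shape inference results are kept as `xla::Shape`, converted with `ToProto()` when written into an `HloInstructionProto`, and wrapped back into `Shape` when read. A condensed sketch of that round trip (the conversion calls appear verbatim in the hunks; the wrapper function and header choices here are illustrative only):

```
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {

// Round-trips a shape through an instruction proto, mirroring the
// Shape <-> ShapeProto conversions introduced in this file.
Shape RoundTripThroughInstructionProto(const Shape& shape) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();  // Shape -> ShapeProto on write.
  return Shape(instr.shape());               // ShapeProto -> Shape on read.
}

}  // namespace xla
```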
@@ -172,7 +172,7 @@ StatusOr XlaBuilder::GetProgramShape(int64 root_id) const { const int64 index = instr.parameter_number(); TF_RET_CHECK(index >= 0 && index < param_count) << "invalid parameter number: " << index; - *program_shape.mutable_parameters(index) = instr.shape(); + *program_shape.mutable_parameters(index) = Shape(instr.shape()); *program_shape.mutable_parameter_names(index) = instr.name(); } } @@ -239,6 +239,19 @@ void XlaBuilder::IsConstantVisitor(const int64 op_handle, visited->insert(op_handle); } +Status XlaBuilder::SetDynamicBinding(int64 dynamic_size_param_num, + ShapeIndex dynamic_size_param_index, + int64 target_param_num, + ShapeIndex target_param_index, + int64 target_dim_num) { + TF_RETURN_IF_ERROR(dynamic_parameter_binding_.Bind( + DynamicParameterBinding::DynamicParameter{dynamic_size_param_num, + dynamic_size_param_index}, + DynamicParameterBinding::DynamicDimension{ + target_param_num, target_param_index, target_dim_num})); + return Status::OK(); +} + XlaComputation XlaBuilder::BuildAndNoteError() { DCHECK(parent_builder_ != nullptr); auto build_status = Build(); @@ -275,7 +288,8 @@ StatusOr XlaBuilder::Build(int64 root_id) { HloComputationProto entry; SetProtoIdAndName(&entry, name_, kNameSeparator, GetNextId()); - TF_ASSIGN_OR_RETURN(*entry.mutable_program_shape(), GetProgramShape(root_id)); + TF_ASSIGN_OR_RETURN(ProgramShape program_shape, GetProgramShape(root_id)); + *entry.mutable_program_shape() = program_shape.ToProto(); entry.set_root_id(root_id); for (auto& instruction : instructions_) { @@ -297,6 +311,9 @@ StatusOr XlaBuilder::Build(int64 root_id) { } module->add_computations()->Swap(&entry); + *(module->mutable_dynamic_parameter_binding()) = + dynamic_parameter_binding_.ToProto(); + // Clear data held by this builder. this->instructions_.clear(); this->handle_to_index_.clear(); @@ -312,7 +329,7 @@ StatusOr XlaBuilder::InDimBroadcast( TF_RETURN_IF_ERROR(first_error_); HloInstructionProto instr; - *instr.mutable_shape() = shape; + *instr.mutable_shape() = shape.ToProto(); for (int64 dim : broadcast_dimensions) { instr.add_dimensions(dim); } @@ -363,8 +380,9 @@ XlaOp XlaBuilder::UnaryOp(HloOpcode unop, const XlaOp& operand) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); - TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferUnaryOpShape(unop, operand_shape)); + *instr.mutable_shape() = shape.ToProto(); return AddInstruction(std::move(instr), unop, {operand}); }); } @@ -375,9 +393,10 @@ XlaOp XlaBuilder::BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); - TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferBinaryOpShape( binop, lhs_shape, rhs_shape, broadcast_dimensions)); + *instr.mutable_shape() = shape.ToProto(); const int64 lhs_rank = ShapeUtil::Rank(lhs_shape); const int64 rhs_rank = ShapeUtil::Rank(rhs_shape); @@ -391,7 +410,7 @@ XlaOp XlaBuilder::BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, const Shape& from_shape = should_broadcast_lhs ? 
lhs_shape : rhs_shape; std::vector to_size; - for (int64 size : instr.shape().dimensions()) { + for (int64 size : shape.dimensions()) { to_size.push_back(size); } for (int64 from_dim = 0; from_dim < ShapeUtil::Rank(from_shape); @@ -411,14 +430,14 @@ XlaOp XlaBuilder::BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, } TF_ASSIGN_OR_RETURN(Shape updated_lhs_shape, GetShape(updated_lhs)); - if (!ShapeUtil::SameDimensions(instr.shape(), updated_lhs_shape)) { + if (!ShapeUtil::SameDimensions(shape, updated_lhs_shape)) { TF_ASSIGN_OR_RETURN(updated_lhs, - AddBroadcastSequence(instr.shape(), updated_lhs)); + AddBroadcastSequence(shape, updated_lhs)); } TF_ASSIGN_OR_RETURN(Shape updated_rhs_shape, GetShape(updated_rhs)); - if (!ShapeUtil::SameDimensions(instr.shape(), updated_rhs_shape)) { + if (!ShapeUtil::SameDimensions(shape, updated_rhs_shape)) { TF_ASSIGN_OR_RETURN(updated_rhs, - AddBroadcastSequence(instr.shape(), updated_rhs)); + AddBroadcastSequence(shape, updated_rhs)); } return AddInstruction(std::move(instr), binop, {updated_lhs, updated_rhs}); @@ -432,30 +451,28 @@ XlaOp XlaBuilder::TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs, TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); TF_ASSIGN_OR_RETURN(const Shape& ehs_shape, GetShape(ehs)); - TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), - ShapeInference::InferTernaryOpShape( - triop, lhs_shape, rhs_shape, ehs_shape)); + TF_ASSIGN_OR_RETURN( + Shape shape, ShapeInference::InferTernaryOpShape(triop, lhs_shape, + rhs_shape, ehs_shape)); + *instr.mutable_shape() = shape.ToProto(); XlaOp updated_lhs = lhs; XlaOp updated_rhs = rhs; XlaOp updated_ehs = ehs; - if (!ShapeUtil::IsTuple(instr.shape())) { + if (!ShapeUtil::IsTuple(shape)) { if (!ShapeUtil::IsTuple(lhs_shape) && - !ShapeUtil::SameDimensions(instr.shape(), lhs_shape)) { + !ShapeUtil::SameDimensions(shape, lhs_shape)) { // lhs is being implicitly broadcasted. Change to explicit. - TF_ASSIGN_OR_RETURN(updated_lhs, - AddBroadcastSequence(instr.shape(), lhs)); + TF_ASSIGN_OR_RETURN(updated_lhs, AddBroadcastSequence(shape, lhs)); } if (!ShapeUtil::IsTuple(rhs_shape) && - !ShapeUtil::SameDimensions(instr.shape(), rhs_shape)) { + !ShapeUtil::SameDimensions(shape, rhs_shape)) { // rhs is being implicitly broadcasted. Change to explicit. - TF_ASSIGN_OR_RETURN(updated_rhs, - AddBroadcastSequence(instr.shape(), rhs)); + TF_ASSIGN_OR_RETURN(updated_rhs, AddBroadcastSequence(shape, rhs)); } if (!ShapeUtil::IsTuple(ehs_shape) && - !ShapeUtil::SameDimensions(instr.shape(), ehs_shape)) { + !ShapeUtil::SameDimensions(shape, ehs_shape)) { // ehs is being implicitly broadcasted. Change to explicit. 
- TF_ASSIGN_OR_RETURN(updated_ehs, - AddBroadcastSequence(instr.shape(), ehs)); + TF_ASSIGN_OR_RETURN(updated_ehs, AddBroadcastSequence(shape, ehs)); } } return AddInstruction(std::move(instr), triop, @@ -476,7 +493,7 @@ XlaOp XlaBuilder::Mul(const XlaOp& lhs, const XlaOp& rhs, XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; - *instr.mutable_shape() = literal.shape(); + *instr.mutable_shape() = literal.shape().ToProto(); *instr.mutable_literal() = literal.ToProto(); return AddInstruction(std::move(instr), HloOpcode::kConstant); }); @@ -485,7 +502,7 @@ XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) { XlaOp XlaBuilder::Iota(const Shape& shape, int64 iota_dimension) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; - *instr.mutable_shape() = shape; + *instr.mutable_shape() = shape.ToProto(); instr.add_dimensions(iota_dimension); return AddInstruction(std::move(instr), HloOpcode::kIota); }); @@ -505,10 +522,10 @@ XlaOp XlaBuilder::Call(const XlaComputation& computation, [](const Shape& shape) { return &shape; }); TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape, computation.GetProgramShape()); - TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), - ShapeInference::InferCallShape(operand_shape_ptrs, - /*to_apply=*/called_program_shape)); + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferCallShape( + operand_shape_ptrs, + /*to_apply=*/called_program_shape)); + *instr.mutable_shape() = shape.ToProto(); AddCalledComputation(computation, &instr); @@ -526,7 +543,7 @@ XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape, } instr.set_parameter_number(parameter_number); instr.set_name(name); - *instr.mutable_shape() = shape; + *instr.mutable_shape() = shape.ToProto(); return AddInstruction(std::move(instr), HloOpcode::kParameter); }); } @@ -556,27 +573,35 @@ XlaOp XlaBuilder::Broadcast(const XlaOp& operand, } XlaOp XlaBuilder::BroadcastInDim( - const XlaOp& operand, const Shape& shape, + const XlaOp& operand, const absl::Span out_dim_size, const absl::Span broadcast_dimensions) { return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); - TF_RETURN_IF_ERROR(ShapeInference::InferBroadcastShape(operand_shape, shape, - broadcast_dimensions) + // Output shape, in the case of degenerate broadcast, the out_dim_size is + // not necessarily the same as the dimension sizes of the output shape. + const auto& output_shape = + ShapeUtil::MakeShape(operand_shape.element_type(), out_dim_size); + + TF_RETURN_IF_ERROR(ShapeInference::InferBroadcastShape( + operand_shape, output_shape, broadcast_dimensions) .status()); - std::vector in_dim_size(ShapeUtil::Rank(shape)); - absl::c_copy(shape.dimensions(), in_dim_size.begin()); + std::vector in_dim_size(out_dim_size.begin(), out_dim_size.end()); for (int i = 0; i < broadcast_dimensions.size(); i++) { in_dim_size[broadcast_dimensions[i]] = operand_shape.dimensions(i); } const auto& in_dim_shape = - ShapeUtil::MakeShape(shape.element_type(), in_dim_size); + ShapeUtil::MakeShape(operand_shape.element_type(), in_dim_size); TF_ASSIGN_OR_RETURN( XlaOp in_dim_broadcast, InDimBroadcast(in_dim_shape, operand, broadcast_dimensions)); - if (ShapeUtil::Equal(in_dim_shape, shape)) { + + // If broadcast is not degenerate, return broadcasted result. 
+ if (ShapeUtil::Equal(in_dim_shape, output_shape)) { return in_dim_broadcast; } - return AddBroadcastSequence(shape, in_dim_broadcast); + + // Otherwise handle degenerate broadcast case. + return AddBroadcastSequence(output_shape, in_dim_broadcast); }); } @@ -584,7 +609,7 @@ StatusOr XlaBuilder::Reshape(const Shape& shape, const XlaOp& operand) { TF_RETURN_IF_ERROR(first_error_); HloInstructionProto instr; - *instr.mutable_shape() = shape; + *instr.mutable_shape() = shape.ToProto(); return AddInstruction(std::move(instr), HloOpcode::kReshape, {operand}); } @@ -596,9 +621,9 @@ XlaOp XlaBuilder::Slice(const XlaOp& operand, HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), - ShapeInference::InferSliceShape(operand_shape, start_indices, - limit_indices, strides)); + Shape shape, ShapeInference::InferSliceShape( + operand_shape, start_indices, limit_indices, strides)); + *instr.mutable_shape() = shape.ToProto(); for (int i = 0; i < start_indices.size(); i++) { auto* slice_config = instr.add_slice_dimensions(); slice_config->set_start(start_indices[i]); @@ -633,9 +658,10 @@ XlaOp XlaBuilder::DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN(const Shape& start_indices_shape, GetShape(start_indices)); - TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferDynamicSliceShape( operand_shape, start_indices_shape, slice_sizes)); + *instr.mutable_shape() = shape.ToProto(); for (int64 size : slice_sizes) { instr.add_dynamic_slice_sizes(size); @@ -655,9 +681,10 @@ XlaOp XlaBuilder::DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, TF_ASSIGN_OR_RETURN(const Shape& update_shape, GetShape(update)); TF_ASSIGN_OR_RETURN(const Shape& start_indices_shape, GetShape(start_indices)); - TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferDynamicUpdateSliceShape( operand_shape, update_shape, start_indices_shape)); + *instr.mutable_shape() = shape.ToProto(); return AddInstruction(std::move(instr), HloOpcode::kDynamicUpdateSlice, {operand, update, start_indices}); @@ -673,9 +700,9 @@ XlaOp XlaBuilder::ConcatInDim(absl::Span operands, TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands)); absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), [](const Shape& shape) { return &shape; }); - TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), - ShapeInference::InferConcatOpShape(operand_shape_ptrs, dimension)); + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConcatOpShape( + operand_shape_ptrs, dimension)); + *instr.mutable_shape() = shape.ToProto(); instr.add_dimensions(dimension); @@ -692,10 +719,9 @@ XlaOp XlaBuilder::Pad(const XlaOp& operand, const XlaOp& padding_value, TF_ASSIGN_OR_RETURN(const Shape& padding_value_shape, GetShape(padding_value)); TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), - ShapeInference::InferPadShape(operand_shape, padding_value_shape, - padding_config)); - + Shape shape, ShapeInference::InferPadShape( + operand_shape, padding_value_shape, padding_config)); + *instr.mutable_shape() = shape.ToProto(); *instr.mutable_padding_config() = padding_config; return AddInstruction(std::move(instr), HloOpcode::kPad, @@ -708,7 +734,7 @@ XlaOp XlaBuilder::Reshape(const XlaOp& operand, absl::Span new_sizes) { return ReportErrorOrReturn([&]() -> StatusOr { 
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); - TF_ASSIGN_OR_RETURN(const Shape& shape, + TF_ASSIGN_OR_RETURN(const Shape shape, ShapeInference::InferReshapeShape( operand_shape, dimensions, new_sizes)); XlaOp transposed = IsIdentityPermutation(dimensions) @@ -721,7 +747,7 @@ XlaOp XlaBuilder::Reshape(const XlaOp& operand, XlaOp XlaBuilder::Reshape(const XlaOp& operand, absl::Span new_sizes) { return ReportErrorOrReturn([&]() -> StatusOr { - TF_ASSIGN_OR_RETURN(auto shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN(Shape shape, GetShape(operand)); std::vector dimensions(shape.dimensions_size()); std::iota(dimensions.begin(), dimensions.end(), 0); return Reshape(operand, dimensions, new_sizes); @@ -771,7 +797,7 @@ XlaOp XlaBuilder::Collapse(const XlaOp& operand, void XlaBuilder::Trace(const string& tag, const XlaOp& operand) { ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; - *instr.mutable_shape() = ShapeUtil::MakeNil(); + *instr.mutable_shape() = ShapeUtil::MakeNil().ToProto(); *instr.mutable_literal() = LiteralUtil::CreateR1U8(tag).ToProto(); return AddInstruction(std::move(instr), HloOpcode::kTrace, {operand}); }); @@ -797,9 +823,10 @@ XlaOp XlaBuilder::Tuple(absl::Span elements) { TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(elements)); absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), [](const Shape& shape) { return &shape; }); - TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + TF_ASSIGN_OR_RETURN(const Shape shape, ShapeInference::InferVariadicOpShape( HloOpcode::kTuple, operand_shape_ptrs)); + *instr.mutable_shape() = shape.ToProto(); return AddInstruction(std::move(instr), HloOpcode::kTuple, elements); }); } @@ -814,7 +841,7 @@ XlaOp XlaBuilder::GetTupleElement(const XlaOp& tuple_data, int64 index) { ShapeUtil::HumanString(tuple_shape)); } *instr.mutable_shape() = - ShapeUtil::GetTupleElementShape(tuple_shape, index); + ShapeUtil::GetTupleElementShape(tuple_shape, index).ToProto(); instr.set_tuple_index(index); @@ -873,9 +900,10 @@ XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs, HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); - TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dimension_numbers)); + *instr.mutable_shape() = shape.ToProto(); *instr.mutable_dot_dimension_numbers() = dimension_numbers; if (precision_config != nullptr) { *instr.mutable_precision_config() = *precision_config; @@ -1017,10 +1045,11 @@ XlaOp XlaBuilder::ConvGeneralDilated( MakeWindow(window_dimensions, window_strides, padding, lhs_dilation, rhs_dilation)); - TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConvolveShape( lhs_shape, rhs_shape, feature_group_count, instr.window(), dimension_numbers)); + *instr.mutable_shape() = shape.ToProto(); *instr.mutable_convolution_dimension_numbers() = dimension_numbers; instr.set_feature_group_count(feature_group_count); @@ -1093,10 +1122,9 @@ XlaOp XlaBuilder::Fft(const XlaOp& operand, const FftType fft_type, return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); - TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), - ShapeInference::InferFftShape(operand_shape, fft_type, fft_length)); - + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferFftShape( + 
operand_shape, fft_type, fft_length)); + *instr.mutable_shape() = shape.ToProto(); instr.set_fft_type(fft_type); for (int64 i : fft_length) { instr.add_fft_length(i); @@ -1114,7 +1142,7 @@ XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) { } const Shape infeed_instruction_shape = ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()}); - *instr.mutable_shape() = infeed_instruction_shape; + *instr.mutable_shape() = infeed_instruction_shape.ToProto(); instr.set_infeed_config(config); if (ShapeUtil::IsArray(shape) && sharding() && @@ -1135,7 +1163,7 @@ XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) { XlaOp token; auto make_token = [&]() { HloInstructionProto token_instr; - *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto(); return AddInstruction(std::move(token_instr), HloOpcode::kAfterAll, {}); }; if (sharding()) { @@ -1174,7 +1202,7 @@ XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) { // TODO(b/80000000): Remove this when clients have been updated to handle // tokens. HloInstructionProto infeed_data; - *infeed_data.mutable_shape() = shape; + *infeed_data.mutable_shape() = shape.ToProto(); infeed_data.set_tuple_index(0); return AddInstruction(std::move(infeed_data), HloOpcode::kGetTupleElement, {infeed}); @@ -1190,7 +1218,7 @@ XlaOp XlaBuilder::InfeedWithToken(const XlaOp& token, const Shape& shape, } const Shape infeed_instruction_shape = ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()}); - *instr.mutable_shape() = infeed_instruction_shape; + *instr.mutable_shape() = infeed_instruction_shape.ToProto(); instr.set_infeed_config(config); if (ShapeUtil::IsArray(shape) && sharding() && @@ -1215,7 +1243,7 @@ void XlaBuilder::Outfeed(const XlaOp& operand, const Shape& shape_with_layout, ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; - *instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto(); // Check and set outfeed shape. if (!LayoutUtil::HasLayout(shape_with_layout)) { @@ -1228,14 +1256,14 @@ void XlaBuilder::Outfeed(const XlaOp& operand, const Shape& shape_with_layout, ShapeUtil::HumanStringWithLayout(shape_with_layout), ShapeUtil::HumanStringWithLayout(operand_shape)); } - *instr.mutable_outfeed_shape() = shape_with_layout; + *instr.mutable_outfeed_shape() = shape_with_layout.ToProto(); instr.set_outfeed_config(outfeed_config); // Outfeed takes a token as its second operand. Generate the token to pass // to the outfeed. HloInstructionProto token_instr; - *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto(); TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr), HloOpcode::kAfterAll, {})); @@ -1249,7 +1277,7 @@ void XlaBuilder::Outfeed(const XlaOp& operand, const Shape& shape_with_layout, // TODO(b/80000000): Remove this when clients have been updated to handle // tokens. HloInstructionProto tuple_instr; - *tuple_instr.mutable_shape() = ShapeUtil::MakeNil(); + *tuple_instr.mutable_shape() = ShapeUtil::MakeNil().ToProto(); // The dummy tuple should have no sharding. 
{ @@ -1268,7 +1296,7 @@ XlaOp XlaBuilder::OutfeedWithToken(const XlaOp& operand, const XlaOp& token, return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; - *instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto(); // Check and set outfeed shape. if (!LayoutUtil::HasLayout(shape_with_layout)) { @@ -1281,7 +1309,7 @@ XlaOp XlaBuilder::OutfeedWithToken(const XlaOp& operand, const XlaOp& token, ShapeUtil::HumanStringWithLayout(shape_with_layout), ShapeUtil::HumanStringWithLayout(operand_shape)); } - *instr.mutable_outfeed_shape() = shape_with_layout; + *instr.mutable_outfeed_shape() = shape_with_layout.ToProto(); instr.set_outfeed_config(outfeed_config); @@ -1293,7 +1321,7 @@ XlaOp XlaBuilder::OutfeedWithToken(const XlaOp& operand, const XlaOp& token, XlaOp XlaBuilder::CreateToken() { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; - *instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto(); return AddInstruction(std::move(instr), HloOpcode::kAfterAll); }); } @@ -1303,8 +1331,17 @@ XlaOp XlaBuilder::AfterAll(absl::Span tokens) { if (tokens.empty()) { return InvalidArgument("AfterAll requires at least one operand"); } + for (int i = 0; i < tokens.size(); ++i) { + const XlaOp& operand = tokens[i]; + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + if (!ShapeUtil::IsToken(operand_shape)) { + return InvalidArgument( + "All operands to AfterAll must be tokens; operand %d has shape %s", + i, ShapeUtil::HumanString(operand_shape)); + } + } HloInstructionProto instr; - *instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto(); return AddInstruction(std::move(instr), HloOpcode::kAfterAll, tokens); }); } @@ -1321,7 +1358,7 @@ XlaOp XlaBuilder::CustomCall( "are reserved for internal use.", call_target_name); } - *instr.mutable_shape() = shape; + *instr.mutable_shape() = shape.ToProto(); instr.set_custom_call_target(call_target_name); instr.set_custom_call_opaque(opaque); if (operand_shapes_with_layout.has_value()) { @@ -1345,7 +1382,7 @@ XlaOp XlaBuilder::CustomCall( "constrained layout.", operand_num); } - *instr.add_operand_shapes_with_layout() = operand_shape; + *instr.add_operand_shapes_with_layout() = operand_shape.ToProto(); ++operand_num; } } @@ -1499,9 +1536,9 @@ XlaOp XlaBuilder::Transpose(const XlaOp& operand, return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); - TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), - ShapeInference::InferTransposeShape(operand_shape, permutation)); + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferTransposeShape( + operand_shape, permutation)); + *instr.mutable_shape() = shape.ToProto(); for (int64 dim : permutation) { instr.add_dimensions(dim); } @@ -1514,9 +1551,9 @@ XlaOp XlaBuilder::Rev(const XlaOp& operand, return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); - TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), - ShapeInference::InferReverseShape(operand_shape, dimensions)); + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferReverseShape( + operand_shape, dimensions)); + *instr.mutable_shape() = shape.ToProto(); for (int64 dim : dimensions) { instr.add_dimensions(dim); } @@ -1535,9 +1572,9 @@ XlaOp XlaBuilder::Sort(const XlaOp& 
keys, absl::Span values, GetOperandShapes(values)); absl::c_transform(values_shapes, std::back_inserter(operand_shape_ptrs), [](const Shape& shape) { return &shape; }); - TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), - ShapeInference::InferVariadicOpShape( - HloOpcode::kSort, operand_shape_ptrs)); + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferVariadicOpShape( + HloOpcode::kSort, operand_shape_ptrs)); + *instr.mutable_shape() = shape.ToProto(); if (dimension == -1) { TF_ASSIGN_OR_RETURN(const Shape& keys_shape, GetShape(keys)); dimension = ShapeUtil::Rank(keys_shape) - 1; @@ -1559,9 +1596,9 @@ XlaOp XlaBuilder::ConvertElementType(const XlaOp& operand, return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); - TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), - ShapeInference::InferConvertShape(operand_shape, new_element_type)); + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConvertShape( + operand_shape, new_element_type)); + *instr.mutable_shape() = shape.ToProto(); return AddInstruction(std::move(instr), HloOpcode::kConvert, {operand}); }); } @@ -1571,9 +1608,9 @@ XlaOp XlaBuilder::BitcastConvertType(const XlaOp& operand, return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); - TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), - ShapeInference::InferConvertShape(operand_shape, new_element_type)); + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConvertShape( + operand_shape, new_element_type)); + *instr.mutable_shape() = shape.ToProto(); return AddInstruction(std::move(instr), HloOpcode::kBitcastConvert, {operand}); }); @@ -1605,11 +1642,11 @@ XlaOp XlaBuilder::Map(absl::Span operands, TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape, computation.GetProgramShape()); TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), - ShapeInference::InferMapShape(operand_shape_ptrs, called_program_shape, - dimensions)); + Shape shape, ShapeInference::InferMapShape( + operand_shape_ptrs, called_program_shape, dimensions)); + *instr.mutable_shape() = shape.ToProto(); - const Shape& output_shape = instr.shape(); + Shape output_shape(instr.shape()); const int64 output_rank = ShapeUtil::Rank(output_shape); AddCalledComputation(computation, &instr); std::vector new_operands(operands.begin(), operands.end()); @@ -1652,7 +1689,7 @@ XlaOp XlaBuilder::RngOp(RandomDistribution distribution, } TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape)); - *instr.mutable_shape() = shape; + *instr.mutable_shape() = shape.ToProto(); instr.set_distribution(distribution); @@ -1680,10 +1717,10 @@ XlaOp XlaBuilder::While(const XlaComputation& condition, TF_ASSIGN_OR_RETURN(const auto& condition_program_shape, condition.GetProgramShape()); TF_ASSIGN_OR_RETURN(const Shape& init_shape, GetShape(init)); - TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), - ShapeInference::InferWhileShape(condition_program_shape, - body_program_shape, init_shape)); + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferWhileShape( + condition_program_shape, + body_program_shape, init_shape)); + *instr.mutable_shape() = shape.ToProto(); // Body comes before condition computation in the vector. 
AddCalledComputation(body, &instr); AddCalledComputation(condition, &instr); @@ -1700,10 +1737,10 @@ XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& start_indices, TF_ASSIGN_OR_RETURN(const Shape& input_shape, GetShape(input)); TF_ASSIGN_OR_RETURN(const Shape& start_indices_shape, GetShape(start_indices)); - TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), - ShapeInference::InferGatherShape(input_shape, start_indices_shape, + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferGatherShape( + input_shape, start_indices_shape, dimension_numbers, slice_sizes)); + *instr.mutable_shape() = shape.ToProto(); *instr.mutable_gather_dimension_numbers() = dimension_numbers; for (int64 bound : slice_sizes) { @@ -1728,10 +1765,11 @@ XlaOp XlaBuilder::Scatter(const XlaOp& input, const XlaOp& scatter_indices, TF_ASSIGN_OR_RETURN(const Shape& updates_shape, GetShape(updates)); TF_ASSIGN_OR_RETURN(const ProgramShape& to_apply_shape, update_computation.GetProgramShape()); - TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferScatterShape( input_shape, scatter_indices_shape, updates_shape, to_apply_shape, dimension_numbers)); + *instr.mutable_shape() = shape.ToProto(); *instr.mutable_scatter_dimension_numbers() = dimension_numbers; @@ -1758,10 +1796,11 @@ XlaOp XlaBuilder::Conditional(const XlaOp& predicate, const XlaOp& true_operand, TF_ASSIGN_OR_RETURN(const ProgramShape& false_computation_shape, false_computation.GetProgramShape()); TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), + Shape shape, ShapeInference::InferConditionalShape( predicate_shape, true_operand_shape, false_operand_shape, true_computation_shape, false_computation_shape)); + *instr.mutable_shape() = shape.ToProto(); // The index of true_computation must be 0 and that of false computation // must be 1. 
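The ordering noted above (true computation at index 0, false computation at index 1) is set up by the builder; callers only see the free-function API. A minimal sketch of such a caller, assuming the existing `Conditional`, `Parameter`, `Add`, `Sub`, and `ConstantR0` wrappers in xla_builder.h; the helper computations and names are illustrative, not taken from the patch:

```
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {

// Builds p ? x + 1 : x - 1. The builder records the true computation as
// called-computation index 0 and the false computation as index 1.
XlaOp ConditionalExample(XlaBuilder* b, XlaOp p, XlaOp x) {
  XlaComputation add_one = [] {
    XlaBuilder tb("true_branch");
    auto v = Parameter(&tb, 0, ShapeUtil::MakeShape(F32, {}), "v");
    Add(v, ConstantR0<float>(&tb, 1.0f));
    return tb.Build().ConsumeValueOrDie();
  }();
  XlaComputation sub_one = [] {
    XlaBuilder fb("false_branch");
    auto v = Parameter(&fb, 0, ShapeUtil::MakeShape(F32, {}), "v");
    Sub(v, ConstantR0<float>(&fb, 1.0f));
    return fb.Build().ConsumeValueOrDie();
  }();
  return Conditional(p, /*true_operand=*/x, add_one,
                     /*false_operand=*/x, sub_one);
}

}  // namespace xla
```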
@@ -1803,9 +1842,10 @@ XlaOp XlaBuilder::Reduce(absl::Span operands, [](const Shape& shape) { return &shape; }); TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), + Shape shape, ShapeInference::InferReduceShape( operand_shape_ptrs, dimensions_to_reduce, called_program_shape)); + *instr.mutable_shape() = shape.ToProto(); for (int64 dim : dimensions_to_reduce) { instr.add_dimensions(dim); @@ -1868,10 +1908,10 @@ XlaOp XlaBuilder::ReduceWindowWithGeneralPadding( MakeWindow(window_dimensions, window_strides, padding, /*lhs_dilation=*/base_dilations, /*rhs_dilation=*/window_dilations)); - TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), - ShapeInference::InferReduceWindowShape(operand_shape, init_shape, - instr.window(), to_apply_shape)); + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferReduceWindowShape( + operand_shape, init_shape, + instr.window(), to_apply_shape)); + *instr.mutable_shape() = shape.ToProto(); AddCalledComputation(computation, &instr); return AddInstruction(std::move(instr), HloOpcode::kReduceWindow, @@ -1889,9 +1929,10 @@ XlaOp XlaBuilder::BatchNormTraining(const XlaOp& operand, const XlaOp& scale, TF_ASSIGN_OR_RETURN(const Shape& scale_shape, GetShape(scale)); TF_ASSIGN_OR_RETURN(const Shape& offset_shape, GetShape(offset)); TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), + Shape shape, ShapeInference::InferBatchNormTrainingShape( operand_shape, scale_shape, offset_shape, feature_index)); + *instr.mutable_shape() = shape.ToProto(); instr.set_epsilon(epsilon); instr.set_feature_index(feature_index); @@ -1913,10 +1954,11 @@ XlaOp XlaBuilder::BatchNormInference(const XlaOp& operand, const XlaOp& scale, TF_ASSIGN_OR_RETURN(const Shape& offset_shape, GetShape(offset)); TF_ASSIGN_OR_RETURN(const Shape& mean_shape, GetShape(mean)); TF_ASSIGN_OR_RETURN(const Shape& variance_shape, GetShape(variance)); - TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), - ShapeInference::InferBatchNormInferenceShape( - operand_shape, scale_shape, offset_shape, - mean_shape, variance_shape, feature_index)); + TF_ASSIGN_OR_RETURN( + Shape shape, ShapeInference::InferBatchNormInferenceShape( + operand_shape, scale_shape, offset_shape, mean_shape, + variance_shape, feature_index)); + *instr.mutable_shape() = shape.ToProto(); instr.set_epsilon(epsilon); instr.set_feature_index(feature_index); @@ -1938,10 +1980,11 @@ XlaOp XlaBuilder::BatchNormGrad(const XlaOp& operand, const XlaOp& scale, TF_ASSIGN_OR_RETURN(const Shape& batch_mean_shape, GetShape(batch_mean)); TF_ASSIGN_OR_RETURN(const Shape& batch_var_shape, GetShape(batch_var)); TF_ASSIGN_OR_RETURN(const Shape& grad_output_shape, GetShape(grad_output)); - TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferBatchNormGradShape( operand_shape, scale_shape, batch_mean_shape, batch_var_shape, grad_output_shape, feature_index)); + *instr.mutable_shape() = shape.ToProto(); instr.set_epsilon(epsilon); instr.set_feature_index(feature_index); @@ -1972,9 +2015,9 @@ XlaOp XlaBuilder::CrossReplicaSum( return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); - TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), - ShapeInference::InferCrossReplicaSumShape({&operand_shape})); + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferCrossReplicaSumShape( + {&operand_shape})); + *instr.mutable_shape() = shape.ToProto(); for (const ReplicaGroup& group : replica_groups) { *instr.add_replica_groups() = group; @@ -2027,8 +2070,8 @@ XlaOp 
XlaBuilder::AllToAll(const XlaOp& operand, int64 split_dimension, absl::c_transform(slice_shapes, std::back_inserter(slice_shape_ptrs), [](const Shape& shape) { return &shape; }); TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), - ShapeInference::InferAllToAllTupleShape(slice_shape_ptrs)); + Shape shape, ShapeInference::InferAllToAllTupleShape(slice_shape_ptrs)); + *instr.mutable_shape() = shape.ToProto(); for (const ReplicaGroup& group : replica_groups) { *instr.add_replica_groups() = group; } @@ -2053,8 +2096,9 @@ XlaOp XlaBuilder::CollectivePermute( TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); HloInstructionProto instr; TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), + Shape shape, ShapeInference::InferCollectivePermuteShape(operand_shape)); + *instr.mutable_shape() = shape.ToProto(); for (const auto& pair : source_target_pairs) { auto* proto_pair = instr.add_source_target_pairs(); @@ -2103,10 +2147,11 @@ XlaOp XlaBuilder::SelectAndScatterWithGeneralPadding( TF_ASSIGN_OR_RETURN(*instr.mutable_window(), MakeWindow(window_dimensions, window_strides, padding, /*lhs_dilation=*/{}, /*rhs_dilation=*/{})); - TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferSelectAndScatterShape( operand_shape, select_shape, instr.window(), source_shape, init_shape, scatter_shape)); + *instr.mutable_shape() = shape.ToProto(); AddCalledComputation(select, &instr); AddCalledComputation(scatter, &instr); @@ -2121,9 +2166,10 @@ XlaOp XlaBuilder::ReducePrecision(const XlaOp& operand, const int exponent_bits, return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); - TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferReducePrecisionShape( operand_shape, exponent_bits, mantissa_bits)); + *instr.mutable_shape() = shape.ToProto(); instr.set_exponent_bits(exponent_bits); instr.set_mantissa_bits(mantissa_bits); return AddInstruction(std::move(instr), HloOpcode::kReducePrecision, @@ -2138,7 +2184,7 @@ void XlaBuilder::Send(const XlaOp& operand, const ChannelHandle& handle) { // TODO(b/80000000): Remove this when clients have been updated to handle // tokens. HloInstructionProto token_instr; - *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto(); TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr), HloOpcode::kAfterAll, {})); @@ -2157,15 +2203,17 @@ XlaOp XlaBuilder::SendWithToken(const XlaOp& operand, const XlaOp& token, // token}. 
HloInstructionProto send_instr; TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand)); - *send_instr.mutable_shape() = ShapeUtil::MakeTupleShape( - {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()}); + *send_instr.mutable_shape() = + ShapeUtil::MakeTupleShape( + {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()}) + .ToProto(); send_instr.set_channel_id(handle.handle()); TF_ASSIGN_OR_RETURN(XlaOp send, AddInstruction(std::move(send_instr), HloOpcode::kSend, {operand, token})); HloInstructionProto send_done_instr; - *send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + *send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto(); send_done_instr.set_channel_id(handle.handle()); return AddInstruction(std::move(send_done_instr), HloOpcode::kSendDone, {send}); @@ -2179,7 +2227,7 @@ XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) { // TODO(b/80000000): Remove this when clients have been updated to handle // tokens. HloInstructionProto token_instr; - *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto(); TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr), HloOpcode::kAfterAll, {})); @@ -2190,7 +2238,7 @@ XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) { // TODO(b/80000000): Remove this when clients have been updated to handle // tokens. HloInstructionProto recv_data; - *recv_data.mutable_shape() = shape; + *recv_data.mutable_shape() = shape.ToProto(); recv_data.set_tuple_index(0); return AddInstruction(std::move(recv_data), HloOpcode::kGetTupleElement, {recv}); @@ -2207,15 +2255,18 @@ XlaOp XlaBuilder::RecvWithToken(const XlaOp& token, const Shape& shape, // Recv instruction produces a tuple of {receive buffer, U32 context, // token}. HloInstructionProto recv_instr; - *recv_instr.mutable_shape() = ShapeUtil::MakeTupleShape( - {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()}); + *recv_instr.mutable_shape() = + ShapeUtil::MakeTupleShape( + {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()}) + .ToProto(); recv_instr.set_channel_id(handle.handle()); TF_ASSIGN_OR_RETURN(XlaOp recv, AddInstruction(std::move(recv_instr), HloOpcode::kRecv, {token})); HloInstructionProto recv_done_instr; *recv_done_instr.mutable_shape() = - ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()}); + ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()}) + .ToProto(); recv_done_instr.set_channel_id(handle.handle()); return AddInstruction(std::move(recv_done_instr), HloOpcode::kRecvDone, {recv}); @@ -2249,9 +2300,11 @@ XlaOp XlaBuilder::SendToHost(const XlaOp& operand, const XlaOp& token, // Send instruction produces a tuple of {aliased operand, U32 context, // token}. 
HloInstructionProto send_instr; - *send_instr.mutable_shape() = ShapeUtil::MakeTupleShape( - {shape_with_layout, ShapeUtil::MakeShape(U32, {}), - ShapeUtil::MakeTokenShape()}); + *send_instr.mutable_shape() = + ShapeUtil::MakeTupleShape({shape_with_layout, + ShapeUtil::MakeShape(U32, {}), + ShapeUtil::MakeTokenShape()}) + .ToProto(); send_instr.set_channel_id(handle.handle()); send_instr.set_is_host_transfer(true); TF_ASSIGN_OR_RETURN(XlaOp send, @@ -2259,7 +2312,7 @@ XlaOp XlaBuilder::SendToHost(const XlaOp& operand, const XlaOp& token, {operand, token})); HloInstructionProto send_done_instr; - *send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + *send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto(); send_done_instr.set_channel_id(handle.handle()); send_done_instr.set_is_host_transfer(true); return AddInstruction(std::move(send_done_instr), HloOpcode::kSendDone, @@ -2288,8 +2341,10 @@ XlaOp XlaBuilder::RecvFromHost(const XlaOp& token, const Shape& shape, // Recv instruction produces a tuple of {receive buffer, U32 context, // token}. HloInstructionProto recv_instr; - *recv_instr.mutable_shape() = ShapeUtil::MakeTupleShape( - {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()}); + *recv_instr.mutable_shape() = + ShapeUtil::MakeTupleShape( + {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()}) + .ToProto(); recv_instr.set_channel_id(handle.handle()); recv_instr.set_is_host_transfer(true); TF_ASSIGN_OR_RETURN(XlaOp recv, AddInstruction(std::move(recv_instr), @@ -2297,7 +2352,8 @@ XlaOp XlaBuilder::RecvFromHost(const XlaOp& token, const Shape& shape, HloInstructionProto recv_done_instr; *recv_done_instr.mutable_shape() = - ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()}); + ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()}) + .ToProto(); recv_done_instr.set_channel_id(handle.handle()); recv_done_instr.set_is_host_transfer(true); return AddInstruction(std::move(recv_done_instr), HloOpcode::kRecvDone, @@ -2309,9 +2365,9 @@ XlaOp XlaBuilder::GetDimensionSize(const XlaOp& operand, int64 dimension) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const auto& operand_shape, GetShape(operand)); - TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), - ShapeInference::InferGetDimensionSizeShape(operand_shape, dimension)); + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferGetDimensionSizeShape( + operand_shape, dimension)); + *instr.mutable_shape() = shape.ToProto(); instr.add_dimensions(dimension); return AddInstruction(std::move(instr), HloOpcode::kGetDimensionSize, {operand}); @@ -2356,7 +2412,7 @@ StatusOr XlaBuilder::BuildConstantSubGraph( SetProtoIdAndName(&entry, StrCat(name_, "_compute_constant"), kNameSeparator, GetNextId()); entry.set_root_id(root->id()); - ProgramShape* program_shape = entry.mutable_program_shape(); + ProgramShapeProto* program_shape = entry.mutable_program_shape(); *program_shape->mutable_result() = root->shape(); // We use std::set to keep the instruction ids in ascending order (which is @@ -2617,9 +2673,10 @@ XlaOp Broadcast(const XlaOp& operand, absl::Span broadcast_sizes) { return operand.builder()->Broadcast(operand, broadcast_sizes); } -XlaOp BroadcastInDim(const XlaOp& operand, const Shape& shape, +XlaOp BroadcastInDim(const XlaOp& operand, + const absl::Span out_dim_size, const absl::Span broadcast_dimensions) { - return operand.builder()->BroadcastInDim(operand, shape, + return operand.builder()->BroadcastInDim(operand, 
out_dim_size, broadcast_dimensions); } diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 68314a026ea..098efb60f9b 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -263,35 +264,30 @@ class XlaBuilder { // evaluating the computation. StatusOr IsConstant(const XlaOp& operand) const; + // Sets up binding which indicates that the `target_dim_num` in the subshape + // `target_param_index` of parameter `target_param_num` is a dynamic dimension + // and its real dynamic size is represented by `dynamic_param_index` in + // parameter `dynamic_param_num`. + // + // TODO(b/119520625): Remove this API once we have more dynamic shape infra + // ready. + Status SetDynamicBinding(int64 dynamic_size_param_num, + ShapeIndex dynamic_size_param_index, + int64 target_param_num, + ShapeIndex target_param_index, int64 target_dim_num); + private: // Build helper which takes the id of the root operation.. StatusOr Build(int64 root_id); - // Enqueues a "retrieve parameter value" instruction for a parameter that was - // passed to the computation. + // Description for the methods below can be found in the corresponding public + // functions section in this file. + XlaOp Parameter(int64 parameter_number, const Shape& shape, const string& name); - // Enqueues a constant with the value of the given literal onto the - // computation. XlaOp ConstantLiteral(const LiteralSlice& literal); - // Enqueues a constant onto the computation. Methods are templated on the - // native host type (NativeT) which corresponds to a specific XLA - // PrimitiveType as given in the following table: - // - // Native Type PrimitiveType - // ----------------------------- - // bool PRED - // int32 S32 - // int64 S64 - // uint32 U32 - // uint64 U64 - // float F32 - // double F64 - // - // Note: not all primitive types defined in xla_data.proto have a - // corresponding native type yet. template XlaOp ConstantR0(NativeT value); template @@ -321,181 +317,79 @@ class XlaBuilder { template XlaOp ConstantR4FromArray4D(const Array4D& values); - // Enqueues a rank one constant (vector) onto the computation. The vector has - // size 'length' and every element has the value 'value'. template XlaOp ConstantR1(int64 length, NativeT value); - // Adds dimensions to an array by duplicating the data in the array. - // - // The new dimensions are inserted on the left, i.e. if - // broadcast_sizes has values {a0, ..., aN} and the operand shape - // has dimensions {b0, ..., bM} then the shape of the output has - // dimensions {a0, ..., aN, b0, ..., bM}. - // - // The new dimensions index into copies of the operand, i.e. 
- // - // output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM] XlaOp Broadcast(const XlaOp& operand, absl::Span broadcast_sizes); - XlaOp BroadcastInDim(const XlaOp& operand, const Shape& shape, + XlaOp BroadcastInDim(const XlaOp& operand, + const absl::Span out_dim_size, const absl::Span broadcast_dimensions); - // Enqueues a pad operation onto the computation that pads the given value on - // the edges as well as between the elements of the input. padding_config - // specifies the padding amount for each dimension. XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value, const PaddingConfig& padding_config); - // Enqueues an operation onto the computation that flattens the operand based - // on the dimension order (major/slowest-varying to minor/fastest-varying) - // given, followed by reshaping it into the shape with the given dimension - // sizes (also major to minor). Conceptually, this is a limited form of - // "shape casting". XlaOp Reshape(const XlaOp& operand, absl::Span dimensions, absl::Span new_sizes); - // Enqueues an operation onto the computation that collapses the operand, from - // first to last dimension (C order), then reshapes it to the given dimension - // sizes. Conceptually, this is a limited form of "shape casting". XlaOp Reshape(const XlaOp& operand, absl::Span new_sizes); - // Wrapper for Reshape. - // Enqueues an operation to collapse the provided dimensions; e.g. an - // operand with dimensions {x=256, y=2, z=2, p=32} can be collapsed to - // {x=1024, y=32} by collapsing dims {0, 1, 2}. Collapsing dimensions must - // be a consecutive, in-order subsequence of the operand dimensions. - // - // Note that collapsing a single dimension does nothing: - // - // {256} collapsing {0} => {256} - // {1} collapsing {0} => {1} - // - // Collapsing multiple dimensions produces a single result dimension: - // - // {256, 2} collapsing {0,1} => {512} - // {256, 2, 3} collapsing {0,1} => {512, 3} - // - // This could potentially cause data to be moved -- it provides a more - // structured form of reshaping than an arbitrary Reshape operation. XlaOp Collapse(const XlaOp& operand, absl::Span dimensions); - // Enqueues a slice operation onto the computation that slices the operand - // from the start indices to the limit indices; e.g. - // - // x - // [ 0 1 2 3 ] - // y [ 4 5 6 7 ] => slice(start={1, 1}, limit={2, 3}) => [ 5 6 ] - // [ 8 9 a b ] - // - // Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D - // range notation. - // The strides parameter determines the stride over the slice XlaOp Slice(const XlaOp& operand, absl::Span start_indices, absl::Span limit_indices, absl::Span strides); - // Enqueues a slice operation in a given dimension, taking all other - // dimensions as they are; e.g. if dimno is 1 from start_index 2 to - // limit_index 4 by 1, and the shape is f32[7,8,9], this call is short-hand - // for: - // - // array[:, 2:4:1, :] XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index, int64 stride, int64 dimno); - // Enqueues a slice operation onto the computation that slices the 'operand' - // from dynamic start indices which are passed in 'start_indices'. - // The size of the slice in each dimension is passed in 'slice_sizes', - // which specify the end point of exclusive slice intervals in each - // dimension [start, start + size). - // The shape of 'start_indices' must be rank == 1, with dimension size - // equal to the rank of the 'operand'. 
- // Slice index calculations are computed modulo input dimension sizes to - // prevent dynamic start indices from generating out-of-bound array accesses. XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, absl::Span slice_sizes); - // Enqueues a dynamic update slice operation onto the computation, which - // updates a slice of 'operand' with 'update' at dynamic 'start_indices'. - // The shape of 'update' determines the shape of the slice of 'operand' - // which is updated. - // The indices specified in 'start_indices' specify the offset of the slice - // of 'operand' which is updated. - // - // update = {10, 11} // calculated at runtime. - // [1 2 3] start = {1, 1} // calculated at runtime. [1 2 3 ] - // [4 5 6] => DynamicUpdateslice(data, update, start) => [4 10 11] - // [7 8 9] [7 8 9 ] - // - // The shape of 'start_indices' must be rank == 1, with dimension size - // equal to the rank of the 'operand'. - // Slice index calculations are computed modulo update dimension sizes to - // prevent dynamic start indices from generating out-of-bound array accesses. XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, const XlaOp& start_indices); - // Enqueues a concatenate instruction onto the computation. 'operands' must - // have >= 1 entry. XlaOp ConcatInDim(absl::Span operands, int64 dimension); - // Enqueue a tracing operation onto the computation; the computation will emit - // a logging message with the operand. void Trace(const string& tag, const XlaOp& operand); - // Enqueues a conditional-move-like select operation onto the computation; - // predicated on pred, selects between on_true and on_false. XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false); - // Enqueues a tuple-creation instruction onto the computation. XlaOp Tuple(absl::Span elements); - // Enqueues a tuple-element-get instruction onto the computation. XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index); - // Enqueues an equal-to comparison instruction onto the computation. XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Enqueues a not-equal comparison instruction onto the computation. XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Enqueues a greater-or-equal comparison instruction onto the computation. XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Enqueues a greater-than comparison instruction onto the computation. XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Enqueues a less-than comparison instruction onto the computation. XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Enqueues a less-or-equal comparison instruction onto the computation. XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Enqueues a dot instruction onto the computation. XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs, const PrecisionConfig* precision_config = nullptr); - // Enqueues a general dot instruction onto the computation. XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, const DotDimensionNumbers& dimension_numbers, const PrecisionConfig* precision_config = nullptr); - // Enqueues a convolution instruction onto the computation, which uses the - // default convolution dimension numbers. 
XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding padding, int64 feature_group_count = 1, const PrecisionConfig* precision_config = nullptr); - // Enqueues a convolution instruction onto the computation, with the caller - // provided padding configuration in the format returned by MakePadding(). XlaOp ConvWithGeneralPadding( const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, @@ -503,8 +397,6 @@ class XlaBuilder { int64 feature_group_count = 1, const PrecisionConfig* precision_config = nullptr); - // Enqueues a convolution instruction onto the computation, with the caller - // provided dimension numbers configuration. XlaOp ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding padding, @@ -512,8 +404,6 @@ class XlaBuilder { int64 feature_group_count = 1, const PrecisionConfig* precision_config = nullptr); - // Enqueues a convolution instruction onto the computation, with the caller - // provided padding configuration as well as the dimension numbers. XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, absl::Span> padding, @@ -521,8 +411,6 @@ class XlaBuilder { int64 feature_group_count = 1, const PrecisionConfig* precision_config = nullptr); - // Enqueues a convolution instruction onto the computation, with the caller - // provided padding configuration, dilation factors and dimension numbers. XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, absl::Span> padding, @@ -532,80 +420,53 @@ class XlaBuilder { int64 feature_group_count = 1, const PrecisionConfig* precision_config = nullptr); - // Enqueues an FFT instruction onto the computation, of the given type and - // with the given FFT length. XlaOp Fft(const XlaOp& operand, FftType fft_type, absl::Span fft_length); - // Enqueues an infeed instruction onto the computation, which writes data of - // the given shape to the infeed buffer of the device. XlaOp Infeed(const Shape& shape, const string& config = ""); XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape, const string& config = ""); - // Enqueues an outfeed instruction onto the computation. This instruction - // generates outgoing data transfers for the given data. - // - // shape_with_layout communicates the laid out shape that we want to outfeed - // -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error - // will occur. void Outfeed(const XlaOp& operand, const Shape& shape_with_layout, const string& outfeed_config); XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token, const Shape& shape_with_layout, const string& outfeed_config); - // Enqueues a call instruction onto the computation. XlaOp Call(const XlaComputation& computation, absl::Span operands); - // Enqueues a custom call instruction onto the computation. XlaOp CustomCall( const string& call_target_name, absl::Span operands, const Shape& shape_with_layout, const string& opaque, absl::optional> operand_shapes_with_layout); - // The following methods enqueue element-wise binary arithmetic operations - // onto the computation. The shapes of the operands have to match unless one - // of the operands is a scalar, or an explicit broadcast dimension is given - // (see g3doc for more details). - - // Enqueues a complex compose instruction onto the computation. XlaOp Complex(const XlaOp& real, const XlaOp& imag, absl::Span broadcast_dimensions = {}); - // Enqueues a complex conjugate instruction onto the computation. 
XlaOp Conj(const XlaOp& operand); - // Enqueues an add instruction onto the computation. XlaOp Add(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Enqueues a subtract instruction onto the computation. XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Enqueues a multiply instruction onto the computation. XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Enqueues a divide instruction onto the computation. XlaOp Div(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Enqueues a remainder instruction onto the computation. XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Enqueues a max instruction onto the computation. XlaOp Max(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Enqueues a min instruction onto the computation. XlaOp Min(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Element-wise logical operators XlaOp And(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); @@ -624,32 +485,23 @@ class XlaBuilder { XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Reduces an array among the provided dimensions, given "computation" as a - // reduction operator. XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, absl::Span dimensions_to_reduce); - // Reduces several arrays simultaneously among the provided dimensions, given - // "computation" as a reduction operator. XlaOp Reduce(absl::Span operands, absl::Span init_values, const XlaComputation& computation, absl::Span dimensions_to_reduce); - // Convenience wrapper around the above that reduces all the dimensions in the - // operand shape. XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation); - // Enqueues a windowed reduce instruction onto the computation. XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, absl::Span window_dimensions, absl::Span window_strides, Padding padding); - // As ReduceWindow(), but the padding is given in the format - // returned by MakePadding(). XlaOp ReduceWindowWithGeneralPadding( const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, @@ -659,48 +511,22 @@ class XlaBuilder { absl::Span window_dilations, absl::Span> padding); - // Returns the sum of the operand value within each subgroup of replicas. All - // replicas supply one input to the sum and all replicas receive the resulting - // sum for each subgroup. XlaOp CrossReplicaSum(const XlaOp& operand, absl::Span replica_groups = {}); - // Enqueues an operation that do an AllReduce of the operand cross cores. Here - // AllReduce means doing a reduction on the input operand cross cores and then - // broadcasting the reduction result to those cores. The reduction function is - // defined by `computation`, which should be a commutative computation on - // scalars, e.g., add, min, or max. The way that AllReduce is applied is - // configured by: - // - // - `replica_groups`: each ReplicaGroup contains a list of replica id. If - // empty, all replicas belong to one group. Allreduce will be applied within - // subgroups. For example, we have 4 replicas, then - // replica_groups={{0,2},{1,3}} means, replica 0 and 2 are in subgroup 0, - // replica 1 and 3 are in subgroup 1. 
- // - // - `channel_id`: for Allreduce nodes from different modules, if they have - // the same channel_id, they will be 'Allreduce'd. If empty, Allreduce will - // not be applied cross modules. - // - // TODO(b/117564385): Rename this to AllReduce when it's ready to use. XlaOp CrossReplicaSum( const XlaOp& operand, const XlaComputation& computation, absl::Span replica_groups = {}, const absl::optional& channel_id = absl::nullopt); - // Enqueues an operation that do an Alltoall of the operand cross cores. XlaOp AllToAll(const XlaOp& operand, int64 split_dimension, int64 concat_dimension, int64 split_count, const std::vector& replica_groups); - // Enqueues an operation that do an CollectivePermute of the operand cross - // cores. XlaOp CollectivePermute( const XlaOp& operand, const std::vector>& source_target_pairs); - // Enqueues an operation that scatters the `source` array to the selected - // indices of each window. XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, absl::Span window_dimensions, absl::Span window_strides, @@ -708,8 +534,6 @@ class XlaBuilder { const XlaOp& init_value, const XlaComputation& scatter); - // As SelectAndScatter(), but the padding is given in the format - // returned by MakePadding(). XlaOp SelectAndScatterWithGeneralPadding( const XlaOp& operand, const XlaComputation& select, absl::Span window_dimensions, @@ -717,217 +541,119 @@ class XlaBuilder { absl::Span> padding, const XlaOp& source, const XlaOp& init_value, const XlaComputation& scatter); - // Enqueues an abs instruction onto the computation. XlaOp Abs(const XlaOp& operand); - // Enqueues a atan2 instruction onto the computation. XlaOp Atan2(const XlaOp& y, const XlaOp& x, absl::Span broadcast_dimensions = {}); - // Enqueues an exp instruction onto the computation. XlaOp Exp(const XlaOp& operand); - // Enqueues an expm1 instruction onto the computation. XlaOp Expm1(const XlaOp& operand); - // Enqueues a floor instruction onto the computation. XlaOp Floor(const XlaOp& operand); - // Enqueues a ceil instruction onto the computation. XlaOp Ceil(const XlaOp& operand); - // Enqueues a round instruction onto the computation, rounding to nearest even - // with half-way cases rounding away from zero. XlaOp Round(const XlaOp& operand); - // Enqueues an log instruction (natural logarithm) onto the computation. XlaOp Log(const XlaOp& operand); - // Enqueues an log1p instruction (log(x+1)) onto the computation. XlaOp Log1p(const XlaOp& operand); - // Enqueues a sign instruction onto the computation. XlaOp Sign(const XlaOp& operand); - // Enqueues a count leading zeros instruction onto the computation. XlaOp Clz(const XlaOp& operand); - // Enqueues a cosine instruction onto the computation. XlaOp Cos(const XlaOp& operand); - // Enqueues a sine instruction onto the computation. XlaOp Sin(const XlaOp& operand); - // Enqueues a tanh instruction onto the computation. XlaOp Tanh(const XlaOp& operand); - // Enqueues a real-part instruction onto the computation. XlaOp Real(const XlaOp& operand); - // Enqueues an imaginary-part instruction onto the computation. XlaOp Imag(const XlaOp& operand); - // Enqueues a lhs^rhs computation onto the computation. XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Enqueues an operator that tests if the operand's values are finite, i.e., - // not Inf or NaN. Defined only for floating-point types. Returns an array of - // booleans with the same shape where entries are true iff the corresponding - // entry was NaN. 
XlaOp IsFinite(const XlaOp& operand); - // Enqueues an iota operation onto the computation. XlaOp Iota(const Shape& shape, int64 iota_dimension); - // Enqueues a rank-1 iota operation onto the computation. XlaOp Iota(PrimitiveType type, int64 size); - // Enqueues a convert instruction onto the computation that changes the - // element type of the operand array to primitive_type. XlaOp ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type); - // Enqueues a no-op instruction onto the computation that changes - // the element type of the operand array to primitive_type. The - // bit-widths of the source and destination element types must be - // identical. XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type); - // Enqueues a negate instruction onto the computation. XlaOp Neg(const XlaOp& operand); - // Enqueues a transpose instruction onto the computation. XlaOp Transpose(const XlaOp& operand, absl::Span permutation); - // Enqueues a reverse instruction onto the computation. The order of the - // elements in the given dimensions is reversed (i.e., the element at index i - // is moved to index dimension_size - 1 - i). XlaOp Rev(const XlaOp& operand, absl::Span dimensions); - // Enqueues a sort (as increasing order) instruction onto the computation. - // If only keys are provided: - // * If the keys are an rank-1 tensor (an array), the result is a sorted array - // of keys, in ascending order. - // * If the keys have higher rank, the keys are sorted along the provided - // dimension. For example, for a rank-2 tensor (a matrix) of keys, a dimension - // value of 0 will indepenently sort every column, and a dimension value of 1 - // will independently sort each row. If no dimension number is provided, then - // the last dimension is chosen by default. - // - // If both keys and values are provided: - // * The keys and all values must be tensors with the same dimensions. The - // element types of the tensors may be different. - // * The result is a tuple that consists of a sorted tensor of keys (along the - // provided dimension, as above) as the first element, and tensors with their - // corresponding values as the other elements. XlaOp Sort(const XlaOp& keys, absl::Span values = {}, int64 dimension = -1); - // Enqueues a clamp instruction onto the computation. XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); - // Enqueues a map instruction onto the computation. XlaOp Map(absl::Span operands, const XlaComputation& computation, absl::Span dimensions, absl::Span static_operands = {}); - // Enqueues a N(mu, sigma) random number generation instruction onto the - // computation. XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, const Shape& shape); - // Enqueues a U(a, b) random number generation instruction onto the - // computation. Returns values in the semi-open interval [a, b). XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape); - // Enqueues a while node onto the computation. XlaOp While(const XlaComputation& condition, const XlaComputation& body, const XlaOp& init); - // Enqueues a conditional node onto the computation. XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand, const XlaComputation& true_computation, const XlaOp& false_operand, const XlaComputation& false_computation); - // Enqueues a ReducePrecision node onto the computation. XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits, const int mantissa_bits); - // Enqueues a Gather node onto the computation. 
XlaOp Gather(const XlaOp& input, const XlaOp& start_indices, const GatherDimensionNumbers& dimension_numbers, absl::Span slice_sizes); - // Enqueues a Scatter node onto the computation. XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices, const XlaOp& updates, const XlaComputation& update_computation, const ScatterDimensionNumbers& dimension_numbers); - // Enqueues a Send node onto the computation for device-to-device - // communication, to send the given operand to a Recv instruction that shares - // the same channel handle. void Send(const XlaOp& operand, const ChannelHandle& handle); XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token, const ChannelHandle& handle); - // Enqueues a Send node which sends data to the host. XlaOp SendToHost(const XlaOp& operand, const XlaOp& token, const Shape& shape_with_layout, const ChannelHandle& handle); - // Enqueues a Recv node which receives data from the host. XlaOp RecvFromHost(const XlaOp& token, const Shape& shape, const ChannelHandle& handle); - // Enqueues an AfterAll operation with no operands producing a token-shaped - // value. XlaOp CreateToken(); - // Enqueues an AfterAll operation with no operands producing a token-shaped - // value. XlaOp AfterAll(absl::Span tokens); - // Enqueues a Recv node onto the computation. The data comes from a Send - // instruction that shares the same channel handle and its shape must - // be the same as the given shape. XlaOp Recv(const Shape& shape, const ChannelHandle& handle); XlaOp RecvWithToken(const XlaOp& token, const Shape& shape, const ChannelHandle& handle); - // Normalizes operand across spatial and batch dimensions for each feature. - // - // Returns a tuple (normalized, batch_mean, batch_var) where `normalized` - // is the normalized result and batch_mean and batch_var are the mean and - // variance, respectively, across batch for the operand. XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale, const XlaOp& offset, float epsilon, int64 feature_index); - // Normalizes operand across spatial and batch dimensions for each feature. - // - // `BatchNormInference` is equivalent to calling `BatchNormTraining` without - // computing `mean` and `variance` for each batch inside the operation. It - // uses the input `mean` and `variance` instead as estimated values. The - // purpose of this op is to reduce latency in inference, hence the name - // `BatchNormInference`. - // - // The output has the same shape as `operand`, and contains the normalized - // values for each batch. XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale, const XlaOp& offset, const XlaOp& mean, const XlaOp& variance, float epsilon, int64 feature_index); - // Calculates the gradients of a batch norm op. - // - // The inputs `batch_mean` and `batch_var` represent the mean and variance - // across the batch. - // - // Returns a tuple of three elements: - // - grad_operand: Gradient with respect to input `operand` - // - grad_offset: Gradient with respect to input `offset` - // - grad_scale: Gradient with respect to input `scale` XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale, const XlaOp& batch_mean, const XlaOp& batch_var, const XlaOp& grad_output, float epsilon, @@ -1019,6 +745,9 @@ class XlaBuilder { // The instructions of this computation. std::vector instructions_; + // Dynamic parameter configuration of this computation. 
+ DynamicParameterBinding dynamic_parameter_binding_; + // A map from XlaOp::Handle to the index in the instructions_ vector where the // instruction is held. absl::flat_hash_map handle_to_index_; @@ -1096,7 +825,7 @@ class XlaBuilder { absl::Span broadcast_sizes); friend XlaOp BroadcastInDim( - const XlaOp& operand, const Shape& shape, + const XlaOp& operand, const absl::Span out_dim_size, const absl::Span broadcast_dimensions); friend XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value, @@ -1393,6 +1122,7 @@ class XlaScopedShardingAssignment { // Free functions for building XlaOps. The intention is that these will // become the public API for building XlaOps rather than calling methods on // XlaBuilder directly. +// // Enqueues a "retrieve parameter value" instruction for a parameter that was // passed to the computation. @@ -1488,7 +1218,8 @@ XlaOp Broadcast(const XlaOp& operand, absl::Span broadcast_sizes); // will generate output // {{1 , 1}, // {2 , 2}} -XlaOp BroadcastInDim(const XlaOp& operand, const Shape& shape, +XlaOp BroadcastInDim(const XlaOp& operand, + const absl::Span out_dim_size, const absl::Span broadcast_dimensions); // Enqueues a pad operation onto the computation that pads the given value on @@ -2138,6 +1869,7 @@ XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale, XlaOp GetDimensionSize(const XlaOp& operand, int64 dimension); // Implementation details below this point. +// template XlaOp XlaBuilder::ConstantR0(NativeT value) { diff --git a/tensorflow/compiler/xla/client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc index 8aa85c3cd63..b3f5be300d3 100644 --- a/tensorflow/compiler/xla/client/xla_builder_test.cc +++ b/tensorflow/compiler/xla/client/xla_builder_test.cc @@ -267,7 +267,7 @@ TEST_F(XlaBuilderTest, BinopHasInDimAndDegenerateBroadcast) { TEST_F(XlaBuilderTest, BroadcastInDim) { XlaBuilder b(TestName()); auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 3}), "x"); - BroadcastInDim(x, ShapeUtil::MakeShape(F32, {2, 4, 3}), + BroadcastInDim(x, {2, 4, 3}, /*broadcast_dimensions=*/{0, 2}); TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); auto root = module->entry_computation()->root_instruction(); @@ -277,7 +277,7 @@ TEST_F(XlaBuilderTest, BroadcastInDim) { TEST_F(XlaBuilderTest, BroadcastInDimWithDegeneratedDim) { XlaBuilder b(TestName()); auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 1, 4}), "x"); - BroadcastInDim(x, ShapeUtil::MakeShape(F32, {2, 3, 4}), + BroadcastInDim(x, {2, 3, 4}, /*broadcast_dimensions=*/{0, 1, 2}); TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); EXPECT_THAT(module->entry_computation()->root_instruction(), @@ -446,5 +446,14 @@ TEST_F(XlaBuilderTest, ProtoMatches) { EXPECT_EQ(c0_string, c1_string); } +TEST_F(XlaBuilderTest, AfterAllWithNonTokenOperands) { + XlaBuilder b(TestName()); + AfterAll(&b, {CreateToken(&b), ConstantR0(&b, 1.0)}); + Status status = b.Build().status(); + ASSERT_IS_NOT_OK(status); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr("All operands to AfterAll must be tokens")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/client/xla_computation.cc b/tensorflow/compiler/xla/client/xla_computation.cc index c9870b65b91..f317892c125 100644 --- a/tensorflow/compiler/xla/client/xla_computation.cc +++ b/tensorflow/compiler/xla/client/xla_computation.cc @@ -25,7 +25,7 @@ namespace xla { StatusOr XlaComputation::GetProgramShape() const { TF_RET_CHECK(proto_.has_host_program_shape()); - return 
proto_.host_program_shape(); + return ProgramShape(proto_.host_program_shape()); } StatusOr> XlaComputation::Snapshot() const { diff --git a/tensorflow/compiler/xla/client/xla_computation.h b/tensorflow/compiler/xla/client/xla_computation.h index 71598ef8b29..3ccbfb28bd0 100644 --- a/tensorflow/compiler/xla/client/xla_computation.h +++ b/tensorflow/compiler/xla/client/xla_computation.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/xla/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc index 033887d7c11..d7e7b9e6218 100644 --- a/tensorflow/compiler/xla/debug_options_flags.cc +++ b/tensorflow/compiler/xla/debug_options_flags.cc @@ -54,7 +54,7 @@ void SetDebugOptionsDefaults(DebugOptions* flags) { // TODO(jlebar): Disable fastmath once doing so is not a performance // regression. flags->set_xla_cpu_enable_fast_math(true); - flags->set_xla_gpu_enable_fast_math(true); + flags->set_xla_gpu_enable_fast_min_max(true); flags->set_xla_force_host_platform_device_count(1); } @@ -160,11 +160,11 @@ void AllocateFlags() { "Enable unsafe fast-math optimizations in the CPU compiler; " "this may produce faster code at the expense of some accuracy."), tensorflow::Flag( - "xla_gpu_enable_fast_math", - bool_setter_for(&DebugOptions::set_xla_cpu_enable_fast_math), - flag_values->xla_cpu_enable_fast_math(), - "Enable unsafe fast-math optimizations in the GPU compiler; " - "this may produce faster code at the expense of some accuracy."), + "xla_gpu_enable_fast_min_max", + bool_setter_for(&DebugOptions::set_xla_gpu_enable_fast_min_max), + flag_values->xla_gpu_enable_fast_min_max(), + "Enable fast floating point min/max lowering that does not propagate " + "NaNs."), tensorflow::Flag( "xla_llvm_enable_alias_scope_metadata", bool_setter_for( @@ -335,7 +335,7 @@ void AllocateFlags() { "behavior to help run tests on the host that run models in parallel " "across multiple devices."), }); - ParseFlagsFromEnv(*flag_objects); + ParseFlagsFromEnvAndDieIfUnknown("XLA_FLAGS", *flag_objects); } } // namespace diff --git a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py index fb135f5ceda..1fea816a803 100644 --- a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py +++ b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py @@ -18,12 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import math - import numpy as _np # Avoids becoming a part of public Tensorflow API. from tensorflow.compiler.xla import xla_data_pb2 -from tensorflow.compiler.xla.python_api import xla_shape from tensorflow.core.framework import attr_value_pb2 @@ -64,22 +61,18 @@ class Sharding(object): tile_assignment_devices=[core])) @classmethod - def tile(cls, tile_shape, tile_assignment): + def tile(cls, tile_assignment): """Returns a Tiled sharding attribute. This causes an op to be partially computed on multiple cores in the XLA device. Args: - tile_shape: A xla_shape.Shape describing the tile shape that each core - will compute. - The tile shape does not need to be divisible by the tile assignment. tile_assignment: An np.ndarray describing the topology of the tiling and which device will compute which part of the topology. 
Raises: - TypeError: tile_assignment was not of np.array type or tile_shape was - not of xla_shape.Shape type. + TypeError: tile_assignment was not of np.array type. TODO(jmolloy): This concept is nefarious and is not something we really want to expose to users (especially as the @@ -87,14 +80,11 @@ class Sharding(object): """ if not isinstance(tile_assignment, _np.ndarray): raise TypeError('Tile assignment must be of type np.ndarray') - if not isinstance(tile_shape, xla_shape.Shape): - raise TypeError('Tile shape must be of type xla_shape.Shape') dims = list(tile_assignment.shape) flattened_devices = tile_assignment.reshape(-1, order='C') return Sharding( proto=xla_data_pb2.OpSharding( type=xla_data_pb2.OpSharding.OTHER, - tile_shape=tile_shape.message, tile_assignment_dimensions=dims, tile_assignment_devices=list(flattened_devices))) @@ -118,14 +108,8 @@ class Sharding(object): shape = tensor.shape.as_list() if shape[split_dimension] < num_devices: raise ValueError('Split dimension was smaller than the required number ' - 'of splits: shape=%r, dimension=%r, num_devices=%r', - shape, split_dimension, num_devices) - - tile_shape = shape - tile_shape[split_dimension] = int( - math.ceil(tile_shape[split_dimension] / num_devices)) - tile_shape_proto = xla_data_pb2.Shape( - element_type=xla_data_pb2.F32, dimensions=tile_shape) + 'of splits: shape=%r, dimension=%r, num_devices=%r' % + (shape, split_dimension, num_devices)) tile_assignment_dims = [1] * len(shape) tile_assignment_dims[split_dimension] = num_devices @@ -133,7 +117,6 @@ class Sharding(object): return Sharding( proto=xla_data_pb2.OpSharding( type=xla_data_pb2.OpSharding.OTHER, - tile_shape=tile_shape_proto, tile_assignment_dimensions=tile_assignment_dims, tile_assignment_devices=range(num_devices))) @@ -149,7 +132,6 @@ class Sharding(object): type=xla_data_pb2.OpSharding.TUPLE, tuple_shardings=tuple_shardings) else: proto = self._proto - attr_value = attr_value_pb2.AttrValue(s=proto.SerializeToString()) # TODO(jmolloy): This need to be seriously revisited before declaring this # API available for public use. 
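With `tile_shape` removed, both `Sharding.tile()` and the module-level `tile()` wrapper describe a tiled sharding purely by its `np.ndarray` device assignment. The following is a rough usage sketch under that assumption; the placeholder tensors and the two-device topology are invented for illustration, and it assumes `Sharding.split` keeps its `(tensor, split_dimension, num_devices)` signature.

```python
import numpy as np
import tensorflow as tf
from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding

a = tf.placeholder(tf.float32, shape=[8, 16])
b = tf.placeholder(tf.float32, shape=[8, 16])

# Tile `a` 2x1 across two devices; only the np.ndarray tile assignment is
# needed now (previously an xla_shape.Shape tile_shape was also required).
a = xla_sharding.tile(a, np.array([[0], [1]]))

# Sharding.split builds an equivalent OpSharding from a split dimension and a
# device count (assuming the (tensor, split_dimension, num_devices) signature)
# and can then be applied to the tensor's producing op.
xla_sharding.Sharding.split(b, 0, 2).apply_to_tensor(b)
```

The tile shape is now inferred from the tensor and the assignment, which is what allows the `xla_shape` dependency above to be dropped.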
@@ -194,8 +176,8 @@ def assign_device(tensor, device): return tensor -def tile(tensor, tile_shape, tile_assignment): - Sharding.tile(tile_shape, tile_assignment).apply_to_tensor(tensor) +def tile(tensor, tile_assignment): + Sharding.tile(tile_assignment).apply_to_tensor(tensor) return tensor diff --git a/tensorflow/compiler/xla/g3doc/_book.yaml b/tensorflow/compiler/xla/g3doc/_book.yaml index bcfbcc3a22f..12b7094705e 100644 --- a/tensorflow/compiler/xla/g3doc/_book.yaml +++ b/tensorflow/compiler/xla/g3doc/_book.yaml @@ -3,15 +3,15 @@ upper_tabs: - include: /_upper_tabs_left.yaml - include: /api_docs/_upper_tabs_api.yaml # Dropdown menu -- name: Ecosystem - path: /ecosystem +- name: Resources + path: /resources is_default: true menu: - - include: /ecosystem/_menu_toc.yaml + - include: /resources/_menu_toc.yaml lower_tabs: # Subsite tabs other: - - name: Guide + - name: Guide & Tutorials contents: - title: XLA overview path: /xla/overview @@ -27,3 +27,7 @@ upper_tabs: path: /xla/shapes - title: Using AOT compilation path: /xla/tfcompile + - heading: Tutorials + - title: XLA compile API + path: /xla/tutorials/xla_compile + status: experimental diff --git a/tensorflow/compiler/xla/g3doc/_index.yaml b/tensorflow/compiler/xla/g3doc/_index.yaml index 7934cd11ba2..858de427119 100644 --- a/tensorflow/compiler/xla/g3doc/_index.yaml +++ b/tensorflow/compiler/xla/g3doc/_index.yaml @@ -17,7 +17,7 @@ landing_page: - classname: devsite-landing-row-cards items: - heading: XLA - TensorFlow, compiled - image_path: /ecosystem/images/tf-logo-card-16x9.png + image_path: /resources/images/tf-logo-card-16x9.png path: https://developers.googleblog.com/2017/03/xla-tensorflow-compiled.html buttons: - label: Read on Google Developers blog @@ -28,7 +28,7 @@ landing_page: - label: Watch the video path: https://www.youtube.com/watch?v=kAOanJczHA0 - heading: XLA on GitHub - image_path: /ecosystem/images/github-card-16x9.png + image_path: /resources/images/github-card-16x9.png path: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/xla buttons: - label: View on GitHub diff --git a/tensorflow/compiler/xla/g3doc/images/xla_array_layout_figure1.png b/tensorflow/compiler/xla/g3doc/images/xla_array_layout_figure1.png new file mode 100644 index 00000000000..00cefe4c780 Binary files /dev/null and b/tensorflow/compiler/xla/g3doc/images/xla_array_layout_figure1.png differ diff --git a/tensorflow/compiler/xla/g3doc/images/xla_array_layout_figure2.png b/tensorflow/compiler/xla/g3doc/images/xla_array_layout_figure2.png new file mode 100644 index 00000000000..6439c6e4027 Binary files /dev/null and b/tensorflow/compiler/xla/g3doc/images/xla_array_layout_figure2.png differ diff --git a/tensorflow/compiler/xla/g3doc/jit.md b/tensorflow/compiler/xla/g3doc/jit.md index ded1e582b24..85fa16ccc7f 100644 --- a/tensorflow/compiler/xla/g3doc/jit.md +++ b/tensorflow/compiler/xla/g3doc/jit.md @@ -86,7 +86,7 @@ on uncompilable operator, xla.compile() returns an explicit error. This is useful if you want more predictable behaviors from XLA compilation. Please see -[xla.compile() tutorial Colab](https://colab.sandbox.google.com/github/tensorflow/compiler/xla/g3doc/tutorials/xla_compile.ipynb) +[xla.compile() tutorial Colab](./tutorials/xla_compile.ipynb) for how to use it. ### Placing operators on XLA devices @@ -144,7 +144,7 @@ Execute the python script to train the model with XLA and turn on a debugging feature of XLA via an environmental variable that outputs the XLA graph. 
```shell -TF_XLA_FLAGS="--xla_hlo_graph_path=/tmp --xla_generate_hlo_graph=.*" python mnist_softmax_xla.py +XLA_FLAGS="--xla_hlo_graph_path=/tmp --xla_generate_hlo_graph=.*" python mnist_softmax_xla.py ``` Open the timeline file created (`timeline.ctf.json`). The rendered timeline diff --git a/tensorflow/compiler/xla/g3doc/layout_with_tiling.md b/tensorflow/compiler/xla/g3doc/layout_with_tiling.md new file mode 100644 index 00000000000..5e990851af7 --- /dev/null +++ b/tensorflow/compiler/xla/g3doc/layout_with_tiling.md @@ -0,0 +1,159 @@ +# Tiled layout + +*Note: This doc describes how tiled layout is intended to work. Tiling is being +implemented, but this is an early effort and it is currently not even guaranteed +to get an Unimplemented error if one tries to use tiling - it may be just +silently ignored.* + +
![](images/xla_array_layout_figure1.png) + +Figure 1
+
+Figure 1 shows how an array F32[3,5] is laid out in memory with 2x2 tiling. A
+shape with this layout is written as F32[3,5]{1,0:(2,2)}, where 1,0 relates to
+the physical order of dimensions (minor_to_major field in Layout) while (2,2)
+after the colon indicates tiling of the physical dimensions by a 2x2 tile.
+
+Intuitively, tiles are laid out to cover the shape, and within each tile the
+elements are then laid out without tiling, as in the example above, where the
+right part of the example shows the layout in memory, including the white
+padding elements that are added in order to have complete 2x2 tiles even though
+the original array bounds are not even.
+
+The extra elements in the padding are not required to contain any particular
+value.
+
+## Linear index formulas for tiling given a shape and a tile
+
+Without tiling, an element e = (e_n, e_{n-1}, ..., e_1) in an array with array
+bounds d = (d_n, d_{n-1}, ..., d_1) (d_1 is the most minor dimension) is laid
+out by major to minor order at position:
+
+   linear_index(e, d) \
+   = linear_index((e_n, e_{n-1}, ..., e_1), (d_n, d_{n-1}, ..., d_1)) \
+   = e_n∙d_{n-1}∙...∙d_1 + e_{n-1}∙d_{n-2}∙...∙d_1 + ... + e_1
+
+For simplicity of notation in this document we assume a tile has the same
+number of dimensions as the array. In XLA's implementation of tiling, this is
+generalized to tilings with fewer dimensions by leaving the initial most-major
+dimensions unchanged and applying the tiling only to the most minor dimensions,
+so that the tiling that is specified mentions a suffix of the physical
+dimensions of the shape being tiled.
+
+When tiling of size (t_n, t_{n-1}, ..., t_1) is used, an element in the array
+with indices (e_n, e_{n-1}, ..., e_1) is mapped to this position in the final
+layout:
+
+   linear_index_with_tile(e, d, t) \
+   = linear_index((⌊e/t⌋, e mod t), (⌈d/t⌉, t))
+     (arithmetic is elementwise, (a,b) is concatenation) \
+   = linear_index((⌊e_n/t_n⌋, ..., ⌊e_1/t_1⌋, e_n mod t_n, ..., e_1 mod t_1),
+     (⌈d_n/t_n⌉, ..., ⌈d_1/t_1⌉, t_n, t_{n-1}, ..., t_1)) \
+   = linear_index((⌊e_n/t_n⌋, ..., ⌊e_1/t_1⌋), (⌈d_n/t_n⌉, ..., ⌈d_1/t_1⌉))
+     ∙ t_n∙t_{n-1}∙...∙t_1
+     + linear_index((e_n mod t_n, ..., e_1 mod t_1), (t_n, t_{n-1}, ..., t_1))
+
+The layout can be thought of as having two parts:
+(⌊e_n/t_n⌋, ..., ⌊e_1/t_1⌋), which corresponds to a tile index in an array of
+tiles of size (⌈d_n/t_n⌉, ..., ⌈d_1/t_1⌉), and
+(e_n mod t_n, ..., e_1 mod t_1), which corresponds to a within-tile index. The
+ceil function appears in ⌈d_i/t_i⌉ because if tiles overrun the bounds of the
+larger array, padding is inserted as in Figure 1. Both the tiles and the
+elements within tiles are laid out recursively without tiling.
+
+For the example in Figure 1, element (2,3) has tile index (1,1) and within-tile
+index (0,1), for a combined coordinate vector of (1, 1, 0, 1). The tile indices
+have bounds (2, 3) and the tile itself is (2, 2) for a combined vector of
+(2, 3, 2, 2). The linear index with tile for the element with index (2, 3) in
+the logical shape is then
+
+   linear_index_with_tile((2,3), (3,5), (2,2)) \
+   = linear_index((1,1,0,1), (2,3,2,2)) \
+   = linear_index((1,1), (2,3)) ∙ 2 ∙ 2 + linear_index((0,1), (2,2)) \
+   = (1 ∙ 3 + 1) ∙ 2 ∙ 2 + (0 ∙ 2 + 1) \
+   = 17.
+
+# Tiling as pad-reshape-transpose
+
+Tiling-based layout operates as follows: \
+Consider an array of dimensions (d_n, d_{n-1}, ..., d_1) (d_1 is the most minor
+dimension). When it's laid out with tiling of size (t_n, t_{n-1}, ..., t_1)
+(t_1 is the most minor dimension), that tiling can be described in terms of
+pad-reshape-transpose in the following way.
+
+1. The array is padded to (⌈d_n/t_n⌉∙t_n, ..., ⌈d_1/t_1⌉∙t_1).
+2. Each dimension i is broken into (⌈d_i/t_i⌉, t_i), i.e. the array is
+   reshaped to \
+   (⌈d_n/t_n⌉, t_n, ..., ⌈d_1/t_1⌉, t_1). \
+   There is no physical layout change in this reshape by itself, so this
+   reshape is a bitcast. If one is not explicitly thinking of a tiling, this
+   reshape could express any shape with the same number of elements as the
+   padded shape - the example here is of how to express a tile in this way.
+3. A transpose happens by moving t_n, ..., t_1 to the most minor dimensions
+   while keeping their relative order, so that the order of dimensions from
+   most major to most minor becomes \
+   (⌈d_n/t_n⌉, ..., ⌈d_1/t_1⌉, t_n, ..., t_1).
+
+The final shape has the prefix \
+(⌈d_n/t_n⌉, ..., ⌈d_1/t_1⌉), which describes the number of tiles in each
+dimension. An element in the array (e_n, ..., e_1) is mapped to this element in
+the final shape: \
+(⌊e_n/t_n⌋, ..., ⌊e_1/t_1⌋, e_n mod t_n, ..., e_1 mod t_1). It is easy to see
+that the linear index of the element follows the formula above as expected.
+
+# Repeated tiling
+
+XLA's tiling becomes even more flexible by applying it repeatedly.
+
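+The index arithmetic above is mechanical, so a small sketch may help make it
+concrete. The following Python snippet is illustrative only and not part of
+the XLA code base: it computes the single-level tiled linear index from the
+formula in the previous section and reproduces the Figure 1 example; each
+additional tiling level can conceptually be handled by applying the same
+transformation again to the resulting shape.
+
+```python
+def linear_index(index, bounds):
+  # Major-to-minor (row-major) linear index of `index` within `bounds`.
+  linear = 0
+  for i, b in zip(index, bounds):
+    linear = linear * b + i
+  return linear
+
+
+def linear_index_with_tile(index, bounds, tile):
+  # linear_index((floor(e/t), e mod t), (ceil(d/t), t)), elementwise.
+  tile_index = [i // t for i, t in zip(index, tile)]
+  within_tile = [i % t for i, t in zip(index, tile)]
+  tile_bounds = [-(-b // t) for b, t in zip(bounds, tile)]  # ceil(b / t)
+  return linear_index(tile_index + within_tile, tile_bounds + list(tile))
+
+
+# Element (2, 3) of an F32[3,5] array with 2x2 tiling, as in Figure 1:
+assert linear_index_with_tile([2, 3], [3, 5], [2, 2]) == 17
+```
+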
![](images/xla_array_layout_figure2.png) + +Figure 2
+ +Figure 2 shows how an array of size 4x8 is tiled by two levels of tiling (first +2x4 then 2x1). We represent this repeated tiling as (2,4)(2,1). Each color +indicates a 2x4 tile and each red border box is a 2x1 tile. The numbers +indicates the linear index in memory of that element in the tiled format. This +format matches the format used for BF16 on TPU, except that the initial tile is +bigger, namely the tiling is (8,128)(2,1), where the purpose of the second +tiling by 2x1 is to collect together two 16 bit values to form one 32 bit value +in a way that aligns with the architecture of a TPU. + +Note that a second or later tile can refer to both the minor within-tile +dimensions, which just rearranges data within the tile, as in this example with +(8,128)(2,1), but can also refer to the major cross-tile dimensions from the +prior tiling. + +# Combining dimensions using tiles + +XLA's tiling also supports combining dimensions. For example, it can combine +dimensions in F32[2,7,8,11,10]{4,3,2,1,0} into F32[112,110]{1,0} first before +tiling it with (2,3). The tile used is (∗,∗,2,∗,3). Here an +asterisk in a tile implies taking that dimension and combining it with the next +more minor dimension. Multiple adjacent dimensions can be subsumed together into +one dimension. A subsumed dimension is represented by a tile value of -1 in that +dimension of the tile, which is not otherwise valid in a tile as a dimension +size. + +More precisely, if dimension i of the shape is eliminated via an asterisk in the +tile, then before the prior definition of tiling is applied, that dimension is +removed from both the shape being tiled and the tile vector, and what was +dimension i-1 of the shape has its array bound increased from di-1 to +didi-1. This step is repeated for each asterisk in the +tile vector. diff --git a/tensorflow/compiler/xla/g3doc/operation_semantics.md b/tensorflow/compiler/xla/g3doc/operation_semantics.md index 73a9db75f6b..d888b1f23f3 100644 --- a/tensorflow/compiler/xla/g3doc/operation_semantics.md +++ b/tensorflow/compiler/xla/g3doc/operation_semantics.md @@ -13,6 +13,22 @@ arbitrary-dimensional array. For convenience, special cases have more specific and familiar names; for example a *vector* is a 1-dimensional array and a *matrix* is a 2-dimensional array. +## AfterAll + +See also +[`XlaBuilder::AfterAll`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). + +AfterAll takes a variadic number of tokens and produces a single token. Tokens +are primitive types which can be threaded between side-effecting operations to +enforce ordering. `AfterAll` can be used as a join of tokens for ordering a +operation after a set operations. + + `AfterAll(operands)` + +Arguments | Type | Semantics +---------- | ------- | ------------------------- +`operands` | `XlaOp` | variadic number of tokens + ## AllToAll See also @@ -402,6 +418,33 @@ then v12 == f32[8x3] {{10, 11, 12}, ``` +## CollectivePermute + +See also +[`XlaBuilder::CollectivePermute`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). + +CollectivePermute is a collective operation that sends and receives data cross +replicas. + + `CollectivePermute(operand, source_target_pairs)` + +| Arguments | Type | Semantics | +| --------------------- | ----------------------- | -------------------------- | +| `operand` | `XlaOp` | n dimensional input array | +| `source_target_pairs` | `` vector | A list of | +: : : (source_replica_id, : +: : : target_replica_id) pairs. 
: +: : : For each pair, the operand : +: : : is sent from source : +: : : replica to target replica. : + +Note that there are the following restrictions on the `source_target_pair`: + +- Any two pairs should not have the same target replica id, and they should + not have the same source replica id. +- If a replica id is not a target in any pair, then the output on that replica + is a tensor consists of 0(s) with the same shape as the input. + ## Concatenate See also @@ -1423,10 +1466,11 @@ Builds a constant literal on device rather than a potentially large host transfer. Creates a rank 1 array of values starting at zero and incrementing by one. -Arguments | Type | Semantics ---------- | --------------- | ------------------------------------ -`type` | `PrimitiveType` | type U -`size` | `int64` | The number of elements in the array. +Arguments | Type | Semantics +---------------- | --------------- | ------------------------------------ +`type` | `PrimitiveType` | type U +`size` | `int64` | The number of elements in the array. +`iota_dimension` | `int64` | The dimension to increment along. ## Map @@ -1780,8 +1824,9 @@ XlaBuilder builder(client_, "reduce_window_2x3"); auto shape = ShapeUtil::MakeShape(F32, {4, 6}); auto input = builder.Parameter(0, shape, "input"); builder.ReduceWindow( - input, *max, + input, /*init_val=*/builder.ConstantLiteral(LiteralUtil::MinValue(F32)), + *max, /*window_dimensions=*/{2, 3}, /*window_stride_dimensions=*/{2, 3}, Padding::kValid); diff --git a/tensorflow/compiler/xla/g3doc/tutorials/xla_compile.ipynb b/tensorflow/compiler/xla/g3doc/tutorials/xla_compile.ipynb index a83e3f78598..2a83092805b 100644 --- a/tensorflow/compiler/xla/g3doc/tutorials/xla_compile.ipynb +++ b/tensorflow/compiler/xla/g3doc/tutorials/xla_compile.ipynb @@ -1,25 +1,38 @@ { + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "The XLA compile API", + "version": "0.3.2", + "provenance": [], + "collapsed_sections": [], + "toc_visible": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + } + }, "cells": [ { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "f4TSNCvpENrW" }, + "cell_type": "markdown", "source": [ "##### Copyright 2018 The TensorFlow Authors." ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { "cellView": "form", - "colab": {}, "colab_type": "code", - "id": "vamNSA0vEP-m" + "id": "vamNSA0vEP-m", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", @@ -32,139 +45,84 @@ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." 
- ] - }, - { - "cell_type": "code", + ], "execution_count": 0, - "metadata": { - "cellView": "form", - "colab": {}, - "colab_type": "code", - "id": "xD_ydfejEV7H" - }, - "outputs": [], - "source": [ - "#@title MIT License\n", - "#\n", - "# Copyright (c) 2017 FranƧois Chollet\n", - "#\n", - "# Permission is hereby granted, free of charge, to any person obtaining a\n", - "# copy of this software and associated documentation files (the \"Software\"),\n", - "# to deal in the Software without restriction, including without limitation\n", - "# the rights to use, copy, modify, merge, publish, distribute, sublicense,\n", - "# and/or sell copies of the Software, and to permit persons to whom the\n", - "# Software is furnished to do so, subject to the following conditions:\n", - "#\n", - "# The above copyright notice and this permission notice shall be included in\n", - "# all copies or substantial portions of the Software.\n", - "#\n", - "# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n", - "# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n", - "# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL\n", - "# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n", - "# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING\n", - "# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER\n", - "# DEALINGS IN THE SOFTWARE." - ] + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "e1oSi4lHFt3z" }, + "cell_type": "markdown", "source": [ - "# Welcome to `xla.compile()` tutorial" + "# The XLA compile API" ] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "b7noD9NjFRL-" }, + "cell_type": "markdown", "source": [ - "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n", - " \u003ctd\u003e\n", - " \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/xla/jit#turning_on_jit_compilation\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n", - " \u003c/td\u003e\n", - " \u003ctd\u003e\n", - " \u003ca target=\"_blank\" href=\"https://colab.sandbox.google.com/github/tensorflow/compiler/xla/g3doc/tutorials/xla_compile.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n", - " \u003c/td\u003e\n", - " \u003ctd\u003e\n", - " \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/compiler/xla/g3doc/tutorials/xla_compile.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n", - " \u003c/td\u003e\n", - "\u003c/table\u003e" + "\n", + " \n", + " \n", + " \n", + "
\n", + " View on TensorFlow.org\n", + " \n", + " Run in Google Colab\n", + " \n", + " View source on GitHub\n", + "
" ] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "v9YbsuLZaBXy" }, + "cell_type": "markdown", "source": [ - "xla.compile() is a new experimental API that compiles part or all of a model with [XLA](https://www.tensorflow.org/extend/xla/).\n", "\n", - "Please run all code blocks in order." + "\n", + "Import TensorFlow and the XLA library. XLA contains `xla.compile()`, an experimental API that compiles part or all of a model with [XLA](https://www.tensorflow.org/extend/xla/)." ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "45kUPj5ZFrRa" + "id": "45kUPj5ZFrRa", + "colab": {} }, - "outputs": [], - "source": [ - "import tensorflow as tf" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "9NMQFjroSMns" - }, - "source": [ - "Imports XLA library, which includes xla.compile() experimental API." - ] - }, - { "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "-Uggy03rSGJm" - }, - "outputs": [], "source": [ + "import tensorflow as tf\n", + "\n", "from tensorflow.contrib.compiler import xla" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "GZVNiRmTDV-5" }, + "cell_type": "markdown", "source": [ - "Define some necessary constants and prepare MNIST dataset." + "Define some necessary constants and prepare the MNIST dataset." ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "f37TSEGvGX4_" + "id": "f37TSEGvGX4_", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "# Size of each input image, 28 x 28 pixels\n", "IMAGE_SIZE = 28 * 28\n", @@ -174,17 +132,17 @@ "TRAIN_BATCH_SIZE = 100\n", "# Number of training steps to run\n", "TRAIN_STEPS = 1000" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "TiVXchblG5hK" + "id": "TiVXchblG5hK", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "# Loads MNIST dataset.\n", "train, test = tf.keras.datasets.mnist.load_data()\n", @@ -195,16 +153,18 @@ "images, labels = iterator.get_next()\n", "images = tf.reshape(images, [-1, IMAGE_SIZE])\n", "images, labels = tf.cast(images, tf.float32), tf.cast(labels, tf.int64)" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "x_ZehpZP-SfS" }, + "cell_type": "markdown", "source": [ - "## Defines build_mnist_model function to construct model\n", + "# Define the model constructing function\n", "\n", "Following code block contains a function that constructs a simple model with one dense layer, including both forward and backward propagation.\n", "\n", @@ -212,14 +172,12 @@ ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "ZbhJl_WvGa3g" + "id": "ZbhJl_WvGa3g", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "def build_mnist_model(x, y_):\n", " y = tf.keras.layers.Dense(NUM_CLASSES).apply(x)\n", @@ -228,47 +186,41 @@ " train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)\n", "\n", " return y, train_step" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "7Jh3lyQHDfM9" }, - "source": [ - "## Uses xla.compile with 
build_mnist_model function to enable XLA" - ] - }, - { "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "EtDwez_1gjzv" - }, "source": [ - "Following code block wraps the model with xla.compile(), which allows the target function with provided inputs to be executed by XLA." + "# Enable XLA\n", + "\n", + "Use `xla.compile` with the `build_mnist_model` function to enable XLA. Following code block wraps the model with `xla.compile()`, which allows the target function with provided inputs to be executed by XLA." ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "kYpCXCdRHNuN" + "id": "kYpCXCdRHNuN", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "[y] = xla.compile(build_mnist_model, inputs=[images, labels])" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "4giQh62IrZGF" }, + "cell_type": "markdown", "source": [ "When compiling the graph, XLA replaces all the graph nodes constructed in the target function with a few XLA ops.\n", "\n", @@ -293,62 +245,62 @@ ] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "TPGas4jjFLZl" }, + "cell_type": "markdown", "source": [ "If you were to print the constructed graph now, you will see that it is not much different from a normal Tensorflow graph and you won't be able to find XLA ops mentioned before. This is because the actual compilation happens later when you try to execute the graph with `sess.run()`. At that time, Tensorflow triggers a series of graph rewrite passes that actually generate XLA ops, which compiles and executes computation when all inputs are ready." ] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "EZD1m_n1DxAF" }, + "cell_type": "markdown", "source": [ - "## Trains and tests the model" + "# Train and test the model" ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "qe28bAHNHUG2" + "id": "qe28bAHNHUG2", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "# Creates session and initialize all variables.\n", "# xla.compile() doesn't work with Keras model.fit() API or TF eager mode yet.\n", "sess = tf.Session()\n", "sess.run(tf.global_variables_initializer())" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "qgsKmz3n2UiW" }, + "cell_type": "markdown", "source": [ - "Following code block trains model.\n", - "\n", - "Note that evaluating `y` also triggers its control dependency node `train_step`, which updates model variables." + "Following code block trains model. Evaluating `y` also triggers its control dependency node `train_step`, which updates model variables." 
] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "_GxF6jTRHVuA" + "id": "_GxF6jTRHVuA", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + }, + "outputId": "fbf299ca-02d5-4e95-f9fe-8f3c0432d132" }, - "outputs": [], + "cell_type": "code", "source": [ "# Feeds training dataset\n", "sess.run(iterator.make_initializer(train_ds))\n", @@ -356,18 +308,31 @@ "# Runs TRAIN_STEPS steps\n", "for i in range(TRAIN_STEPS):\n", " sess.run(y)\n", + "\n", "print(\"Model trained for %s steps.\" % TRAIN_STEPS)" + ], + "execution_count": 21, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Model trained for 1000 steps.\n" + ], + "name": "stdout" + } ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "dHlQlRSRHXD1" + "id": "dHlQlRSRHXD1", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + }, + "outputId": "9c3677a2-ec84-406f-9d2c-d722844f3093" }, - "outputs": [], + "cell_type": "code", "source": [ "# Tests trained model\n", "\n", @@ -378,35 +343,31 @@ "correct_prediction = tf.equal(tf.argmax(y, 1), labels)\n", "accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n", "print(\"Prediction accuracy after training: %s\" % sess.run(accuracy))" + ], + "execution_count": 22, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Prediction accuracy after training: 0.91\n" + ], + "name": "stdout" + } ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "ynJQIuzjHYOb" + "id": "ynJQIuzjHYOb", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "# Cleans up session\n", "sess.close()" - ] + ], + "execution_count": 0, + "outputs": [] } - ], - "metadata": { - "colab": { - "collapsed_sections": [], - "name": "xla.compile() Tutorial", - "provenance": [], - "version": "0.3.2" - }, - "kernelspec": { - "display_name": "Python 2", - "name": "python2" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} + ] +} \ No newline at end of file diff --git a/tensorflow/compiler/xla/index_util.h b/tensorflow/compiler/xla/index_util.h index 458bdaf2f89..d76f61eb62c 100644 --- a/tensorflow/compiler/xla/index_util.h +++ b/tensorflow/compiler/xla/index_util.h @@ -21,6 +21,7 @@ limitations under the License. #include #include "absl/types/span.h" +#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/macros.h" diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index 2398470dd49..dbb81381acd 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -460,6 +460,13 @@ std::ostream& operator<<(std::ostream& out, const Layout& layout) { } hash_value = Hash64Combine(hash_value, layout.max_sparse_elements()); + for (Tile tile : layout.tiles()) { + for (int64 tile_dim : tile.dimensions()) { + hash_value = Hash64Combine(hash_value, hash()(tile_dim)); + } + } + hash_value = Hash64Combine(hash_value, layout.element_size_in_bits()); + return hash_value; } diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h index 6e0390763da..6c298e57252 100644 --- a/tensorflow/compiler/xla/layout_util.h +++ b/tensorflow/compiler/xla/layout_util.h @@ -21,6 +21,7 @@ limitations under the License. 
#include #include "absl/types/span.h" +#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index cb00a0ab16d..8f480c1f107 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -62,6 +63,14 @@ void ConvertEndianShort(char* bytes, int64 size) { } } +// Since Eigen::half doesn't satisfy the absl::bit_cast contract, we need to be +// able to transparently access the raw 16-bit value contained within. +template +T GetRawValue(T val) { + return val; +} +uint16 GetRawValue(Eigen::half val) { return val.x; } + } // namespace LiteralBase::~LiteralBase() {} @@ -283,16 +292,17 @@ Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal, if (!proto.has_shape()) { return InvalidArgument("LiteralProto has no shape"); } - if (ShapeUtil::HasPrimitiveType(proto.shape(), OPAQUE)) { + Shape shape(proto.shape()); + if (ShapeUtil::HasPrimitiveType(shape, OPAQUE)) { return InvalidArgument("Literal shape cannot include OPAQUE sub-shape"); } - if (!LayoutUtil::HasLayout(proto.shape())) { + if (!LayoutUtil::HasLayout(shape)) { return InvalidArgument("LiteralProto has no layout"); } - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(proto.shape())); + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape)); - Literal literal(proto.shape()); + Literal literal(shape); TF_RETURN_IF_ERROR(literal.root_piece_->ForEachMutableSubpieceWithStatus( [&](const ShapeIndex& index, Piece* piece) { @@ -1012,166 +1022,143 @@ void LiteralBase::Piece::SortSparseElementsInternal() { namespace { +string ShapeToString(bool print_layout, const Shape& shape) { + return print_layout ? 
ShapeUtil::HumanStringWithLayout(shape) + : ShapeUtil::HumanString(shape); +} + +void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, + bool print_layout, std::vector* pieces); + +void TupleToStringHelper(const LiteralBase& literal, + const ShapeIndex& shape_index, bool print_layout, + std::vector* pieces) { + const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index); + pieces->push_back(ShapeToString(print_layout, subshape)); + pieces->push_back(" (\n"); + std::vector tuple_pieces; + for (int i = 0; i < ShapeUtil::TupleElementCount(subshape); ++i) { + ShapeIndex element_index = shape_index; + element_index.push_back(i); + std::vector element_pieces; + ToStringHelper(literal, element_index, print_layout, &element_pieces); + tuple_pieces.push_back(absl::StrJoin(element_pieces, "")); + } + pieces->push_back(absl::StrJoin(tuple_pieces, ",\n")); + pieces->push_back("\n)"); +} + +void SparseArrayToStringHelper(const LiteralBase& literal, + const Shape& subshape, bool print_layout, + std::vector* pieces) { + pieces->push_back(ShapeToString(print_layout, subshape)); + pieces->push_back("{"); + int64 rank = ShapeUtil::Rank(subshape); + int64 num_elements = literal.sparse_element_count(); + for (int64 i = 0; i < num_elements; ++i) { + if (i > 0) { + pieces->push_back(", "); + } + if (rank == 1) { + pieces->push_back(StrCat(literal.GetSparseIndex(i)[0])); + pieces->push_back(": "); + } else { + pieces->push_back("["); + pieces->push_back(absl::StrJoin(literal.GetSparseIndex(i), ", ")); + pieces->push_back("]: "); + } + pieces->push_back(literal.GetSparseElementAsString(i)); + } + pieces->push_back("}"); +} + +void DenseArrayToStringHelper(const LiteralBase& literal, + const ShapeIndex& shape_index, bool print_layout, + std::vector* pieces) { + const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index); + int64 rank = ShapeUtil::Rank(subshape); + + std::function dimensions, std::vector*)> + to_string_recursive = [&](absl::Span dimensions, + std::vector* accum_indices) { + // dimensions.size() decreases by 1 at each recursive call, + // and accum_indices->size() increases by 1. + // Their sum is equal to the rank of the tensor. + CHECK_EQ(rank, dimensions.size() + accum_indices->size()); + + auto brace_to_string = [&](string brace) -> string { + // Handle 1D tensor + if (rank == 1) { + return brace; + } + // Handle the innermost tensor of a 2D+ tensor. + if (dimensions.size() == 1 && brace == "{") { + return StrCat(" ", brace, dimensions[0] <= 1 ? "" : " "); + } + if (dimensions.size() == 1 && brace == "}") { + return StrCat(dimensions[0] <= 1 ? "" : " ", brace); + } + // Handle the non-innermost tensors of a 2D+ tensor. + if (brace == "{") { + if (rank > 3 && !accum_indices->empty() && + accum_indices->size() < rank) { + int index = accum_indices->size() - 1; + int value = accum_indices->back(); + return StrCat(brace, " /*i", index, "=", value, "*/\n"); + } + return StrCat(brace, "\n"); + } + return StrCat("\n", brace); + }; + + if (dimensions.empty()) { + // Display predicates as 0s and 1s so that the string is more dense. + string elem; + if (subshape.element_type() == PRED && rank > 0) { + elem = literal.Get(*accum_indices, shape_index) ? 
"1" : "0"; + } else { + elem = literal.GetAsString(*accum_indices, shape_index); + } + pieces->push_back(elem); + } else { + pieces->push_back(brace_to_string("{")); + for (int i = 0; i < dimensions[0]; ++i) { + std::vector cloned_indices(*accum_indices); + cloned_indices.push_back(i); + to_string_recursive(dimensions.subspan(1), &cloned_indices); + if (i < dimensions[0] - 1) { + pieces->push_back(","); + pieces->push_back(dimensions.size() > 1 ? "\n" : " "); + } + } + pieces->push_back(brace_to_string("}")); + } + }; + + if (rank > 1) { + pieces->push_back(ShapeToString(print_layout, subshape)); + pieces->push_back(" "); + } + std::vector indices = {}; + std::vector dimensions(subshape.dimensions().begin(), + subshape.dimensions().end()); + to_string_recursive(dimensions, &indices); +} + void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, bool print_layout, std::vector* pieces) { const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index); CHECK(LayoutUtil::HasLayout(literal.shape())); CHECK(LayoutUtil::HasLayout(subshape)); - - auto shape_to_string = [print_layout](const Shape& shape) { - if (print_layout) { - return ShapeUtil::HumanStringWithLayout(shape); - } else { - return ShapeUtil::HumanString(shape); - } - }; - - // TODO(b/32894291): refactor this code to reduce code duplication. if (ShapeUtil::IsTuple(subshape)) { - pieces->push_back(shape_to_string(subshape)); - pieces->push_back(" (\n"); - std::vector tuple_pieces; - for (int i = 0; i < ShapeUtil::TupleElementCount(subshape); ++i) { - ShapeIndex element_index = shape_index; - element_index.push_back(i); - std::vector element_pieces; - ToStringHelper(literal, element_index, print_layout, &element_pieces); - tuple_pieces.push_back(absl::StrJoin(element_pieces, "")); - } - pieces->push_back(absl::StrJoin(tuple_pieces, ",\n")); - pieces->push_back("\n)"); - return; - } - - if (ShapeUtil::IsToken(subshape)) { + TupleToStringHelper(literal, shape_index, print_layout, pieces); + } else if (ShapeUtil::IsToken(subshape)) { pieces->push_back("token"); - return; - } - - if (LayoutUtil::IsSparseArray(subshape)) { - pieces->push_back(shape_to_string(subshape)); - pieces->push_back("{"); - int64 rank = ShapeUtil::Rank(subshape); - int64 num_elements = literal.sparse_element_count(); - for (int64 i = 0; i < num_elements; ++i) { - if (i > 0) { - pieces->push_back(", "); - } - if (rank == 1) { - pieces->push_back(StrCat(literal.GetSparseIndex(i)[0])); - pieces->push_back(": "); - } else { - pieces->push_back("["); - pieces->push_back(absl::StrJoin(literal.GetSparseIndex(i), ", ")); - pieces->push_back("]: "); - } - pieces->push_back(literal.GetSparseElementAsString(i)); - } - pieces->push_back("}"); - return; - } - - CHECK(LayoutUtil::IsDenseArray(subshape)); - - auto element_to_string = [&](absl::Span indices) -> string { - PrimitiveType element_type = subshape.element_type(); - // We display predicates as 0s and 1s so that the string is more dense. - string elem = element_type == PRED - ? literal.Get(indices, shape_index) ? "1" : "0" - : literal.GetAsString(indices, shape_index); - return ((!indices.empty() && indices.back() > 0) ? 
", " : "") + elem; - }; - - if (ShapeUtil::Rank(subshape) == 0) { - pieces->push_back(literal.GetAsString({}, shape_index)); - } else if (ShapeUtil::Rank(subshape) == 1) { - pieces->push_back("{"); - for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { - pieces->push_back(element_to_string({i0})); - } - pieces->push_back("}"); - } else if (ShapeUtil::Rank(subshape) == 2) { - pieces->push_back(shape_to_string(subshape)); - pieces->push_back(" {\n"); - for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { - pieces->push_back(" { "); - for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { - pieces->push_back(element_to_string({i0, i1})); - } - pieces->push_back(" "); - pieces->push_back(i0 == subshape.dimensions(0) - 1 ? "}\n" : "},\n"); - } - pieces->push_back("}"); - } else if (ShapeUtil::Rank(subshape) == 3) { - pieces->push_back(shape_to_string(subshape)); - pieces->push_back(" {\n"); - for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { - pieces->push_back(i0 > 0 ? ",\n{" : "{"); - for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { - pieces->push_back(i1 > 0 ? ",\n { " : " { "); - for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) { - pieces->push_back(element_to_string({i0, i1, i2})); - } - pieces->push_back(" }"); - } - pieces->push_back(" }"); - } - pieces->push_back("\n}"); - } else if (ShapeUtil::Rank(subshape) == 4) { - pieces->push_back(shape_to_string(subshape)); - pieces->push_back(" {\n"); - for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { - pieces->push_back(StrFormat(" { /*i0=%d*/\n", i0)); - for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { - pieces->push_back(StrFormat(" { /*i1=%d*/\n", i1)); - for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) { - pieces->push_back(" {"); - for (int64 i3 = 0; i3 < subshape.dimensions(3); ++i3) { - pieces->push_back(element_to_string({i0, i1, i2, i3})); - } - pieces->push_back(i2 == subshape.dimensions(2) - 1 ? "}\n" : "},\n"); - } - pieces->push_back(i1 == subshape.dimensions(1) - 1 ? " }\n" - : " },\n"); - } - pieces->push_back(i0 == subshape.dimensions(0) - 1 ? " }\n" : " },\n"); - } - pieces->push_back("}"); - } else if (ShapeUtil::Rank(subshape) == 5) { - pieces->push_back(shape_to_string(subshape)); - pieces->push_back(" {\n"); - for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { - pieces->push_back(StrFormat(" { /*i0=%d*/\n", i0)); - for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { - pieces->push_back(StrFormat(" { /*i1=%d*/\n", i1)); - for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) { - pieces->push_back(StrFormat(" { /*i2=%d*/\n", i2)); - for (int64 i3 = 0; i3 < subshape.dimensions(3); ++i3) { - pieces->push_back(" {"); - for (int64 i4 = 0; i4 < subshape.dimensions(4); ++i4) { - pieces->push_back(element_to_string({i0, i1, i2, i3, i4})); - } - pieces->push_back(i3 == subshape.dimensions(3) - 1 ? "}\n" - : "},\n"); - } - pieces->push_back(i2 == subshape.dimensions(2) - 1 ? " }\n" - : " },\n"); - } - pieces->push_back(i1 == subshape.dimensions(1) - 1 ? " }\n" - : " },\n"); - } - pieces->push_back(i0 == subshape.dimensions(0) - 1 ? 
" }\n" : " },\n"); - } - pieces->push_back("}"); + } else if (LayoutUtil::IsSparseArray(subshape)) { + SparseArrayToStringHelper(literal, subshape, print_layout, pieces); } else { - pieces->push_back(shape_to_string(subshape)); - pieces->push_back(" {"); - literal.EachCellAsString( - [&](absl::Span indices, const string& value) { - pieces->push_back(" "); - pieces->push_back(value); - }); - pieces->push_back("}"); + CHECK(LayoutUtil::IsDenseArray(subshape)); + DenseArrayToStringHelper(literal, shape_index, print_layout, pieces); } } @@ -1228,16 +1215,32 @@ Literal ConvertBetweenNativeTypes(const LiteralBase& src_literal) { } template -typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT)), +typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT) && + !std::is_same::value), Literal>::type BitcastBetweenNativeTypes(const LiteralBase& src_literal) { auto converter = [](NativeSrcT src) { - return absl::bit_cast(src); + return absl::bit_cast(GetRawValue(src)); }; return ConvertBetweenNativeTypesWithConverter( src_literal, converter); } +template +typename std::enable_if<(sizeof(NativeSrcT) == sizeof(Eigen::half) && + std::is_same::value), + Literal>::type +BitcastBetweenNativeTypes(const LiteralBase& src_literal) { + // Eigen::half doesn't satisfy the absl::bit_cast contract, so explicitly + // cast to unsigned short and then use raw_uint16_to_half. + auto converter = [](NativeSrcT src) { + return Eigen::half_impl::raw_uint16_to_half( + absl::bit_cast(GetRawValue(src))); + }; + return ConvertBetweenNativeTypesWithConverter( + src_literal, converter); +} + // This template specialization is here to make the compiler happy. bit_cast has // a static check that the types are the same size. This specialization should // never be used because the source and destination types are checked for @@ -1792,7 +1795,7 @@ void CopyToRepeatedField(RepeatedFieldT* dest, } // namespace void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { - *proto->mutable_shape() = subshape(); + *proto->mutable_shape() = subshape().ToProto(); switch (subshape().element_type()) { case PRED: CopyToRepeatedField(proto->mutable_preds(), data()); @@ -1898,8 +1901,9 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { // These conditions should have been checked in // MutableLiteralBase::CreateFromProto. TF_RET_CHECK(proto.has_shape()); - TF_RET_CHECK(LayoutUtil::HasLayout(proto.shape())); - TF_RET_CHECK(ShapeUtil::Equal(proto.shape(), subshape())); + Shape shape(proto.shape()); + TF_RET_CHECK(LayoutUtil::HasLayout(shape)); + TF_RET_CHECK(ShapeUtil::Equal(shape, subshape())); if (LayoutUtil::IsSparseArray(subshape())) { // Compute the number of elements (indices) in the sparse shape and reserve diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h index e791048b4d9..fa9a71af4ce 100644 --- a/tensorflow/compiler/xla/literal.h +++ b/tensorflow/compiler/xla/literal.h @@ -301,7 +301,7 @@ class LiteralBase { // // Note: It's an antipattern to use this method then immediately call // MutableLiteralBase::Populate on the result (since that results in zero - // initialization, then reinitialization. Conside if a call to + // initialization, then reinitialization. Consider if a call to // absl::make_unique(shape), followed by the call to // MutableLiteralBase::Populate can be used instead. 
static Literal CreateFromShape(const Shape& shape); diff --git a/tensorflow/compiler/xla/literal_test.cc b/tensorflow/compiler/xla/literal_test.cc index 8cec37897a9..49363ad802d 100644 --- a/tensorflow/compiler/xla/literal_test.cc +++ b/tensorflow/compiler/xla/literal_test.cc @@ -150,12 +150,58 @@ TEST_F(LiteralUtilTest, R3ToString) { const auto literal = LiteralUtil::CreateR3({{{1}, {2}}, {{3}, {4}}, {{5}, {6}}}); const string expected = R"(s32[3,2,1] { -{ { 1 }, - { 2 } }, -{ { 3 }, - { 4 } }, -{ { 5 }, - { 6 } } +{ + {1}, + {2} +}, +{ + {3}, + {4} +}, +{ + {5}, + {6} +} +})"; + EXPECT_EQ(expected, literal.ToString()); +} + +TEST_F(LiteralUtilTest, R6ToString) { + const auto literal = + LiteralUtil::CreateFromDimensions(S32, {2, 2, 1, 1, 1, 2}); + const string expected = R"(s32[2,2,1,1,1,2] { +{ /*i0=0*/ +{ /*i1=0*/ +{ /*i2=0*/ +{ /*i3=0*/ + { 0, 0 } +} +} +}, +{ /*i1=1*/ +{ /*i2=0*/ +{ /*i3=0*/ + { 0, 0 } +} +} +} +}, +{ /*i0=1*/ +{ /*i1=0*/ +{ /*i2=0*/ +{ /*i3=0*/ + { 0, 0 } +} +} +}, +{ /*i1=1*/ +{ /*i2=0*/ +{ /*i3=0*/ + { 0, 0 } +} +} +} +} })"; EXPECT_EQ(expected, literal.ToString()); } @@ -190,12 +236,16 @@ TEST_F(LiteralUtilTest, CreateR3FromArray3d) { EXPECT_THAT(literal.shape().dimensions(), ElementsAre(2, 3, 2)); string result = literal.ToString(); const string expected = R"(f32[2,3,2] { -{ { 1, 2 }, +{ + { 1, 2 }, { 3, 4 }, - { 5, 6 } }, -{ { 7, 8 }, + { 5, 6 } +}, +{ + { 7, 8 }, { 9, 10 }, - { 11, 12 } } + { 11, 12 } +} })"; EXPECT_EQ(expected, result); } @@ -247,18 +297,18 @@ TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) { EXPECT_THAT(literal.shape().dimensions(), ElementsAre(1, 2, 3, 2)); string result = literal.ToString(); const string expected = R"(f32[1,2,3,2] { - { /*i0=0*/ - { /*i1=0*/ - {1, 2}, - {1001, 1002}, - {2001, 2002} - }, - { /*i1=1*/ - {1, 2}, - {1001, 1002}, - {2001, 2002} - } - } +{ /*i0=0*/ +{ /*i1=0*/ + { 1, 2 }, + { 1001, 1002 }, + { 2001, 2002 } +}, +{ /*i1=1*/ + { 1, 2 }, + { 1001, 1002 }, + { 2001, 2002 } +} +} })"; EXPECT_EQ(expected, result); } @@ -268,30 +318,30 @@ TEST_F(LiteralUtilTest, LiteralR4F32Stringifies) { ElementsAre(2, 2, 3, 3)); string result = literal_r4_2x2x3x3_dim0major_.ToString(); const string expected = R"(f32[2,2,3,3] { - { /*i0=0*/ - { /*i1=0*/ - {1, 2, 3}, - {4, 5, 6}, - {7, 8, 9} - }, - { /*i1=1*/ - {11, 12, 13}, - {14, 15, 16}, - {17, 18, 19} - } - }, - { /*i0=1*/ - { /*i1=0*/ - {101, 102, 103}, - {104, 105, 106}, - {107, 108, 109} - }, - { /*i1=1*/ - {201, 202, 203}, - {204, 205, 206}, - {207, 208, 209} - } - } +{ /*i0=0*/ +{ /*i1=0*/ + { 1, 2, 3 }, + { 4, 5, 6 }, + { 7, 8, 9 } +}, +{ /*i1=1*/ + { 11, 12, 13 }, + { 14, 15, 16 }, + { 17, 18, 19 } +} +}, +{ /*i0=1*/ +{ /*i1=0*/ + { 101, 102, 103 }, + { 104, 105, 106 }, + { 107, 108, 109 } +}, +{ /*i1=1*/ + { 201, 202, 203 }, + { 204, 205, 206 }, + { 207, 208, 209 } +} +} })"; EXPECT_EQ(expected, result); } @@ -1327,13 +1377,26 @@ TEST_F(LiteralUtilTest, BitcastConvertBetweenInvalidTypes) { absl::StrContains(status.error_message(), "bit widths are different")); } +// Sets the layout of the given ShapeProto to the default. 
+void SetDefaultLayoutOnProto(ShapeProto* shape_proto) { + CHECK(ShapeUtil::IsArrayPrimitiveType(shape_proto->element_type())); + shape_proto->mutable_layout()->set_format(DENSE); + auto* minor_to_major = + shape_proto->mutable_layout()->mutable_minor_to_major(); + minor_to_major->Resize(shape_proto->dimensions_size(), 0); + const int64 size = minor_to_major->size(); + for (int64 i = 0; i < size; ++i) { + minor_to_major->Set(i, size - 1 - i); + } +} + TEST_F(LiteralUtilTest, CopyFromProto_Bool) { LiteralProto p; p.mutable_shape()->set_element_type(PRED); for (int len = 0; len < 25; ++len) { p.mutable_shape()->clear_dimensions(); p.mutable_shape()->add_dimensions(len); - LayoutUtil::SetToDefaultLayout(p.mutable_shape()); + SetDefaultLayoutOnProto(p.mutable_shape()); p.clear_preds(); for (int i = 0; i < len; ++i) { p.add_preds((i % 2) == (len % 2)); @@ -1359,7 +1422,7 @@ TEST_F(LiteralUtilTest, ToProto_f16) { EXPECT_EQ(4, m.data().size()); LiteralProto p = m.ToProto(); - EXPECT_EQ(4, ShapeUtil::ElementsIn(p.shape())); + EXPECT_EQ(4, ShapeUtil::ElementsIn(Shape(p.shape()))); EXPECT_EQ(8, p.f16s().size()); const char* d = p.f16s().data(); EXPECT_EQ(d[0], 0); @@ -1382,7 +1445,7 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) { p.mutable_shape()->set_element_type(F16); p.mutable_shape()->clear_dimensions(); p.mutable_shape()->add_dimensions(4); - LayoutUtil::SetToDefaultLayout(p.mutable_shape()); + SetDefaultLayoutOnProto(p.mutable_shape()); p.clear_f16s(); p.set_f16s(half_vals, 8); TF_ASSERT_OK_AND_ASSIGN(Literal literal, Literal::CreateFromProto(p)); @@ -1404,7 +1467,7 @@ TEST_F(LiteralUtilTest, CopyFromProto_u16) { p.mutable_shape()->set_element_type(U16); p.mutable_shape()->clear_dimensions(); p.mutable_shape()->add_dimensions(4); - LayoutUtil::SetToDefaultLayout(p.mutable_shape()); + SetDefaultLayoutOnProto(p.mutable_shape()); p.clear_u16s(); p.set_u16s(uint16_vals, 8); TF_ASSERT_OK_AND_ASSIGN(Literal literal, Literal::CreateFromProto(p)); @@ -1537,9 +1600,9 @@ TEST_F(LiteralUtilTest, DecomposeTuple) { Literal nested_tuple = LiteralUtil::MakeTuple( {&tuple_elements[0], &tuple_elements[1], &nil_literal}); - EXPECT_FALSE(ShapeUtil::IsNil(nested_tuple.shape())); + EXPECT_FALSE(ShapeUtil::IsEmptyTuple(nested_tuple.shape())); std::vector elements = nested_tuple.DecomposeTuple(); - EXPECT_TRUE(ShapeUtil::IsNil(nested_tuple.shape())); + EXPECT_TRUE(ShapeUtil::IsEmptyTuple(nested_tuple.shape())); ASSERT_EQ(elements.size(), 3); @@ -1590,7 +1653,7 @@ TEST_F(LiteralUtilTest, MoveIntoTuple) { EXPECT_EQ(literal.Get({1}, /*shape_index=*/{2, 1}), 44.0); for (const Literal& element : elements) { - EXPECT_TRUE(ShapeUtil::IsNil(element.shape())); + EXPECT_TRUE(ShapeUtil::IsEmptyTuple(element.shape())); } } @@ -1706,7 +1769,7 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) { TEST_F(LiteralUtilTest, InvalidProtoNoValues) { // Proto contains a shape, but no values. LiteralProto proto; - *proto.mutable_shape() = ShapeUtil::MakeShape(F32, {3}); + *proto.mutable_shape() = ShapeUtil::MakeShape(F32, {3}).ToProto(); Status status = Literal::CreateFromProto(proto).status(); ASSERT_FALSE(status.ok()); EXPECT_THAT(status.error_message(), @@ -1727,7 +1790,7 @@ TEST_F(LiteralUtilTest, InvalidProtoNoShape) { TEST_F(LiteralUtilTest, InvalidProtoWrongContainer) { // Proto contains values in wrong container. 
LiteralProto proto; - *proto.mutable_shape() = ShapeUtil::MakeShape(F32, {3}); + *proto.mutable_shape() = ShapeUtil::MakeShape(F32, {3}).ToProto(); proto.add_preds(false); proto.add_preds(true); proto.add_preds(false); @@ -1740,7 +1803,7 @@ TEST_F(LiteralUtilTest, InvalidProtoWrongContainer) { TEST_F(LiteralUtilTest, InvalidProtoTooFewValues) { // Proto contains too few values. LiteralProto proto; - *proto.mutable_shape() = ShapeUtil::MakeShape(F32, {42, 2}); + *proto.mutable_shape() = ShapeUtil::MakeShape(F32, {42, 2}).ToProto(); proto.add_f32s(1.0); proto.add_f32s(2.0); proto.add_f32s(3.0); @@ -1753,7 +1816,7 @@ TEST_F(LiteralUtilTest, InvalidProtoTooFewValues) { TEST_F(LiteralUtilTest, InvalidProtoTooManyValues) { // Proto contains too many values. LiteralProto proto; - *proto.mutable_shape() = ShapeUtil::MakeShape(S32, {2}); + *proto.mutable_shape() = ShapeUtil::MakeShape(S32, {2}).ToProto(); proto.add_s32s(42); proto.add_s32s(-10); proto.add_s32s(100); @@ -1766,8 +1829,8 @@ TEST_F(LiteralUtilTest, InvalidProtoTooManyValues) { TEST_F(LiteralUtilTest, InvalidProtoMissingLayout) { // Proto shape missing layout. LiteralProto proto; - *proto.mutable_shape() = ShapeUtil::MakeShape(PRED, {2, 2}); - LayoutUtil::ClearLayout(proto.mutable_shape()); + *proto.mutable_shape() = ShapeUtil::MakeShape(PRED, {2, 2}).ToProto(); + proto.mutable_shape()->clear_layout(); proto.add_preds(true); proto.add_preds(false); proto.add_preds(true); @@ -1780,11 +1843,13 @@ TEST_F(LiteralUtilTest, InvalidProtoMissingLayout) { TEST_F(LiteralUtilTest, InvalidProtoTooFewTupleElements) { // Proto has the too few tuple elements. LiteralProto proto; - *proto.mutable_shape() = ShapeUtil::MakeTupleShape( - {ShapeUtil::MakeShape(PRED, {2}), ShapeUtil::MakeShape(F32, {})}); + *proto.mutable_shape() = + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(PRED, {2}), ShapeUtil::MakeShape(F32, {})}) + .ToProto(); LiteralProto* element0 = proto.add_tuple_literals(); *element0->mutable_shape() = - ShapeUtil::GetTupleElementShape(proto.shape(), 0); + ShapeUtil::GetTupleElementShape(Shape(proto.shape()), 0).ToProto(); element0->add_preds(false); element0->add_preds(true); @@ -1796,19 +1861,21 @@ TEST_F(LiteralUtilTest, InvalidProtoTooFewTupleElements) { TEST_F(LiteralUtilTest, InvalidProtoTooManyTupleElements) { // Proto has the too many tuple elements. 
LiteralProto proto; - *proto.mutable_shape() = ShapeUtil::MakeTupleShape( - {ShapeUtil::MakeShape(PRED, {2}), ShapeUtil::MakeShape(F32, {})}); + *proto.mutable_shape() = + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(PRED, {2}), ShapeUtil::MakeShape(F32, {})}) + .ToProto(); LiteralProto* element0 = proto.add_tuple_literals(); *element0->mutable_shape() = - ShapeUtil::GetTupleElementShape(proto.shape(), 0); + ShapeUtil::GetTupleElementShape(Shape(proto.shape()), 0).ToProto(); element0->add_preds(false); element0->add_preds(true); LiteralProto* element1 = proto.add_tuple_literals(); *element1->mutable_shape() = - ShapeUtil::GetTupleElementShape(proto.shape(), 1); + ShapeUtil::GetTupleElementShape(Shape(proto.shape()), 1).ToProto(); element1->add_f32s(42.0); LiteralProto* element2 = proto.add_tuple_literals(); - *element2->mutable_shape() = ShapeUtil::MakeShape(F32, {}); + *element2->mutable_shape() = ShapeUtil::MakeShape(F32, {}).ToProto(); element2->add_f32s(123.0); Status status = Literal::CreateFromProto(proto).status(); diff --git a/tensorflow/compiler/xla/parse_flags_from_env.cc b/tensorflow/compiler/xla/parse_flags_from_env.cc index 40481331b69..5b568888d14 100644 --- a/tensorflow/compiler/xla/parse_flags_from_env.cc +++ b/tensorflow/compiler/xla/parse_flags_from_env.cc @@ -13,15 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// This module exports ParseFlagsFromEnv(), which allows other modules to parse -// flags from an environtment variable, or a file named by the environment -// variable. +// This module exports ParseFlagsFromEnvAndDieIfUnknown(), which allows other +// modules to parse flags from an environtment variable, or a file named by the +// environment variable. #include #include #include +#include +#include #include +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/parse_flags_from_env.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/logging.h" @@ -32,7 +37,6 @@ limitations under the License. namespace xla { -static const char kEnvVar[] = "TF_XLA_FLAGS"; // environment variable queried static const char kWS[] = " \t\r\n"; // whitespace // The following struct represents an argv[]-style array, parsed @@ -42,12 +46,20 @@ static const char kWS[] = " \t\r\n"; // whitespace // constructor/destructor collisions with other "private" types // in the same named namespace. namespace { + +// Functor which deletes objects by calling `free`. Necessary to free strdup'ed +// strings created by AppendToEnvArgv. +struct FreeDeleter { + void operator()(char* ptr) { free(ptr); } +}; + struct EnvArgv { EnvArgv() : initialized(false), argc(0) {} bool initialized; // whether the other fields have been set. int argc; // elements used in argv[] std::vector argv; // flag arguments parsed from environment string. 
- std::vector argv_save; // saved values from argv[] to avoid leaks + // saved values from argv[] to avoid leaks + std::vector> argv_save; }; } // anonymous namespace @@ -63,7 +75,7 @@ static void AppendToEnvArgv(const char* s0, size_t s0len, const char* s1, string s = string(s0, s0len) + string(s1, s1len); char* str = strdup(s.c_str()); a->argv.push_back(str); - a->argv_save.push_back(str); + a->argv_save.emplace_back(str); a->argc++; } } @@ -127,14 +139,14 @@ static void ParseArgvFromString(const string& flag_str, EnvArgv* a) { } } -// Call ParseArgvFromString(..., a) on a string derived from the setting of an -// environment variable kEnvVar, or a file it points to. -static void SetArgvFromEnv(EnvArgv* a) { +// Call ParseArgvFromString(..., a) on a string derived from the setting of the +// environment variable `envvar`, or a file it points to. +static void SetArgvFromEnv(absl::string_view envvar, EnvArgv* a) { if (!a->initialized) { static const char kDummyArgv[] = ""; AppendToEnvArgv(kDummyArgv, strlen(kDummyArgv), nullptr, 0, a); // dummy argv[0] - const char* env = getenv(kEnvVar); + const char* env = getenv(string(envvar).c_str()); if (env == nullptr || env[0] == '\0') { // nothing } else if (env[strspn(env, kWS)] == '-') { // flags in env var value @@ -157,48 +169,66 @@ static void SetArgvFromEnv(EnvArgv* a) { } } -// The simulated argv[] parsed from the environment. -static EnvArgv* env_argv; +// The simulated argv[] parsed from the environment, one for each different +// environment variable we've seen. +static std::unordered_map& EnvArgvs() { + static auto* env_argvs = new std::unordered_map(); + return *env_argvs; +} -// Used to protect accesses to env_argv. +// Used to protect accesses to env_argvs. static tensorflow::mutex env_argv_mu(tensorflow::LINKER_INITIALIZED); -// Call Flags::Parse(argc, argv, flag_list) against any as yet unrecognized -// flags passed in from the environment. -bool ParseFlagsFromEnv(const std::vector& flag_list) { - env_argv_mu.lock(); - if (env_argv == nullptr) { - env_argv = new EnvArgv; - } - SetArgvFromEnv(env_argv); // a no-op if already initialized +bool ParseFlagsFromEnvAndDieIfUnknown( + absl::string_view envvar, const std::vector& flag_list) { + tensorflow::mutex_lock lock(env_argv_mu); + auto* env_argv = &EnvArgvs()[string(envvar)]; + SetArgvFromEnv(envvar, env_argv); // a no-op if already initialized bool result = tensorflow::Flags::Parse(&env_argv->argc, &env_argv->argv[0], flag_list); - env_argv_mu.unlock(); + + // There's always at least one unparsed argc, namely the fake argv[0]. + if (result && env_argv->argc != 1) { + // Skip the first argv, which is the fake argv[0]. + auto unknown_flags = absl::MakeSpan(env_argv->argv); + unknown_flags.remove_prefix(1); + + // Some flags are set on XLA_FLAGS, others on TF_XLA_FLAGS. If we find an + // unrecognized flag, suggest the alternative. + string alternate_envvar; + if (envvar == "TF_XLA_FLAGS") { + alternate_envvar = "XLA_FLAGS"; + } else if (envvar == "XLA_FLAGS") { + alternate_envvar = "TF_XLA_FLAGS"; + } + string did_you_mean; + if (!alternate_envvar.empty()) { + did_you_mean = absl::StrFormat( + "\nPerhaps you meant to specify these on the %s envvar?", + alternate_envvar); + } + + LOG(FATAL) << "Unknown flag" << (unknown_flags.size() > 1 ? "s" : "") + << " in " << envvar << ": " << absl::StrJoin(unknown_flags, " ") + << did_you_mean; + return false; + } return result; } // Testing only. 
-// Reset the env_argv struct so that subsequent calls to ParseFlagsFromEnv() -// will parse the environment variable (or the file it points to) anew, and set -// *pargc, and *pargv to point to the internal locations of the argc and argv -// constructed from the environment. -void ResetFlagsFromEnvForTesting(int** pargc, std::vector** pargv) { - env_argv_mu.lock(); - if (env_argv == nullptr) { - env_argv = new EnvArgv; - } - if (!env_argv->argv_save.empty()) { - for (int i = 0; env_argv->argv_save[i] != nullptr; i++) { - free(env_argv->argv_save[i]); - } - } - env_argv->initialized = false; - env_argv->argc = 0; - env_argv->argv.clear(); - env_argv->argv_save.clear(); - env_argv_mu.unlock(); - *pargc = &env_argv->argc; - *pargv = &env_argv->argv; +// +// Resets the env_argv struct so that subsequent calls to +// ParseFlagsFromEnvAndDieIfUnknown() will parse the environment variable (or +// the file it points to) anew, and set *pargc, and *pargv to point to the +// internal locations of the argc and argv constructed from the environment. +void ResetFlagsFromEnvForTesting(absl::string_view envvar, int** pargc, + std::vector** pargv) { + tensorflow::mutex_lock lock(env_argv_mu); + EnvArgvs().erase(string(envvar)); + auto& env_argv = EnvArgvs()[string(envvar)]; + *pargc = &env_argv.argc; + *pargv = &env_argv.argv; } } // namespace xla diff --git a/tensorflow/compiler/xla/parse_flags_from_env.h b/tensorflow/compiler/xla/parse_flags_from_env.h index fe86ee687f8..76940a4299a 100644 --- a/tensorflow/compiler/xla/parse_flags_from_env.h +++ b/tensorflow/compiler/xla/parse_flags_from_env.h @@ -16,48 +16,58 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_PARSE_FLAGS_FROM_ENV_H_ #define TENSORFLOW_COMPILER_XLA_PARSE_FLAGS_FROM_ENV_H_ -// This module exports ParseFlagsFromEnv(), which allows other modules to parse -// flags from the environtment variable TF_XLA_FLAGS, or (if the first +// This module exports ParseFlagsFromEnvAndDieIfUnknown(), which allows other +// modules to parse flags from an environtment variable, or (if the first // non-whitespace in the variable value is not '-'), a file named by that -// environment variable. The accepted syntax is that flags arguments are of -// the form --flag=value or (for boolean flags) --flag, and are whitespace -// separated. The may be one of: -// - -// in which case the effective value is the string itself -// - in which case the effective value is the -// string with the single-quotes removed -// - in which case the effective value if the -// string with the double-quotes removed, and escaped sequences of -// replaced by . +// environment variable. +// +// The accepted syntax is that flags arguments are of the form --flag=value or +// (for boolean flags) --flag, and are whitespace separated. The may be +// one of: +// +// - +// in which case the effective value is the string itself +// - in which case the effective value is the +// string with the single-quotes removed +// - in which case the effective value if the +// string with the double-quotes removed, and escaped sequences of +// replaced by . // // Flags values inconsistent with the type of the flag will be rejected by the // flag parser. // // Examples: -// TF_XLA_FLAGS="--foo=bar --wombat='value with a space'" // -// TF_XLA_FLAGS=/tmp/flagfile +// - TF_XLA_FLAGS="--foo=bar --wombat='value with a space'" +// - TF_XLA_FLAGS=/tmp/flagfile +// // where /tmp/flagfile might contain -// --some_flag="This is a string containing a \" and a '." 
-// --another_flag=wombats +// +// --some_flag="This is a string containing a \" and a '." +// --another_flag=wombats #include +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/command_line_flags.h" namespace xla { -// Call tensorflow::Flags::Parse(argc, argv, flag_list) against any as yet -// unrecognized flags passed in from the environment, and return its -// return value. -bool ParseFlagsFromEnv(const std::vector& flag_list); +// Calls tensorflow::Flags::Parse(argc, argv, flag_list) against any as yet +// unrecognized flags passed in the environment variable `envvar`, and returns +// its return value. +// +// Raises a fatal error if any flags in `envvar` were not recognized. +bool ParseFlagsFromEnvAndDieIfUnknown( + absl::string_view envvar, const std::vector& flag_list); // Used only for testing. Not to be used by clients. -void ResetFlagsFromEnvForTesting(int** pargc, std::vector** pargv); +void ResetFlagsFromEnvForTesting(absl::string_view envvar, int** pargc, + std::vector** pargv); } // namespace xla diff --git a/tensorflow/compiler/xla/parse_flags_from_env_test.cc b/tensorflow/compiler/xla/parse_flags_from_env_test.cc index edd6538402d..3465552ebbf 100644 --- a/tensorflow/compiler/xla/parse_flags_from_env_test.cc +++ b/tensorflow/compiler/xla/parse_flags_from_env_test.cc @@ -37,20 +37,7 @@ static void TestParseFlagsFromEnv(const char* msg) { // Initialize module under test. int* pargc; std::vector* pargv; - ResetFlagsFromEnvForTesting(&pargc, &pargv); - - // Ensure that environment variable can be parsed when - // no flags are expected. - std::vector empty_flag_list; - bool parsed_ok = ParseFlagsFromEnv(empty_flag_list); - CHECK(parsed_ok) << msg; - const std::vector& argv_first = *pargv; - CHECK_NE(argv_first[0], nullptr) << msg; - int i = 0; - while (argv_first[i] != nullptr) { - i++; - } - CHECK_EQ(i, *pargc) << msg; + ResetFlagsFromEnvForTesting("TF_XLA_FLAGS", &pargc, &pargv); // Check that actual flags can be parsed. 
bool simple = false; @@ -65,7 +52,7 @@ static void TestParseFlagsFromEnv(const char* msg) { tensorflow::Flag("single_quoted", &single_quoted, ""), tensorflow::Flag("double_quoted", &double_quoted, ""), }; - parsed_ok = ParseFlagsFromEnv(flag_list); + bool parsed_ok = ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", flag_list); CHECK_EQ(*pargc, 1) << msg; const std::vector& argv_second = *pargv; CHECK_NE(argv_second[0], nullptr) << msg; @@ -171,7 +158,8 @@ int main(int argc, char* argv[]) { tensorflow::Flag("int_flag", &int_flag, "An integer flag to test with"), }; xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); - bool parse_ok = xla::ParseFlagsFromEnv(flag_list); + bool parse_ok = + xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", flag_list); if (!parse_ok) { LOG(QFATAL) << "can't parse from environment\n" << usage; } diff --git a/tensorflow/compiler/xla/protobuf_util.cc b/tensorflow/compiler/xla/protobuf_util.cc index b507a2ef79f..ac342bf40fb 100644 --- a/tensorflow/compiler/xla/protobuf_util.cc +++ b/tensorflow/compiler/xla/protobuf_util.cc @@ -40,16 +40,6 @@ bool ProtobufEquals(const tensorflow::protobuf::Message& m1, namespace { -string SanitizeFilename(const string& file_name) { - string safe_file_name = file_name; - for (char& c : safe_file_name) { - if (c == '/' || c == '\\') { - c = '_'; - } - } - return safe_file_name; -} - std::pair>*> GetDirectoryExpanders() { static auto* mutex = new tensorflow::mutex; diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index 4d2a37cfac3..6e2ee866321 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -148,14 +148,19 @@ static StatusOr ToBuffer(LocalClient* client, /* static */ StatusOr LocalShapedBuffer::FromLiteral( - const Literal& argument, const absl::optional& shape_with_layout) { + const Literal& argument, const absl::optional& shape_with_layout, + int replica_number) { LocalClient* client = GetOrCreateLocalClient(); + TF_ASSIGN_OR_RETURN(int device_ordinal, + client->ReplicaNumberToDeviceOrdinal(replica_number)); + VLOG(1) << "Creating shaped buffer from literal on replica/ordinal: " + << replica_number << "/" << device_ordinal; StatusOr buf = [&] { if (shape_with_layout) { Literal relaid = argument.Relayout(shape_with_layout.value()); - return ToBuffer(client, /*device_ordinal=*/0, relaid); + return ToBuffer(client, device_ordinal, relaid); } - return ToBuffer(client, /*device_ordinal=*/0, argument); + return ToBuffer(client, device_ordinal, argument); }(); TF_RETURN_IF_ERROR(buf.status()); return new LocalShapedBuffer(std::move(buf).ValueOrDie()); @@ -312,67 +317,127 @@ CompiledLocalComputation::CompiledLocalComputation( StatusOr CompiledLocalComputation::Execute( absl::Span argument_handles) { LocalClient* client = GetOrCreateLocalClient(); + StatusOr device_ordinal_status = client->ReplicaNumberToDeviceOrdinal(0); + StatusOr result_buffer_status; + if (!device_ordinal_status.ok()) { + result_buffer_status = device_ordinal_status.status(); + } else { + const int device_ordinal = device_ordinal_status.ValueOrDie(); + VLOG(3) << "Replica 0 mapped to device ordinal for execution: " + << device_ordinal; - VLOG(1) << "Execution requested with " << GetReplicaCount() << " replicas."; + std::vector argument_buffers; + argument_buffers.reserve(argument_handles.size()); + for (auto& handle : argument_handles) { + 
argument_buffers.push_back(handle->shaped_buffer()); + } + + DeviceAssignment device_assignment = + client->backend() + .computation_placer() + ->AssignDevices(1, /*computation_count=*/1) + .ConsumeValueOrDie(); + + ExecutableRunOptions options; + options.set_device_ordinal(device_ordinal); + options.set_allocator(client->backend().memory_allocator()); + options.set_intra_op_thread_pool( + client->backend().eigen_intra_op_thread_pool_device()); + options.set_device_assignment(&device_assignment); + + result_buffer_status = executable_->Run(argument_buffers, options); + } + + if (!result_buffer_status.ok()) { + return InternalError( + "Failed running replica 0 (other replicas may have failed as well): " + "%s.", + result_buffer_status.status().ToString()); + } + return new LocalShapedBuffer(std::move(result_buffer_status).ValueOrDie()); +} + +StatusOr CompiledLocalComputation::ExecutePerReplica( + absl::Span> argument_handles) { + LocalClient* client = GetOrCreateLocalClient(); + const int num_replicas = GetReplicaCount(); + + if (argument_handles.size() != num_replicas) { + return InvalidArgument( + "Attempted to execute with %d replicas when replica count is %d", + argument_handles.size(), num_replicas); + } + + VLOG(1) << "Executing with " << num_replicas << " replicas."; // Each replica populates a StatusOr result, but only the output value of // replica zero is returned. - std::vector> results(GetReplicaCount()); - { - tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "xlarun", - GetReplicaCount()); - - for (int replica = 0; replica < GetReplicaCount(); ++replica) { - pool.Schedule( - [this, client, replica, &argument_handles, &results] { - StatusOr device_ordinal_status = - client->ReplicaNumberToDeviceOrdinal(replica); - if (!device_ordinal_status.ok()) { - results[replica] = device_ordinal_status.status(); - return; - } - const int device_ordinal = device_ordinal_status.ValueOrDie(); - VLOG(3) << "Replica " << replica - << " mapped to device ordinal for execution: " - << device_ordinal; - - std::vector argument_buffers; - argument_buffers.reserve(argument_handles.size()); - for (auto& handle : argument_handles) { - argument_buffers.push_back(handle->shaped_buffer()); - } - - DeviceAssignment device_assignment = - client->backend() - .computation_placer() - ->AssignDevices(GetReplicaCount(), /*computation_count=*/1) - .ConsumeValueOrDie(); - - ExecutableRunOptions options; - options.set_device_ordinal(device_ordinal); - options.set_allocator(client->backend().memory_allocator()); - options.set_intra_op_thread_pool( - client->backend().eigen_intra_op_thread_pool_device()); - options.set_device_assignment(&device_assignment); - StatusOr result_buffer_status = - executable_->Run(argument_buffers, options); - - results[replica] = std::move(result_buffer_status); - }); + std::vector> results(num_replicas); + auto execute = [this, client, num_replicas, &argument_handles, + &results](int replica) { + StatusOr device_ordinal_status = + client->ReplicaNumberToDeviceOrdinal(replica); + if (!device_ordinal_status.ok()) { + results[replica] = device_ordinal_status.status(); + return; } + const int device_ordinal = device_ordinal_status.ValueOrDie(); + VLOG(3) << "Replica " << replica + << " mapped to device ordinal for execution: " << device_ordinal; + + std::vector argument_buffers; + argument_buffers.reserve(argument_handles[replica].size()); + for (auto& handle : argument_handles[replica]) { + argument_buffers.push_back(handle->shaped_buffer()); + } + + DeviceAssignment 
device_assignment = + client->backend() + .computation_placer() + ->AssignDevices(num_replicas, /*computation_count=*/1) + .ConsumeValueOrDie(); + + ExecutableRunOptions options; + options.set_device_ordinal(device_ordinal); + options.set_allocator(client->backend().memory_allocator()); + options.set_intra_op_thread_pool( + client->backend().eigen_intra_op_thread_pool_device()); + options.set_device_assignment(&device_assignment); + StatusOr result_buffer_status = + executable_->Run(argument_buffers, options); + + results[replica] = std::move(result_buffer_status); + }; + + if (num_replicas == 1) { + // Fast-path if there is only one replica ā€” run the computation on the + // current thread. + execute(0); + } else { + // TODO(phawkins): don't recreate the threadpool for each execution. + tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "xlarun", + num_replicas - 1); + + for (int replica = 0; replica < num_replicas - 1; ++replica) { + pool.Schedule([&execute, replica] { execute(replica); }); + } + execute(num_replicas - 1); } - for (int replica = 0; replica < GetReplicaCount(); ++replica) { - const auto& statusor = results[replica]; + std::vector wrapped_results(num_replicas); + for (int replica = 0; replica < num_replicas; ++replica) { + auto& statusor = results[replica]; if (!statusor.ok()) { return InternalError( "Failed running replica %d (other replicas may have failed as well): " "%s.", replica, statusor.status().ToString()); } + wrapped_results[replica] = + new LocalShapedBuffer(std::move(statusor).ValueOrDie()); } - return new LocalShapedBuffer(std::move(results[0]).ValueOrDie()); + return new LocalShapedBufferTuple(std::move(wrapped_results)); } static StatusOr GetReturnValueShape(const XlaComputation& computation) { @@ -487,12 +552,13 @@ StatusOr LocalComputation::CompileForXrt( xrt::XLAComputation c; auto config = c.mutable_config(); - auto shapes = config->mutable_program_shape(); + ProgramShape shapes; for (auto& shape : argument_shapes) { - *shapes->add_parameters() = shape; + *shapes.add_parameters() = shape; } - TF_ASSIGN_OR_RETURN(*shapes->mutable_result(), GetReturnValueShape()); - LayoutUtil::SetToDefaultLayout(shapes); + TF_ASSIGN_OR_RETURN(*shapes.mutable_result(), GetReturnValueShape()); + LayoutUtil::SetToDefaultLayout(&shapes); + *config->mutable_program_shape() = shapes.ToProto(); auto snapshot = computation().Snapshot().ValueOrDie(); *c.mutable_hlo_snapshot() = *snapshot; @@ -584,9 +650,9 @@ LocalOp LocalComputationBuilder::Broadcast( } LocalOp LocalComputationBuilder::BroadcastInDim( - const LocalOp& operand, const Shape& shape, + const LocalOp& operand, absl::Span out_dim_sizes, absl::Span broadcast_dimensions) { - return xla::BroadcastInDim(operand.op(), shape, broadcast_dimensions); + return xla::BroadcastInDim(operand.op(), out_dim_sizes, broadcast_dimensions); } LocalOp LocalComputationBuilder::Pad(const LocalOp& operand, diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index 9e617c48bdc..149e44570df 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -71,7 +71,8 @@ StatusOr TransferFromOutfeedLocalReplica(const Shape& shape, class LocalShapedBuffer { public: static StatusOr FromLiteral( - const Literal& argument, const absl::optional& shape_with_layout); + const Literal& argument, const absl::optional& shape_with_layout, + int replica_number); 
LocalShapedBuffer(ScopedShapedBuffer shaped_buffer); StatusOr ToLiteral() const; @@ -175,6 +176,12 @@ class CompiledLocalComputation { StatusOr Execute( absl::Span argument_handles); + // Execute on many replicas. Takes a sequence of argument lists (one argument + // list per replica) and returns a tuple of results (one result per replica). + // The number of argument lists must be equal to the replica count. + StatusOr ExecutePerReplica( + absl::Span > argument_handles); + private: std::unique_ptr executable_; }; @@ -282,7 +289,8 @@ class LocalComputationBuilder { LocalOp Broadcast(const LocalOp& operand, absl::Span broadcast_sizes); - LocalOp BroadcastInDim(const LocalOp& operand, const Shape& shape, + LocalOp BroadcastInDim(const LocalOp& operand, + absl::Span out_dim_sizes, absl::Span broadcast_dimensions); LocalOp Pad(const LocalOp& operand, const LocalOp& padding_value, diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index feabfdb889c..d23d693c1e5 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -363,6 +363,37 @@ tensorflow::ImportNumpy(); $1 = temps; } +%typemap(in) absl::Span > + (std::vector > temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + SWIG_fail; + } + const int size = PySequence_Size($input); + temps.reserve(size); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + std::vector vec; + const int vec_size = PySequence_Size(o); + vec.reserve(vec_size); + for (int j = 0; j < vec_size; ++j) { + PyObject* vec_elt = PySequence_GetItem(o, j); + LocalShapedBuffer* lsbp; + if ((SWIG_ConvertPtr(vec_elt, (void**) &lsbp, $descriptor(xla::swig::LocalShapedBuffer*), + SWIG_POINTER_EXCEPTION)) == -1) { + Py_DECREF(vec_elt); + Py_DECREF(o); + SWIG_fail; + } + vec.push_back(lsbp); + Py_DECREF(vec_elt); + } + temps.push_back(vec); + Py_DECREF(o); + } + $1 = temps; +} + %typemap(in) absl::Span (std::vector temps) { if (!PySequence_Check($input)) { @@ -921,22 +952,22 @@ tensorflow::ImportNumpy(); $1 = NULL; } else { if (!HandleStringAttribute($input, "generate_hlo_graph", [&](string s) { - build_options.set_generate_hlo_graph(std::move(s)); + build_options.mutable_debug_options()->set_xla_generate_hlo_graph(std::move(s)); })) { return nullptr; } if (!HandleStringAttribute($input, "dump_optimized_hlo_proto_to", [&](string s) { - build_options.set_dump_optimized_hlo_proto_to(std::move(s)); + build_options.mutable_debug_options()->set_xla_dump_optimized_hlo_proto_to(std::move(s)); })) { return nullptr; } if (!HandleStringAttribute($input, "dump_unoptimized_hlo_proto_to", [&](string s) { - build_options.set_dump_unoptimized_hlo_proto_to(std::move(s)); + build_options.mutable_debug_options()->set_xla_dump_unoptimized_hlo_proto_to(std::move(s)); })) { return nullptr; } if (!HandleStringAttribute($input, "dump_per_pass_hlo_proto_to", [&](string s) { - build_options.set_dump_per_pass_hlo_proto_to(std::move(s)); + build_options.mutable_debug_options()->set_xla_dump_per_pass_hlo_proto_to(std::move(s)); })) { return nullptr; } @@ -950,7 +981,7 @@ tensorflow::ImportNumpy(); PyErr_SetString(PyExc_TypeError, "ExecutableBuildOptions.hlo_profile must be a bool or None."); SWIG_fail; } - build_options.set_hlo_profile(o == Py_True); + build_options.mutable_debug_options()->set_xla_hlo_profile(o == Py_True); } Py_DECREF(o); @@ -992,11 +1023,13 @@ 
tensorflow::ImportNumpy(); %unignore xla::swig::XrtAllocation; %unignore xla::swig::XrtAllocation::FromLiteral; %unignore xla::swig::XrtAllocation::ToLiteral; +%unignore xla::swig::XrtAllocation::shape; %unignore xla::swig::XrtAllocationTuple; %unignore xla::swig::XrtAllocationTuple::Release; %unignore xla::swig::XrtAllocationTuple::size; %unignore xla::swig::CompiledLocalComputation; %unignore xla::swig::CompiledLocalComputation::Execute; +%unignore xla::swig::CompiledLocalComputation::ExecutePerReplica; %unignore xla::swig::CompiledXrtComputation; %unignore xla::swig::CompiledXrtComputation::Execute; %unignore xla::swig::LocalComputation; diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 92b0685dbba..c91a2aaf56d 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -26,6 +26,9 @@ import os import numpy as np +import six +from six.moves import xrange + from tensorflow.compiler.xla import xla_data_pb2 from tensorflow.compiler.xla.python import pywrap_xla as c_api from tensorflow.compiler.xla.service import hlo_pb2 @@ -75,6 +78,13 @@ def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1): source_line=lineno) +def _maybe_encode_string(s): + if six.PY3: + return s.encode('utf-8') + else: + return s + + class PaddingType(enum.Enum): VALID = 1 SAME = 2 @@ -212,23 +222,33 @@ class LocalBuffer(object): means the referent is in device memory. """ - def __init__(self, c_buffer, backend): + def __init__(self, c_buffer, backend, replica): self.c_buffer = c_buffer self._backend = backend + self._replica = replica if backend.backend_type == BackendType.XRT: self._delete = c_api.DeleteXrtAllocation else: self._delete = c_api.DeleteLocalShapedBuffer @staticmethod - def from_pyval(pyval, backend=XLA_LOCAL_BACKEND): + def from_pyval(pyval, replica=0, backend=XLA_LOCAL_BACKEND): """Allocate and copy to XLA the given python value.""" pyval = require_numpy_array_layout(pyval) + num_replicas = get_replica_count() + if not 0 <= replica < num_replicas: + raise ValueError( + 'Attempt to place buffer on replica {} when the replica count is {}' + .format(replica, num_replicas)) if backend.backend_type == BackendType.XRT: - cbuf = c_api.XrtAllocation.FromLiteral(pyval, backend.target) + if replica != 0: + raise NotImplementedError( + 'Multi-replica execution is not yet supported via the XRT backend.') + cbuf = c_api.XrtAllocation.FromLiteral( + pyval, _maybe_encode_string(backend.target)) else: - cbuf = c_api.LocalShapedBuffer.FromLiteral(pyval, None) - return LocalBuffer(cbuf, backend) + cbuf = c_api.LocalShapedBuffer.FromLiteral(pyval, None, replica) + return LocalBuffer(cbuf, backend, replica) def to_py(self): return self.c_buffer.ToLiteral() @@ -236,6 +256,9 @@ class LocalBuffer(object): def shape(self): return _wrap_shape(self.c_buffer.shape()) + def replica(self): + return self._replica + def delete(self): if self.c_buffer is not None: self._delete(self.c_buffer) @@ -245,14 +268,15 @@ class LocalBuffer(object): """Assuming a tuple buffer, unpack it into constituent tuple elements.""" assert self.c_buffer is not None if self._backend.backend_type == BackendType.XRT: - result = c_api.DestructureXrtAllocationTuple(self.c_buffer, - self._backend.target) + result = c_api.DestructureXrtAllocationTuple( + self.c_buffer, _maybe_encode_string(self._backend.target)) else: result = c_api.DestructureLocalShapedBufferTuple(self.c_buffer) self.delete() size = result.size() destructured = 
tuple( - LocalBuffer(result.Release(i), backend=self._backend) + LocalBuffer( + result.Release(i), replica=self._replica, backend=self._backend) for i in xrange(size)) return destructured @@ -322,6 +346,9 @@ class Shape(object): def __ne__(self, other): return not self == other + def __hash__(self): + return hash((self._dtype, self._dimensions, self._minor_to_major)) + def __repr__(self): return ('xla_client.Shape(_dtype={!r}, _dimensions={!r}, ' '_is_tuple={!r}, _minor_to_major={!r})').format( @@ -541,10 +568,13 @@ class LocalComputation(object): ] result_shape = result_shape.map_leaves(layout_fn) + argument_shapes = list(argument_shapes) + compile_options = compile_options or CompileOptions() compile_options.result_shape = result_shape if self._backend.backend_type == BackendType.XRT: - c = self.computation.CompileForXrt(argument_shapes, self._backend.target) + c = self.computation.CompileForXrt( + argument_shapes, _maybe_encode_string(self._backend.target)) else: c = self.computation.Compile(argument_shapes, compile_options) return LocalComputation(c, is_compiled=True, backend=self._backend) @@ -558,23 +588,87 @@ class LocalComputation(object): compile_options=compile_options, layout_fn=layout_fn) - def Execute(self, arguments=()): - """Execute with LocalBuffer arguments and return value.""" + def GetReturnValueShape(self): + return _wrap_shape(self._c_computation.GetReturnValueShape()) + + def Execute(self, arguments=(), check_for_deleted_args=True): + """Execute on one replica with LocalBuffer arguments and return value.""" + if check_for_deleted_args and any(arg.is_deleted() for arg in arguments): + raise ValueError('Executing with deleted local buffer argument') + raw_args = [arg.c_buffer for arg in arguments] + output_buffer = self._c_computation.Execute(raw_args) + return LocalBuffer(output_buffer, backend=self._backend, replica=0) + + def ExecutePerReplica(self, arguments=None): + """Execute on many replicas with LocalBuffer arguments and return value. + + Args: + arguments: A sequence of sequences of LocalBuffers. The i'th inner + sequence comprises the arguments for execution on the i'th replica. + + Returns: + A list of the computation's outputs on each replica, as a LocalBuffer. If + a shallow sequence of arguments was passed in for `arguments`, then the + sole, zero'th replica's output is returned instead, as a LocalBuffer. 
+ """ if not self._is_compiled: raise ValueError('Cannot execute an uncompiled local XLA computation.') - arguments = tuple(arguments) - if any(arg.is_deleted() for arg in arguments): - raise ValueError('Executing with deleted local buffer argument') - return LocalBuffer( - self._c_computation.Execute([arg.c_buffer for arg in arguments]), - backend=self._backend) + if arguments is None: + arguments = ((),) * get_replica_count() + else: + arguments = [list(replica_args) for replica_args in arguments] + + # Check arguments + for replica, replica_args in enumerate(arguments): + for arg in replica_args: + if arg.is_deleted(): + raise ValueError('Executing with deleted local buffer argument') + if arg.replica() != replica: + raise ValueError( + 'Executing on replica {} with argument from replica {}'.format( + replica, arg.replica())) + + # Pull out argument buffer handles + stripped_args = [ + [arg.c_buffer for arg in replica_args] for replica_args in arguments + ] + + # Execute + if self._backend.backend_type == BackendType.XRT: + if len(stripped_args) > 1: + raise NotImplementedError( + 'Multi-replica execution is not yet supported via the XRT backend.') + output_buffers = [self._c_computation.Execute(stripped_args[0])] + else: + output_buffer_tup = self._c_computation.ExecutePerReplica(stripped_args) + size = output_buffer_tup.size() + output_buffers = [output_buffer_tup.Release(i) for i in xrange(size)] + + # Wrap output handles in LocalBuffer instances + return tuple( + LocalBuffer(output_buffer, backend=self._backend, replica=replica) + for replica, output_buffer in enumerate(output_buffers)) def ExecuteWithPythonValues(self, arguments=()): - """Execute with Python values as arguments and return value.""" - arguments = tuple( - LocalBuffer.from_pyval(arg, backend=self._backend) for arg in arguments) + """Execute on one replica with Python values as arguments and output.""" + + def put(arg): + return LocalBuffer.from_pyval(arg, backend=self._backend) + + arguments = [put(arg) for arg in arguments] return self.Execute(arguments).to_py() + def ExecuteWithPythonValuesPerReplica(self, arguments): + """Execute on many replicas with Python values as arguments and output.""" + + def put(arg, replica): + return LocalBuffer.from_pyval(arg, replica, backend=self._backend) + + arguments = [[put(arg, replica) + for arg in replica_args] + for replica, replica_args in enumerate(arguments)] + return [out.to_py() for out in self.ExecutePerReplica(arguments)] + def __del__(self): self._delete(self._c_computation) @@ -761,8 +855,7 @@ class ComputationBuilder(object): Returns: A LocalOp representing the added broadcast-in-dimensions op. """ - xla_shape = Shape.array_shape(self.GetShape(operand).element_type(), shape) - return self._client.BroadcastInDim(operand, xla_shape, broadcast_dimensions) + return self._client.BroadcastInDim(operand, shape, broadcast_dimensions) def Concatenate(self, operands, dimension): """Enqueues a concatenate operation onto the computation. @@ -1380,6 +1473,7 @@ def initialize_platform_name(platform_name): Raises: A runtime exception if the XLA service has already been initialized. 
""" + platform_name = _maybe_encode_string(platform_name) c_api.InitializePlatformName(platform_name) diff --git a/tensorflow/compiler/xla/python_api/xla_shape.py b/tensorflow/compiler/xla/python_api/xla_shape.py index f158f6b2410..95b2bf300ec 100644 --- a/tensorflow/compiler/xla/python_api/xla_shape.py +++ b/tensorflow/compiler/xla/python_api/xla_shape.py @@ -25,9 +25,10 @@ from tensorflow.compiler.xla.python_api import types class Shape(object): - """Wraps a xla_data_pb2.Shape message with a convenient Python type. + """Wraps a xla_data_pb2.ShapeProto message with a convenient Python type. - Provides direct access to the underlying xla_data_pb2.Shape message in the + Provides direct access to the underlying xla_data_pb2.ShapeProto message in + the message attribute, along with accessor wrappers to the message's fields. Avoid direct access to .message unless interacting directly with protobuf APIs like CopyFrom. In other words, prefer hauling the shape around in a Shape, and @@ -48,7 +49,7 @@ class Shape(object): Raises: ValueError: if element_type is TUPLE but dimensions are not Shape objects. """ - self.message = xla_data_pb2.Shape() + self.message = xla_data_pb2.ShapeProto() self.message.element_type = element_type if element_type == xla_data_pb2.TUPLE: if not all(isinstance(subshape, Shape) for subshape in dimensions): diff --git a/tensorflow/compiler/xla/rpc/BUILD b/tensorflow/compiler/xla/rpc/BUILD index 3abb3855a42..26affbcceb3 100644 --- a/tensorflow/compiler/xla/rpc/BUILD +++ b/tensorflow/compiler/xla/rpc/BUILD @@ -16,7 +16,6 @@ xla_proto_library( use_grpc_plugin = True, visibility = ["//visibility:public"], deps = [ - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_proto", ], ) diff --git a/tensorflow/compiler/xla/rpc/xla_service.proto b/tensorflow/compiler/xla/rpc/xla_service.proto index e4f332cda22..0ff8adc2acb 100644 --- a/tensorflow/compiler/xla/rpc/xla_service.proto +++ b/tensorflow/compiler/xla/rpc/xla_service.proto @@ -43,7 +43,6 @@ limitations under the License. 
syntax = "proto3"; import "tensorflow/compiler/xla/xla.proto"; -import "tensorflow/compiler/xla/xla_data.proto"; package xla; diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 19b5c1ca25d..81e71eee520 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -281,10 +281,12 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo_element_type_converter", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", "//tensorflow/core:test", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", ], ) @@ -292,6 +294,7 @@ cc_library( name = "hlo", srcs = [ "dfs_hlo_visitor.cc", + "dynamic_parameter_binding.cc", "hlo_computation.cc", "hlo_input_output_alias_config.cc", "hlo_instruction.cc", @@ -305,6 +308,7 @@ cc_library( hdrs = [ "dfs_hlo_visitor.h", "dfs_hlo_visitor_with_default.h", + "dynamic_parameter_binding.h", "hlo_clone_context.h", "hlo_computation.h", "hlo_domain_metadata.h", @@ -350,6 +354,25 @@ cc_library( ], ) +tf_cc_test( + name = "dynamic_parameter_binding_test", + srcs = ["dynamic_parameter_binding_test.cc"], + deps = [ + ":hlo", + ":hlo_dce", + ":hlo_memory_scheduler", + ":hlo_ordering", + ":hlo_parser", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + "@com_google_absl//absl/algorithm:container", + ], +) + tf_cc_test( name = "dfs_hlo_visitor_with_default_test", srcs = ["dfs_hlo_visitor_with_default_test.cc"], @@ -387,9 +410,36 @@ tf_cc_test( ":hlo", ":pattern_matcher", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "pattern_matcher_gmock", + testonly = 1, + hdrs = ["pattern_matcher_gmock.h"], + deps = [ + ":pattern_matcher", + "//tensorflow/compiler/xla:test", + "//tensorflow/core:test", + ], +) + +tf_cc_test( + name = "pattern_matcher_gmock_test", + srcs = ["pattern_matcher_gmock_test.cc"], + deps = [ + ":hlo", + ":pattern_matcher", + ":pattern_matcher_gmock", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", ], ) @@ -403,6 +453,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/base", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/types:span", ], @@ -1336,6 +1387,7 @@ cc_library( ":hlo", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", + "@com_google_absl//absl/container:flat_hash_set", ], ) @@ -1539,7 +1591,10 @@ tf_cc_test( ":hlo", ":hlo_casting_utils", ":hlo_matchers", + ":hlo_parser", ":hlo_pass", + ":pattern_matcher", + ":pattern_matcher_gmock", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -1707,7 +1762,9 @@ cc_library( ":hlo", ":hlo_pass", ":hlo_query", + ":pattern_matcher", ":while_loop_analysis", + 
"//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -1720,9 +1777,14 @@ tf_cc_test( name = "while_loop_simplifier_test", srcs = ["while_loop_simplifier_test.cc"], deps = [ + ":algebraic_simplifier", ":hlo", + ":hlo_cse", ":hlo_dce", ":hlo_matchers", + ":hlo_pass", + ":hlo_pass_pipeline", + ":tuple_simplifier", ":while_loop_simplifier", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/tests:hlo_test_base", @@ -2347,6 +2409,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", ], ) @@ -2600,6 +2663,8 @@ tf_cc_test( ":hlo", ":hlo_matchers", ":layout_assignment", + ":pattern_matcher", + ":pattern_matcher_gmock", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:shape_util", @@ -2744,6 +2809,8 @@ tf_cc_test( ":hlo_matchers", ":hlo_parser", ":hlo_pass", + ":pattern_matcher", + ":pattern_matcher_gmock", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -2855,6 +2922,46 @@ tf_cc_test( ], ) +cc_library( + name = "hlo_get_dimension_size_rewriter", + srcs = ["hlo_get_dimension_size_rewriter.cc"], + hdrs = ["hlo_get_dimension_size_rewriter.h"], + deps = [ + ":hlo", + ":hlo_pass", + ":shape_inference", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + ], +) + +tf_cc_test( + name = "hlo_get_dimension_size_rewriter_test", + srcs = ["hlo_get_dimension_size_rewriter_test.cc"], + deps = [ + ":hlo", + ":hlo_get_dimension_size_rewriter", + ":hlo_matchers", + ":hlo_parser", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + cc_library( name = "device_memory_allocator", srcs = [ @@ -2913,6 +3020,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@llvm//:core", "@llvm//:transform_utils", @@ -3026,6 +3134,7 @@ cc_library( ":hlo_casting_utils", ":hlo_execution_profile", ":hlo_tfgraph_builder", + ":pattern_matcher", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", @@ -3318,9 +3427,9 @@ cc_library( ":tuple_util", ":while_loop_analysis", ":while_util", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", - "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -3463,6 +3572,8 @@ tf_cc_test( ":hlo_casting_utils", ":hlo_matchers", ":hlo_parser", + ":pattern_matcher", + ":pattern_matcher_gmock", 
"//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:window_util", "//tensorflow/core:lib", @@ -3513,6 +3624,41 @@ cc_library( ], ) +cc_library( + name = "ar_crs_combiner", + srcs = ["ar_crs_combiner.cc"], + hdrs = ["ar_crs_combiner.h"], + deps = [ + ":call_graph", + ":pattern_matcher", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_pass", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", + ], +) + +tf_cc_test( + name = "ar_crs_combiner_test", + srcs = ["ar_crs_combiner_test.cc"], + deps = [ + ":ar_crs_combiner", + ":hlo", + ":hlo_matchers", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + tf_cc_test( name = "map_inliner_test", srcs = ["map_inliner_test.cc"], diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 89e62bd2f0d..985c5af1c4d 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include +#include #include #include #include @@ -68,6 +69,45 @@ bool IsAll(const HloInstruction* op, int8 value) { } } +// Checks whether `op` is a floating-point constant or broadcast of a constant +// of the form +/- 2^k for some integer k positive, negative, or zero. Such +// values are interesting because multiplying by a power of 2 just moves the +// exponent. +bool IsAllFpConstantPowerOf2(const HloInstruction* op) { + // Unwrap the broadcast if necessary. + const HloInstruction* c; + if (!Match(op, m::ConstantEffectiveScalar(&c)) && + !Match(op, m::Broadcast(m::Constant(&c).WithShape( + m::Shape().IsEffectiveScalar())))) { + return false; + } + auto val = [&]() -> absl::optional { + switch (c->shape().element_type()) { + case BF16: + return static_cast(c->literal().GetFirstElement()); + case F16: + return static_cast(c->literal().GetFirstElement()); + case F32: + return c->literal().GetFirstElement(); + case F64: + return c->literal().GetFirstElement(); + default: + // Cowardly refuse to consider complex types. + return absl::nullopt; + } + }(); + if (!val) { + return false; + } + + int exp; + double mantissa = std::frexp(*val, &exp); + // frexp returns a value in the range (-1; -0.5] U [0.5, 1). A return value + // of +/-0.5 therefore indicates that the floating point value is a power of + // 2. + return mantissa == 0.5 || mantissa == -0.5; +} + // Returns whether the given transpose produces a result which is bit-wise // identical to its operand and thus may be replaced with a bitcast. bool TransposeIsBitcast(const HloInstruction* transpose) { @@ -84,7 +124,8 @@ bool TransposeIsBitcast(const HloInstruction* transpose) { // reshape may still be a bitcast. For example, a reshape from [28x28] to [784]. 
bool ReshapeOrCopyIsBitcast( const HloInstruction* instr, - const AlgebraicSimplifier::ValidBitcastCallback& valid_bitcast_callback) { + const AlgebraicSimplifierOptions::ValidBitcastCallback& + valid_bitcast_callback) { CHECK(HloOpcode::kReshape == instr->opcode() || HloOpcode::kCopy == instr->opcode()); @@ -95,6 +136,11 @@ bool ReshapeOrCopyIsBitcast( valid_bitcast_callback(operand->shape(), instr->shape()); } +bool IsUnstridedSlice(const HloInstruction* hlo) { + return absl::c_all_of(hlo->slice_strides(), + [](int64 stride) { return stride == 1; }); +} + // AlgebraicSimplifierVisitor traverses the HLO computation and reduces certain // algebraic expressions to simplified forms. Note: This only supports // simplifications that simply look at the operands of an instruction. For the @@ -180,21 +226,13 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { const bool changed() const { return changed_; } // Runs the visitor on a computation. - static bool Run( - HloComputation* computation, bool is_layout_sensitive, - AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback, - bool enable_dot_strength_reduction, bool enable_conv_simplification); + static bool Run(HloComputation* computation, + const AlgebraicSimplifierOptions& options); private: - explicit AlgebraicSimplifierVisitor( - HloComputation* computation, bool is_layout_sensitive, - AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback, - bool enable_dot_strength_reduction, bool enable_conv_simplification) - : computation_(computation), - is_layout_sensitive_(is_layout_sensitive), - valid_bitcast_callback_(std::move(valid_bitcast_callback)), - enable_dot_strength_reduction_(enable_dot_strength_reduction), - enable_conv_simplification_(enable_conv_simplification) {} + explicit AlgebraicSimplifierVisitor(HloComputation* computation, + const AlgebraicSimplifierOptions& options) + : computation_(computation), options_(options) {} // Transforms Dots where at least one input is a vector or has a degenerate // dimension and converts it into a multiply and reduce. This should enable @@ -233,10 +271,10 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { HloInstruction* new_instruction); // Returns whether the shape of the output of the given instructions are the - // same for the purposes of simplification. If is_layout_sensitive_ is true, - // then this tests shape equality including layout (ShapeUtil::Equal). If - // is_layout_sensitive_ is false, then the tests shape compatibility - // (ShapeUtil::Compatible). + // same for the purposes of simplification. If options_.is_layout_sensitive() + // is true, then this tests shape equality including layout + // (ShapeUtil::Equal). If options_.is_layout_sensitive() is false, then the + // tests shape compatibility (ShapeUtil::Compatible). bool SameShape(const HloInstruction* lhs, const HloInstruction* rhs) const; // Returns whether it was possible to transform `root` to a clamp instruction. @@ -325,22 +363,12 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { // traversing. HloComputation* computation_; + // The backend-specific options selected for the algebraic simplifier. + const AlgebraicSimplifierOptions& options_; + // Whether algebraic simplification has occurred. bool changed_ = false; - // Whether layout is considered during transformation. - bool is_layout_sensitive_; - - // Callback used to determine if a bitcast is possible. 
- AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback_; - - // Disable dot strength reduction on platforms where it causes a slowdown. - bool enable_dot_strength_reduction_; - - // Disable convolution -> dot simplification on platforms where it causes a - // slowdown. - bool enable_conv_simplification_; - // Cached computation for adding two scalar F32. HloComputation* scalar_add_computation_ = nullptr; }; @@ -348,19 +376,15 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { } // namespace bool AlgebraicSimplifierVisitor::Run( - HloComputation* computation, bool is_layout_sensitive, - AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback, - bool enable_dot_strength_reduction, bool enable_conv_simplification) { - AlgebraicSimplifierVisitor visitor( - computation, is_layout_sensitive, std::move(valid_bitcast_callback), - enable_dot_strength_reduction, enable_conv_simplification); + HloComputation* computation, const AlgebraicSimplifierOptions& options) { + AlgebraicSimplifierVisitor visitor(computation, options); TF_CHECK_OK(computation->Accept(&visitor)); return visitor.changed_; } bool AlgebraicSimplifierVisitor::SameShape(const HloInstruction* lhs, const HloInstruction* rhs) const { - if (is_layout_sensitive_) { + if (options_.is_layout_sensitive()) { return ShapeUtil::Equal(lhs->shape(), rhs->shape()); } else { return ShapeUtil::Compatible(lhs->shape(), rhs->shape()); @@ -431,6 +455,40 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { sum_of_constants)); } + // A*C + B*C => (A+B)*C + // + // - If A, B, and C are integers, do this unconditionally. Proof of + // correctness: https://rise4fun.com/Alive/u9X. + // + // - If A, B, and C are floating point, do this if C is a scalar constant or + // broadcast of scalar constant and is equal to +/- 2^k for some (possibly + // negative) integer k. + // + // Multiplying by a power of 2 just moves the exponent, so our answer is + // exact modulo rounding of intermediate results so long as + // + // - none of the three products has an exponent which underflows (so the + // result is 0 or denormal), and + // - none of the three products overflows to inf. + // + // Proof: See algebraic_simplifier_proof_distributive_property.py. + // + // We deem these differences in rounding, underflow, and overflow + // acceptable in the ML context. 
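A quick numerical illustration of the comment above, using plain Python floats rather than XLA: when C is a power of two the factored form rounds identically, while a non-power-of-two C can introduce a rounding difference.

```python
a, b = 0.1, 0.2

# C = 2**-3: multiplying by a power of two only shifts the exponent, so
# A*C + B*C and (A+B)*C round to the same double (barring overflow/underflow).
c = 0.125
assert a * c + b * c == (a + b) * c

# C = 0.3 is not a power of two; the intermediate roundings may differ, which
# is why the rewrite is restricted to power-of-two constants for FP types.
c = 0.3
print(a * c + b * c, (a + b) * c)  # the two results are not guaranteed to match
```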
+ HloInstruction *b, *c; + if (((Match(lhs, m::Multiply(m::Op(&a), m::Op(&c))) && + Match(rhs, m::MultiplyAnyOrder(m::Op().Is(c), m::Op(&b)))) || + (Match(lhs, m::Multiply(m::Op(&c), m::Op(&a))) && + Match(rhs, m::MultiplyAnyOrder(m::Op().Is(c), m::Op(&b))))) && + (ShapeUtil::ElementIsIntegral(add->shape()) || + IsAllFpConstantPowerOf2(c))) { + return ReplaceWithNewInstruction( + add, HloInstruction::CreateBinary( + add->shape(), HloOpcode::kMultiply, + computation_->AddInstruction(HloInstruction::CreateBinary( + add->shape(), HloOpcode::kAdd, a, b)), + c)); + } return Status::OK(); } @@ -504,8 +562,8 @@ Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) { return Status::OK(); } - if (is_layout_sensitive_ && - ReshapeOrCopyIsBitcast(copy, valid_bitcast_callback_)) { + if (options_.is_layout_sensitive() && + ReshapeOrCopyIsBitcast(copy, options_.valid_bitcast_callback())) { ReplaceWithBitcast(copy); } @@ -541,7 +599,74 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate( VLOG(10) << "trying to replace " << concatenate->ToString() << " with " << replacement->ToString(); ReplaceInstructionIfSameShape(concatenate, replacement); - } else if (operands.size() == 2) { + return Status::OK(); + } + + // Check if we can merge "adjacent" slice operands which take slices from the + // same other op. For simplicity we only merge unstrided slices. + int64 concatenate_dimension = concatenate->concatenate_dimension(); + for (int64 i = 0; i < operands.size(); ++i) { + if (operands[i]->opcode() != HloOpcode::kSlice || + !IsUnstridedSlice(operands[i])) { + continue; + } + int64 slice_end = operands[i]->slice_limits(concatenate_dimension); + HloInstruction* slice_operand = operands[i]->mutable_operand(0); + int64 j = i + 1; + while (j < operands.size() && operands[j]->opcode() == HloOpcode::kSlice && + IsUnstridedSlice(operands[j]) && + operands[j]->operand(0) == slice_operand && + operands[j]->slice_starts(concatenate_dimension) == slice_end) { + // Check that all the slice_start values are the same in all other + // dimensions. This implies that the slice_limit values are also the same, + // because operands of concatenate need to have the same shape, and we + // already checked that the slices are unstrided. 
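The loop being built here merges runs of adjacent unstrided slices of the same operand along the concatenate dimension. The intended effect, sketched with NumPy (illustrative only):

```python
import numpy as np

x = np.arange(24).reshape(4, 6)

# Two unstrided slices of the same operand that abut along the concatenate
# dimension (the first slice's limit equals the second slice's start, and all
# other start indices agree) ...
merged = np.concatenate([x[:, 1:3], x[:, 3:6]], axis=1)

# ... are equivalent to a single wider slice of that operand.
assert np.array_equal(merged, x[:, 1:6])
```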
+ bool same_other_starts = true; + for (int64 k = 0; k < operands[j]->slice_starts().size(); ++k) { + if (k == concatenate_dimension) { + continue; + } + if (operands[i]->slice_starts(k) != operands[j]->slice_starts(k)) { + same_other_starts = false; + break; + } + } + if (!same_other_starts) { + break; + } + slice_end = operands[j]->slice_limits(concatenate_dimension); + ++j; + } + if (j - i > 1) { + Shape new_slice_shape = operands[i]->shape(); + new_slice_shape.set_dimensions( + concatenate_dimension, + slice_end - operands[i]->slice_starts(concatenate_dimension)); + auto new_limit_indices = operands[i]->slice_limits(); + new_limit_indices[concatenate_dimension] = slice_end; + auto new_slice_op = + computation_->AddInstruction(HloInstruction::CreateSlice( + new_slice_shape, slice_operand, + /*start_indices=*/operands[i]->slice_starts(), + /*limit_indices=*/new_limit_indices, + /*strides=*/operands[i]->slice_strides())); + std::vector new_operands; + for (int64 k = 0; k < i; ++k) { + new_operands.push_back(operands[k]); + } + new_operands.push_back(new_slice_op); + for (int64 k = j; k < operands.size(); ++k) { + new_operands.push_back(operands[k]); + } + auto replacement = + computation_->AddInstruction(concatenate->CloneWithNewOperands( + concatenate->shape(), new_operands)); + ReplaceInstructionIfSameShape(concatenate, replacement); + return Status::OK(); + } + } + + if (operands.size() == 2) { // A binary concat with a broadcasted scalar as an operand can be converted // into a pad which is simpler to fold into other operations. bool is_effective_low_pad = Match( @@ -557,7 +682,7 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate( padding_config_dim->set_edge_padding_high(0); padding_config_dim->set_edge_padding_low(0); padding_config_dim->set_interior_padding(0); - if (dim == concatenate->concatenate_dimension()) { + if (dim == concatenate_dimension) { if (is_effective_low_pad) { padding_config_dim->set_edge_padding_low( operands[0]->shape().dimensions(dim)); @@ -1215,7 +1340,8 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { return ReplaceInstruction(dot, dot_of_gather_optimized); } - if (enable_dot_strength_reduction_ && !is_layout_sensitive_) { + if (options_.enable_dot_strength_reduction() && + !options_.is_layout_sensitive()) { TF_ASSIGN_OR_RETURN(bool did_strength_reduction, HandleDotStrengthReduction(dot)); if (did_strength_reduction) { @@ -1619,6 +1745,27 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { pad, HloInstruction::CreateBroadcast(pad->shape(), pad->mutable_operand(1), {})); } + + // Interior padding on size-one dimensions has no effect. As a result it + // makes other simplifications possible if there is no interior padding. + if (HasInteriorPadding(pad->padding_config())) { + PaddingConfig padding_config = pad->padding_config(); + bool cleared_interior_padding = false; + for (int64 i = 0; i < ShapeUtil::Rank(pad->shape()); ++i) { + if (padding_config.dimensions(i).interior_padding() > 0 && + pad->operand(0)->shape().dimensions(i) == 1) { + cleared_interior_padding = true; + padding_config.mutable_dimensions(i)->set_interior_padding(0); + } + } + if (cleared_interior_padding) { + return ReplaceWithNewInstruction( + pad, + HloInstruction::CreatePad(pad->shape(), pad->mutable_operand(0), + pad->mutable_operand(1), padding_config)); + } + } + // Eliminate nop pads (padding all zero), and replace a pad with negative // padding with a pad with non-negative padding followed by a slice.
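Interior padding inserts elements between adjacent entries of a dimension, so a dimension of extent 1 has no interior positions to pad and the value can be cleared without changing the result. A toy extent calculation (not XLA code) that assumes the usual edge-low/edge-high/interior padding semantics:

```python
def padded_extent(size, edge_low, edge_high, interior):
    # Interior padding adds `interior` elements between each adjacent pair.
    return edge_low + size + max(size - 1, 0) * interior + edge_high

# For a size-1 dimension, any interior padding value yields the same extent ...
assert padded_extent(1, edge_low=2, edge_high=3, interior=7) == \
       padded_extent(1, edge_low=2, edge_high=3, interior=0)

# ... whereas it matters as soon as the dimension has more than one element.
assert padded_extent(4, edge_low=0, edge_high=0, interior=1) != \
       padded_extent(4, edge_low=0, edge_high=0, interior=0)
```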
bool all_zero = true; @@ -1910,8 +2057,8 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { } // Make this a bitcast if possible. - if (is_layout_sensitive_ && - ReshapeOrCopyIsBitcast(reshape, valid_bitcast_callback_)) { + if (options_.is_layout_sensitive() && + ReshapeOrCopyIsBitcast(reshape, options_.valid_bitcast_callback())) { ReplaceWithBitcast(reshape); return Status::OK(); } @@ -2030,11 +2177,6 @@ StatusOr AlgebraicSimplifierVisitor::TrySimplifyScalarSlice( return false; } -bool IsUnstridedSlice(const HloInstruction* hlo) { - return absl::c_all_of(hlo->slice_strides(), - [](int64 stride) { return stride == 1; }); -} - StatusOr AlgebraicSimplifierVisitor::TryToReorderSliceAndReshape( HloInstruction* slice) { CHECK_EQ(slice->opcode(), HloOpcode::kSlice); @@ -2501,6 +2643,108 @@ Status AlgebraicSimplifierVisitor::HandleSort(HloInstruction* sort) { return ReplaceWithNewInstruction( sort, HloInstruction::CreateTuple(sort->operands())); } + if (!options_.enable_permutation_sort_replacement()) { + return Status::OK(); + } + // Check if we are sorting a permutation. In that case, we know that the keys + // will be sorted to the identity permutation, and we can represent the + // changes to the 'values' parameter as a scatter. + if (sort->operand_count() == 2 && + operand->opcode() == HloOpcode::kGetTupleElement) { + const HloInstruction* other_sort = operand->operand(0); + // Check whether the 'values' parameter is the result of another sort with + // the same sort dimension. + if (other_sort->opcode() == HloOpcode::kSort && + other_sort->operand_count() >= 2 && + other_sort->dimensions(0) == dimension_to_sort && + other_sort->operand(operand->tuple_index())->opcode() == + HloOpcode::kIota) { + auto* iota = + Cast(other_sort->operand(operand->tuple_index())); + // The sort operand needs to be an integral iota, and the iota dimension + // needs to be the dimension that was sorted. + if (iota->iota_dimension() == dimension_to_sort && + ShapeUtil::ElementIsIntegral(iota->shape())) { + // We use the following construction method for a Scatter that applies + // the permutation from 'keys' to the 'values' parameter. + // - Take the "keys" parameter of the second sort and reshape it to have + // another "1" dimension at the end. + // - Concatenate it with iotas of the same extended shape with all + // different iota_dimensions except the dimension_to_sort in the order + // of iota_dimensions/dimension_to_sort, so e.g. with rank 3 and + // dimension_to_sort = 1, we would have concatenate of (iota with + // iota_dimension=0, keys, iota with iota_dimension = 2) + // - Use this as the indices parameter of scatter, and set updates + // of the scatter to be a reshaped 'values' parameter of sort (adding + // 'rank' many 1 dimensions at the end). 
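The construction described in the comment above rests on the observation that sorting by keys which are themselves a permutation reorders the values exactly as a scatter indexed by those keys would. A NumPy sketch of that equivalence (illustrative only):

```python
import numpy as np

keys = np.array([2, 0, 3, 1])        # a permutation, e.g. an iota reordered by an earlier sort
values = np.array([10, 20, 30, 40])

# Sorting (keys, values) by keys yields (iota, permuted values) ...
order = np.argsort(keys)
sorted_keys, sorted_values = keys[order], values[order]

# ... which is the same as scattering `values` to positions given by `keys`.
scattered = np.empty_like(values)
scattered[keys] = values

assert np.array_equal(sorted_keys, np.arange(4))
assert np.array_equal(sorted_values, scattered)
```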
+ int64 rank = ShapeUtil::Rank(operand->shape()); + Shape extended_shape = operand->shape(); + extended_shape.add_dimensions(1); + extended_shape.mutable_layout()->add_minor_to_major(rank); + auto reshaped_permutation = computation_->AddInstruction( + HloInstruction::CreateReshape(extended_shape, operand)); + std::vector concat_operands; + for (int64 i = 0; i < rank; ++i) { + if (i == dimension_to_sort) { + concat_operands.push_back(reshaped_permutation); + } else { + concat_operands.push_back(computation_->AddInstruction( + HloInstruction::CreateIota(extended_shape, i))); + } + } + Shape concat_shape = operand->shape(); + concat_shape.add_dimensions(rank); + concat_shape.mutable_layout()->add_minor_to_major(rank); + auto scatter_indices = + rank > 1 ? computation_->AddInstruction( + HloInstruction::CreateConcatenate( + concat_shape, concat_operands, rank)) + : reshaped_permutation; + + // We don't care about the operand, it will be completely overridden by + // the updates. + auto scatter_operand = computation_->AddInstruction( + HloInstruction::CreateIota(sort->operand(1)->shape(), 0)); + + // Construct the updates operand of scatter. + Shape update_shape = sort->operand(1)->shape(); + for (int64 i = 0; i < rank; ++i) { + update_shape.add_dimensions(1); + update_shape.mutable_layout()->add_minor_to_major(rank + i); + } + auto scatter_updates = + computation_->AddInstruction(HloInstruction::CreateReshape( + update_shape, sort->mutable_operand(1))); + + // Construct the updates computation, which simply replaces the operand + // values with the update values. + HloComputation::Builder b("update_replace_computation"); + Shape scalar_shape = ShapeUtil::MakeShape(S32, {}); + b.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "scalar_lhs")); + auto scalar_rhs = b.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "scalar_rhs")); + auto update_replace_computation = + computation_->parent()->AddEmbeddedComputation(b.Build(scalar_rhs)); + + ScatterDimensionNumbers dim_numbers; + dim_numbers.set_index_vector_dim(rank); + for (int64 i = 0; i < rank; ++i) { + dim_numbers.add_update_window_dims(rank + i); + dim_numbers.add_scatter_dims_to_operand_dims(i); + } + auto scatter = + computation_->AddInstruction(HloInstruction::CreateScatter( + sort->operand(1)->shape(), scatter_operand, scatter_indices, + scatter_updates, update_replace_computation, dim_numbers)); + return ReplaceWithNewInstruction( + sort, HloInstruction::CreateTuple( + {computation_->AddInstruction(HloInstruction::CreateIota( + operand->shape(), dimension_to_sort)), + scatter})); + } + } + } return Status::OK(); } @@ -2525,7 +2769,7 @@ Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { return ReplaceInstruction(transpose, operand); } - if (is_layout_sensitive_ && TransposeIsBitcast(transpose)) { + if (options_.is_layout_sensitive() && TransposeIsBitcast(transpose)) { ReplaceWithBitcast(transpose); return Status::OK(); } @@ -2674,13 +2918,13 @@ StatusOr AlgebraicSimplifierVisitor::SimplifyConvToDot( const ConvolutionDimensionNumbers& dnums = convolution->convolution_dimension_numbers(); - if (!enable_conv_simplification_) { + if (!options_.enable_conv_simplification()) { return false; } // TODO(b/31337498): For now, we cowardly refuse to do this optimization in // layout-insensitive mode, for fear of adding nontrivial reshapes. 
- if (!is_layout_sensitive_) { + if (!options_.is_layout_sensitive()) { return false; } @@ -2770,9 +3014,9 @@ StatusOr AlgebraicSimplifierVisitor::SimplifyConvToDot( // We cannot insert bitcasts if the layouts will not be compatible. // TODO(b/33178038): Consider inserting a transpose if a bitcast would be // invalid. - if (!valid_bitcast_callback_(input_shape, new_input_shape) || - !valid_bitcast_callback_(filter_shape, new_filter_shape) || - !valid_bitcast_callback_(dot_output_shape, convolution_shape)) { + if (!options_.valid_bitcast_callback()(input_shape, new_input_shape) || + !options_.valid_bitcast_callback()(filter_shape, new_filter_shape) || + !options_.valid_bitcast_callback()(dot_output_shape, convolution_shape)) { return false; } @@ -2878,9 +3122,7 @@ StatusOr AlgebraicSimplifier::Run(HloModule* module) { "AlgebraicSimplifier::Run(), before:\n" + module->ToString()); bool changed = false; for (auto* comp : module->MakeNonfusionComputations()) { - if (AlgebraicSimplifierVisitor::Run( - comp, is_layout_sensitive_, valid_bitcast_callback_, - enable_dot_strength_reduction_, enable_conv_simplification_)) { + if (AlgebraicSimplifierVisitor::Run(comp, options_)) { changed = true; } } diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h index 9f8d0ee88bd..d2775b9fafa 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.h +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h @@ -23,8 +23,7 @@ limitations under the License. namespace xla { -// A pass which performs algebraic simplifications. -class AlgebraicSimplifier : public HloModulePass { +class AlgebraicSimplifierOptions { public: // Given shapes 'from_shape' and 'to_shape', determines if it is valid to // bitcast from 'from_shape' to 'to_shape' after considering platform @@ -34,18 +33,63 @@ class AlgebraicSimplifier : public HloModulePass { using ValidBitcastCallback = std::function; + explicit AlgebraicSimplifierOptions( + ValidBitcastCallback valid_bitcast_callback) + : valid_bitcast_callback_(std::move(valid_bitcast_callback)) {} + // If valid_bitcast_callback returns true, then the pass will replace reshapes + // and transposes with bitcasts. + const ValidBitcastCallback& valid_bitcast_callback() const { + return valid_bitcast_callback_; + } + // If is_layout_sensitive is true, then the simplifier preserves layout during - // transformation. Otherwise, layout is ignored. If valid_bitcast_callback - // returns true, then the pass will replace reshapes and transposes with - // bitcasts. - AlgebraicSimplifier(bool is_layout_sensitive, - ValidBitcastCallback valid_bitcast_callback, - bool enable_dot_strength_reduction = true, - bool enable_conv_simplification = true) - : is_layout_sensitive_(is_layout_sensitive), - valid_bitcast_callback_(std::move(valid_bitcast_callback)), - enable_dot_strength_reduction_(enable_dot_strength_reduction), - enable_conv_simplification_(enable_conv_simplification) {} + // transformation. Otherwise, layout is ignored. + void set_is_layout_sensitive(bool is_layout_sensitive) { + is_layout_sensitive_ = is_layout_sensitive; + } + bool is_layout_sensitive() const { return is_layout_sensitive_; } + + // Enable dot simplification on platforms where it is profitable. 
+ void set_enable_dot_strength_reduction(bool enable_dot_strength_reduction) { + enable_dot_strength_reduction_ = enable_dot_strength_reduction; + } + bool enable_dot_strength_reduction() const { + return enable_dot_strength_reduction_; + } + + // Enable convolution simplification on platforms where it is profitable. + void set_enable_conv_simplification(bool enable_conv_simplification) { + enable_conv_simplification_ = enable_conv_simplification; + } + bool enable_conv_simplification() const { + return enable_conv_simplification_; + } + + // If enable_permutation_sort_replacement is true, a sort op that is known to + // sort a permutation will be replaced with a scatter op. + void set_enable_permutation_sort_replacement( + bool enable_permutation_sort_replacement) { + enable_permutation_sort_replacement_ = enable_permutation_sort_replacement; + } + bool enable_permutation_sort_replacement() const { + return enable_permutation_sort_replacement_; + } + + private: + ValidBitcastCallback valid_bitcast_callback_; + bool is_layout_sensitive_{false}; + bool enable_dot_strength_reduction_{true}; + bool enable_conv_simplification_{true}; + bool enable_permutation_sort_replacement_{false}; +}; + +// A pass which performs algebraic simplifications. +class AlgebraicSimplifier : public HloModulePass { + public: + // If is_layout_sensitive is true, then the simplifier preserves layout during + // transformation. Otherwise, layout is ignored. + explicit AlgebraicSimplifier(const AlgebraicSimplifierOptions& options) + : options_(options) {} ~AlgebraicSimplifier() override = default; absl::string_view name() const override { return "algsimp"; } @@ -54,14 +98,7 @@ class AlgebraicSimplifier : public HloModulePass { StatusOr Run(HloModule* module) override; private: - bool is_layout_sensitive_; - ValidBitcastCallback valid_bitcast_callback_; - - // Enable dot simplification on platforms where it is profitable. - bool enable_dot_strength_reduction_; - - // Enable convolution simplification on platforms where it is profitable. - bool enable_conv_simplification_; + AlgebraicSimplifierOptions options_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_proof_distributive_property.py b/tensorflow/compiler/xla/service/algebraic_simplifier_proof_distributive_property.py new file mode 100644 index 00000000000..5da13da041b --- /dev/null +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_proof_distributive_property.py @@ -0,0 +1,82 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Proof that transforming (A*C)+(B*C) <=> (A+B)*C is "safe" if C=2^k. 
+ +Specifically, for all floating-point values A, B, and C, if + + - C is equal to +/- 2^k for some (possibly negative) integer k, and + - A, B, C, A*C, B*C, and A+B are not subnormal, zero, or inf, + +then there exists a rounding mode rm in [RTZ, RNE] such that + + (A*C) + (B*C) == (A+B) * C (computed with rounding mode rm). + +Informally, this means that the equivalence holds for powers of 2 C, modulo +flushing to zero or inf, and modulo rounding of intermediate results. + +Requires z3 python bindings; try `pip install z3-solver`. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import z3 + +# We do float16 because it lets the solver run much faster. These results +# should generalize to fp32 and fp64, and you can verify this by changing the +# value of FLOAT_TY (and then waiting a while). +FLOAT_TY = z3.Float16 + +a = z3.FP("a", FLOAT_TY()) +b = z3.FP("b", FLOAT_TY()) +c = z3.FP("c", FLOAT_TY()) + +s = z3.Solver() + +# C must be a power of 2, i.e. significand bits must all be 0. +s.add(z3.Extract(FLOAT_TY().sbits() - 1, 0, z3.fpToIEEEBV(c)) == 0) + +for rm in [z3.RTZ(), z3.RNE()]: + z3.set_default_rounding_mode(rm) + before = a * c + b * c + after = (a + b) * c + + # Check that before == after, allowing that 0 == -0. + s.add( + z3.Not( + z3.Or( + before == after, # + z3.And(z3.fpIsZero(before), z3.fpIsZero(after))))) + + for x in [ + (a * c), + (b * c), + (a + b), + ]: + s.add(z3.Not(z3.fpIsSubnormal(x))) + s.add(z3.Not(z3.fpIsZero(x))) + s.add(z3.Not(z3.fpIsInf(x))) + +if s.check() == z3.sat: + m = s.model() + print("Counterexample found!") + print(m) + print("a*c: ", z3.simplify(m[a] * m[c])) + print("b*c: ", z3.simplify(m[b] * m[c])) + print("a+b: ", z3.simplify(m[a] + m[b])) + print("a*c + b*c: ", z3.simplify(m[a] * m[c] + m[b] * m[c])) + print("(a+b) * c: ", z3.simplify((m[a] + m[b]) * m[c])) +else: + print("Proved!") diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index e4c4da1b0e7..14ce519b6a0 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -27,9 +27,11 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" -#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" @@ -42,18 +44,20 @@ namespace xla { namespace { using ::testing::ElementsAre; +namespace m = match; -namespace op = xla::testing::opcode_matchers; - -AlgebraicSimplifier::ValidBitcastCallback bitcasting_callback() { +AlgebraicSimplifierOptions::ValidBitcastCallback bitcasting_callback() { return [](const Shape&, const Shape&) { return true; }; } -AlgebraicSimplifier::ValidBitcastCallback non_bitcasting_callback() { +AlgebraicSimplifierOptions::ValidBitcastCallback non_bitcasting_callback() { return [](const Shape&, const Shape&) { return false; }; } -class AlgebraicSimplifierTest : public HloTestBase {}; +class AlgebraicSimplifierTest : public HloTestBase { + protected: + AlgebraicSimplifierOptions default_options_{non_bitcasting_callback()}; +}; // Test that A + 0 is simplified to A TEST_F(AlgebraicSimplifierTest, AddZero) { @@ -70,13 +74,134 @@ TEST_F(AlgebraicSimplifierTest, AddZero) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAdd); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } +TEST_F(AlgebraicSimplifierTest, FactorIntegerAddition) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = s32[8] parameter(0) + p1 = s32[8] parameter(1) + p2 = s32[8] parameter(2) + x = s32[8] multiply(p0, p2) + y = s32[8] multiply(p1, p2) + ROOT sum = s32[8] add(x, y) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::MultiplyAnyOrder( + m::AddAnyOrder(m::Parameter(0), m::Parameter(1)), m::Parameter(2)))); +} + +// A*C + B*C => (A+B)*C if C is a floating-point power of 2. +TEST_F(AlgebraicSimplifierTest, FactorFpAddition) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + c = f32[] constant(0.125) + x = f32[] multiply(p0, c) + y = f32[] multiply(p1, c) + ROOT sum = f32[] add(x, y) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::MultiplyAnyOrder( + m::AddAnyOrder(m::Parameter(0), m::Parameter(1)), + m::ConstantScalar(0.125)))); +} + +// A*C + B*C => (A+B)*C if C is a broadcast of a floating-point power of 2. 
+TEST_F(AlgebraicSimplifierTest, FactorFpAdditionWithBroadcast) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[4] parameter(0) + p1 = f32[4] parameter(1) + c = f32[] constant(0.125) + b = f32[4] broadcast(c), dimensions={} + x = f32[4] multiply(p0, b) + y = f32[4] multiply(p1, b) + ROOT sum = f32[4] add(x, y) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::MultiplyAnyOrder( + m::AddAnyOrder(m::Parameter(0), m::Parameter(1)), + m::Broadcast(m::ConstantScalar(0.125))))); +} + +// A*C + B*C => (A+B)*C simplification should not happen if C is not a +// floating-point power of 2. +TEST_F(AlgebraicSimplifierTest, FactorFpAdditionNotPowerOf2) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + c = f32[] constant(0.3) + x = f32[] multiply(p0, c) + y = f32[] multiply(p1, c) + ROOT sum = f32[] add(x, y) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + EXPECT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); +} + +// A*C + B*C => (A+B)*C simplification should not happen if A, B, and C are +// complex numbers. +TEST_F(AlgebraicSimplifierTest, FactorFpAdditionComplex) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = c64[8] parameter(0) + p1 = c64[8] parameter(1) + p2 = c64[8] parameter(2) + x = c64[8] multiply(p0, p2) + y = c64[8] multiply(p1, p2) + ROOT sum = c64[8] add(x, y) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + EXPECT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); +} + +// A*C + B*C => (A+B)*C simplification is OK if A, B, and C are bfloat16.
+TEST_F(AlgebraicSimplifierTest, FactorFpAdditionBfloat16) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = bf16[4] parameter(0) + p1 = bf16[4] parameter(1) + c = bf16[] constant(0.125) + b = bf16[4] broadcast(c), dimensions={} + x = bf16[4] multiply(p0, b) + y = bf16[4] multiply(p1, b) + ROOT sum = bf16[4] add(x, y) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::MultiplyAnyOrder( + m::AddAnyOrder(m::Parameter(0), m::Parameter(1)), + m::Broadcast(m::ConstantScalar(0.125))))); +} + // Test that A * 0 is simplified to 0 TEST_F(AlgebraicSimplifierTest, MulZero) { auto m = CreateNewVerifiedModule(); @@ -92,8 +217,7 @@ TEST_F(AlgebraicSimplifierTest, MulZero) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kMultiply); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_EQ(computation->root_instruction(), zero); } @@ -115,8 +239,7 @@ TEST_F(AlgebraicSimplifierTest, SelectTrue) { auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kSelect); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_EQ(computation->root_instruction(), param0); } @@ -138,8 +261,7 @@ TEST_F(AlgebraicSimplifierTest, SelectFalse) { auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kSelect); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_EQ(computation->root_instruction(), param1); } @@ -159,8 +281,7 @@ TEST_F(AlgebraicSimplifierTest, SelectIdentical) { auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kSelect); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_EQ(computation->root_instruction(), param1); } @@ -196,11 +317,10 @@ TEST_F(AlgebraicSimplifierTest, TwoReducesToOne) { builder.AddInstruction(HloInstruction::CreateReduce(r1f32, reduce0, zero, dims1, add_computation)); m->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); HloInstruction* root = m->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Reduce(param, zero)); + EXPECT_THAT(root, GmockMatch(m::Reduce(m::Parameter(0), m::Op().Is(zero)))); EXPECT_EQ(root->dimensions(), std::vector({0, 2, 3})); } @@ -219,11 +339,10 @@ TEST_F(AlgebraicSimplifierTest, AddConstOnLHS) { auto computation = m->AddEntryComputation(builder.Build()); 
HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAdd); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Add(param0, op::Constant())); + EXPECT_THAT(root, GmockMatch(m::Add(m::Parameter(0), m::Constant()))); } // Test that [(A + C1) + C2] => [A + (C1 + C2)] for constants C1 and C2. @@ -246,11 +365,12 @@ TEST_F(AlgebraicSimplifierTest, AddReassociateMergeConstants) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAdd); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Add(param0, op::Add(constant1, constant2))); + EXPECT_THAT(root, GmockMatch(m::Add( + m::Op().Is(param0), + m::Add(m::Op().Is(constant1), m::Op().Is(constant2))))); } TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) { @@ -269,8 +389,7 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAdd); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); @@ -306,11 +425,11 @@ TEST_F(AlgebraicSimplifierTest, InlineTrivialMap) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kMap); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Add(param0, op::Broadcast(zero))); + EXPECT_THAT(root, GmockMatch(m::Add(m::Parameter(0), + m::Broadcast(m::Op().Is(zero))))); } TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) { @@ -329,8 +448,7 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAdd); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); @@ -344,12 +462,11 @@ TEST_F(AlgebraicSimplifierTest, ConstantToBroadcast) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); - EXPECT_THAT(root, op::Constant()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + EXPECT_THAT(root, GmockMatch(m::Constant())); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Broadcast(op::Constant())); + EXPECT_THAT(root, 
GmockMatch(m::Broadcast(m::Constant()))); EXPECT_EQ(3.14f, root->operand(0)->literal().GetFirstElement()); } @@ -361,12 +478,11 @@ TEST_F(AlgebraicSimplifierTest, ConstantNotToBroadcast) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); - EXPECT_THAT(root, op::Constant()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + EXPECT_THAT(root, GmockMatch(m::Constant())); + AlgebraicSimplifier simplifier(default_options_); ASSERT_FALSE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Constant()); + EXPECT_THAT(root, GmockMatch(m::Constant())); } TEST_F(AlgebraicSimplifierTest, IotaToBroadcast) { @@ -377,12 +493,11 @@ TEST_F(AlgebraicSimplifierTest, IotaToBroadcast) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); - EXPECT_THAT(root, op::Constant()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + EXPECT_THAT(root, GmockMatch(m::Constant())); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Iota()); + EXPECT_THAT(root, GmockMatch(m::Iota())); } // Test that A - 0 is simplified to A @@ -400,8 +515,7 @@ TEST_F(AlgebraicSimplifierTest, SubZero) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kSubtract); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); @@ -422,11 +536,11 @@ TEST_F(AlgebraicSimplifierTest, SubConstCanonicalization) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kSubtract); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Add(param0, op::Negate(constant))); + EXPECT_THAT(root, GmockMatch(m::Add(m::Parameter(0), + m::Negate(m::Op().Is(constant))))); } // Test that (A/B)/C is simplified to A/(B*C). @@ -448,14 +562,16 @@ TEST_F(AlgebraicSimplifierTest, LhsDivOfDiv) { auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Divide(op::Divide(param0, param1), param2)); + GmockMatch(m::Divide(m::Divide(m::Parameter(0), m::Parameter(1)), + m::Parameter(2)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), - op::Divide(param0, op::Multiply(param1, param2))); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Divide(m::Parameter(0), + m::Multiply(m::Parameter(1), m::Parameter(2))))); } // Test that A/(B/C) is simplified to (A*C)/B. 
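The next several hunks cover the division rewrites (LhsDivOfDiv, RhsDivOfDiv, DivOfDivAndDiv, DivOfExp, DivOfPower, DivideByConstant). For reference, the identities those tests encode are the elementary ones summarized below; the notation is mine and only restates the comments already present in the test file:

$$\frac{A/B}{C}=\frac{A}{B\,C},\qquad \frac{A}{B/C}=\frac{A\,C}{B},\qquad \frac{A/B}{C/D}=\frac{A\,D}{B\,C},\qquad \frac{A}{e^{B}}=A\,e^{-B},\qquad \frac{A}{B^{C}}=A\,B^{-C}.$$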
@@ -476,15 +592,18 @@ TEST_F(AlgebraicSimplifierTest, RhsDivOfDiv) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), - op::Divide(param0, op::Divide(param1, param2))); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Divide(m::Parameter(0), + m::Divide(m::Parameter(1), m::Parameter(2))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), - op::Divide(op::Multiply(param0, param2), param1)); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Divide(m::Multiply(m::Parameter(0), m::Parameter(2)), + m::Parameter(1)))); } // Test that (A/B)/(C/D) is simplified to (A*D)/(B*C). @@ -511,15 +630,16 @@ TEST_F(AlgebraicSimplifierTest, DivOfDivAndDiv) { EXPECT_THAT( computation->root_instruction(), - op::Divide(op::Divide(param0, param1), op::Divide(param2, param3))); + GmockMatch(m::Divide(m::Divide(m::Parameter(0), m::Parameter(1)), + m::Divide(m::Parameter(2), m::Parameter(3))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT( computation->root_instruction(), - op::Divide(op::Multiply(param0, param3), op::Multiply(param1, param2))); + GmockMatch(m::Divide(m::Multiply(m::Parameter(0), m::Parameter(3)), + m::Multiply(m::Parameter(1), m::Parameter(2))))); } // Test that A/exp(B) is simplified to A*exp(-B). @@ -539,14 +659,14 @@ TEST_F(AlgebraicSimplifierTest, DivOfExp) { auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Divide(param0, op::Exp(param1))); + GmockMatch(m::Divide(m::Parameter(0), m::Exp(m::Parameter(1))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - op::Multiply(param0, op::Exp(op::Negate(param1)))); + GmockMatch(m::Multiply(m::Parameter(0), + m::Exp(m::Negate(m::Parameter(1)))))); } // Test that A/pow(B,C) is simplified to A*pow(B,-C). 
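Beyond swapping in `default_options_`, the recurring mechanical change in these hunks is the move from the HloMatcher-style `op::` matchers to `GmockMatch` with the `m::` pattern matchers, where concrete `HloInstruction*` pointers are wrapped in `m::Op().Is(...)` and parameters are matched by index. A minimal before/after sketch, built only from names that already appear in this diff (the includes and the `op::`/`m::` namespace aliases are assumed to be set up elsewhere in the test file):

```
// Old style: HloMatcher-based matchers, binding instruction pointers directly.
EXPECT_THAT(computation->root_instruction(),
            op::Divide(param0, op::Exp(param1)));

// New style: pattern matchers via GmockMatch; parameters matched by index,
// arbitrary instruction pointers wrapped in m::Op().Is(...).
EXPECT_THAT(computation->root_instruction(),
            GmockMatch(m::Divide(m::Parameter(0), m::Exp(m::Parameter(1)))));
EXPECT_THAT(root, GmockMatch(m::Reduce(m::Parameter(0), m::Op().Is(zero))));
```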
@@ -567,15 +687,18 @@ TEST_F(AlgebraicSimplifierTest, DivOfPower) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), - op::Divide(param0, op::Power(param1, param2))); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Divide(m::Parameter(0), + m::Power(m::Parameter(1), m::Parameter(2))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - op::Multiply(param0, op::Power(param1, op::Negate(param2)))); + GmockMatch(m::Multiply( + m::Parameter(0), + m::Power(m::Parameter(1), m::Negate(m::Parameter(2)))))); } // Test that broadcasting is done on the right step when simplifying A/pow(B,C) @@ -597,15 +720,18 @@ TEST_F(AlgebraicSimplifierTest, DivOfBroadcastingPower) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), - op::Divide(param0, op::Power(param1, param2))); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Divide(m::Parameter(0), + m::Power(m::Parameter(1), m::Parameter(2))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); ASSERT_THAT(computation->root_instruction(), - op::Multiply(param0, op::Power(param1, op::Negate(param2)))); + GmockMatch(m::Multiply( + m::Parameter(0), + m::Power(m::Parameter(1), m::Negate(m::Parameter(2)))))); } // A / Const => A * InvertedConst @@ -623,12 +749,11 @@ TEST_F(AlgebraicSimplifierTest, DivideByConstant) { auto computation = m->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - op::Multiply(param0, op::Constant())); + GmockMatch(m::Multiply(m::Parameter(0), m::Constant()))); } // pow(pow(A, X), Y) => pow(A, X*Y) @@ -648,11 +773,12 @@ TEST_F(AlgebraicSimplifierTest, PowerOfPower) { inner_power, exp2)); auto computation = m->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), - op::Power(base, op::Multiply(exp1, exp2))); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Power(m::Op().Is(base), + m::Multiply(m::Op().Is(exp1), m::Op().Is(exp2))))); } // Don't simplify pow(pow(A, X), Y) => pow(A, X*Y) if X and Y are complex @@ -673,8 +799,7 @@ TEST_F(AlgebraicSimplifierTest, PowerOfPowerComplex) { inner_power, exp2)); m->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_FALSE(simplifier.Run(m.get()).ValueOrDie()); } @@ -693,8 +818,7 @@ TEST_F(AlgebraicSimplifierTest, DivOneScalar) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, div); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); 
ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); @@ -715,8 +839,7 @@ TEST_F(AlgebraicSimplifierTest, DivOneArray) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, div); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); @@ -740,8 +863,7 @@ TEST_F(AlgebraicSimplifierTest, ComplexOfRealImagC) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, cplx); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); @@ -765,8 +887,7 @@ TEST_F(AlgebraicSimplifierTest, RealOfComplex) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, real); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); @@ -790,8 +911,7 @@ TEST_F(AlgebraicSimplifierTest, ImagOfComplex) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, imag); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param1); @@ -818,11 +938,10 @@ TEST_F(AlgebraicSimplifierTest, SelectMakeTuple) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, add); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Add(param1, param2)); + EXPECT_THAT(root, GmockMatch(m::Add(m::Parameter(1), m::Parameter(2)))); } // Test that exp(A)/exp(B) is simplified to exp(A-B) @@ -843,15 +962,16 @@ TEST_F(AlgebraicSimplifierTest, ExpDiv) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), - op::Divide(op::Exp(param0), op::Exp(param1))); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Divide(m::Exp(m::Parameter(0)), m::Exp(m::Parameter(1))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), - op::Exp(op::Subtract(param0, param1))); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Exp(m::Subtract(m::Parameter(0), m::Parameter(1))))); } // Test that exp(A)*exp(B) is simplified to exp(A+B) @@ -873,14 +993,14 @@ TEST_F(AlgebraicSimplifierTest, ExpMul) { auto computation = m->AddEntryComputation(builder.Build()); 
EXPECT_THAT(computation->root_instruction(), - op::Multiply(op::Exp(param0), op::Exp(param1))); + GmockMatch(m::Multiply(m::Exp(m::Parameter(0)), + m::Exp(m::Parameter(1))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - op::Exp(op::Add(param0, param1))); + GmockMatch(m::Exp(m::Add(m::Parameter(0), m::Parameter(1))))); } // Test that pow(exp(A), B) is simplified to exp(A*B) @@ -900,14 +1020,14 @@ TEST_F(AlgebraicSimplifierTest, PowExp) { auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Power(op::Exp(param0), param1)); + GmockMatch(m::Power(m::Exp(m::Parameter(0)), m::Parameter(1)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), - op::Exp(op::Multiply(param0, param1))); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Exp(m::Multiply(m::Parameter(0), m::Parameter(1))))); } // Test that ln(pow(A, B)) is simplified to ln(A)*B @@ -927,14 +1047,14 @@ TEST_F(AlgebraicSimplifierTest, LnPow) { auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Log(op::Power(param0, param1))); + GmockMatch(m::Log(m::Power(m::Parameter(0), m::Parameter(1))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), - op::Multiply(op::Log(param0), param1)); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Multiply(m::Log(m::Parameter(0)), m::Parameter(1)))); } // Test that ln(exp(A)) is simplified to A @@ -951,10 +1071,10 @@ TEST_F(AlgebraicSimplifierTest, LnExp) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Log(op::Exp(param0))); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Log(m::Exp(m::Parameter(0))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_EQ(computation->root_instruction(), param0); @@ -981,13 +1101,14 @@ TEST_F(AlgebraicSimplifierTest, LnExpDiv) { auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Log(op::Divide(op::Exp(param0), op::Exp(param1)))); + GmockMatch(m::Log(m::Divide(m::Exp(m::Parameter(0)), + m::Exp(m::Parameter(1)))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Subtract(param0, param1)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Subtract(m::Parameter(0), m::Parameter(1)))); } // Test that pow(A, 0) where A is a scalar is simplified to the scalar @@ -1005,14 +1126,14 @@ TEST_F(AlgebraicSimplifierTest, Pow0Scalar) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Power(param0, 
zero)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Power(m::Parameter(0), m::Op().Is(zero)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); - EXPECT_THAT(root, op::Constant()); + EXPECT_THAT(root, GmockMatch(m::Constant())); EXPECT_EQ(root->literal().GetFirstElement(), 1); } @@ -1030,14 +1151,14 @@ TEST_F(AlgebraicSimplifierTest, Pow0Vector) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Power(param0, zero)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Power(m::Parameter(0), m::Op().Is(zero)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); - EXPECT_THAT(root, op::Broadcast()); + EXPECT_THAT(root, GmockMatch(m::Broadcast())); EXPECT_TRUE(ShapeUtil::Equal(root->shape(), r1f32)) << ShapeUtil::HumanString(root->shape()); EXPECT_EQ(root->dimensions().size(), 0); @@ -1059,10 +1180,10 @@ TEST_F(AlgebraicSimplifierTest, Pow1) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Power(param0, one)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Power(m::Parameter(0), m::Op().Is(one)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_EQ(computation->root_instruction(), param0); @@ -1082,13 +1203,14 @@ TEST_F(AlgebraicSimplifierTest, Pow2) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Power(param0, two)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Power(m::Parameter(0), m::Op().Is(two)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Multiply(param0, param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(0)))); } // Test that pow(A, -1) is simplified to 1/A. 
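The surrounding exponential and power hunks (ExpDiv, ExpMul, PowExp, LnPow, LnExp, LnExpDiv, Pow0, Pow1, Pow2, PowNegative1) exercise the standard identities below; again, the formulas merely restate the comments already in the tests:

$$\frac{e^{A}}{e^{B}}=e^{A-B},\quad e^{A}e^{B}=e^{A+B},\quad \bigl(e^{A}\bigr)^{B}=e^{AB},\quad \ln\!\bigl(A^{B}\bigr)=B\ln A,\quad \ln\!\bigl(e^{A}\bigr)=A,$$
$$A^{0}=1,\quad A^{1}=A,\quad A^{2}=A\cdot A,\quad A^{-1}=\frac{1}{A}.$$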
@@ -1105,14 +1227,14 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Power(param0, negative_one)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Power(m::Parameter(0), m::Op().Is(negative_one)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); - EXPECT_THAT(root, op::Divide(op::Broadcast(), param0)); + EXPECT_THAT(root, GmockMatch(m::Divide(m::Broadcast(), m::Parameter(0)))); EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kBroadcast); EXPECT_EQ(root->operand(0)->operand(0)->literal().GetFirstElement(), 1); @@ -1153,13 +1275,12 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedConvolution) { ShapeUtil::MakeShape(F32, {3, 3, 3}), lhs, rhs, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); m->AddEntryComputation(builder.Build()); - HloPassFix simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + HloPassFix simplifier(default_options_); EXPECT_THAT(m->entry_computation()->root_instruction(), - op::Convolution(lhs, rhs)); + GmockMatch(m::Convolution(m::Op().Is(lhs), m::Op().Is(rhs)))); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(m->entry_computation()->root_instruction(), - op::Broadcast(op::Constant())); + GmockMatch(m::Broadcast(m::Constant()))); } TEST_F(AlgebraicSimplifierTest, ZeroSizedReduceWindow) { @@ -1196,13 +1317,12 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedReduceWindow) { HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))), window, add_computation)); m->AddEntryComputation(builder.Build()); - HloPassFix simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + HloPassFix simplifier(default_options_); EXPECT_THAT(m->entry_computation()->root_instruction(), - op::ReduceWindow(param, op::Constant())); + GmockMatch(m::ReduceWindow(m::Parameter(0), m::Constant()))); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(m->entry_computation()->root_instruction(), - op::Broadcast(op::Constant())); + GmockMatch(m::Broadcast(m::Constant()))); } TEST_F(AlgebraicSimplifierTest, ZeroSizedPad) { @@ -1225,12 +1345,11 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedPad) { padding)); m->AddEntryComputation(builder.Build()); EXPECT_THAT(m->entry_computation()->root_instruction(), - op::Pad(param, op::Constant())); - HloPassFix simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + GmockMatch(m::Pad(m::Parameter(0), m::Constant()))); + HloPassFix simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(m->entry_computation()->root_instruction(), - op::Broadcast(op::Constant())); + GmockMatch(m::Broadcast(m::Constant()))); } TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) { @@ -1251,10 +1370,9 @@ TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) { m->AddEntryComputation(std::move(computation)); EXPECT_THAT(m->entry_computation()->root_instruction(), - op::Reshape(op::Broadcast(op::Reshape(op)))); + GmockMatch(m::Reshape(m::Broadcast(m::Reshape(m::Op().Is(op)))))); - HloPassFix simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + HloPassFix simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(m->entry_computation()->root_instruction(), op); @@ -1271,10 +1389,10 @@ 
TEST_F(AlgebraicSimplifierTest, ConvertBetweenSameType) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Convert(input)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Convert(m::Op().Is(input)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), input); @@ -1292,10 +1410,10 @@ TEST_F(AlgebraicSimplifierTest, RemoveCopy) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Copy(m::Parameter(0)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param0); @@ -1314,19 +1432,24 @@ TEST_F(AlgebraicSimplifierTest, CopyEqualsBitcast) { *copy->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 2, 0, 3}); auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Copy(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Copy(m::Parameter(0)))); - AlgebraicSimplifier simplifier1(/*is_layout_sensitive=*/true, - non_bitcasting_callback()); + AlgebraicSimplifierOptions options(non_bitcasting_callback()); + options.set_is_layout_sensitive(true); + AlgebraicSimplifier simplifier1(options); ASSERT_FALSE(simplifier1.Run(m.get()).ValueOrDie()); // Verify that the copy is not replaced. - EXPECT_THAT(computation->root_instruction(), op::Copy(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Copy(m::Parameter(0)))); - AlgebraicSimplifier simplifier2(/*is_layout_sensitive=*/true, - bitcasting_callback()); + AlgebraicSimplifierOptions options2(bitcasting_callback()); + options2.set_is_layout_sensitive(true); + AlgebraicSimplifier simplifier2(options2); ASSERT_TRUE(simplifier2.Run(m.get()).ValueOrDie()); // Verify that the copy is replaced. - EXPECT_THAT(computation->root_instruction(), op::Bitcast(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Bitcast(m::Parameter(0)))); } // Test that unary concatenates are removed. 
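CopyEqualsBitcast above is the template for every test in this file that needs non-default behaviour: instead of the removed two-argument constructor, an `AlgebraicSimplifierOptions` object is built from the callback and individual knobs are toggled before constructing the pass. A condensed sketch of that pattern, using only calls that appear in this diff (the variable names in the sketch are mine):

```
// Layout-insensitive, no bitcast rewriting: the common case, shared through
// the fixture member default_options_.
AlgebraicSimplifier simplifier(default_options_);

// Layout-sensitive variant with bitcast rewriting enabled, as in
// CopyEqualsBitcast and TransposeEqualsBitcast1/2.
AlgebraicSimplifierOptions options(bitcasting_callback());
options.set_is_layout_sensitive(true);
AlgebraicSimplifier bitcasting_simplifier(options);
ASSERT_TRUE(bitcasting_simplifier.Run(m.get()).ValueOrDie());
```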
@@ -1341,10 +1464,10 @@ TEST_F(AlgebraicSimplifierTest, RemoveUnaryConcatenate) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Concatenate(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Concatenate(m::Parameter(0)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param0); @@ -1371,16 +1494,17 @@ TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT( - computation->root_instruction(), - op::Concatenate(empty_literal, param0, param0, empty_slice, param1)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Concatenate( + m::Op().Is(empty_literal), m::Parameter(0), m::Parameter(0), + m::Op().Is(empty_slice), m::Parameter(1)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - op::Concatenate(param0, param0, param1)); + GmockMatch(m::Concatenate(m::Parameter(0), m::Parameter(0), + m::Parameter(1)))); } // Test that reduce of concat is simplified. @@ -1423,14 +1547,14 @@ TEST_F(AlgebraicSimplifierTest, SimplifyReduceOfConcat) { auto computation = m->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT( computation->root_instruction(), - op::Map(op::Map(op::Reduce(param0, zero), op::Reduce(param1, zero)), - op::Reduce(param2, zero))); + GmockMatch(m::Map(m::Map(m::Reduce(m::Parameter(0), m::Op().Is(zero)), + m::Reduce(m::Parameter(1), m::Op().Is(zero))), + m::Reduce(m::Parameter(2), m::Op().Is(zero))))); } // Test a concatenate with only empty operands is removed. 
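The expectation in SimplifyReduceOfConcat above encodes how a reduce over the concatenation dimension decomposes: each concatenated operand is reduced separately and the partial results are combined pairwise, via map, with the reduction's own binary function (addition here). Informally, with $f$ denoting that combiner (notation mine, summarizing the matcher in the test):

$$\mathrm{reduce}\bigl(\mathrm{concat}(A,B,C)\bigr)\;=\;f\bigl(f(\mathrm{reduce}(A),\,\mathrm{reduce}(B)),\,\mathrm{reduce}(C)\bigr).$$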
@@ -1453,10 +1577,10 @@ TEST_F(AlgebraicSimplifierTest, OnlyEmptyConcatenateOperands) { auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Concatenate(empty_literal, empty_slice)); + GmockMatch(m::Concatenate(m::Op().Is(empty_literal), + m::Op().Is(empty_slice)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_EQ(computation->root_instruction(), empty_literal); @@ -1479,10 +1603,80 @@ TEST_F(AlgebraicSimplifierTest, ConcatenateOfBroadcastBecomesPad) { auto computation = m->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Pad(param0, param1)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Pad(m::Parameter(0), m::Parameter(1)))); +} + +TEST_F(AlgebraicSimplifierTest, SimplifyConcatenateOfSlices) { + auto m = CreateNewVerifiedModule(); + Shape r2f32 = ShapeUtil::MakeShape(F32, {100, 99}); + Shape concat_shape = ShapeUtil::MakeShape(F32, {50, 80}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r2f32, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, r2f32, "param1")); + + HloInstruction* slice0 = builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{0, 0}, + /*limit_indices=*/{50, 10}, /*strides=*/{1, 1})); + + // Cannot merge 'slice0' and 'slice1' because of different start indices in + // dimension 0. + HloInstruction* slice1 = builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 10}, + /*limit_indices=*/{100, 20}, /*strides=*/{1, 1})); + + // Cannot merge 'slice1' and 'slice2' because of stride in dimension 2. + HloInstruction* slice2 = builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 20}, + /*limit_indices=*/{100, 40}, /*strides=*/{1, 2})); + + // Cannot merge 'slice2' and 'slice3' because of stride in dimension 2. + HloInstruction* slice3 = builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 40}, + /*limit_indices=*/{100, 50}, /*strides=*/{1, 1})); + + // Can merge 'slice3' and 'slice4'. + HloInstruction* slice4 = builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 50}, + /*limit_indices=*/{100, 60}, /*strides=*/{1, 1})); + + // Can merge 'slice4' and 'slice5'. + HloInstruction* slice5 = builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 60}, + /*limit_indices=*/{100, 70}, /*strides=*/{1, 1})); + + // Cannot merge 'slice5' and 'slice6' because of overlap. + HloInstruction* slice6 = builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 69}, + /*limit_indices=*/{100, 79}, /*strides=*/{1, 1})); + + // Cannot merge 'slice6' and 'slice7' because of slicing from a different + // parameter. 
+ HloInstruction* slice7 = builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {50, 10}), param1, /*start_indices=*/{50, 79}, + /*limit_indices=*/{100, 89}, /*strides=*/{1, 1})); + + builder.AddInstruction(HloInstruction::CreateConcatenate( + concat_shape, + {slice0, slice1, slice2, slice3, slice4, slice5, slice6, slice7}, 1)); + auto computation = m->AddEntryComputation(builder.Build()); + + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + auto s = m::Slice(m::Parameter(0)); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Concatenate(s, s, s, s, s, m::Slice(m::Parameter(1))))); + // The operand 3 should be a merge of 'slice3', 'slice4' and 'slice5', so its + // shape should have dimensions {50, 30}. + EXPECT_TRUE( + ShapeUtil::Equal(computation->root_instruction()->operand(3)->shape(), + ShapeUtil::MakeShape(F32, {50, 30}))); + EXPECT_EQ(computation->root_instruction()->operand(3)->slice_starts(1), 40); } // Test that a simplification which changes layouts is not performed if layout @@ -1502,14 +1696,17 @@ TEST_F(AlgebraicSimplifierTest, CopyWithDifferentLayout) { *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); *copy->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 0}); - EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Copy(m::Parameter(0)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, - non_bitcasting_callback()); + AlgebraicSimplifierOptions options(non_bitcasting_callback()); + options.set_is_layout_sensitive(true); + AlgebraicSimplifier simplifier(options); EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie()); // Copy has not been removed. - EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Copy(m::Parameter(0)))); } // Test that a simplification which preserves layouts is performed if layout @@ -1529,10 +1726,12 @@ TEST_F(AlgebraicSimplifierTest, CopyWithSameLayout) { *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); *copy->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); - EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Copy(m::Parameter(0)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, - non_bitcasting_callback()); + AlgebraicSimplifierOptions options(non_bitcasting_callback()); + options.set_is_layout_sensitive(true); + AlgebraicSimplifier simplifier(options); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); // Copy has been removed. @@ -1557,14 +1756,17 @@ TEST_F(AlgebraicSimplifierTest, NoBitcastAdded) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Reshape(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Parameter(0)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, - non_bitcasting_callback()); + AlgebraicSimplifierOptions options(non_bitcasting_callback()); + options.set_is_layout_sensitive(true); + AlgebraicSimplifier simplifier(options); EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie()); // Reshape is not replaced with a bitcast. 
- EXPECT_THAT(computation->root_instruction(), op::Reshape(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Parameter(0)))); } // Test transforming reshapes and transposes of rng. @@ -1588,13 +1790,13 @@ TEST_F(AlgebraicSimplifierTest, ReshapeOfTransposeOfRngToRng) { auto computation = m->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - bitcasting_callback()); + AlgebraicSimplifier simplifier( + (AlgebraicSimplifierOptions(bitcasting_callback()))); EXPECT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - // Verify that that reshape(transpose(rng)) is replace by a single rng of the + // Verify that reshape(transpose(rng)) is replace by a single rng of the // same shape as the reshape. - EXPECT_THAT(computation->root_instruction(), op::Rng()); + EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Rng())); EXPECT_TRUE(ShapeUtil::Equal(computation->root_instruction()->shape(), reshape_shape)); } @@ -1636,17 +1838,20 @@ TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) { auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Tuple(transformable_reshape, dimensions_wrong_reshape, - layout_wrong_reshape)); + GmockMatch(m::Tuple(m::Op().Is(transformable_reshape), + m::Op().Is(dimensions_wrong_reshape), + m::Op().Is(layout_wrong_reshape)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, - bitcasting_callback()); + AlgebraicSimplifierOptions options(bitcasting_callback()); + options.set_is_layout_sensitive(true); + AlgebraicSimplifier simplifier(options); simplifier.Run(m.get()).ValueOrDie(); // Verify that only the first reshape is replaced. EXPECT_THAT( computation->root_instruction(), - op::Tuple(op::Bitcast(), dimensions_wrong_reshape, layout_wrong_reshape)); + GmockMatch(m::Tuple(m::Bitcast(), m::Op().Is(dimensions_wrong_reshape), + m::Op().Is(layout_wrong_reshape)))); } // Regression test for a bug where if we failed to sink a reshape, we'd set the @@ -1667,8 +1872,8 @@ TEST_F(AlgebraicSimplifierTest, FailureToSinkReshapeDoesntAffectChangedBit) { builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {4}), add)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - bitcasting_callback()); + AlgebraicSimplifier simplifier( + (AlgebraicSimplifierOptions(bitcasting_callback()))); m->AddEntryComputation(builder.Build()); EXPECT_TRUE(simplifier.Run(m.get()).ValueOrDie()); } @@ -1692,8 +1897,8 @@ TEST_F(AlgebraicSimplifierTest, FailureToSinkBroadcastDoesntAffectChangedBit) { HloInstruction::CreateBroadcast(ShapeUtil::MakeShape(F32, {2, 2, 2}), add, /*broadcast_dimensions=*/{0, 1})); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - bitcasting_callback()); + AlgebraicSimplifier simplifier( + (AlgebraicSimplifierOptions(bitcasting_callback()))); m->AddEntryComputation(builder.Build()); EXPECT_TRUE(simplifier.Run(m.get()).ValueOrDie()); } @@ -1715,14 +1920,17 @@ TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Transpose(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Transpose(m::Parameter(0)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, - bitcasting_callback()); + AlgebraicSimplifierOptions options(bitcasting_callback()); + options.set_is_layout_sensitive(true); + AlgebraicSimplifier simplifier(options); 
ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); // Verify that the reshape is replaced. - EXPECT_THAT(computation->root_instruction(), op::Bitcast(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Bitcast(m::Parameter(0)))); } TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast2) { @@ -1742,14 +1950,17 @@ TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast2) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Transpose(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Transpose(m::Parameter(0)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, - bitcasting_callback()); + AlgebraicSimplifierOptions options(bitcasting_callback()); + options.set_is_layout_sensitive(true); + AlgebraicSimplifier simplifier(options); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); // Verify that the reshape is replaced. - EXPECT_THAT(computation->root_instruction(), op::Bitcast(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Bitcast(m::Parameter(0)))); } TEST_F(AlgebraicSimplifierTest, ReshapesMerged) { @@ -1769,13 +1980,13 @@ TEST_F(AlgebraicSimplifierTest, ReshapesMerged) { auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Reshape(op::Reshape(param0))); + GmockMatch(m::Reshape(m::Reshape(m::Parameter(0))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Reshape(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Parameter(0)))); } TEST_F(AlgebraicSimplifierTest, CopiesMerged) { @@ -1796,13 +2007,16 @@ TEST_F(AlgebraicSimplifierTest, CopiesMerged) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Copy(op::Copy(param0))); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Copy(m::Copy(m::Parameter(0))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, - non_bitcasting_callback()); + AlgebraicSimplifierOptions options(non_bitcasting_callback()); + options.set_is_layout_sensitive(true); + AlgebraicSimplifier simplifier(options); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Copy(m::Parameter(0)))); } TEST_F(AlgebraicSimplifierTest, TransposesMerged) { @@ -1821,13 +2035,14 @@ TEST_F(AlgebraicSimplifierTest, TransposesMerged) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Transpose(transpose1)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Transpose(m::Op().Is(transpose1)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Transpose(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Transpose(m::Parameter(0)))); EXPECT_EQ(std::vector({2, 1, 0}), computation->root_instruction()->dimensions()); } @@ -1846,13 +2061,13 @@ TEST_F(AlgebraicSimplifierTest, ReshapeAndBroadcastMerged) { auto computation = m->AddEntryComputation(builder.Build()); 
EXPECT_THAT(computation->root_instruction(), - op::Broadcast(op::Reshape(param0))); + GmockMatch(m::Broadcast(m::Reshape(m::Parameter(0))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Broadcast(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Broadcast(m::Parameter(0)))); } // Test merging broadcast and reshape. @@ -1869,13 +2084,13 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshapeMerged) { auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Reshape(op::Broadcast(param0))); + GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Broadcast(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Broadcast(m::Parameter(0)))); } TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x1_3) { @@ -1891,14 +2106,13 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x1_3) { auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Reshape(op::Broadcast(param))); + GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - op::Reshape(op::Broadcast(param))); + GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0))))); } TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4_6x1x1x4) { @@ -1914,13 +2128,13 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4_6x1x1x4) { HloComputation* computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Reshape(op::Broadcast(param))); + GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Broadcast(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Broadcast(m::Parameter(0)))); EXPECT_THAT(computation->root_instruction()->dimensions(), ::testing::ElementsAre(3)); } @@ -1938,13 +2152,13 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x2x1_6x1x1x1) { HloComputation* computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Reshape(op::Broadcast(param))); + GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Broadcast(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Broadcast(m::Parameter(0)))); const std::vector broadcast_dims = computation->root_instruction()->dimensions(); EXPECT_EQ(1, broadcast_dims.size()); @@ -1964,14 +2178,13 @@ TEST_F(AlgebraicSimplifierTest, 
BroadcastAndReshape_4_3x2x4x2_6x8) { HloComputation* computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Reshape(op::Broadcast(param))); + GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - op::Reshape(op::Broadcast(param))); + GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0))))); } TEST_F(AlgebraicSimplifierTest, IotaAndReshapeMerged) { @@ -1984,13 +2197,13 @@ TEST_F(AlgebraicSimplifierTest, IotaAndReshapeMerged) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Iota()))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Iota()); + EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Iota())); EXPECT_TRUE( ShapeUtil::Equal(computation->root_instruction()->shape(), result_shape)); } @@ -2004,14 +2217,13 @@ TEST_F(AlgebraicSimplifierTest, IotaEffectiveScalar) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Iota()); + EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Iota())); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); auto root = computation->root_instruction(); - EXPECT_THAT(root, op::Broadcast(op::Constant())); + EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Constant()))); EXPECT_EQ(0.0f, root->operand(0)->literal().GetFirstElement()); EXPECT_TRUE( ShapeUtil::Equal(computation->root_instruction()->shape(), result_shape)); @@ -2027,13 +2239,14 @@ TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x2_6) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Iota()))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Iota()))); } TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4_6x1x1x4) { @@ -2046,13 +2259,13 @@ TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4_6x1x1x4) { HloComputation* computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Iota()))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Iota()); + EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Iota())); 
EXPECT_EQ(Cast(computation->root_instruction()) ->iota_dimension(), 3); @@ -2068,13 +2281,13 @@ TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x2x2_6x1x1x2) { HloComputation* computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Iota()))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Iota()); + EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Iota())); const int64 iota_dim = Cast(computation->root_instruction()) ->iota_dimension(); @@ -2091,13 +2304,14 @@ TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4x2_6x8) { HloComputation* computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Iota()))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Iota()))); } TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) { @@ -2120,10 +2334,10 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) { auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Pad(m::Parameter(0), m::Op().Is(zero)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); @@ -2153,8 +2367,7 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) { auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); auto has_negative_padding = [](const HloInstruction* pad) { for (auto& padding_dimension : pad->padding_config().dimensions()) { @@ -2166,16 +2379,54 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) { return false; }; - EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Pad(m::Parameter(0), m::Op().Is(zero)))); EXPECT_TRUE(has_negative_padding(pad)); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Slice(op::Pad(param, zero))); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Slice(m::Pad(m::Parameter(0), m::Op().Is(zero))))); EXPECT_FALSE( has_negative_padding(computation->root_instruction()->operand(0))); } +TEST_F(AlgebraicSimplifierTest, TrivialInteriorPadding) { + // Verify that a pad instruction with interior padding on one-sized + // dimensions, removes the interior padding. 
+ HloComputation::Builder builder(TestName()); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {2, 1}), "param")); + HloInstruction* zero = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + PaddingConfig padding; + for (int i = 0; i < 2; ++i) { + auto dimension = padding.add_dimensions(); + dimension->set_edge_padding_low(3); + dimension->set_edge_padding_high(3); + dimension->set_interior_padding(i * 3); + } + HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad( + ShapeUtil::MakeShape(F32, {8, 7}), param, zero, padding)); + + auto module = CreateNewVerifiedModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); + + AlgebraicSimplifier simplifier(default_options_); + + ASSERT_THAT(computation->root_instruction(), + GmockMatch(m::Pad(m::Parameter(0), m::Op().Is(zero)))); + ASSERT_TRUE(HasInteriorPadding(pad->padding_config())); + + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Pad(m::Parameter(0), m::Op().Is(zero)))); + EXPECT_FALSE( + HasInteriorPadding(computation->root_instruction()->padding_config())); +} + TEST_F(AlgebraicSimplifierTest, RemoveNoopReshape) { HloComputation::Builder builder(TestName()); HloInstruction* param = @@ -2187,10 +2438,10 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopReshape) { auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Reshape(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Parameter(0)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); @@ -2210,10 +2461,10 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopSlice) { auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Slice(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Slice(m::Parameter(0)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); @@ -2239,13 +2490,14 @@ TEST_F(AlgebraicSimplifierTest, SliceOfSliceToSlice) { auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Slice(op::Slice(param))); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Slice(m::Slice(m::Parameter(0))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Slice(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Slice(m::Parameter(0)))); EXPECT_EQ(computation->root_instruction()->slice_starts(0), 3); EXPECT_EQ(computation->root_instruction()->slice_starts(1), 5); EXPECT_EQ(computation->root_instruction()->slice_limits(0), dim0 - 2); @@ -2271,13 +2523,14 @@ 
TEST_F(AlgebraicSimplifierTest, SliceOfReshapeToReshapeOfSlice) { auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Slice(op::Reshape(param))); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Slice(m::Reshape(m::Parameter(0))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Slice(param))); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Slice(m::Parameter(0))))); } TEST_F(AlgebraicSimplifierTest, SliceOfReshapeUnchanged) { @@ -2296,10 +2549,10 @@ TEST_F(AlgebraicSimplifierTest, SliceOfReshapeUnchanged) { auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Slice(op::Reshape(param))); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Slice(m::Reshape(m::Parameter(0))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); } @@ -2312,12 +2565,84 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopSort) { builder.AddInstruction(HloInstruction::CreateSort(keys_shape, 0, keys)); auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), keys); } +TEST_F(AlgebraicSimplifierTest, ReplacePermutationSortWithScatter) { + const char* hlo_string = R"( + HloModule permutation_sort + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = s32[64,8732]{1,0} iota(), iota_dimension=1 + sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), dimensions={1} + gte = s32[64,8732]{1,0} get-tuple-element(sort), index=1 + ROOT sort2 = (s32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(gte, values), dimensions={1} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options(non_bitcasting_callback()); + options.set_enable_permutation_sort_replacement(true); + AlgebraicSimplifier simplifier(options); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, + GmockMatch(m::Tuple( + m::Iota(), + m::Scatter(m::Iota(), m::Concatenate(m::Iota(), m::Reshape()), + m::Reshape())))); +} + +TEST_F(AlgebraicSimplifierTest, DontReplacePermutationSortIfNonIntegral) { + // Same as ReplacePermutationSortWithScatter except that the iota has F32 + // type. 
+ const char* hlo_string = R"( + HloModule permutation_sort + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = f32[64,8732]{1,0} iota(), iota_dimension=1 + sort = (f32[64,8732]{1,0}, f32[64,8732]{1,0}) sort(keys, values), dimensions={1} + gte = f32[64,8732]{1,0} get-tuple-element(sort), index=1 + ROOT sort2 = (f32[64,8732]{1,0}, f32[64,8732]{1,0}) sort(gte, values), dimensions={1} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options(non_bitcasting_callback()); + options.set_enable_permutation_sort_replacement(true); + AlgebraicSimplifier simplifier(options); + EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); +} + +TEST_F(AlgebraicSimplifierTest, DontReplacePermutationSortWrongDimensions) { + // Same as ReplacePermutationSortWithScatter except that the sort dimensions + // don't match. + const char* hlo_string = R"( + HloModule permutation_sort + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = s32[64,8732]{1,0} iota(), iota_dimension=1 + sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), dimensions={1} + gte = s32[64,8732]{1,0} get-tuple-element(sort), index=1 + ROOT sort2 = (s32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(gte, values), dimensions={0} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options(non_bitcasting_callback()); + options.set_enable_permutation_sort_replacement(true); + AlgebraicSimplifier simplifier(options); + EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); +} + TEST_F(AlgebraicSimplifierTest, ReplaceEffectiveScalarKeyValueSortWithTuple) { auto builder = HloComputation::Builder(TestName()); @@ -2334,11 +2659,11 @@ TEST_F(AlgebraicSimplifierTest, ReplaceEffectiveScalarKeyValueSortWithTuple) { keys, {values0, values1})); auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - op::Tuple(keys, values0, values1)); + GmockMatch(m::Tuple(m::Op().Is(keys), m::Op().Is(values0), + m::Op().Is(values1)))); } // Test that A && True is simplified to A @@ -2356,8 +2681,7 @@ TEST_F(AlgebraicSimplifierTest, AndTrue) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAnd); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); @@ -2378,8 +2702,7 @@ TEST_F(AlgebraicSimplifierTest, AndTrue2) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAnd); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); @@ -2400,8 +2723,7 @@ TEST_F(AlgebraicSimplifierTest, AndFalse) { auto computation = m->AddEntryComputation(builder.Build()); 
HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAnd); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, const_false); @@ -2422,8 +2744,7 @@ TEST_F(AlgebraicSimplifierTest, AndFalse2) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAnd); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, const_false); @@ -2444,8 +2765,7 @@ TEST_F(AlgebraicSimplifierTest, OrTrue) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kOr); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, const_true); @@ -2466,8 +2786,7 @@ TEST_F(AlgebraicSimplifierTest, OrTrue2) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kOr); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, const_true); @@ -2488,8 +2807,7 @@ TEST_F(AlgebraicSimplifierTest, OrFalse) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kOr); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); @@ -2510,8 +2828,7 @@ TEST_F(AlgebraicSimplifierTest, OrFalse2) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kOr); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); @@ -2641,15 +2958,15 @@ TEST_P(ConvInputPaddingTest, DoTest) { auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); if (testcase.expected_conv_window.empty()) { ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); } else { ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto* conv = module->entry_computation()->root_instruction(); SCOPED_TRACE(module->ToString()); - ASSERT_THAT(conv, op::Convolution(op::Parameter(), op::Parameter())); + ASSERT_THAT(conv, + GmockMatch(m::Convolution(m::Parameter(), m::Parameter()))); 
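A quick reference for the conversion applied throughout these hunks: assertions move from the op:: HloMatchers to the pattern-matcher m:: combinators wrapped in GmockMatch, and the two-argument AlgebraicSimplifier constructor gives way to an options object. The block below only collects forms that already appear in this diff:

```
// Old style:
//   EXPECT_THAT(root, op::Broadcast(op::Constant()));
//   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
//                                  non_bitcasting_callback());
//
// New style:
//   EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Constant())));
//   // Binding to a specific instruction uses m::Op().Is(instr):
//   EXPECT_THAT(root, GmockMatch(m::Pad(m::Parameter(0), m::Op().Is(zero))));
//   AlgebraicSimplifier simplifier(default_options_);
//   // Non-default configurations build AlgebraicSimplifierOptions explicitly:
//   AlgebraicSimplifierOptions options(bitcasting_callback());
//   options.set_is_layout_sensitive(true);
//   AlgebraicSimplifier layout_sensitive_simplifier(options);
```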
EXPECT_EQ(window_util::ToString(conv->window()), absl::StrCat("size=3x3 ", testcase.expected_conv_window)); } @@ -2759,15 +3076,15 @@ TEST_P(ConvFilterPaddingTest, DoIt) { auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); if (testcase.expected_conv_window.empty()) { ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); } else { ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto* conv = module->entry_computation()->root_instruction(); SCOPED_TRACE(module->ToString()); - ASSERT_THAT(conv, op::Convolution(op::Parameter(), op::Parameter())); + ASSERT_THAT(conv, + GmockMatch(m::Convolution(m::Parameter(), m::Parameter()))); EXPECT_EQ(window_util::ToString(conv->window()), absl::StrFormat("size=%dx%d %s", conv->operand(1)->shape().dimensions(2), @@ -2908,8 +3225,9 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { auto module = CreateNewUnverifiedModule(); auto* computation = module->AddEntryComputation(b.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, - bitcasting_callback()); + AlgebraicSimplifierOptions simplifier_options(bitcasting_callback()); + simplifier_options.set_is_layout_sensitive(true); + AlgebraicSimplifier simplifier(simplifier_options); if (!simplifier.Run(module.get()).ValueOrDie()) { return "NO_CHANGE"; } @@ -3032,17 +3350,15 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) { EXPECT_EQ(root, slice); EXPECT_TRUE(ShapeUtil::Equal(root->shape(), slice_shape)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); // Running simplification again should not result in any further changes. ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); - - root = computation->root_instruction(); - EXPECT_THAT(root, op::Broadcast(scalar_param)); - EXPECT_TRUE(ShapeUtil::Equal(root->shape(), slice_shape)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Broadcast(m::Op().Is(scalar_param)) + .WithShapeEqualTo(&slice_shape))); } // Test that reshape(transpose(broadcast(/*scalar value*/))) simplifies to a @@ -3071,13 +3387,11 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { EXPECT_EQ(root, reshape); EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reshape_shape)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - - root = computation->root_instruction(); - EXPECT_THAT(root, op::Broadcast(forty_two)); - EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reshape_shape)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Broadcast(m::Op().Is(forty_two)) + .WithShapeEqualTo(&reshape_shape))); } // Test that ReduceWindow(Pad(op, x), y) can simplify to ReduceWindow(op, x). 
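The comment above summarizes the next two tests: a kPad feeding a kReduceWindow can be absorbed by widening the window's own padding, so the pad disappears and the reduce-window reads the original operand directly (the expectations below check ReduceWindow(operand, constant)). A rough sketch of the padding bookkeeping, under the assumption that only edge padding is merged and that the pad value is compatible with the reduce-window's init value (interior padding and strides are ignored here, and the field names are made up for this sketch rather than mirroring xla::Window or xla::PaddingConfig):

```
#include <cstdint>

struct DimPad { int64_t low; int64_t high; };

// Illustrative only: fold a pad's per-dimension edge padding into the
// reduce-window's per-dimension window padding.
DimPad MergePadIntoWindow(const DimPad& pad_edge, const DimPad& window_pad) {
  return {window_pad.low + pad_edge.low, window_pad.high + pad_edge.high};
}
```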
@@ -3138,8 +3452,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, reduce_window); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); // Running simplification again should not result in any further changes. @@ -3147,7 +3460,8 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { // Verify the result root = computation->root_instruction(); - EXPECT_THAT(root, op::ReduceWindow(operand, op::Constant())); + EXPECT_THAT(root, + GmockMatch(m::ReduceWindow(m::Op().Is(operand), m::Constant()))); EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reduce_window_shape)) << ShapeUtil::HumanString(root->shape()) << " vs " << ShapeUtil::HumanString(reduce_window_shape); @@ -3224,8 +3538,7 @@ TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) { auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, reduce_window); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); // Running simplification again should not result in any further changes. @@ -3233,7 +3546,8 @@ TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) { // Verify the result root = computation->root_instruction(); - EXPECT_THAT(root, op::ReduceWindow(op::Convert(parameter), op::Constant())); + EXPECT_THAT(root, GmockMatch(m::ReduceWindow(m::Convert(m::Parameter(0)), + m::Constant()))); EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reduce_window_shape)) << ShapeUtil::HumanString(root->shape()) << " vs " << ShapeUtil::HumanString(reduce_window_shape); @@ -3258,8 +3572,7 @@ TEST_F(AlgebraicSimplifierTest, ReversalOfTrivialDimensionsToBitcast) { auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); @@ -3295,8 +3608,7 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) { m->AddEmbeddedComputation(std::move(dot_computation)); m->AddEntryComputation(call_builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); } @@ -3313,11 +3625,10 @@ TEST_F(AlgebraicSimplifierTest, ConstantTupleBecomesTupleOfConstants) { auto computation = m->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - op::Tuple(op::Constant(), op::Constant())); + GmockMatch(m::Tuple(m::Constant(), m::Constant()))); } // A dynamic-slice is trivial if its start indices are all zeroes and the size @@ -3337,10 +3648,9 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicSlice) { /*slice_sizes=*/{10, 100, 1000})); auto computation = 
m->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Parameter()); + EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Parameter())); } // A dynamic-update-slice is trivial if its start indices are all zeroes and the @@ -3371,11 +3681,10 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicUpdateSlice) { 3, ShapeUtil::MakeShape(U32, {3}), "update_indices")))); auto computation = m->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - op::DynamicSlice(op::Parameter(), op::Parameter())); + GmockMatch(m::DynamicSlice(m::Parameter(), m::Parameter()))); } // Test that two consecutive broadcasts can be merged to one. @@ -3394,11 +3703,10 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcasts) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Broadcast(op::Constant())); + EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Constant()))); EXPECT_THAT(root->dimensions(), ElementsAre(2)); } @@ -3421,11 +3729,10 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcasts2) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Broadcast(op::Parameter(0))); + EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Parameter(0)))); EXPECT_THAT(root->dimensions(), ElementsAre(1, 3)); } @@ -3442,11 +3749,10 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Iota()); + EXPECT_THAT(root, GmockMatch(m::Iota())); EXPECT_EQ(Cast(root)->iota_dimension(), 2); } @@ -3464,11 +3770,10 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota2) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Iota()); + EXPECT_THAT(root, GmockMatch(m::Iota())); 
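Before the final iota_dimension check below: when broadcast(iota) collapses into a single iota, the iota dimension is remapped through the broadcast's dimensions map, since a broadcast sends operand dimension d to output dimension dimensions[d]. A small standalone sketch (the dimensions vector here is hypothetical, picked to land on the value 2 that the test asserts):

```
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // A broadcast maps operand dimension d to output dimension dimensions[d],
  // so the merged iota varies along dimensions[old_iota_dimension].
  std::vector<int64_t> broadcast_dimensions = {1, 2};  // hypothetical
  int64_t old_iota_dimension = 1;
  int64_t new_iota_dimension = broadcast_dimensions[old_iota_dimension];
  std::cout << new_iota_dimension << "\n";  // 2
}
```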
EXPECT_EQ(Cast(root)->iota_dimension(), 2); } @@ -3486,11 +3791,11 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadLow) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - bitcasting_callback()); + AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Reshape(op::Constant())); + EXPECT_THAT(root, GmockMatch(m::Reshape(m::Constant()))); } TEST_F(AlgebraicSimplifierTest, SliceOfPadHigh) { @@ -3507,11 +3812,11 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadHigh) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - bitcasting_callback()); + AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Reshape(op::Constant())); + EXPECT_THAT(root, GmockMatch(m::Reshape(m::Constant()))); } TEST_F(AlgebraicSimplifierTest, SliceOfPadMidNonScalar) { @@ -3528,8 +3833,8 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadMidNonScalar) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - bitcasting_callback()); + AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifier simplifier(options); EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); } @@ -3547,11 +3852,11 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalar) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - bitcasting_callback()); + AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Parameter()); + EXPECT_THAT(root, GmockMatch(m::Parameter())); } TEST_F(AlgebraicSimplifierTest, SliceOfConcatScalarInput) { @@ -3569,11 +3874,11 @@ TEST_F(AlgebraicSimplifierTest, SliceOfConcatScalarInput) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - bitcasting_callback()); + AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Parameter(1)); + EXPECT_THAT(root, GmockMatch(m::Parameter(1))); } TEST_F(AlgebraicSimplifierTest, SliceOfConcatNonScalarInput) { @@ -3591,11 +3896,11 @@ TEST_F(AlgebraicSimplifierTest, SliceOfConcatNonScalarInput) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - bitcasting_callback()); + AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Slice(op::Parameter(2))); + EXPECT_THAT(root, 
GmockMatch(m::Slice(m::Parameter(2)))); EXPECT_EQ(root->slice_starts(0), 1); EXPECT_EQ(root->slice_limits(0), 2); } @@ -3613,11 +3918,11 @@ TEST_F(AlgebraicSimplifierTest, NegateNegate) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - bitcasting_callback()); + AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Parameter(0)); + EXPECT_THAT(root, GmockMatch(m::Parameter(0))); } TEST_F(AlgebraicSimplifierTest, NotNot) { @@ -3633,11 +3938,11 @@ TEST_F(AlgebraicSimplifierTest, NotNot) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - bitcasting_callback()); + AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Parameter(0)); + EXPECT_THAT(root, GmockMatch(m::Parameter(0))); } struct PadReduceWindowEffectiveBroadcastCase { @@ -3733,8 +4038,7 @@ TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) { output_shape, pad, zero, window, add_computation)); auto computation = m->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get())); ASSERT_TRUE(run_successful); @@ -3742,10 +4046,10 @@ TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) { ShapeUtil::Equal(computation->root_instruction()->shape(), output_shape)); if (param.should_become_broadcast) { - EXPECT_THAT(computation->root_instruction(), op::Broadcast(::testing::_)); + EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Broadcast())); } else { EXPECT_THAT(computation->root_instruction(), - op::ReduceWindow(::testing::_, zero)); + GmockMatch(m::ReduceWindow(m::Op(), m::Op().Is(zero)))); } } @@ -3815,8 +4119,7 @@ TEST_P(DotStrengthReductionTest, DotStrengthReduction) { builder.AddInstruction(HloInstruction::CreateDot( dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2))); auto computation = module->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get())); const bool dot_should_be_transformed = m == 1 || k == 1 || n == 1; const bool computation_should_be_modified = @@ -3845,7 +4148,7 @@ struct DotOfConcatTestSpec { }; class DotOfConcatSimplificationTest - : public HloTestBase, + : public AlgebraicSimplifierTest, public ::testing::WithParamInterface {}; // Test that we transform @@ -3893,19 +4196,19 @@ TEST_P(DotOfConcatSimplificationTest, ConstantLHS) { dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2))); auto computation = m->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get())); ASSERT_TRUE(run_successful); EXPECT_TRUE( 
ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape)); - auto match_dot_0 = op::Dot(op::Slice(op::Constant()), op::Parameter(0)); - auto match_dot_1 = op::Dot(op::Slice(op::Constant()), op::Parameter(1)); - auto match_dot_2 = op::Dot(op::Slice(op::Constant()), op::Parameter(2)); - EXPECT_THAT(computation->root_instruction(), - op::Add(op::Add(match_dot_0, match_dot_1), match_dot_2)); + auto match_dot_0 = m::Dot(m::Slice(m::Constant()), m::Parameter(0)); + auto match_dot_1 = m::Dot(m::Slice(m::Constant()), m::Parameter(1)); + auto match_dot_2 = m::Dot(m::Slice(m::Constant()), m::Parameter(2)); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Add(m::Add(match_dot_0, match_dot_1), match_dot_2))); } // Test that we transform @@ -3958,20 +4261,20 @@ TEST_P(DotOfConcatSimplificationTest, ConstantRHS) { dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2))); auto computation = m->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get())); ASSERT_TRUE(run_successful); EXPECT_TRUE( ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape)); - auto match_dot_0 = op::Dot(op::Parameter(0), op::Slice(op::Constant())); - auto match_dot_1 = op::Dot(op::Parameter(1), op::Slice(op::Constant())); - auto match_dot_2 = op::Dot(op::Parameter(2), op::Slice(op::Constant())); - auto match_dot_3 = op::Dot(op::Parameter(3), op::Slice(op::Constant())); - EXPECT_THAT(computation->root_instruction(), - op::Add(op::Add(op::Add(match_dot_0, match_dot_1), match_dot_2), - match_dot_3)); + auto match_dot_0 = m::Dot(m::Parameter(0), m::Slice(m::Constant())); + auto match_dot_1 = m::Dot(m::Parameter(1), m::Slice(m::Constant())); + auto match_dot_2 = m::Dot(m::Parameter(2), m::Slice(m::Constant())); + auto match_dot_3 = m::Dot(m::Parameter(3), m::Slice(m::Constant())); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Add(m::Add(m::Add(match_dot_0, match_dot_1), match_dot_2), + match_dot_3))); } DotOfConcatTestSpec kDotOfConcatTestSpecs[] = { @@ -4000,8 +4303,7 @@ TEST_F(AlgebraicSimplifierTest, DynamicUpdateSliceZeroUpdate) { const HloComputation* const computation = m->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), operand); } @@ -4021,7 +4323,7 @@ struct DotOfGatherTestSpec { }; class DotOfGatherSimplificationTest - : public HloTestBase, + : public AlgebraicSimplifierTest, public ::testing::WithParamInterface {}; // input: dot(DS(ctA), ctB)) @@ -4078,8 +4380,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) { dot_shape, ds, rhs, dot_dnums, DefaultPrecisionConfig(2))); auto computation = m->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get())); ASSERT_TRUE(run_successful); EXPECT_TRUE( @@ -4090,8 +4391,8 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) { HloOpcode::kDynamicSlice); } else { EXPECT_THAT(computation->root_instruction(), - op::DynamicSlice(op::Dot(op::Constant(), op::Constant()), - op::Concatenate())); + 
GmockMatch(m::DynamicSlice(m::Dot(m::Constant(), m::Constant()), + m::Concatenate()))); } } @@ -4149,8 +4450,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantLHS) { dot_shape, lhs, ds, dot_dnums, DefaultPrecisionConfig(2))); auto computation = m->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get())); ASSERT_TRUE(run_successful); EXPECT_TRUE( @@ -4161,8 +4461,8 @@ TEST_P(DotOfGatherSimplificationTest, ConstantLHS) { HloOpcode::kDynamicSlice); } else { EXPECT_THAT(computation->root_instruction(), - op::DynamicSlice(op::Dot(op::Constant(), op::Constant()), - op::Concatenate())); + GmockMatch(m::DynamicSlice(m::Dot(m::Constant(), m::Constant()), + m::Concatenate()))); } } diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner.cc b/tensorflow/compiler/xla/service/ar_crs_combiner.cc new file mode 100644 index 00000000000..c11452a6fbd --- /dev/null +++ b/tensorflow/compiler/xla/service/ar_crs_combiner.cc @@ -0,0 +1,286 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/ar_crs_combiner.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/call_graph.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +namespace { + +namespace m = match; + +// If the argument instruction is a CRS in the sequence +// AR -> Convert -> Add -> CRS +// then return the AR in the sequence. +// TODO(b/117554291): Rewrite this to recognize more general patterns, +// not just the specific one of AR -> Add -> Convert -> CRS. 
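To make the pattern and the later rewrite easier to follow: per the Match call below, the convert sits between the all-reduce and the add, i.e. the matched dataflow is AR -> Convert -> Add -> CRS, and RewriteGraph later replaces it with a combined all-reduce minus other_summand * (num_spatial_partitions_ - 1). A small standalone check of that identity, assuming the other summand is the same value on every partition (which is what KeepProvablyEqualInstructionGroups verifies); the replica dimension of the CRS and the bf16/f32 conversion are ignored here, and the numbers are made up:

```
#include <iostream>
#include <vector>

int main() {
  // Hypothetical per-partition values y_p feeding the early all-reduce, and a
  // summand x that is identical across the P spatial partitions.
  std::vector<double> y = {1.5, 2.5};
  double x = 10.0;
  int P = static_cast<int>(y.size());

  // Before the rewrite each partition computes x + sum_p(y_p).
  double sum_y = 0;
  for (double v : y) sum_y += v;
  double before = x + sum_y;

  // After the rewrite the combined all-reduce sums (x + y_p) over all
  // partitions, i.e. P*x + sum_p(y_p); subtracting x*(P - 1) restores it.
  double combined = 0;
  for (double v : y) combined += x + v;
  double after = combined - x * (P - 1);

  std::cout << before << " == " << after << "\n";  // 14 == 14
}
```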
+absl::optional MatchesArCrsPattern( + HloInstruction* instruction) { + HloInstruction *ar, *convert, *add, *crs; + if (Match(instruction, + m::CrossReplicaSum( + &crs, m::Add(&add, m::Op(), + m::Convert(&convert, + m::CrossReplicaSum(&ar, m::Op()))))) && + ar->users().size() == 1 && ar->shape().element_type() == BF16 && + convert->shape().element_type() == F32 && !crs->all_reduce_id()) { + return ar; + } + return absl::optional(); +} + +} // namespace + +absl::optional ArCrsCombiner::WhileFromBodyParameter( + HloInstruction* instruction) { + CHECK(HloOpcode::kParameter == instruction->opcode()); + HloComputation* computation = instruction->parent(); + auto caller_instructions = call_graph_->GetComputationCallers(computation); + if (caller_instructions.size() == 1) { + auto caller_instruction = caller_instructions[0]; + if (caller_instruction->opcode() == HloOpcode::kWhile) { + return caller_instruction; + } + } + return absl::optional(); +} + +std::vector ArCrsCombiner::GetAllTuples( + HloInstruction* instruction) { + if (instruction->opcode() == HloOpcode::kTuple) { + return {instruction}; + } + if (instruction->opcode() == HloOpcode::kDomain) { + return GetAllTuples(instruction->operands()[0]); + } + if (instruction->opcode() == HloOpcode::kParameter) { + auto maybe_while = WhileFromBodyParameter(instruction); + if (!maybe_while) { + return {}; + } + auto while_instr = *maybe_while; + auto init_tuples = GetAllTuples(while_instr->while_init()); + auto body_tuples = + GetAllTuples(while_instr->while_body()->root_instruction()); + if (init_tuples.empty() || body_tuples.empty()) { + return {}; + } + init_tuples.insert(init_tuples.end(), body_tuples.begin(), + body_tuples.end()); + return init_tuples; + } + if (instruction->opcode() == HloOpcode::kGetTupleElement) { + std::vector result_tuples; + for (auto tuple : GetAllTuples(instruction->operands()[0])) { + auto tmp_tuples = + GetAllTuples(tuple->mutable_operand(instruction->tuple_index())); + if (tmp_tuples.empty()) { + return {}; + } + result_tuples.insert(result_tuples.end(), tmp_tuples.begin(), + tmp_tuples.end()); + } + return result_tuples; + } + return {}; +} + +bool ArCrsCombiner::TupleElementsComputeSameValue( + HloInstruction* tuple_shaped_instruction, int64 i1, int64 i2, + absl::flat_hash_map* visited_pairs) { + auto tuples = GetAllTuples(tuple_shaped_instruction); + if (tuples.empty()) { + return false; + } + for (auto tuple : tuples) { + CHECK(tuple->opcode() == HloOpcode::kTuple); + if (!InstructionsComputeSameValue(tuple->mutable_operand(i1), + tuple->mutable_operand(i2), + visited_pairs)) { + return false; + } + } + return true; +} + +/* static */ +bool ArCrsCombiner::TestInstructionsComputeSameValue(HloInstruction* i1, + HloInstruction* i2) { + ArCrsCombiner combiner(/*num_spatial_partitions=*/2); + auto module = i1->parent()->parent(); + CHECK_EQ(module, i2->parent()->parent()); + combiner.call_graph_ = CallGraph::Build(module); + absl::flat_hash_map visited_pairs; + return combiner.InstructionsComputeSameValue(i1, i2, &visited_pairs); +} + +bool ArCrsCombiner::InstructionsComputeSameValue( + HloInstruction* i1, HloInstruction* i2, + absl::flat_hash_map* visited_pairs) { + if (i1 == i2) { + return true; + } + auto uid1 = i1->unique_id(); + auto uid2 = i2->unique_id(); + auto min_uid = std::min(uid1, uid2); + auto max_uid = std::max(uid1, uid2); + auto it = visited_pairs->find(min_uid); + if (it != visited_pairs->end() && max_uid == it->second) { + return true; + } + auto opcode1 = i1->opcode(); + auto operands1 = 
i1->operands(); + if (opcode1 != i2->opcode() || operands1.size() != i2->operands().size()) { + return false; + } + if (opcode1 == HloOpcode::kConstant || i1->IsCrossModuleAllReduce()) { + return i1->Identical( + *i2, + /*eq_operands=*/std::equal_to(), + /*eq_computations=*/std::equal_to(), + /*layout_sensitive=*/false); + } + visited_pairs->emplace(min_uid, max_uid); + for (int i = 0; i < operands1.size(); ++i) { + auto operand1 = operands1[i]; + auto operand2 = i2->operands()[i]; + if (!InstructionsComputeSameValue(operand1, operand2, visited_pairs)) { + return false; + } + } + if (opcode1 == HloOpcode::kGetTupleElement) { + if (i1->tuple_index() == i2->tuple_index()) { + return true; + } + return TupleElementsComputeSameValue(operands1[0], i1->tuple_index(), + i2->tuple_index(), visited_pairs); + } + return true; +} + +void ArCrsCombiner::GroupAllReducesById(HloModule* module) { + for (HloComputation* computation : module->MakeNonfusionComputations()) { + for (HloInstruction* instruction : computation->instructions()) { + auto ar = MatchesArCrsPattern(instruction); + if (ar) { + all_reduce_map_[*((*ar)->all_reduce_id())].push_back(*ar); + } + } + } +} + +void ArCrsCombiner::KeepProvablyEqualInstructionGroups() { + for (auto it : all_reduce_map_) { + auto instruction_vec = it.second; + CHECK_EQ(instruction_vec.size(), num_spatial_partitions_); + + auto instr_0 = instruction_vec[0]; + auto add_0 = instr_0->users()[0]->users()[0]; + CHECK(HloOpcode::kAdd == add_0->opcode()); + + for (int i = 1; i < instruction_vec.size(); ++i) { + auto instr_i = instruction_vec[i]; + auto add_i = instr_i->users()[0]->users()[0]; + CHECK(HloOpcode::kAdd == add_i->opcode()); + absl::flat_hash_map visited_pairs; + if (!InstructionsComputeSameValue(add_0, add_i, &visited_pairs)) { + all_reduce_map_.erase(it.first); + } + } + } +} + +StatusOr ArCrsCombiner::RewriteGraph() { + if (all_reduce_map_.empty()) { + return false; + } + + auto computation_is_addition = [](HloComputation* c) { + return c->instruction_count() == 3 && + Match(c->root_instruction(), m::Add(m::Parameter(), m::Parameter())); + }; + + for (auto it : all_reduce_map_) { + auto instruction_vec = it.second; + for (auto all_reduce : instruction_vec) { + auto parent_computation = all_reduce->parent(); + auto convert = all_reduce->users()[0]; + auto add = convert->users()[0]; + auto crs = add->users()[0]; + + if (!computation_is_addition(all_reduce->called_computations()[0]) || + !computation_is_addition(crs->called_computations()[0])) { + continue; + } + HloInstruction* other_summand = (add->operands()[0] == convert) + ? 
add->operands()[1] + : add->operands()[0]; + // Remove the AllReduce and replace the CRS with: + // AllReduce - (other_summand * (num_spatial_partitions_ - 1)) + TF_CHECK_OK( + all_reduce->ReplaceAllUsesWith(all_reduce->mutable_operand(0))); + crs->set_all_reduce_id(all_reduce->all_reduce_id()); + auto new_shape = crs->shape(); + HloInstruction* to_subtract; + if (num_spatial_partitions_ == 2) { + to_subtract = other_summand; + } else { + Literal partitions_minus_1_lit = Literal(new_shape); + partitions_minus_1_lit.PopulateWithValue( + num_spatial_partitions_ - 1); + auto partitions_minus_1_const = parent_computation->AddInstruction( + HloInstruction::CreateConstant(partitions_minus_1_lit.Clone())); + to_subtract = + parent_computation->AddInstruction(HloInstruction::CreateBinary( + new_shape, HloOpcode::kMultiply, other_summand, + partitions_minus_1_const)); + } + auto sub = + parent_computation->AddInstruction(HloInstruction::CreateBinary( + new_shape, HloOpcode::kSubtract, crs, to_subtract)); + TF_CHECK_OK(crs->ReplaceAllUsesWith(sub)); + TF_CHECK_OK(parent_computation->RemoveInstruction(all_reduce)); + } + } + + return true; +} + +StatusOr ArCrsCombiner::Run(HloModule* module) { + call_graph_ = CallGraph::Build(module); + + GroupAllReducesById(module); + + KeepProvablyEqualInstructionGroups(); + + return RewriteGraph(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner.h b/tensorflow/compiler/xla/service/ar_crs_combiner.h new file mode 100644 index 00000000000..f6a7ef76ec3 --- /dev/null +++ b/tensorflow/compiler/xla/service/ar_crs_combiner.h @@ -0,0 +1,88 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_AR_CRS_COMBINER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_AR_CRS_COMBINER_H_ + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/xla/service/call_graph.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// Combine an AllReduce and a CrossReplicaSum when they are close to each other +// in the graph, to use an efficient CrossReplicaSum implementation that +// fully utilizes the interconnect bandwidth. +class ArCrsCombiner : public HloModulePass { + public: + ArCrsCombiner(int num_spatial_partitions) + : num_spatial_partitions_(num_spatial_partitions) {} + absl::string_view name() const override { return "ar-crs-combiner"; } + StatusOr Run(HloModule* module) override; + + // Helper method to allow testing of InstructionsComputeSameValue. 
+ static bool TestInstructionsComputeSameValue(HloInstruction* i1, + HloInstruction* i2); + + private: + // If the passed instruction is a while parameter, and the while body is only + // called by a single while instruction, return the while instruction. + absl::optional WhileFromBodyParameter( + HloInstruction* instruction); + + // Returns a vector of tuple instructions. + // If all instructions that flow to "instruction" are tuples, return them. + // Otherwise, return an empty vector. + std::vector GetAllTuples(HloInstruction* instruction); + + // Checks whether two different elements in the same tuple compute the same + // value. + bool TupleElementsComputeSameValue( + HloInstruction* tuple_shaped_instruction, int64 i1, int64 i2, + absl::flat_hash_map* visited_pairs); + + // Returns whether the instructions i1 and i2 can be shown to evaluate to the + // same value. Handling WHILE requires recursion, which may cause us to visit + // the same instruction again. To avoid infinite loops, we pass a cache of + // visited instruction pairs. + bool InstructionsComputeSameValue( + HloInstruction* i1, HloInstruction* i2, + absl::flat_hash_map* visited_pairs); + + // Populates all_reduce_map_. + void GroupAllReducesById(HloModule* module); + + // Looks at each AllReduce group in all_reduce_map_, and keeps only the + // groups for which it's safe to move the AllReduce later in the HLO graph. + void KeepProvablyEqualInstructionGroups(); + + // Performs the graph rewrite that eliminates the early AllReduce and turns + // the later CRS into an AllReduce. + StatusOr RewriteGraph(); + + int num_spatial_partitions_; + + // Map from all-reduce ids to the all reduce instructions. + absl::flat_hash_map> all_reduce_map_; + + std::unique_ptr call_graph_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_AR_CRS_COMBINER_H_ diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc b/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc new file mode 100644 index 00000000000..9d5eaf63ccf --- /dev/null +++ b/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc @@ -0,0 +1,415 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/ar_crs_combiner.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +namespace op = xla::testing::opcode_matchers; + +class ArCrsCombinerTest : public HloTestBase {}; + +TEST_F(ArCrsCombinerTest, SameValueTestBasecase) { + const char* module_str = R"( +HloModule foobar + +ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { + %p = f32[2,2] parameter(0) + %constant.f32.1 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %constant.f32.2 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32.1, %constant.f32.2) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto root_tuple = module->entry_computation()->root_instruction(); + auto i1 = root_tuple->operands()[0]; + auto i2 = root_tuple->operands()[1]; + EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue( + i1, module->entry_computation()->parameter_instruction(0))); + EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); +} + +TEST_F(ArCrsCombinerTest, SameValueTestNumOperands) { + const char* module_str = R"( +HloModule foobar + +ENTRY %entrycomp (p: f32[2,2]) -> ((f32[2,2]), (f32[2,2], f32[2,2])) { + %p = f32[2,2] parameter(0) + %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %tuple1 = (f32[2,2]) tuple(%constant.f32) + %tuple2 = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32) + ROOT %tuple = ((f32[2,2]), (f32[2,2], f32[2,2])) tuple(%tuple1, %tuple2) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto root_tuple = module->entry_computation()->root_instruction(); + auto i1 = root_tuple->operands()[0]; + auto i2 = root_tuple->operands()[1]; + EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); +} + +TEST_F(ArCrsCombinerTest, SameValueTestTupleElementSameIndex) { + const char* module_str = R"( +HloModule foobar + +ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { + %p = f32[2,2] parameter(0) + %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %tuple.1 = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32) + %get-tuple-element.1 = f32[2,2] get-tuple-element(%tuple.1), index=0 + %get-tuple-element.2 = f32[2,2] get-tuple-element(%tuple.1), index=0 + ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%get-tuple-element.1, %get-tuple-element.2) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto root_tuple = module->entry_computation()->root_instruction(); + auto i1 = root_tuple->operands()[0]; + auto i2 = root_tuple->operands()[1]; + EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); +} + +TEST_F(ArCrsCombinerTest, SameValueTestTupleElementDifferentIndex1) { + const char* module_str = R"( +HloModule foobar + +ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { + %p = f32[2,2] parameter(0) + %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %tuple.1 = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32) + %get-tuple-element.1 = f32[2,2] get-tuple-element(%tuple.1), index=0 + %get-tuple-element.2 = f32[2,2] get-tuple-element(%tuple.1), index=1 + ROOT %tuple = (f32[2,2], 
f32[2,2]) tuple(%get-tuple-element.1, %get-tuple-element.2) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto root_tuple = module->entry_computation()->root_instruction(); + auto i1 = root_tuple->operands()[0]; + auto i2 = root_tuple->operands()[1]; + EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); +} + +TEST_F(ArCrsCombinerTest, SameValueTestTupleElementDifferentIndex2) { + const char* module_str = R"( +HloModule foobar + +ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { + %p = f32[2,2] parameter(0) + %constant.f32.1 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %constant.f32.2 = f32[2,2] constant(f32[2,2] {{2, 3}, {4, 5}}) + %tuple.1 = (f32[2,2], f32[2,2]) tuple(%constant.f32.1, %constant.f32.2) + %get-tuple-element.1 = f32[2,2] get-tuple-element(%tuple.1), index=0 + %get-tuple-element.2 = f32[2,2] get-tuple-element(%tuple.1), index=1 + ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%get-tuple-element.1, %get-tuple-element.2) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto root_tuple = module->entry_computation()->root_instruction(); + auto i1 = root_tuple->operands()[0]; + auto i2 = root_tuple->operands()[1]; + EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); +} + +TEST_F(ArCrsCombinerTest, SameValueTestWhile1) { + const char* module_str = R"( +HloModule foobar + +%condition (x: (f32[2,2], f32[2,2])) -> pred[] { + %x = (f32[2,2], f32[2,2]) parameter(0) + %constant.0 = s32[] constant(0) + %constant.1 = s32[] constant(1) + ROOT %greater-than = pred[] greater-than(s32[] %constant.1, s32[] %constant.0) +} + +%body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) { + %x = (f32[2,2], f32[2,2]) parameter(0) + %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %get-tuple-element.1 = f32[2,2] get-tuple-element(%x), index=0 + %get-tuple-element.2 = f32[2,2] get-tuple-element(%x), index=1 + %add.1 = f32[2,2] add(%get-tuple-element.1, %constant.f32) + %add.2 = f32[2,2] add(%get-tuple-element.2, %constant.f32) + ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%add.1, %add.2) +} + +ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) { + %constant.f32 = f32[2,2] constant(f32[2,2] {{3, 4}, {5, 6}}) + %init.tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32) + ROOT %while = (f32[2,2], f32[2,2]) while(%init.tuple), condition=%condition, body=%body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto root_while = module->entry_computation()->root_instruction(); + auto body_tuple = root_while->while_body()->root_instruction(); + auto i1 = body_tuple->operands()[0]; + auto i2 = body_tuple->operands()[1]; + EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); +} + +TEST_F(ArCrsCombinerTest, SameValueTestWhile2) { + const char* module_str = R"( +HloModule foobar + +%condition (x: (f32[2,2], f32[2,2])) -> pred[] { + %x = (f32[2,2], f32[2,2]) parameter(0) + %constant.0 = s32[] constant(0) + %constant.1 = s32[] constant(1) + ROOT %greater-than = pred[] greater-than(s32[] %constant.1, s32[] %constant.0) +} + +%body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) { + %x = (f32[2,2], f32[2,2]) parameter(0) + %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %get-tuple-element.1 = f32[2,2] get-tuple-element(%x), index=0 + %get-tuple-element.2 = f32[2,2] get-tuple-element(%x), index=1 + %add.1 = f32[2,2] add(%get-tuple-element.1, %constant.f32) 
+ %add.2 = f32[2,2] add(%get-tuple-element.2, %constant.f32) + ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%add.1, %add.2) +} + +ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) { + %constant.f32.1 = f32[2,2] constant(f32[2,2] {{3, 4}, {5, 6}}) + %constant.f32.2 = f32[2,2] constant(f32[2,2] {{3, 4}, {7, 8}}) + %init.tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32.1, %constant.f32.2) + ROOT %while = (f32[2,2], f32[2,2]) while(%init.tuple), condition=%condition, body=%body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto root_while = module->entry_computation()->root_instruction(); + auto body_tuple = root_while->while_body()->root_instruction(); + auto i1 = body_tuple->operands()[0]; + auto i2 = body_tuple->operands()[1]; + EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); +} + +TEST_F(ArCrsCombinerTest, SameValueTestWhile3) { + const char* module_str = R"( +HloModule foobar + +%condition (x: (f32[2,2], f32[2,2])) -> pred[] { + %x = (f32[2,2], f32[2,2]) parameter(0) + %constant.0 = s32[] constant(0) + %constant.1 = s32[] constant(1) + ROOT %greater-than = pred[] greater-than(s32[] %constant.1, s32[] %constant.0) +} + +%body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) { + %x = (f32[2,2], f32[2,2]) parameter(0) + %constant.f32.1 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %constant.f32.2 = f32[2,2] constant(f32[2,2] {{3, 4}, {1, 2}}) + %get-tuple-element.1 = f32[2,2] get-tuple-element(%x), index=0 + %get-tuple-element.2 = f32[2,2] get-tuple-element(%x), index=1 + %add.1 = f32[2,2] add(%get-tuple-element.1, %constant.f32.1) + %add.2 = f32[2,2] add(%get-tuple-element.2, %constant.f32.2) + ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%add.1, %add.2) +} + +ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) { + %constant.f32 = f32[2,2] constant(f32[2,2] {{3, 4}, {5, 6}}) + %init.tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32) + ROOT %while = (f32[2,2], f32[2,2]) while(%init.tuple), condition=%condition, body=%body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto root_while = module->entry_computation()->root_instruction(); + auto body_tuple = root_while->while_body()->root_instruction(); + auto i1 = body_tuple->operands()[0]->operands()[0]; // %get-tuple-element.1 + auto i2 = body_tuple->operands()[1]->operands()[0]; // %get-tuple-element.2 + EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); +} + +TEST_F(ArCrsCombinerTest, RewritePatternArConvertAddCrs) { + const char* module_str = R"( +HloModule foobar + +%binary_add (a: bf16[], b: bf16[]) -> bf16[] { + %a = bf16[] parameter(0) + %b = bf16[] parameter(1) + ROOT %add = bf16[] add(%a, %b) +} + +%sum.f32 (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { + %p = f32[2,2] parameter(0) + %constant.bf16 = bf16[2,2] constant(bf16[2,2] {{1, 2}, {3, 4}}) + %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + + %cross-replica-sum.ar.1 = bf16[2,2] + cross-replica-sum(%constant.bf16), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%binary_add, + sharding={maximal device=0} + %convert.1 = f32[2,2] + convert(%cross-replica-sum.ar.1), + sharding={maximal device=0} + %add.1 = f32[2,2] + add(%constant.f32, %convert.1), + sharding={maximal device=0} + %cross-replica-sum.1 = f32[2,2] + cross-replica-sum(%add.1), + replica_groups={{0,1}}, + 
to_apply=%sum.f32, + sharding={maximal device=0} + + %cross-replica-sum.ar.2 = bf16[2,2] + cross-replica-sum(%constant.bf16), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%binary_add, + sharding={maximal device=1} + %convert.2 = f32[2,2] + convert(%cross-replica-sum.ar.2), + sharding={maximal device=1} + %add.2 = f32[2,2] + add(%constant.f32, %convert.2), + sharding={maximal device=1} + %cross-replica-sum.2 = f32[2,2] + cross-replica-sum(%add.2), + replica_groups={{0,1}}, + to_apply=%sum.f32, + sharding={maximal device=1} + + ROOT %tuple = (f32[2,2], f32[2,2]) + tuple(%cross-replica-sum.1, %cross-replica-sum.2), + sharding={{maximal device=0}, {maximal device=1}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto crs_before = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_before = crs_before->replica_groups(); + ArCrsCombiner combiner(2); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_TRUE(changed); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::Subtract(op::CrossReplicaSum(), op::Constant()), + op::Subtract(op::CrossReplicaSum(), op::Constant()))); + auto sub = module->entry_computation()->root_instruction()->operands()[0]; + auto crs_after = sub->operands()[0]; + auto replica_groups_after = crs_after->replica_groups(); + ASSERT_EQ(replica_groups_before.size(), replica_groups_after.size()); + for (int i = 0; i < replica_groups_before.size(); ++i) { + // Somewhat verbose way to compare the replica_ids, because EqualsProto + // is not available in the open-source build. + auto group_before = replica_groups_before[i]; + std::vector ids_before(group_before.replica_ids().begin(), + group_before.replica_ids().end()); + auto group_after = replica_groups_after[i]; + std::vector ids_after(group_after.replica_ids().begin(), + group_after.replica_ids().end()); + EXPECT_EQ(ids_before, ids_after); + } +} + +TEST_F(ArCrsCombinerTest, OtherSummandNotTheSameDontRewrite) { + const char* module_str = R"( +HloModule foobar + +%binary_add (a: bf16[], b: bf16[]) -> bf16[] { + %a = bf16[] parameter(0) + %b = bf16[] parameter(1) + ROOT %add = bf16[] add(%a, %b) +} + +%sum.f32 (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { + %p = f32[2,2] parameter(0) + %constant.bf16 = bf16[2,2] constant(bf16[2,2] {{1, 2}, {3, 4}}) + %constant.f32.1 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %constant.f32.2 = f32[2,2] constant(f32[2,2] {{3, 4}, {5, 6}}) + + %cross-replica-sum.ar.1 = bf16[2,2] + cross-replica-sum(%constant.bf16), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%binary_add, + sharding={maximal device=0} + %convert.1 = f32[2,2] + convert(%cross-replica-sum.ar.1), + sharding={maximal device=0} + %add.1 = f32[2,2] + add(%constant.f32.1, %convert.1), + sharding={maximal device=0} + %cross-replica-sum.1 = f32[2,2] + cross-replica-sum(%add.1), + replica_groups={{0,1}}, + to_apply=%sum.f32, + sharding={maximal device=0} + + %cross-replica-sum.ar.2 = bf16[2,2] + cross-replica-sum(%constant.bf16), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%binary_add, + sharding={maximal device=1} + %convert.2 = f32[2,2] + convert(%cross-replica-sum.ar.2), + sharding={maximal device=1} + %add.2 = f32[2,2] + add(%constant.f32.2, %convert.2), + sharding={maximal device=1} + %cross-replica-sum.2 = f32[2,2] + 
cross-replica-sum(%add.2), + replica_groups={{0,1}}, + to_apply=%sum.f32, + sharding={maximal device=1} + + ROOT %tuple = (f32[2,2], f32[2,2]) + tuple(%cross-replica-sum.1, %cross-replica-sum.2), + sharding={{maximal device=0}, {maximal device=1}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + ArCrsCombiner combiner(2); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_FALSE(changed); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc index f70f6ddfec6..0e6ca1871b3 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc @@ -107,19 +107,37 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { } std::unique_ptr Mean( - int64 element_count, HloInstruction* operand, + HloInstruction* element_count, HloInstruction* operand, const std::function)>& add_instruction) { - HloInstruction* elem_count_recip = - add_instruction(HloInstruction::CreateBroadcast( - operand->shape(), - add_instruction(HloInstruction::CreateConvert( - ShapeUtil::MakeShape(operand->shape().element_type(), {}), - add_instruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR0(1.0 / element_count))))), - {})); - return HloInstruction::CreateBinary(operand->shape(), HloOpcode::kMultiply, - operand, elem_count_recip); + auto broadcast = add_instruction( + HloInstruction::CreateBroadcast(operand->shape(), element_count, {})); + return HloInstruction::CreateBinary(operand->shape(), HloOpcode::kDivide, + operand, broadcast); + } + + std::unique_ptr DynamicElementCountPerFeature( + HloInstruction* operand, int64 feature_index, + const std::function)>& + add_instruction) { + auto elements_per_feature_u32 = add_instruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); + + for (int64 i = 0; i < ShapeUtil::Rank(operand->shape()); ++i) { + if (i == feature_index) { + continue; + } + auto dynamic_dimension_size = + add_instruction(HloInstruction::CreateGetDimensionSize( + ShapeUtil::MakeShape(U32, {}), operand, i)); + elements_per_feature_u32 = add_instruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(U32, {}), HloOpcode::kMultiply, + dynamic_dimension_size, elements_per_feature_u32)); + } + + return HloInstruction::CreateConvert( + ShapeUtil::MakeShape(operand->shape().element_type(), {}), + elements_per_feature_u32); } // Replaces the existing HLO instruction old_instruction, with @@ -195,9 +213,6 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( const Shape operand_shape = operand->shape(); PrimitiveType ptype = operand_shape.element_type(); int64 feature_index = batch_norm->feature_index(); - const int64 feature_count = operand_shape.dimensions(feature_index); - const int64 size_in_elements = ShapeUtil::ElementsIn(operand_shape); - int64 elements_per_feature_int64 = size_in_elements / feature_count; HloInstruction* scale = batch_norm->mutable_operand(1); HloInstruction* offset = batch_norm->mutable_operand(2); @@ -220,6 +235,9 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( } } + auto elements_per_feature = + add(DynamicElementCountPerFeature(operand, feature_index, add)); + auto scale_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, scale, {feature_index})); @@ -243,13 +261,13 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( add_reduce_computation)); // E[X]. 
- auto mean = add(Mean(elements_per_feature_int64, sum, add)); + auto mean = add(Mean(elements_per_feature, sum, add)); auto mean_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index})); // E[X^2]. - auto square_mean = add(Mean(elements_per_feature_int64, squared_sum, add)); + auto square_mean = add(Mean(elements_per_feature, squared_sum, add)); // E^2[X]. auto mean_square = @@ -458,9 +476,8 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( int64 feature_index = batch_norm->feature_index(); - const int64 size_in_elements = ShapeUtil::ElementsIn(activation_shape); - const int64 feature_count = activation_shape.dimensions(feature_index); - const int64 elements_per_feature_int64 = size_in_elements / feature_count; + auto elements_per_feature = + add(DynamicElementCountPerFeature(activation, feature_index, add)); auto zero_literal = LiteralUtil::CreateR0(0.0f); TF_ASSIGN_OR_RETURN(zero_literal, zero_literal.Convert(ptype)); @@ -553,15 +570,9 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( add_binary(activation_shape, HloOpcode::kMultiply, scale_broadcasted, rsqrt_var_add_epsilon_broadcasted); - scale_times_rsqrt_var_add_epsilon = add( - Mean(elements_per_feature_int64, scale_times_rsqrt_var_add_epsilon, add)); + scale_times_rsqrt_var_add_epsilon = + add(Mean(elements_per_feature, scale_times_rsqrt_var_add_epsilon, add)); - auto elements_per_feature_literal = - LiteralUtil::CreateR0(elements_per_feature_int64); - TF_ASSIGN_OR_RETURN(elements_per_feature_literal, - elements_per_feature_literal.Convert(ptype)); - auto elements_per_feature = add( - HloInstruction::CreateConstant(std::move(elements_per_feature_literal))); auto i1 = add_binary(activation_shape, HloOpcode::kMultiply, grad_output, add(HloInstruction::CreateBroadcast( activation_shape, elements_per_feature, {}))); diff --git a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc index 08cf8026177..8e8fbbd935b 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc @@ -36,7 +36,21 @@ limitations under the License. namespace xla { namespace { -using BatchNormExpanderTest = HloTestBase; +class BatchNormExpanderTest : public HloTestBase { + protected: + // BatchNorm should have a dynamic sized dividor for mean operations. + int64 CountGetDimensionSize(const HloModule& module) { + int64 count = 0; + for (HloComputation* comp : module.computations()) { + for (HloInstruction* inst : comp->instructions()) { + if (inst->opcode() == HloOpcode::kGetDimensionSize) { + count++; + } + } + } + return count; + } +}; // Test that we expand BatchNormTraining. TEST_F(BatchNormExpanderTest, BatchNormTraining) { @@ -68,6 +82,7 @@ TEST_F(BatchNormExpanderTest, BatchNormTraining) { /*rewrite_grad_op=*/true); ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie()); root = computation->root_instruction(); + EXPECT_EQ(CountGetDimensionSize(*module), 3); // Make sure this operation is expanded. EXPECT_EQ(root->opcode(), HloOpcode::kTuple); } @@ -110,6 +125,7 @@ TEST_F(BatchNormExpanderTest, BatchNormGrad) { /*rewrite_grad_op=*/true); ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie()); root = computation->root_instruction(); + EXPECT_EQ(CountGetDimensionSize(*module), 3); // Make sure this operation is expanded. 
EXPECT_EQ(root->opcode(), HloOpcode::kTuple); } diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 40c012a5e42..8d7c6244785 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -746,8 +746,7 @@ StatusOr> BufferAssigner::Run( LogicalBuffer::AlignmentFunction color_alignment, bool allow_input_output_aliasing, bool allocate_buffers_for_constants, BufferLiveness::Colorer colorer, ReuseAllocationFunction reuse_checker) { - BufferAssigner assigner(allow_input_output_aliasing, - allocate_buffers_for_constants, std::move(colorer), + BufferAssigner assigner(allocate_buffers_for_constants, std::move(colorer), std::move(reuse_checker)); return assigner.CreateAssignment(module, std::move(hlo_ordering), std::move(buffer_size), @@ -1434,33 +1433,40 @@ BufferAssigner::MergeColocatedBufferSets( computation == module->entry_computation(); }; + std::vector set_can_be_merged(colocated_buffer_sets.size(), true); + + // Do not merge if one of the sets includes live outs, entry parameters or + // constants. + // + // Buffer liveness does not report the correct live range for entry + // parameter and live out buffers so we have to special case them here. On + // backends that support constant buffer allocations, constant buffers are + // assigned globals in readonly storage so we can't merge colocated buffer + // sets containing constants with colocated buffer sets containing writing + // instructions or other constants. + // + // Moreover (on the CPU/GPU backends) the entry parameter buffers belong to + // the caller of the executable so we can't write to entry parameters + // either, and the argument for not merging constants also applies to entry + // parameters. + for (int64 i = 0; i < colocated_buffer_sets.size(); ++i) { + for (auto& buffer : colocated_buffer_sets[i]) { + if (buffer_liveness.MaybeLiveOut(*buffer) || + is_entry_parameter(*buffer) || + buffer->instruction()->opcode() == HloOpcode::kConstant) { + set_can_be_merged[i] = false; + break; + } + } + } + // Returns true if the two colocated buffer sets (specified by their indices // into the colocated_buffer_sets) can be merged into a single set. auto cannot_merge_buffer_sets = [&colocated_buffer_sets, &buffer_liveness, &buffer_size, - &is_entry_parameter](int64 i, int64 j) { - // Do not merge if one of the sets includes live outs, entry parameters or - // constants. - // - // Buffer liveness does not report the correct live range for entry - // parameter and live out buffers so we have to special case them here. On - // backends that support constant buffer allocations, constant buffers are - // assigned globals in readonly storage so we can't merge colocated buffer - // sets containing constants with colocated buffer sets containing writing - // instructions or other constants. - // - // Moreover (on the CPU/GPU backends) the entry parameter buffers belong to - // the caller of the executable so we can't write to entry parameters - // either, and the argument for not merging constants also applies to entry - // parameters. 
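The buffer-assignment hunk above hoists the disqualification check (live-out buffers, entry parameters, constants) out of the pairwise `cannot_merge_buffer_sets` lambda into a precomputed `set_can_be_merged` table, so each pair query becomes two lookups instead of a rescan of both sets. A minimal standalone sketch of that precompute-then-lookup pattern, using a hypothetical `Buffer` struct in place of the real logical-buffer and liveness queries:

```
#include <cstddef>
#include <iostream>
#include <vector>

// Simplified stand-ins: the real pass inspects LogicalBuffers via liveness
// queries; here a Buffer just records the three disqualifying properties.
struct Buffer {
  bool maybe_live_out;
  bool is_entry_parameter;
  bool is_constant;
};
using BufferSet = std::vector<Buffer>;

int main() {
  std::vector<BufferSet> colocated_buffer_sets = {
      {{false, false, false}},  // ordinary temporaries: mergeable
      {{true, false, false}},   // contains a live-out buffer: not mergeable
      {{false, false, true}},   // contains a constant: not mergeable
  };

  // Precompute per-set mergeability once, instead of rescanning both sets
  // inside every pairwise query.
  std::vector<bool> set_can_be_merged(colocated_buffer_sets.size(), true);
  for (std::size_t i = 0; i < colocated_buffer_sets.size(); ++i) {
    for (const Buffer& buffer : colocated_buffer_sets[i]) {
      if (buffer.maybe_live_out || buffer.is_entry_parameter ||
          buffer.is_constant) {
        set_can_be_merged[i] = false;
        break;
      }
    }
  }

  // The pairwise predicate now reduces to two table lookups.
  auto cannot_merge_buffer_sets = [&set_can_be_merged](std::size_t i,
                                                       std::size_t j) {
    return !set_can_be_merged[i] || !set_can_be_merged[j];
  };

  std::cout << std::boolalpha;
  std::cout << cannot_merge_buffer_sets(0, 1) << "\n";  // true
  std::cout << cannot_merge_buffer_sets(0, 0) << "\n";  // false
  return 0;
}
```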
- for (int64 key : {i, j}) { - for (auto& buffer : colocated_buffer_sets[key]) { - if (buffer_liveness.MaybeLiveOut(*buffer) || - is_entry_parameter(*buffer) || - buffer->instruction()->opcode() == HloOpcode::kConstant) { - return true; - } - } + &set_can_be_merged](int64 i, int64 j) { + if (!set_can_be_merged[i] || !set_can_be_merged[j]) { + return true; } // Colocated sets satisfy the invariant that all buffers within a set have diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index d8e1612b899..0a9fdede803 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -545,12 +545,10 @@ class BufferAssigner { ReuseAllocationFunction reuse_checker = nullptr); private: - BufferAssigner(bool allow_input_output_aliasing, - bool allocate_buffers_for_constants, + BufferAssigner(bool allocate_buffers_for_constants, BufferLiveness::Colorer colorer, ReuseAllocationFunction reuse_checker) - : allow_input_output_aliasing_(allow_input_output_aliasing), - allocate_buffers_for_constants_(allocate_buffers_for_constants), + : allocate_buffers_for_constants_(allocate_buffers_for_constants), colorer_(colorer), reuse_checker_(reuse_checker) {} virtual ~BufferAssigner() = default; @@ -640,10 +638,6 @@ class BufferAssigner { LogicalBuffer::Color::Hasher> SplitBuffersByColor(const absl::flat_hash_set& buffers); - // If true, buffer assignments assumes that input parameter buffers and output - // buffers can be shared if their sizes match. - bool allow_input_output_aliasing_; - // If true, allocate buffers for constant instructions. bool allocate_buffers_for_constants_; diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index b1fc50cb188..8f482e6ba8c 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -137,8 +137,7 @@ class BufferAssignmentTest : public HloTestBase { } std::unique_ptr RunBufferAssignmentWithInstructionSequence( - HloModule* module, - absl::Span instruction_sequence, + HloModule* module, absl::Span instruction_sequence, int64 alignment = 1) { HloSchedule schedule(module); schedule.set_sequence(module->entry_computation(), instruction_sequence); @@ -1853,7 +1852,7 @@ class WhileBufferAssignmentTest : public HloTestBase { std::unique_ptr RunBufferAssignment(HloModule* module, int64 alignment = 1) { HloSchedule schedule = - ScheduleModule(*module, ByteSizeOf).ConsumeValueOrDie(); + ScheduleModule(module, ByteSizeOf).ConsumeValueOrDie(); return BufferAssigner::Run( module, absl::make_unique(schedule), ByteSizeOf, @@ -2162,7 +2161,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { // nodes are traversed during BufferAssignment. TF_ASSERT_OK_AND_ASSIGN( HloSchedule schedule, - ScheduleModule(*module, [](const BufferValue& buffer) { + ScheduleModule(module.get(), [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/sizeof(void*)); })); @@ -2391,15 +2390,16 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { RunCopyInsertion(module.get()); HloSchedule schedule = - ScheduleModule(*module, ByteSizeOf).ConsumeValueOrDie(); + ScheduleModule(module.get(), ByteSizeOf).ConsumeValueOrDie(); // To trigger b/38494731, we want a specific Hlo schedule for the // root computation, so we overwrite that entry with a manually // crafted sequence. 
- schedule.set_sequence(module->entry_computation(), - {input1, weights1, one, output1, while1->operand(0), - while1, input0, weights0, zero, output0, - while0->operand(0), while0, gte0, gte1, root_add}); + schedule.set_sequence( + module->entry_computation(), + {input1, weights1, one, output1, while1->mutable_operand(0), while1, + input0, weights0, zero, output0, while0->mutable_operand(0), while0, + gte0, gte1, root_add}); // If this ASSERT fails, we constructed a bogus sequence above and this test // itself is buggy. diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc index aeee543e843..40825a78716 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc @@ -117,7 +117,7 @@ TEST_F(BufferLivenessTest, ElementwiseChain) { auto log = builder.AddInstruction( HloInstruction::CreateUnary(vec_, HloOpcode::kLog, exp)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); auto liveness = @@ -164,7 +164,7 @@ TEST_F(BufferLivenessTest, MultipleEntryParameters_Sequential) { auto add = builder.AddInstruction( HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, negate, exp)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); HloComputation* entry = module->AddEntryComputation(builder.Build()); HloSchedule schedule(module.get()); @@ -213,7 +213,7 @@ TEST_F(BufferLivenessTest, NonElementwiseOperand) { auto reverse = builder.AddInstruction(HloInstruction::CreateReverse(vec_, negate, {0})); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); auto liveness = @@ -247,7 +247,7 @@ TEST_F(BufferLivenessTest, OverlappedBuffers) { auto add = builder.AddInstruction( HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, negate, exp)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); auto liveness = @@ -289,7 +289,7 @@ TEST_F(BufferLivenessTest, OverlappedBuffersSequentialOrder) { auto add = builder.AddInstruction( HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, negate, exp)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); HloSchedule schedule(module.get()); @@ -336,7 +336,7 @@ TEST_F(BufferLivenessTest, RootInstructionIsNotLastInSequentialOrder) { HloInstruction::CreateSend(recv_done, token, /*channel_id=*/1)); auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build(add)); HloSchedule schedule(module.get()); @@ -373,7 +373,7 @@ TEST_F(BufferLivenessTest, TupleLiveOut) { auto outer_tuple = builder.AddInstruction(HloInstruction::CreateTuple({inner_tuple, exp})); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); auto liveness = @@ -393,7 +393,7 @@ TEST_F(BufferLivenessTest, TupleLiveOut) { TEST_F(BufferLivenessTest, EmbeddedComputation) { // Test MaybeLiveOut and MayInterfere for embedded computation. 
- auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto embedded_builder = HloComputation::Builder(TestName() + "_embedded"); auto embedded_param = embedded_builder.AddInstruction( @@ -450,7 +450,7 @@ TEST_F(BufferLivenessTest, TupleConstantLiveOut) { builder.AddInstruction(HloInstruction::CreateGetTupleElement( inner_tuple0.shape(), tuple_constant, 0)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); auto liveness = @@ -576,7 +576,7 @@ TEST_F(BufferLivenessTest, DependentTupleElements) { auto tuple_root = builder.AddInstruction(HloInstruction::CreateTuple({add0, add1})); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(BuildDummyComputation()); module->AddEmbeddedComputation(builder.Build()); @@ -611,8 +611,8 @@ TEST_F(BufferLivenessTest, DependentTupleElements) { class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { protected: // Builds and runs a computation (see test case computation graphs below). - std::unique_ptr BuildModule(const bool update_uses_tuple_element1, - const bool fuse_gte0) { + std::unique_ptr BuildModule( + const bool update_uses_tuple_element1, const bool fuse_gte0) { auto builder = HloComputation::Builder(TestName()); // Create param0 Tuple. Shape data_shape = ShapeUtil::MakeShape(F32, {8}); @@ -646,7 +646,7 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { builder.AddInstruction( HloInstruction::CreateTuple({gte0, dynamic_update_slice})); // Build module and get reference to entry computation. - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); auto* computation = module->entry_computation(); // Create fusion instruction based on number of tuple element 1 users. @@ -802,7 +802,7 @@ class DynamicUpdateSliceLivenessTest : public BufferLivenessTest { auto tuple_root = builder.AddInstruction( HloInstruction::CreateTuple({gte0, dynamic_update_slice})); // Build module and get reference to entry computation. - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(BuildDummyComputation()); module->AddEmbeddedComputation(builder.Build()); // Run BufferLiveness on 'module'. diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc index bdd5069632e..7987343bfaf 100644 --- a/tensorflow/compiler/xla/service/call_graph.cc +++ b/tensorflow/compiler/xla/service/call_graph.cc @@ -325,6 +325,15 @@ bool CallGraph::IsFlattened() const { return true; } +std::vector CallGraph::GetComputationCallers( + HloComputation* c) { + std::vector callers; + for (auto callsite : GetNode(c).caller_callsites()) { + callers.push_back(callsite.instruction()); + } + return callers; +} + std::pair CallGraph::NearestAncestorsInSameComputation(HloInstruction* a, HloInstruction* b) const { diff --git a/tensorflow/compiler/xla/service/call_graph.h b/tensorflow/compiler/xla/service/call_graph.h index cb56f4789d0..05c7c998738 100644 --- a/tensorflow/compiler/xla/service/call_graph.h +++ b/tensorflow/compiler/xla/service/call_graph.h @@ -236,6 +236,10 @@ class CallGraph { // FlattenCallGraph. bool IsFlattened() const; + // Returns a vector of instructions calling the passed computation. + // (Often a vector of size 1.) 
+ std::vector GetComputationCallers(HloComputation* c); + string ToString() const; private: diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc index 67132274c0d..1965925fa7f 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.cc +++ b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -86,15 +86,15 @@ CompileOnlyService::CompileAheadOfTime( Executable::DumpToDirectory(per_host_path, filename, hlo_snapshot)); } - const auto& program_shape = instance.computation.host_program_shape(); ExecutionOptions execution_options; *execution_options.mutable_debug_options() = debug_options; *execution_options.mutable_shape_with_output_layout() = - *instance.result_layout; + instance.result_layout->ToProto(); TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, - CreateModuleConfig(program_shape, instance.argument_layouts, - &execution_options)); + CreateModuleConfig( + ProgramShape(instance.computation.host_program_shape()), + instance.argument_layouts, &execution_options)); TF_ASSIGN_OR_RETURN( std::unique_ptr hlo_module, diff --git a/tensorflow/compiler/xla/service/computation_placer.h b/tensorflow/compiler/xla/service/computation_placer.h index c899ffb9dc5..844b42a38d7 100644 --- a/tensorflow/compiler/xla/service/computation_placer.h +++ b/tensorflow/compiler/xla/service/computation_placer.h @@ -105,8 +105,6 @@ class ComputationPlacer { // Map from platform kind to computation placer singleton. static std::map* GetPlatformComputationPlacers(); - se::Platform::Id platform_id_; - TF_DISALLOW_COPY_AND_ASSIGN(ComputationPlacer); }; diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc index 7f7f1503a09..95c7724c3c9 100644 --- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc +++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc @@ -142,16 +142,16 @@ std::vector GetMaskIds(int64 group_size, int64 group_count) { // Finally we use the Eq op of these two broadcasted constants and get the // desired mask. HloInstruction* GetExpandedFilterMask( - const Shape& filter_shape, int64 input_feature_dim, - int64 output_feature_dim, int64 group_count, + const Shape& filter_shape, int64 kernel_input_feature_dim, + int64 kernel_output_feature_dim, int64 group_count, const std::function)>& add_instruction) { Shape expanded_filter_shape = - ExpandedFilterShape(filter_shape, group_count, input_feature_dim); + ExpandedFilterShape(filter_shape, group_count, kernel_input_feature_dim); Shape mask_shape = ShapeUtil::MakeShape( S32, AsInt64Slice(expanded_filter_shape.dimensions())); - int64 output_feature = filter_shape.dimensions(output_feature_dim); - int64 group_size = filter_shape.dimensions(input_feature_dim); + int64 output_feature = filter_shape.dimensions(kernel_output_feature_dim); + int64 group_size = filter_shape.dimensions(kernel_input_feature_dim); // Create a 'input_feature' sized linspace and 'output_feature' sized linspace // that will be broadcasted into perpendicular dimensions and compared. 
@@ -159,15 +159,14 @@ HloInstruction* GetExpandedFilterMask( GetMaskIds(group_size, group_count); const std::vector output_feature_filter_mask = GetMaskIds(output_feature / group_count, group_count); - auto mask1 = add_instruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1(input_feature_filter_mask))); - auto broadcasted_mask1 = add_instruction( - HloInstruction::CreateBroadcast(mask_shape, mask1, {input_feature_dim})); + auto broadcasted_mask1 = add_instruction(HloInstruction::CreateBroadcast( + mask_shape, mask1, {kernel_input_feature_dim})); auto mask2 = add_instruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1(output_feature_filter_mask))); - auto broadcasted_mask2 = add_instruction( - HloInstruction::CreateBroadcast(mask_shape, mask2, {output_feature_dim})); + auto broadcasted_mask2 = add_instruction(HloInstruction::CreateBroadcast( + mask_shape, mask2, {kernel_output_feature_dim})); // Compare the broadcasted output feature linspace to the input feature // linspace to create a diagonal predicate. @@ -189,91 +188,203 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { }; auto dim_numbers = convolution->convolution_dimension_numbers(); - int64 input_feature_dim = dim_numbers.kernel_input_feature_dimension(); - int64 group_size = filter->shape().dimensions(input_feature_dim); - int64 output_feature_dim = dim_numbers.kernel_output_feature_dimension(); - auto expanded_filter_shape = - ExpandedFilterShape(filter->shape(), group_count, input_feature_dim); - HloInstruction* filter_mask = GetExpandedFilterMask( - filter->shape(), input_feature_dim, output_feature_dim, group_count, add); + int64 kernel_input_feature_dim = dim_numbers.kernel_input_feature_dimension(); + int64 group_size = filter->shape().dimensions(kernel_input_feature_dim); + int64 kernel_output_feature_dim = + dim_numbers.kernel_output_feature_dimension(); + auto expanded_filter_shape = ExpandedFilterShape(filter->shape(), group_count, + kernel_input_feature_dim); + HloInstruction* filter_mask = + GetExpandedFilterMask(filter->shape(), kernel_input_feature_dim, + kernel_output_feature_dim, group_count, add); HloInstruction* expanded_filter; if (group_size == 1) { bool depthwise_separable = - (group_count == filter->shape().dimensions(output_feature_dim)); + (group_count == filter->shape().dimensions(kernel_output_feature_dim)); // If the code generator handles depthwise separable convolutions // inherently, then no filter expansion is needed. if (!filter_expansion_ && depthwise_separable) { - const int64 old_kernel_input_feature_dimension = - dim_numbers.kernel_input_feature_dimension(); - const int64 old_kernel_output_feature_dimension = - dim_numbers.kernel_output_feature_dimension(); - - // For depthwise convolutions, we want the kernel input feature dimension - // to be smaller than the output feature dimension. If that's not the - // case, we swap the dimensions. 
- if (old_kernel_input_feature_dimension > - old_kernel_output_feature_dimension) { - Shape reshaped_filter_shape = filter->shape(); - auto& dimensions = *reshaped_filter_shape.mutable_dimensions(); - std::swap(dimensions[old_kernel_input_feature_dimension], - dimensions[old_kernel_output_feature_dimension]); - - auto reshaped_filter = - add(HloInstruction::CreateReshape(reshaped_filter_shape, filter)); - - dim_numbers.set_kernel_input_feature_dimension( - old_kernel_output_feature_dimension); - - dim_numbers.set_kernel_output_feature_dimension( - old_kernel_input_feature_dimension); - - auto new_convolution = HloInstruction::CreateConvolve( - convolution->shape(), convolution->mutable_operand(0), - reshaped_filter, group_count, convolution->window(), dim_numbers, - convolution->precision_config()); - - TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( - convolution, std::move(new_convolution))); - } return Status::OK(); } // We want to repeat 'filter' in the 'input_feature_dim' dimension // 'group_count' times. Shape reshaped_filter_shape = - ShapeUtil::DeleteDimension(input_feature_dim, filter->shape()); + ShapeUtil::DeleteDimension(kernel_input_feature_dim, filter->shape()); auto reshaped_filter = add(HloInstruction::CreateReshape(reshaped_filter_shape, filter)); std::vector broadcast_dims; for (int64 i = 0; i < filter->shape().dimensions_size(); ++i) { - if (i == input_feature_dim) { + if (i == kernel_input_feature_dim) { continue; } broadcast_dims.push_back(i); } expanded_filter = add(HloInstruction::CreateBroadcast( expanded_filter_shape, reshaped_filter, broadcast_dims)); + + auto zero = add(HloInstruction::CreateConstant( + LiteralUtil::Zero(expanded_filter_shape.element_type()))); + auto zero_filter = + add(HloInstruction::CreateBroadcast(expanded_filter_shape, zero, {})); + auto new_filter = add(HloInstruction::CreateTernary( + expanded_filter_shape, HloOpcode::kSelect, filter_mask, expanded_filter, + zero_filter)); + + auto new_convolution = HloInstruction::CreateConvolve( + convolution->shape(), convolution->mutable_operand(0), new_filter, + /*feature_group_count=*/1, convolution->window(), dim_numbers, + convolution->precision_config()); + TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( + convolution, std::move(new_convolution))); } else { - // We could possibly also use reshape, broadcast, reshape instead of concat - // here, but it would require more complex code, and for depthwise - // convolution we would never end up in this branch. - std::vector concat_operands(group_count, filter); - expanded_filter = add(HloInstruction::CreateConcatenate( - expanded_filter_shape, concat_operands, input_feature_dim)); + int64 activation_input_feature_dim = dim_numbers.input_feature_dimension(); + + int64 output_feature = + filter->shape().dimensions(kernel_output_feature_dim); + + // If group_count == output_feature, then we map those grouped convolutions + // onto depthwise convolution. This is done by adding an additional spatial + // dimension to the activations, kernel, and the output. + // E.g., we would turn + // [2, 12]{B, IF} conv [3, 4]{IF, OF} into + // [3, 2, 4]{S, B, IF} depth conv [3, 1, 4]{S, IF, OF}, where S is the + // additional spatial dimension. The generated convolution output will be + // [1, 2, 4]{S, B, OF} and then reshape the output back to [2, 4] {B, OF}. 
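To make the shape bookkeeping in the comment above concrete, the following standalone sketch (plain integer arithmetic, not real HLO construction) derives the reshaped activation, kernel, and output dimensions for the `[2, 12] conv [3, 4]` example, assuming `group_count == output_feature` as in this branch:

```
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Shape bookkeeping only: given a grouped convolution where
// group_count == output_feature, derive the shapes used by the depthwise
// rewrite described above. Dimension labels (S, B, IF, OF) follow the
// comment; no real HLO instructions are built here.
int main() {
  const int64_t batch = 2, input_feature = 12, output_feature = 4;
  const int64_t group_count = output_feature;              // precondition of this branch
  const int64_t group_size = input_feature / group_count;  // kernel IF size = 3

  // Original operands: activation {B, IF} and kernel {IF(=group_size), OF}.
  std::vector<int64_t> activation = {batch, input_feature};    // [2, 12]
  std::vector<int64_t> kernel = {group_size, output_feature};  // [3, 4]

  // Append a major spatial dimension S of size group_size to the activation
  // and shrink its feature dimension to group_count: [3, 2, 4] {S, B, IF}.
  std::vector<int64_t> reshaped_activation = {group_size, batch, group_count};

  // Append a size-1 dimension to the kernel; the old kernel IF dimension
  // becomes spatial and the new size-1 dimension is the kernel IF:
  // [3, 1, 4] {S, IF, OF}.
  std::vector<int64_t> reshaped_kernel = {group_size, 1, output_feature};

  // A depthwise convolution over S with window size group_size and stride 1
  // collapses S to 1: [1, 2, 4] {S, B, OF}; dropping S restores {B, OF}.
  std::vector<int64_t> conv_output = {1, batch, output_feature};
  std::vector<int64_t> final_output = {batch, output_feature};  // [2, 4]

  auto print = [](const char* name, const std::vector<int64_t>& dims) {
    std::cout << name << " = [";
    for (std::size_t i = 0; i < dims.size(); ++i) {
      std::cout << dims[i] << (i + 1 < dims.size() ? ", " : "");
    }
    std::cout << "]\n";
  };
  print("activation", activation);
  print("kernel", kernel);
  print("reshaped_activation", reshaped_activation);
  print("reshaped_kernel", reshaped_kernel);
  print("conv_output", conv_output);
  print("final_output", final_output);
  return 0;
}
```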
+ + if (group_count == output_feature && !filter_expansion_) { + auto filter = convolution->mutable_operand(1); + auto activation = convolution->mutable_operand(0); + + // Add spatial dimension to the activation, and reshape. + Shape reshaped_activation_shape = activation->shape(); + ShapeUtil::AppendMajorDimension(group_size, &reshaped_activation_shape); + + int64 new_spatial_dim = reshaped_activation_shape.dimensions().size() - 1; + + reshaped_activation_shape.set_dimensions(activation_input_feature_dim, + group_count); + activation = add( + HloInstruction::CreateReshape(reshaped_activation_shape, activation)); + + // Add spatial dimension to the filter, and reshape. + Shape reshaped_filter_shape = filter->shape(); + ShapeUtil::AppendMajorDimension(1, &reshaped_filter_shape); + + filter = + add(HloInstruction::CreateReshape(reshaped_filter_shape, filter)); + + Shape new_output_shape = convolution->shape(); + ShapeUtil::AppendMajorDimension(1, &new_output_shape); + + // Edit convolution dimension numbers. Note that kernel_input_feature_dim + // now becomes a spatial dimension, and the newly added dimension of size + // 1 is the new kernel_input_feature_dim. + dim_numbers.add_input_spatial_dimensions(new_spatial_dim); + dim_numbers.add_kernel_spatial_dimensions(kernel_input_feature_dim); + dim_numbers.set_kernel_input_feature_dimension(new_spatial_dim); + dim_numbers.add_output_spatial_dimensions(new_spatial_dim); + + // Add window for the new spatial dimension. + Window new_window = convolution->window(); + auto* dim = new_window.add_dimensions(); + dim->set_window_dilation(1); + dim->set_base_dilation(1); + dim->set_stride(1); + dim->set_size(group_size); + + auto new_convolution = add(HloInstruction::CreateConvolve( + new_output_shape, activation, filter, group_count, new_window, + dim_numbers, convolution->precision_config())); + + // Delete the extra spatial dimension, and reshape. + Shape reshaped_convolution_shape = + ShapeUtil::DeleteDimension(new_spatial_dim, new_convolution->shape()); + auto reshaped_convolution = HloInstruction::CreateReshape( + reshaped_convolution_shape, new_convolution); + + TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( + convolution, std::move(reshaped_convolution))); + + } else { + // The filter expansion mechanism adds zeroes in the kernel. + // For an OF = 12, IF = 6, and kernel IF = 2, the expanded filter mask + // would look like (IF on the Y-axis, OF on the X-axis) + // 1 1 1 1 0 0 0 0 0 0 0 0 + // 1 1 1 1 0 0 0 0 0 0 0 0 + // 0 0 0 0 1 1 1 1 0 0 0 0 + // 0 0 0 0 1 1 1 1 0 0 0 0 + // 0 0 0 0 0 0 0 0 1 1 1 1 + // 0 0 0 0 0 0 0 0 1 1 1 1 + // + // Instead of convolving the above with the input, we instead slice the + // kernel into three kernels, each containing islands of 1s from the + // filter above. We also slice the activations in the IF dimension with + // each slice of size = group_size. For each slice, we perform + // convolutions, and concatenate the generated outputs in the output OF + // dimension. 
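As a quick cross-check of the comment above, the slice bounds for the OF = 12, IF = 6, kernel-IF = 2 example work out as below. This is a standalone sketch of the index arithmetic only; the pass itself materializes these bounds as kSlice, kConvolution, and kConcatenate instructions.

```
#include <cstdint>
#include <iostream>

// Slice bookkeeping for the example above: OF = 12, IF = 6, kernel IF
// (group_size) = 2, hence group_count = 3. Each group takes a kernel slice
// of width OF / group_count along the output-feature dimension and an
// activation slice of width group_size along the input-feature dimension;
// the per-group convolution results are concatenated along OF.
int main() {
  const int64_t output_feature = 12;
  const int64_t group_size = 2;   // kernel input features
  const int64_t group_count = 3;  // = IF / group_size
  const int64_t filter_slice_width = output_feature / group_count;  // 4

  for (int64_t i = 0; i < group_count; ++i) {
    const int64_t filter_start = i * filter_slice_width;  // OF dimension
    const int64_t filter_limit = (i + 1) * filter_slice_width;
    const int64_t activation_start = i * group_size;      // IF dimension
    const int64_t activation_limit = (i + 1) * group_size;
    std::cout << "group " << i << ": kernel OF slice [" << filter_start
              << ", " << filter_limit << "), activation IF slice ["
              << activation_start << ", " << activation_limit << ")\n";
  }
  return 0;
}
```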
+ + std::vector sliced_convolutions; + auto activation = convolution->mutable_operand(0); + std::vector slice_strides(filter->shape().dimensions_size(), 1); + std::vector filter_slice_starts(filter->shape().dimensions_size(), + 0); + std::vector filter_slice_limits( + filter->shape().dimensions().begin(), + filter->shape().dimensions().end()); + std::vector activation_slice_starts( + activation->shape().dimensions_size(), 0); + std::vector activation_slice_limits( + activation->shape().dimensions().begin(), + activation->shape().dimensions().end()); + + int64 output_feature = + filter->shape().dimensions(kernel_output_feature_dim); + auto output_feature_dim = dim_numbers.output_feature_dimension(); + int64 filter_slice_width = output_feature / group_count; + + int64 activation_input_feature_dim = + dim_numbers.input_feature_dimension(); + + for (int64 i = 0; i < group_count; i++) { + filter_slice_starts[kernel_output_feature_dim] = i * filter_slice_width; + filter_slice_limits[kernel_output_feature_dim] = + (i + 1) * filter_slice_width; + auto filter_sliced_shape = filter->shape(); + filter_sliced_shape.set_dimensions(kernel_output_feature_dim, + filter_slice_width); + auto filter_slice = add(HloInstruction::CreateSlice( + filter_sliced_shape, filter, filter_slice_starts, + filter_slice_limits, slice_strides)); + + activation_slice_starts[activation_input_feature_dim] = i * group_size; + activation_slice_limits[activation_input_feature_dim] = + (i + 1) * group_size; + auto activation_sliced_shape = activation->shape(); + activation_sliced_shape.set_dimensions(activation_input_feature_dim, + group_size); + auto activation_slice = add(HloInstruction::CreateSlice( + activation_sliced_shape, activation, activation_slice_starts, + activation_slice_limits, slice_strides)); + + auto conv_slice_shape = convolution->shape(); + conv_slice_shape.set_dimensions(output_feature_dim, filter_slice_width); + + auto new_convolution = add(HloInstruction::CreateConvolve( + conv_slice_shape, activation_slice, filter_slice, + /*feature_group_count=*/1, convolution->window(), dim_numbers, + convolution->precision_config())); + + sliced_convolutions.push_back(new_convolution); + } + + auto new_conv = HloInstruction::CreateConcatenate( + convolution->shape(), sliced_convolutions, output_feature_dim); + TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( + convolution, std::move(new_conv))); + } } - auto zero = add(HloInstruction::CreateConstant( - LiteralUtil::Zero(expanded_filter_shape.element_type()))); - auto zero_filter = - add(HloInstruction::CreateBroadcast(expanded_filter_shape, zero, {})); - auto new_filter = add( - HloInstruction::CreateTernary(expanded_filter_shape, HloOpcode::kSelect, - filter_mask, expanded_filter, zero_filter)); - auto new_convolution = HloInstruction::CreateConvolve( - convolution->shape(), convolution->mutable_operand(0), new_filter, - /*feature_group_count=*/1, convolution->window(), dim_numbers, - convolution->precision_config()); - TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( - convolution, std::move(new_convolution))); + return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter_test.cc b/tensorflow/compiler/xla/service/convolution_feature_group_converter_test.cc index 28373ebf636..e6bf2143a21 100644 --- a/tensorflow/compiler/xla/service/convolution_feature_group_converter_test.cc +++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter_test.cc @@ -82,18 +82,14 @@ ENTRY %Convolve1D1Window_0.v3 
(input: f32[1,2,4], filter: f32[1,2,2]) -> f32[1,2 ConvolutionFeatureGroupConverter converter; ASSERT_TRUE(converter.Run(module.get()).ValueOrDie()); root = computation->root_instruction(); - // Make sure the convolution is converted to one with feature_group_count = 1. - EXPECT_EQ(root->opcode(), HloOpcode::kConvolution); - EXPECT_EQ(root->feature_group_count(), 1); - // Verify that the filter operand has been replaced. - EXPECT_THAT(root->operand(1), - op::Select(op::Eq(op::Broadcast(op::Constant()), - op::Broadcast(op::Constant())), - // We expect to see Concatenate here instead of - // Broadcast, because feature_group_count < input - // feature dimension. - op::Concatenate(op::Parameter(), op::Parameter()), - op::Broadcast(op::Constant()))); + // Make sure the convolution is replaced with a concatenate. + EXPECT_EQ(root->opcode(), HloOpcode::kConcatenate); + // And the operands of the concatenate are convolutions, each with a feature + // group count = 1. + EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kConvolution); + EXPECT_EQ(root->operand(1)->opcode(), HloOpcode::kConvolution); + EXPECT_EQ(root->operand(0)->feature_group_count(), 1); + EXPECT_EQ(root->operand(1)->feature_group_count(), 1); } } // namespace diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index 4e547d925f6..df605966387 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -442,7 +442,6 @@ class CopyRemover { const HloOrdering& ordering, HloModule* module) : module_(module), alias_analysis_(alias_analysis), - ordering_(ordering), buffer_value_tracker_(*module, alias_analysis, ordering) {} // Try to elide the given copy. The copy is elided if the instruction is not @@ -1003,7 +1002,6 @@ class CopyRemover { HloModule* module_; const HloAliasAnalysis& alias_analysis_; - const HloOrdering& ordering_; // Object tracking the HLO values contained in each HLO buffer. BufferValueTracker buffer_value_tracker_; diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index 7446bc7cc11..e4e9d7ba05c 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -94,7 +94,7 @@ TEST_F(CopyInsertionTest, SingleParameter) { EXPECT_THAT(x->users(), UnorderedElementsAre(tuple)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); InsertCopies(module.get()); @@ -114,7 +114,7 @@ TEST_F(CopyInsertionTest, SingleConstant) { EXPECT_THAT(constant->users(), UnorderedElementsAre(tuple)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); InsertCopies(module.get()); @@ -127,7 +127,7 @@ TEST_F(CopyInsertionTest, SingleConstant) { TEST_F(CopyInsertionTest, ExistingCopiesNotRemoved) { // Verify that kCopy instructions which change layout and exist before // copy-insertion remain in the graph after copy-insertion. 
- auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); HloInstruction* constant = @@ -181,7 +181,7 @@ TEST_F(CopyInsertionTest, MultipleConstantsAndParameters) { builder.AddInstruction(HloInstruction::CreateTuple({constant2, x, add})); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); InsertCopies(module.get()); @@ -217,7 +217,7 @@ TEST_F(CopyInsertionTest, AmbiguousPointsToSet) { EXPECT_THAT(constant2->users(), UnorderedElementsAre(tuple1, tuple2)); EXPECT_THAT(constant3->users(), UnorderedElementsAre(tuple2)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); HloInstruction* old_root = module->entry_computation()->root_instruction(); @@ -238,7 +238,7 @@ TEST_F(CopyInsertionTest, BitcastParameter) { HloInstruction* bitcast = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, x)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_THAT(x->users(), UnorderedElementsAre(bitcast)); @@ -261,7 +261,7 @@ TEST_F(CopyInsertionTest, BitcastConstant) { HloInstruction* bitcast = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, constant)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_THAT(constant->users(), UnorderedElementsAre(bitcast)); @@ -283,7 +283,7 @@ TEST_F(CopyInsertionTest, BitcastTupleElementParameter) { ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, x)); builder.AddInstruction(HloInstruction::CreateTuple({bitcast})); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_THAT(x->users(), UnorderedElementsAre(bitcast)); @@ -310,7 +310,7 @@ TEST_F(CopyInsertionTest, NestedTupleParameter) { ShapeUtil::MakeShape(F32, {42})}), "param0")); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_EQ(HloOpcode::kParameter, @@ -351,7 +351,7 @@ TEST_F(CopyInsertionTest, ElementOfNestedTupleParameter) { auto gte = builder.AddInstruction(HloInstruction::CreateGetTupleElement( ShapeUtil::GetSubshape(param->shape(), {0}), param, 0)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_EQ(gte, module->entry_computation()->root_instruction()); @@ -388,7 +388,7 @@ TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) { builder.AddInstruction(HloInstruction::CreateGetTupleElement( ShapeUtil::GetSubshape(select->shape(), {0}), select, 0)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_EQ(gte, module->entry_computation()->root_instruction()); @@ -1295,7 +1295,7 @@ TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinctUsedByTwoWhileLoops) { TEST_F(CopyInsertionTest, SwizzlingWhile) { // Test a while instruction with a body which permutes its tuple parameter // elements. 
- auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); const Shape loop_state_shape = ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); @@ -1362,7 +1362,7 @@ TEST_F(CopyInsertionTest, CrossingParameters) { // | / \ | // | / \| // (p1 , p0) - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); const Shape tuple_shape = ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); @@ -1395,7 +1395,7 @@ TEST_F(CopyInsertionTest, ParametersAliasing) { // | | // | | // (p0 , p1) - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); const Shape tuple_shape = ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); @@ -1428,7 +1428,7 @@ TEST_F(CopyInsertionTest, ParameterWithNoAliasing) { // | | // | | // (p0 , p1) - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); const Shape tuple_shape = ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); @@ -1461,7 +1461,7 @@ TEST_F(CopyInsertionTest, ParameterWithPartialAliasing) { // | | // | | // (p0 , p1) - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); const Shape tuple_shape = ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); @@ -1496,7 +1496,7 @@ TEST_F(CopyInsertionTest, ParameterAndParallelOpsWithPartialAliasing) { // | | | // | | | // +-- (p0 , p1) - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); const Shape tuple_shape = ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); @@ -1534,7 +1534,7 @@ TEST_F(CopyInsertionTest, ParameterAndOpsWithPartialAliasing) { // | Add----+ // | | | // +-- (p0 , p1) - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); const Shape tuple_shape = ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); @@ -1569,7 +1569,7 @@ TEST_F(CopyInsertionTest, SwizzlingWhileWithOneOp) { // the operation (instruction) on the element makes the live range of the // respective input and output elements different than if the instruction were // not there (as in the SwizzlingWhile test above). - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); const Shape loop_state_shape = ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); @@ -1632,7 +1632,7 @@ TEST_F(CopyInsertionTest, SwizzlingWhileSharedInput) { // the while body is a single constant (both loop state elements are the same // constant). This means no copies are necessary because both loop state // elements are the same so interchanging them is a no-op. - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); const Shape loop_state_shape = ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); @@ -1693,7 +1693,7 @@ TEST_F(CopyInsertionTest, SequentialWhiles) { const Shape loop_state_shape = ShapeUtil::MakeTupleShape( {element_shape, element_shape, element_shape, element_shape}); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto param_0 = builder.AddInstruction( HloInstruction::CreateParameter(0, element_shape, "param_0")); @@ -1783,7 +1783,7 @@ TEST_F(CopyInsertionTest, SequentialWhiles) { TEST_F(CopyInsertionTest, WhileBodyWithConstantRoot) { // Test a while body and condition which are each simply a constant (root of // computation is a constant). The body constant should be copied. 
- auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto param_0 = builder.AddInstruction( HloInstruction::CreateParameter(0, scalar_shape_, "param_0")); diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 2763d18121a..ce4c2a9cc69 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -96,6 +96,7 @@ cc_library( "@com_google_absl//absl/types:span", "//tensorflow/compiler/tf2xla:cpu_function_runtime", "//tensorflow/compiler/xla/service:map_inliner", + "//tensorflow/compiler/xla/service:hlo_get_dimension_size_rewriter", "//tensorflow/compiler/xla/service:scatter_expander", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:protobuf_util", diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc index 73b03440cbb..796a7cf94d0 100644 --- a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc +++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc @@ -61,19 +61,6 @@ Disabling these as a starting point. // TODO(b/64227304) Creating a custom pass pipeline will replace this. namespace { -class FilteredFunctionPassManager : public llvm::legacy::FunctionPassManager { - public: - FilteredFunctionPassManager(llvm::Module* m, bool disable_expensive_passes) - : llvm::legacy::FunctionPassManager(m), - disable_expensive_passes_(disable_expensive_passes) {} - void add(llvm::Pass* p) override { - llvm::legacy::FunctionPassManager::add(p); - } - - private: - bool disable_expensive_passes_; -}; - class FilteredPassManager : public llvm::legacy::PassManager { public: explicit FilteredPassManager(bool disable_expensive_passes) @@ -96,8 +83,7 @@ class FilteredPassManager : public llvm::legacy::PassManager { std::unique_ptr CompilerFunctor::operator()( llvm::Module& module) const { FilteredPassManager module_passes(disable_expensive_passes_); - FilteredFunctionPassManager function_passes(&module, - disable_expensive_passes_); + llvm::legacy::FunctionPassManager function_passes(&module); VLOG(2) << "IR before optimizations"; XLA_VLOG_LINES(2, llvm_ir::DumpModuleToString(module)); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 4ce5a8a2925..6374822c81b 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -76,6 +76,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_cse.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_element_type_converter.h" +#include "tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -268,10 +269,11 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( /*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true); - pass.AddPass( - /*is_layout_sensitive=*/false, - [](const Shape&, const Shape&) { return false; }, - /*enable_dot_strength_reduction=*/false); + pipeline.AddPass(); + AlgebraicSimplifierOptions options( + [](const Shape&, const Shape&) { return false; }); + options.set_enable_dot_strength_reduction(false); + pass.AddPass(options); pass.AddPass(); // BatchNormExpander can create zero-sized ops, so zero-sized HLO @@ -334,10 +336,11 @@ Status CpuCompiler::RunHloPassesAfterLayoutAssn( pass.AddInvariantChecker( /*layout_sensitive=*/true, /*allow_mixed_precision=*/false); - pass.AddPass>( - /*is_layout_sensitive=*/true, - [](const Shape&, const Shape&) { return true; }, - /*enable_dot_strength_reduction=*/false); + AlgebraicSimplifierOptions options( + [](const Shape&, const Shape&) { return true; }); + options.set_is_layout_sensitive(true); + options.set_enable_dot_strength_reduction(false); + pass.AddPass>(options); pass.AddPass(); pass.AddPass(/*is_layout_sensitive=*/true); } @@ -587,9 +590,9 @@ StatusOr> CpuCompiler::RunBackend( // Select an order for emitting the HLO instructions for each // computation. Using this sequence enables tighter buffer liveness analysis // and reduced memory usage (as compared to using DependencyHloOrdering). - TF_ASSIGN_OR_RETURN( - HloSchedule schedule, - ScheduleModule(*module, BufferSizeBytesFunction(), DFSMemoryScheduler)); + TF_ASSIGN_OR_RETURN(HloSchedule schedule, + ScheduleModule(module.get(), BufferSizeBytesFunction(), + DFSMemoryScheduler)); // Run buffer allocation on the HLO graph. TF_ASSIGN_OR_RETURN( @@ -779,7 +782,7 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, XLA_VLOG_LINES(2, module->ToString()); TF_ASSIGN_OR_RETURN(HloSchedule schedule, - ScheduleModule(*module, BufferSizeBytesFunction())); + ScheduleModule(module, BufferSizeBytesFunction())); // Run buffer analysis on the HLO graph. This analysis figures out which // temporary buffers are required to run the computation. 
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index 29abf38e439..818b2b0d0db 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -51,8 +51,7 @@ namespace cpu { CpuExecutable::CpuExecutable( std::unique_ptr jit, std::unique_ptr assignment, - std::unique_ptr hlo_module, - const string& entry_function_name, + std::unique_ptr hlo_module, const string& entry_function_name, std::unique_ptr hlo_profile_printer_data, std::unique_ptr hlo_profile_index_map) : Executable(std::move(hlo_module), std::move(hlo_profile_printer_data), diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h index 3c3c047bfe8..3b91b15ba9b 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h @@ -49,7 +49,7 @@ class CpuExecutable : public Executable { public: CpuExecutable(std::unique_ptr jit, std::unique_ptr assignment, - std::unique_ptr hlo_module, + std::unique_ptr hlo_module, const string& entry_function_name, std::unique_ptr hlo_profile_printer_data, std::unique_ptr hlo_profile_index_map); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc index f9cd61bea3d..6f79ad7c146 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc @@ -48,10 +48,15 @@ bool IsMatrixVectorDot(const HloInstruction* hlo) { (hlo_shape.dimensions(0) == 1 || hlo_shape.dimensions(1) == 1); } +bool HasExactlyOneUse(const HloInstruction& hlo_instr) { + return hlo_instr.user_count() == 1 && + absl::c_count(hlo_instr.users().front()->operands(), &hlo_instr) == 1; +} + bool CanBeOutputFused(const HloInstruction* producer, const HloInstruction* consumer) { return consumer->opcode() == HloOpcode::kAdd && IsMatrixVectorDot(producer) && - producer->user_count() == 1; + HasExactlyOneUse(*producer) == 1; } bool CanBeOutputFusedIntoSomeOperand(const HloInstruction* consumer) { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index c95a514ca04..527df0bd1c2 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -321,7 +321,7 @@ TEST_F(OpcodeFusionTest, Exponential_Reshape_Negate) { builder.AddInstruction( HloInstruction::CreateUnary(result_shape, HloOpcode::kNegate, reshape2)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( @@ -370,7 +370,7 @@ TEST_F(OpcodeFusionTest, Broadcast_Negate) { builder.AddInstruction(HloInstruction::CreateUnary( result_shape, HloOpcode::kNegate, broadcast1)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( @@ -410,7 +410,7 @@ TEST_F(OpcodeFusionTest, Exponential_Negate) { builder.AddInstruction( HloInstruction::CreateUnary(param_shape, HloOpcode::kNegate, exp1)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( @@ -429,7 +429,7 
@@ TEST_F(OpcodeFusionTest, Reshape_Negate) { builder.AddInstruction( HloInstruction::CreateUnary(result_shape, HloOpcode::kNegate, reshape1)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( @@ -447,7 +447,7 @@ TEST_F(OpcodeFusionTest, Reverse_Negate) { builder.AddInstruction( HloInstruction::CreateUnary(param_shape, HloOpcode::kNegate, reverse1)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( @@ -489,7 +489,7 @@ TEST_F(OpcodeFusionTest, Exponential_Transpose_Negate) { builder.AddInstruction(HloInstruction::CreateUnary( result_shape, HloOpcode::kNegate, transpose2)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( @@ -498,7 +498,7 @@ TEST_F(OpcodeFusionTest, Exponential_Transpose_Negate) { } TEST_F(OpcodeFusionTest, UnaryMapOfExp) { - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {3, 4}); @@ -517,7 +517,7 @@ TEST_F(OpcodeFusionTest, UnaryMapOfExp) { } TEST_F(OpcodeFusionTest, BinaryMapOfExps) { - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {3, 4}); @@ -542,7 +542,7 @@ TEST_F(OpcodeFusionTest, BinaryMapOfExps) { } TEST_F(OpcodeFusionTest, DynamicSliceWithDynamicUpdateSlice) { - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); Shape full_shape = ShapeUtil::MakeShape(F32, {10, 100, 1000}); @@ -573,7 +573,7 @@ TEST_F(OpcodeFusionTest, DynamicSliceWithDynamicUpdateSlice) { } TEST_F(OpcodeFusionTest, MessOfFusibleNodes) { - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); Shape full_shape = ShapeUtil::MakeShape(F32, {4, 100, 10, 100, 50}); @@ -712,7 +712,7 @@ void CreateComputationForDotAddOutputFusionTest(const string& test_name, } TEST_F(OpcodeFusionTest, DotAddOutputFusion_1x50x19) { - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/1, /*k=*/50, /*n=*/19, /*add_extra_use_for_dot=*/false); @@ -725,7 +725,7 @@ TEST_F(OpcodeFusionTest, DotAddOutputFusion_1x50x19) { } TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x1) { - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/19, /*k=*/50, /*n=*/1, /*add_extra_use_for_dot=*/false); @@ -738,7 +738,7 @@ TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x1) { } TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x19) { - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/19, /*k=*/50, /*n=*/19, /*add_extra_use_for_dot=*/false); @@ -751,7 +751,7 @@ TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x19) { } TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x1_multi_use) { - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); 
CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/19, /*k=*/50, /*n=*/1, /*add_extra_use_for_dot=*/true); @@ -763,6 +763,28 @@ TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x1_multi_use) { Not(op::Fusion())); } +TEST_F(InstructionFusionTest, + DotOperationFusion_DontOutputFuseDuplicateOperands) { + absl::string_view module_string = R"( +HloModule module + +ENTRY main { + a = f32[50,60]{1,0} parameter(0) + b = f32[60,1]{1,0} parameter(1) + c = f32[50,1]{1,0} dot(a, b), lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT d = f32[50,1]{1,0} add(c, c) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_string)); + TF_ASSERT_OK_AND_ASSIGN(bool fused_something, + CpuInstructionFusion().Run(module.get())); + EXPECT_FALSE(fused_something); + EXPECT_THAT(module->entry_computation()->root_instruction(), + Not(op::Fusion())); +} + struct GatherLoopFusionTestSpec { string test_name; string hlo_computation_text; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc index 2cd52e4a18a..6c61b64758e 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc @@ -73,7 +73,7 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensor) { auto result = builder.AddInstruction( CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); ComputationLayout computation_layout(computation->ComputeProgramShape()); @@ -114,7 +114,7 @@ TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor0) { builder.AddInstruction(HloInstruction::CreateBinary( result_shape, HloOpcode::kAdd, dot_a_result, dot_b_result)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); ComputationLayout computation_layout(computation->ComputeProgramShape()); @@ -158,7 +158,7 @@ TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor1) { auto tuple_result = builder.AddInstruction( HloInstruction::CreateTuple({dot_a_result, dot_b_result})); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); ComputationLayout computation_layout(computation->ComputeProgramShape()); @@ -192,7 +192,7 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantLhsTensor) { auto dot_result = builder.AddInstruction( CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); ComputationLayout computation_layout(computation->ComputeProgramShape()); @@ -232,7 +232,7 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensorThroughGTE) { auto dot_result = builder.AddInstruction( CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); ComputationLayout computation_layout(computation->ComputeProgramShape()); @@ -353,7 +353,7 @@ static void AssertCorrectLayoutForDotOutputFusion( } TEST_F(CpuLayoutAssignmentTest, 
DotOutputFusion_1x50x19_dot_idx_0) { - std::unique_ptr module = CreateNewUnverifiedModule(); + std::unique_ptr module = CreateNewVerifiedModule(); TF_ASSERT_OK_AND_ASSIGN( DotOutputFusionLayoutAssignmentResult layout_assignment_result, RunDotOutputFusion(module.get(), TestName(), /*m=*/1, /*k=*/50, /*n=*/19, @@ -365,7 +365,7 @@ TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_1x50x19_dot_idx_0) { } TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_1x50x19_dot_idx_1) { - std::unique_ptr module = CreateNewUnverifiedModule(); + std::unique_ptr module = CreateNewVerifiedModule(); TF_ASSERT_OK_AND_ASSIGN( DotOutputFusionLayoutAssignmentResult layout_assignment_result, RunDotOutputFusion(module.get(), TestName(), /*m=*/1, /*k=*/50, /*n=*/19, @@ -377,7 +377,7 @@ TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_1x50x19_dot_idx_1) { } TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x1_dot_idx_0) { - std::unique_ptr module = CreateNewUnverifiedModule(); + std::unique_ptr module = CreateNewVerifiedModule(); TF_ASSERT_OK_AND_ASSIGN( DotOutputFusionLayoutAssignmentResult layout_assignment_result, RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/1, @@ -389,7 +389,7 @@ TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x1_dot_idx_0) { } TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x1_dot_idx_1) { - std::unique_ptr module = CreateNewUnverifiedModule(); + std::unique_ptr module = CreateNewVerifiedModule(); TF_ASSERT_OK_AND_ASSIGN( DotOutputFusionLayoutAssignmentResult layout_assignment_result, RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/1, @@ -401,7 +401,7 @@ TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x1_dot_idx_1) { } TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x19_dot_idx_0) { - std::unique_ptr module = CreateNewUnverifiedModule(); + std::unique_ptr module = CreateNewVerifiedModule(); TF_ASSERT_OK_AND_ASSIGN( DotOutputFusionLayoutAssignmentResult layout_assignment_result, RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/19, @@ -413,7 +413,7 @@ TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x19_dot_idx_0) { } TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x19_dot_idx_1) { - std::unique_ptr module = CreateNewUnverifiedModule(); + std::unique_ptr module = CreateNewVerifiedModule(); TF_ASSERT_OK_AND_ASSIGN( DotOutputFusionLayoutAssignmentResult layout_assignment_result, RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/19, diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.cc b/tensorflow/compiler/xla/service/cpu/cpu_options.cc index b8ace570268..92debb83e33 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_options.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.cc @@ -22,7 +22,6 @@ limitations under the License. 
namespace { const char* const kXlaOptimizeForSizeCpuOption = "xla_cpu_optimize_for_size"; -const char* const kXlaDisableVectorizedReduce = "xla_disable_vectorized_reduce"; const char* const kLlvmIrDotTilingFactor = "xla_llvm_dot_tiling_factor"; const char* const kXlaEnableExperimentalLlvmIrGemm = "xla_enable_experimental_llvm_ir_gemm"; diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 620c45fa391..4032c2da2f3 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -111,7 +111,7 @@ IrEmitter::IrEmitter( StatusOr IrEmitter::EmitComputation( HloComputation* computation, const string& function_name_prefix, bool is_top_level_computation, - const std::vector* instruction_order) { + const std::vector* instruction_order) { string function_name = name_uniquer_.GetUniqueName(function_name_prefix); VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix << "]; ordered? " << (instruction_order != nullptr); @@ -140,7 +140,7 @@ StatusOr IrEmitter::EmitComputation( // readcyclecounter if it is unavailable. bool use_rdtscp = arch_type_ == llvm::Triple::ArchType::x86 || arch_type_ == llvm::Triple::ArchType::x86_64; - profiling_state_ = ProfilingState(use_rdtscp, GetProfileCountersArgument()); + profiling_state_ = ProfilingState(use_rdtscp); if (instruction_order == nullptr) { TF_RETURN_IF_ERROR(computation->Accept(this)); } else { @@ -1379,33 +1379,6 @@ Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) { return Status::OK(); } -// Fills up the free variables in 'index_with_free_var' with values from -// 'filler_index'. The size of free variables must be the same as the -// size of 'filler_index'. -// -// This is often used after dimension reduction, where -// 'index_with_free_var' has one or more dimensions reduced, which serves as -// free variables (represented as nullptr). For example, if we have a 4 -// dimensional input and index for the dimension being reduced is -// 2 (third dimension), we will have an index like [i, j, NULL, k] -// after reduced dimension. -// -// Here we fill up that free variable by 'filler_index', which contains -// the value in the reduced dimension. -static llvm_ir::IrArray::Index FillReducedDimensionIndex( - llvm_ir::IrArray::Index index_with_free_var, - llvm_ir::IrArray::Index filler_index) { - llvm_ir::IrArray::Index::const_iterator it = filler_index.begin(); - - for (size_t i = 0; i < index_with_free_var.size(); ++i) { - if (index_with_free_var[i] == nullptr) { - index_with_free_var[i] = *it++; - } - } - CHECK(filler_index.end() == it); - return index_with_free_var; -} - Status IrEmitter::HandleParameter(HloInstruction* parameter) { VLOG(2) << "HandleParameter: " << parameter->ToString(); return EmitTargetAddressForOp(parameter); @@ -2194,14 +2167,6 @@ Status IrEmitter::HandlePad(HloInstruction* pad) { return Status::OK(); } -// If `hlo` is a Transpose, returns its operand; otherwise returns `hlo` itself. 
-static const HloInstruction* StripTranspose(const HloInstruction& hlo) { - if (hlo.IsRank2Transpose()) { - return hlo.operand(0); - } - return &hlo; -} - Status IrEmitter::HandleFusion(HloInstruction* fusion) { auto* root = fusion->fused_expression_root(); if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(fusion, assignment_)) { @@ -2600,10 +2565,17 @@ Status IrEmitter::HandleConditional(HloInstruction* conditional) { return Status::OK(); } -Status IrEmitter::HandleAfterAll(HloInstruction* gen_token) { - TF_RET_CHECK(ByteSizeOf(gen_token->shape()) == 0); +Status IrEmitter::HandleAfterAll(HloInstruction* after_all) { + TF_RET_CHECK(ByteSizeOf(after_all->shape()) == 0); // No code to generate, but we need to emit an address for book-keeping. - TF_RETURN_IF_ERROR(EmitTargetAddressForOp(gen_token)); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(after_all)); + return Status::OK(); +} + +Status IrEmitter::HandleAddDependency(HloInstruction* add_dependency) { + // AddDedendency just forwards its zero-th operand. + emitted_value_[add_dependency] = + GetEmittedValueFor(add_dependency->operand(0)); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 136b88ff75e..559a8162a2d 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -101,7 +101,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, StatusOr EmitComputation( HloComputation* computation, const string& function_name_prefix, bool is_top_level_computation, - const std::vector* instruction_order); + const std::vector* instruction_order); llvm::IRBuilder<>* b() { return &b_; } @@ -159,7 +159,8 @@ class IrEmitter : public DfsHloVisitorWithDefault, Status HandleConcatenate(HloInstruction* concatenate) override; Status HandleConditional(HloInstruction* conditional) override; Status HandleScatter(HloInstruction* scatter) override; - Status HandleAfterAll(HloInstruction* gen_token) override; + Status HandleAfterAll(HloInstruction* after_all) override; + Status HandleAddDependency(HloInstruction* add_dependency) override; Status HandleRng(HloInstruction* rng) override; Status FinishVisit(HloInstruction* root) override; @@ -467,9 +468,8 @@ class IrEmitter : public DfsHloVisitorWithDefault, // profiling a computation. class ProfilingState { public: - ProfilingState() : use_rdtscp_(false), prof_counters_(nullptr) {} - ProfilingState(bool use_rdtscp, llvm::Value* prof_counters) - : use_rdtscp_(use_rdtscp), prof_counters_(prof_counters) {} + ProfilingState() : use_rdtscp_(false) {} + explicit ProfilingState(bool use_rdtscp) : use_rdtscp_(use_rdtscp) {} // Record the cycle counter before an HLO executes. void RecordCycleStart(llvm::IRBuilder<>* b, HloInstruction* hlo); @@ -494,9 +494,6 @@ class IrEmitter : public DfsHloVisitorWithDefault, // intrinsic? bool use_rdtscp_; - // The argument which corresponds to the profile counter buffer. - llvm::Value* prof_counters_; - // The first read cycle counter in the program. llvm::Value* first_read_cycle_start_ = nullptr; diff --git a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc index 669eeb95f32..722aa3120ef 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include #include #include +#include #include #include #include @@ -41,61 +42,60 @@ void KeyValueSort(std::pair* row_to_sort, int64 num_elements) { std::sort(row_to_sort, row_to_sort + num_elements); } -// For floating point numbers, we want a total order comparator. -NaN and NaN -// should appear at the beginning and end of the ordering, and -0.0 should -// appear before 0.0. Also we want to have a stable sort, so if the keys are the -// same, we compare the index values. -template -bool LessThan(KeyType lhs, int64 lhs_index, KeyType rhs, int64 rhs_index) { - bool lhs_is_negative = std::signbit(lhs); - bool rhs_is_negative = std::signbit(rhs); - // If the signs are different, we can just compare the signs. - if (lhs_is_negative != rhs_is_negative) { - return lhs_is_negative && !rhs_is_negative; +// We would like a total order of floating point numbers so that the +// sort has a predictable behavior in the presence of NaNs. Rather +// than using floating point comparison, we use the following trick: +// If f is a float, and +// x = bit_cast(f); +// y = x < 0 ? 0x7FFFFFFF - x : x; +// then y is ordered as an int32 such that finite values have the +// obvious order, -0 is ordered before 0, and -NaN and NaN appear at +// the beginning and end of the ordering. +template +CastType Convert(KeyType value) { + CastType casted_value; + memcpy(&casted_value, &value, sizeof(CastType)); + if (casted_value < 0) { + return static_cast(std::numeric_limits::max()) - + casted_value; } - bool lhs_nan = std::isnan(lhs); - bool rhs_nan = std::isnan(rhs); - // Exactly one number is nan? - if (lhs_nan != rhs_nan) { - if (lhs_nan) { - return lhs_is_negative; - } - return !rhs_is_negative; - } - if (lhs != rhs) { - return lhs < rhs; - } - return lhs_index < rhs_index; + return casted_value; +} + +template +bool LessThan(KeyType lhs, KeyType rhs) { + return Convert(lhs) < + Convert(rhs); } template <> void KeyValueSort(std::pair* row_to_sort, int64 num_elements) { - std::sort(row_to_sort, row_to_sort + num_elements, - [](const std::pair& lhs, - const std::pair& rhs) -> bool { - return LessThan(lhs.first, lhs.second, rhs.first, rhs.second); - }); + std::stable_sort(row_to_sort, row_to_sort + num_elements, + [](const std::pair& lhs, + const std::pair& rhs) -> bool { + return LessThan(lhs.first, rhs.first); + }); } template <> void KeyValueSort(std::pair* row_to_sort, int64 num_elements) { - std::sort(row_to_sort, row_to_sort + num_elements, - [](const std::pair& lhs, - const std::pair& rhs) -> bool { - return LessThan(lhs.first, lhs.second, rhs.first, rhs.second); - }); + std::stable_sort(row_to_sort, row_to_sort + num_elements, + [](const std::pair& lhs, + const std::pair& rhs) -> bool { + return LessThan(lhs.first, rhs.first); + }); } template <> void KeyValueSort(std::pair* row_to_sort, int64 num_elements) { - std::sort(row_to_sort, row_to_sort + num_elements, - [](const std::pair& lhs, - const std::pair& rhs) -> bool { - return LessThan( - Eigen::half_impl::half_to_float(lhs.first), lhs.second, - Eigen::half_impl::half_to_float(rhs.first), rhs.second); - }); + std::stable_sort(row_to_sort, row_to_sort + num_elements, + [](const std::pair& lhs, + const std::pair& rhs) -> bool { + return LessThan( + Eigen::half_impl::half_to_float(lhs.first), + Eigen::half_impl::half_to_float(rhs.first)); + }); } template diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index f77641eb7da..efccadedf27 100644 --- 
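The comment in the hunk above describes mapping float bit patterns to integers so that ordinary integer comparison yields a total order over floats, with -NaN first, -0.0 before +0.0, and NaN last, while std::stable_sort now handles tie-breaking. The sketch below illustrates the same ordering; to stay clear of signed-overflow questions it uses the equivalent unsigned formulation (flip all bits of negative patterns, set the sign bit of non-negative ones) rather than the exact expression quoted in the comment, and the names are illustrative only.

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <limits>

// Map a float to a uint32_t whose plain '<' ordering matches the total order
// described above: -NaN, -Inf, ..., -0.0, +0.0, ..., +Inf, NaN.
uint32_t TotalOrderKey(float value) {
  uint32_t bits;
  static_assert(sizeof(bits) == sizeof(value), "size mismatch");
  std::memcpy(&bits, &value, sizeof(bits));
  // Negative values (sign bit set): flip all bits so larger magnitudes sort
  // earlier. Non-negative values: set the sign bit so they sort after all
  // negatives.
  return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u);
}

bool TotalLessThan(float lhs, float rhs) {
  return TotalOrderKey(lhs) < TotalOrderKey(rhs);
}

int main() {
  const float inf = std::numeric_limits<float>::infinity();
  const float nan = std::numeric_limits<float>::quiet_NaN();
  std::printf("-0.0 before +0.0: %d\n", TotalLessThan(-0.0f, 0.0f));   // 1
  std::printf("-NaN before -Inf: %d\n", TotalLessThan(-nan, -inf));    // 1
  std::printf("+Inf before NaN:  %d\n", TotalLessThan(inf, nan));      // 1
  std::printf("-2.0 before -1.0: %d\n", TotalLessThan(-2.0f, -1.0f));  // 1
  return 0;
}
```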
a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -128,8 +128,18 @@ SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options, } llvm::JITSymbol SimpleOrcJIT::ResolveRuntimeSymbol(const std::string& name) { - void* func_addr = CustomCallTargetRegistry::Global()->Lookup(name); + void* func_addr = nullptr; + if (name.size() > 1 && name.front() == data_layout_.getGlobalPrefix()) { + // On Mac OS X, 'name' may have a leading underscore prefix, even though the + // registered name may not. + std::string stripped_name(name.begin() + 1, name.end()); + func_addr = CustomCallTargetRegistry::Global()->Lookup(stripped_name); + } else { + func_addr = CustomCallTargetRegistry::Global()->Lookup(name); + } + if (func_addr == nullptr) { + VLOG(2) << "Unable to resolve runtime symbol: " << name; return nullptr; } llvm::JITEvaluatedSymbol symbol_info(reinterpret_cast(func_addr), diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc index 691b3c7bee2..f8f5f392da8 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc @@ -50,7 +50,7 @@ class CpuEigenDotOperationTest /*entry_point_name=*/"entry", /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static}; - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(std::move(entry_computation)); CompileAheadOfTimeAndVerifyIr(std::move(hlo_module), options, diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc index d201a151d7a..e30f95311fc 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc @@ -46,7 +46,7 @@ class CpuExternalConstantsTest : public CpuCodegenTest { builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, constant)); - std::unique_ptr module = CreateNewUnverifiedModule(); + std::unique_ptr module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); CompileAndVerifyIr(std::move(module), filecheck_pattern, diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc index 773336c7a92..9b10c49f4f5 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc @@ -91,7 +91,7 @@ TEST_P(CpuUnaryIntrinsicTest, DoIt) { /*entry_point_name=*/"entry", /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static}; - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(std::move(computation)); string check_lines{spec.check_lines.data(), spec.check_lines.size()}; diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc index 3b87683ffff..fa0e09ff6b5 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc @@ -63,7 +63,7 @@ CHECK-NOT: private constant [48 x i8] )"; 
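The ResolveRuntimeSymbol change at the top of this hunk strips the data layout's global prefix (a leading underscore on Mac OS X) before consulting the custom-call registry, and logs when resolution fails. A minimal standalone sketch of that lookup flow, with a std::map standing in for CustomCallTargetRegistry and hypothetical names:

```cpp
#include <cstdio>
#include <map>
#include <string>

using Registry = std::map<std::string, void*>;

void* Lookup(const Registry& registry, const std::string& name) {
  auto it = registry.find(name);
  return it == registry.end() ? nullptr : it->second;
}

// Mirrors the control flow above: if the requested name carries the global
// prefix (e.g. '_' in Mach-O data layouts), look up the stripped name;
// otherwise look up the name as-is.
void* ResolveRuntimeSymbol(const Registry& registry, const std::string& name,
                           char global_prefix) {
  void* func_addr = nullptr;
  if (name.size() > 1 && name.front() == global_prefix) {
    std::string stripped_name(name.begin() + 1, name.end());
    func_addr = Lookup(registry, stripped_name);
  } else {
    func_addr = Lookup(registry, name);
  }
  if (func_addr == nullptr) {
    std::fprintf(stderr, "Unable to resolve runtime symbol: %s\n", name.c_str());
  }
  return func_addr;
}

int main() {
  static int target = 0;
  Registry registry = {{"my_custom_call", &target}};
  // With a '_' global prefix the JIT may ask for "_my_custom_call" even
  // though the registration used the bare name.
  std::printf("%d\n", ResolveRuntimeSymbol(registry, "_my_custom_call", '_') != nullptr);  // 1
  std::printf("%d\n", ResolveRuntimeSymbol(registry, "my_custom_call", '\0') != nullptr);  // 1
  return 0;
}
```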
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(hlo_text)); + ParseAndReturnVerifiedModule(hlo_text)); CpuAotCompilationOptions options{ /*triple=*/"x86_64-pc-linux", /*cpu_name=*/"", /*features=*/"", @@ -104,14 +104,14 @@ ENTRY main { )"; string filecheck_pattern = R"( -CHECK: private constant [4 x i8] -CHECK: private constant [8 x i8] +CHECK-DAG: private constant [4 x i8] +CHECK-DAG: private constant [8 x i8] CHECK-NOT: private constant [4 x i8] CHECK-NOT: private constant [8 x i8] )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(hlo_text)); + ParseAndReturnVerifiedModule(hlo_text)); CpuAotCompilationOptions options{ /*triple=*/"x86_64-pc-linux", /*cpu_name=*/"", /*features=*/"", diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc index f5419b7063b..a7702c2aeea 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc @@ -56,7 +56,7 @@ TEST_F(CpuNoAliasTest, Concat) { std::unique_ptr computation = builder.Build(); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(std::move(computation)); // Now that we have an HLO module, build an llvm_ir::AliasAnalysis for it. diff --git a/tensorflow/compiler/xla/service/cpu/xfeed_manager.h b/tensorflow/compiler/xla/service/cpu/xfeed_manager.h index 990ff94ba23..70008947f37 100644 --- a/tensorflow/compiler/xla/service/cpu/xfeed_manager.h +++ b/tensorflow/compiler/xla/service/cpu/xfeed_manager.h @@ -23,6 +23,7 @@ limitations under the License. #include #include "absl/types/span.h" +#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index d6371283221..e84bf00153a 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -251,6 +251,7 @@ class DfsHloVisitorBase { virtual Status HandleBatchNormGrad(HloInstructionPtr hlo) = 0; + virtual Status HandleAddDependency(HloInstructionPtr add_dependency) = 0; virtual Status HandleAfterAll(HloInstructionPtr token) = 0; // Invoked to inform the visitor that the traversal has completed, and that diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index e57184f639f..80ea5be298a 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -206,6 +206,9 @@ class DfsHloVisitorWithDefaultBase Status HandleGetDimensionSize(HloInstructionPtr get_size) override { return DefaultAction(get_size); } + Status HandleAddDependency(HloInstructionPtr add_dependency) override { + return DefaultAction(add_dependency); + } // Invoked to inform the visitor that the traversal has completed, and that // the root was "root". diff --git a/tensorflow/compiler/xla/service/dynamic_parameter_binding.cc b/tensorflow/compiler/xla/service/dynamic_parameter_binding.cc new file mode 100644 index 00000000000..c8bfc890506 --- /dev/null +++ b/tensorflow/compiler/xla/service/dynamic_parameter_binding.cc @@ -0,0 +1,138 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" + +namespace xla { + +Status DynamicParameterBinding::Bind( + const DynamicParameter& dynamic_parameter, + const DynamicDimension& dynamic_dimension) { + auto result = bindings_.emplace(dynamic_dimension, dynamic_parameter); + TF_RET_CHECK(result.second); + return Status::OK(); +} + +absl::optional +DynamicParameterBinding::GetBinding(const DynamicDimension& dynamic_dimension) { + auto param_iter = bindings_.find(dynamic_dimension); + if (param_iter == bindings_.end()) { + return absl::nullopt; + } + return param_iter->second; +} + +DynamicParameterBindingProto DynamicParameterBinding::ToProto() const { + DynamicParameterBindingProto result; + for (const auto& binding : bindings_) { + const DynamicDimension& dynamic_dimension = binding.first; + const DynamicParameter& dynamic_param = binding.second; + DynamicParameterBindingProto::Binding binding_proto; + binding_proto.set_dynamic_param_num(dynamic_param.parameter_num); + for (int64 i : dynamic_param.parameter_index) { + binding_proto.add_dynamic_param_index(i); + } + + binding_proto.set_target_param_num(dynamic_dimension.parameter_num); + + for (int64 i : dynamic_dimension.parameter_index) { + binding_proto.add_target_param_index(i); + } + + binding_proto.set_target_param_dim_num(dynamic_dimension.dimension); + result.add_entries()->Swap(&binding_proto); + } + return result; +} + +StatusOr DynamicParameterBinding::CreateFromProto( + const DynamicParameterBindingProto& proto) { + DynamicParameterBinding result; + for (const DynamicParameterBindingProto::Binding& binding : proto.entries()) { + int64 dynamic_param_num = binding.dynamic_param_num(); + ShapeIndex dynamic_param_index(binding.dynamic_param_index().begin(), + binding.dynamic_param_index().end()); + int64 target_param_num = binding.target_param_num(); + ShapeIndex target_param_index(binding.target_param_index().begin(), + binding.target_param_index().end()); + int64 target_dim_num = binding.target_param_num(); + + TF_RETURN_IF_ERROR( + result.Bind(DynamicParameter{dynamic_param_num, dynamic_param_index}, + DynamicDimension{target_param_num, target_param_index, + target_dim_num})); + } + + return result; +} + +string DynamicParameterBinding::ToString() const { + std::vector pieces; + pieces.push_back("DynamicParameterBinding: "); + for (const auto& binding : bindings_) { + const DynamicDimension& dynamic_dimension = binding.first; + const DynamicParameter& dynamic_param = binding.second; + pieces.push_back(absl::StrFormat( + " -- Input param number %lld at %s has dim %lld as dynamic" + " dimension, which is represented by param number %lld at " + "%s", + dynamic_dimension.parameter_num, + 
dynamic_dimension.parameter_index.ToString(), + dynamic_dimension.dimension, dynamic_param.parameter_num, + dynamic_param.parameter_index.ToString())); + } + return absl::StrJoin(pieces, "\n"); +} + +Status DynamicParameterBinding::ForEachBinding(BindingFn fn) const { + for (const auto& binding : bindings_) { + TF_RETURN_IF_ERROR(fn(binding.second, binding.first)); + } + return Status::OK(); +} + +Status DynamicParameterBinding::Verify(const HloModule& module) const { + const HloComputation* entry = module.entry_computation(); + return ForEachBinding([&](const DynamicParameter& dynamic_parameter, + const DynamicDimension& dynamic_dimension) + -> Status { + TF_RET_CHECK(dynamic_parameter.parameter_num < entry->num_parameters()); + TF_RET_CHECK(dynamic_dimension.parameter_num < entry->num_parameters()); + TF_RET_CHECK(ShapeUtil::IndexIsValid( + entry->parameter_instruction(dynamic_parameter.parameter_num)->shape(), + dynamic_parameter.parameter_index)); + TF_RET_CHECK(ShapeUtil::IndexIsValid( + entry->parameter_instruction(dynamic_dimension.parameter_num)->shape(), + dynamic_dimension.parameter_index)); + TF_RET_CHECK( + dynamic_dimension.dimension < + ShapeUtil::Rank(ShapeUtil::GetSubshape( + entry->parameter_instruction(dynamic_dimension.parameter_num) + ->shape(), + dynamic_dimension.parameter_index))); + return Status::OK(); + }); +} + +std::ostream& operator<<(std::ostream& out, + const DynamicParameterBinding& binding) { + out << binding.ToString(); + return out; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/dynamic_parameter_binding.h b/tensorflow/compiler/xla/service/dynamic_parameter_binding.h new file mode 100644 index 00000000000..dd474d8eed1 --- /dev/null +++ b/tensorflow/compiler/xla/service/dynamic_parameter_binding.h @@ -0,0 +1,125 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_PARAMETER_BINDING_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_PARAMETER_BINDING_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/shape_tree.h" +#include "tensorflow/compiler/xla/shape_util.h" + +namespace xla { + +class HloModule; +// We currently use an explicit API that takes an extra parameter to indicate +// the runtime size of a dynamic dimension. DynamicParameterBinding indicates +// the relationship between parameter: We can have a dynamic parameter that +// points to another target parameter to indicate that the target parameter is +// dynamic. +// +// +// TODO(b/119520625): Remove this API once we have more dynamic shape infra +// ready. +class DynamicParameterBinding { + public: + // DynamicParameter represents a special parameter that is used to represent + // the runtime size of a dimension of another parameter. 
A dynamic parameter + // has to be a scalar value. + struct DynamicParameter { + // The parameter number of dynamic parameter. + int64 parameter_num; + // The index of the parameter. + ShapeIndex parameter_index; + }; + + // DynamicDimension represents a dimension whose size is determined at + // runtime. A DynamicDimension's runtime size is determined by the binded + // DynamicParameter using `DynamicParameterBinding::Bind` method. + struct DynamicDimension { + // The parameter number of dynamic dimension. + int64 parameter_num; + // The subshape index of the parameter. + ShapeIndex parameter_index; + // The dimension number in the subshape. + int64 dimension; + + // "friend" keyword are added so these functions can be found by ADL. + template + friend H AbslHashValue(H h, const DynamicDimension& m) { + return H::combine(std::move(h), m.parameter_num, m.parameter_index, + m.dimension); + } + + friend bool operator==(const DynamicDimension& lhs, + const DynamicDimension& rhs) { + return lhs.parameter_num == rhs.parameter_num && + lhs.parameter_index == rhs.parameter_index && + lhs.dimension == rhs.dimension; + } + }; + + DynamicParameterBinding() = default; + + virtual ~DynamicParameterBinding() = default; + + // Adds binding which indicates that the dimension indicated by + // `dynamic_dimension` is dynamic, and its runtime size is represented by + // `dynamic_parameter`. + Status Bind(const DynamicParameter& dynamic_parameter, + const DynamicDimension& dynamic_dimension); + + // Returns the parameter and the index representing the runtime size of + // dimension `dim_num` of parameter `param_num` at `param_index`. + // + // Returns nullopt if the binding is not set. + absl::optional GetBinding( + const DynamicDimension& dynamic_dimension); + + using BindingFn = + std::function; + + // Iterate through each binding. + Status ForEachBinding(BindingFn fn) const; + + DynamicParameterBindingProto ToProto() const; + + static StatusOr CreateFromProto( + const DynamicParameterBindingProto& proto); + + string ToString() const; + + // Verifies that the given binding is valid for the given module. + // Specifically, the binding's parameter and parameter size should be valid. + Status Verify(const HloModule& module) const; + + private: + // Keeps track of mappings from DynamicDimension to DynamicParameter. The + // direction of is chosen so that we can easily query if a dimension is + // dynamic and which dynamic parameter represents the real size of that + // dimension. + absl::flat_hash_map bindings_; +}; + +std::ostream& operator<<(std::ostream& out, + const DynamicParameterBinding& binding); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_PARAMETER_BINDING_H_ diff --git a/tensorflow/compiler/xla/service/dynamic_parameter_binding_test.cc b/tensorflow/compiler/xla/service/dynamic_parameter_binding_test.cc new file mode 100644 index 00000000000..83a6d83dffd --- /dev/null +++ b/tensorflow/compiler/xla/service/dynamic_parameter_binding_test.cc @@ -0,0 +1,153 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
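DynamicDimension above defines AbslHashValue and operator== as friend functions so they are found by ADL, which is what Abseil's hashing framework needs for the struct to serve as the key of the bindings_ flat_hash_map. Below is a minimal self-contained illustration of that pattern, assuming Abseil is available; the Dimension and Parameter structs are simplified stand-ins, not the real binding types.

```cpp
#include <cstdint>
#include <iostream>
#include <utility>

#include "absl/container/flat_hash_map.h"

struct Dimension {
  int64_t parameter_num;
  int64_t dimension;

  // Friend functions are discovered via ADL, so no std::hash specialization
  // or external hasher is needed.
  template <typename H>
  friend H AbslHashValue(H h, const Dimension& d) {
    return H::combine(std::move(h), d.parameter_num, d.dimension);
  }
  friend bool operator==(const Dimension& lhs, const Dimension& rhs) {
    return lhs.parameter_num == rhs.parameter_num &&
           lhs.dimension == rhs.dimension;
  }
};

struct Parameter {
  int64_t parameter_num;
};

int main() {
  absl::flat_hash_map<Dimension, Parameter> bindings;
  // Dimension 0 of parameter 1 is dynamic; parameter 0 carries its size.
  bindings.emplace(Dimension{1, 0}, Parameter{0});
  auto it = bindings.find(Dimension{1, 0});
  std::cout << (it != bindings.end() ? it->second.parameter_num : -1) << "\n";  // 0
  return 0;
}
```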
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h" + +#include +#include + +#include "absl/algorithm/container.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { +class DynamicParameterBindingTest : public HloTestBase {}; + +TEST_F(DynamicParameterBindingTest, SimpleBinding) { + // 'b' is a dynamic shape; 'a' represents the real size of b's first + // dimension. + const string module_str = R"( +HloModule TEST + +ENTRY main { + a = f32[] parameter(0) + b = f32[10] parameter(1) + ROOT root = (f32[], f32[10]) tuple(%a, %b) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + + DynamicParameterBinding binding; + + TF_EXPECT_OK( + binding.Bind(DynamicParameterBinding::DynamicParameter{0, {}}, + DynamicParameterBinding::DynamicDimension{1, {}, 0})); + + absl::optional param = + binding.GetBinding( + DynamicParameterBinding::DynamicDimension{/*parameter_num=*/1, + /*parameter_index=*/{}, + /*dimension=*/0}); + EXPECT_TRUE(param); + EXPECT_EQ(param->parameter_num, 0); + EXPECT_EQ(param->parameter_index, ShapeIndex({})); + TF_EXPECT_OK(binding.Verify(*module)); +} + +TEST_F(DynamicParameterBindingTest, TupleBinding) { + // 'gte2' is a dynamic shape; 'gte1' represents the real size of gte2's first + // dimension. + const string module_str = R"( +HloModule TEST + +ENTRY main { + param = (f32[], f32[10]) parameter(0) + gte1 = f32[] get-tuple-element(%param), index=0 + gte2 = f32[10] get-tuple-element(%param), index=1 + ROOT root = (f32[], f32[10]) tuple(%gte1, %gte2) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + + DynamicParameterBinding binding; + + TF_EXPECT_OK( + binding.Bind(DynamicParameterBinding::DynamicParameter{0, {0}}, + DynamicParameterBinding::DynamicDimension{0, {1}, 0})); + + absl::optional param = + binding.GetBinding( + DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0, + /*parameter_index=*/{1}, + /*dimension=*/0}); + + EXPECT_TRUE(param); + EXPECT_EQ(param->parameter_num, 0); + EXPECT_EQ(param->parameter_index, ShapeIndex({0})); + TF_EXPECT_OK(binding.Verify(*module)); +} + +TEST_F(DynamicParameterBindingTest, TupleBindingWithMultiDimension) { + // 'gte2' is a dynamic shape; 'gte1' represents the real size of gte2's both + // dimensions. 
+ const string module_str = R"( +HloModule TEST + +ENTRY main { + param = (f32[], f32[10, 10]) parameter(0) + gte1 = f32[] get-tuple-element(%param), index=0 + gte2 = f32[10, 10] get-tuple-element(%param), index=1 + ROOT root = (f32[], f32[10, 10]) tuple(%gte1, %gte2) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + + DynamicParameterBinding binding; + + TF_EXPECT_OK( + binding.Bind(DynamicParameterBinding::DynamicParameter{0, {0}}, + DynamicParameterBinding::DynamicDimension{0, {1}, 0})); + + TF_EXPECT_OK( + binding.Bind(DynamicParameterBinding::DynamicParameter{0, {0}}, + DynamicParameterBinding::DynamicDimension{0, {1}, 1})); + + absl::optional param = + binding.GetBinding( + DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0, + /*parameter_index=*/{1}, + /*dimension=*/0}); + + EXPECT_TRUE(param); + EXPECT_EQ(param->parameter_num, 0); + EXPECT_EQ(param->parameter_index, ShapeIndex({0})); + + absl::optional param2 = + binding.GetBinding( + DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0, + /*parameter_index=*/{1}, + /*dimension=*/0}); + EXPECT_TRUE(param2); + EXPECT_EQ(param2->parameter_num, 0); + EXPECT_EQ(param2->parameter_index, ShapeIndex({0})); + + TF_EXPECT_OK(binding.Verify(*module)); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index f98c943669b..6f1f95f2e90 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -22,6 +22,7 @@ limitations under the License. // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" #include "absl/strings/str_cat.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Instructions.h" @@ -1671,26 +1672,66 @@ StatusOr ElementalIrEmitter::EmitElementalConcatenate( b_->SetInsertPoint(init_block); + // Assign a unique id for each *different* operand, and count how often each + // operand is used. If all operands are different, the usage count will be 1 + // for each operand. + absl::flat_hash_map to_unique_operand_id; + std::vector operand_usage_count; + for (const auto* operand : hlo->operands()) { + if (to_unique_operand_id.contains(operand)) { + ++operand_usage_count[to_unique_operand_id[operand]]; + } else { + int64 unique_operand_id = to_unique_operand_id.size(); + to_unique_operand_id[operand] = unique_operand_id; + operand_usage_count.push_back(1); + } + } + + // To avoid that we emit the same operand more than once, we create one basic + // block for each *different* operand with a PHI node for the different source + // index inputs. 
+ std::vector emit_operand_blocks( + to_unique_operand_id.size(), nullptr); + std::vector source_index_phis(to_unique_operand_id.size(), + nullptr); + for (const auto* operand : hlo->operands()) { + int64 operand_id = to_unique_operand_id[operand]; + if (emit_operand_blocks[operand_id] != nullptr) { + continue; + } + + emit_operand_blocks[operand_id] = llvm_ir::CreateBasicBlock( + exit_block, StrCat("concat_index_from_operand_id", operand_id), b_); + auto saved_insert_point = b_->GetInsertPoint(); + llvm_ir::SetToFirstInsertPoint(emit_operand_blocks[operand_id], b_); + source_index_phis[operand_id] = + PHI(source_index.GetType(), operand_usage_count[operand_id]); + auto operand_index = source_index; + operand_index[concat_dim] = source_index_phis[operand_id]; + + // Create the terminator of the block before calling operand generators, + // because they require non-degenerate basic blocks. + b_->SetInsertPoint(llvm::BranchInst::Create( + exit_block, /*InsertAtEnd=*/emit_operand_blocks[operand_id])); + TF_ASSIGN_OR_RETURN(llvm::Value * value, + operand_to_generator.at(operand)(operand_index)); + output->addIncoming(value, b_->GetInsertBlock()); + b_->SetInsertPoint(init_block, saved_insert_point); + } + for (int64 operand_idx = 0; operand_idx < hlo->operand_count(); ++operand_idx) { const HloInstruction* operand = hlo->operand(operand_idx); - auto true_block = llvm_ir::CreateBasicBlock( - exit_block, StrCat("concat_index_from_operand", operand_idx), b_); auto false_block = llvm_ir::CreateBasicBlock( exit_block, StrCat("concat_index_not_from_operand", operand_idx), b_); auto concat_dim_size = llvm::ConstantInt::get(source_index[concat_dim]->getType(), operand->shape().dimensions(concat_dim)); - CondBr(ICmpULT(source_index[concat_dim], concat_dim_size), true_block, - false_block); - - // Create the terminator of the true block before calling operand - // generators, because they require non-degenerate basic blocks. - b_->SetInsertPoint( - llvm::BranchInst::Create(exit_block, /*InsertAtEnd=*/true_block)); - TF_ASSIGN_OR_RETURN(llvm::Value * value, - operand_to_generator.at(operand)(source_index)); - output->addIncoming(value, b_->GetInsertBlock()); + int64 operand_id = to_unique_operand_id[operand]; + source_index_phis[operand_id]->addIncoming(source_index[concat_dim], + b_->GetInsertBlock()); + CondBr(ICmpULT(source_index[concat_dim], concat_dim_size), + emit_operand_blocks[operand_id], false_block); // Subtract the size of the concat dimension of the current operand // from the source index. 
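The rewritten concatenate emitter above avoids generating code for the same operand twice: every distinct operand receives a unique id and a usage count, one basic block with a PHI node is created per distinct operand, and the usage count sizes the PHI's list of incoming edges. The bookkeeping part can be sketched standalone, with strings standing in for HloInstruction pointers:

```cpp
#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  // Stand-in for hlo->operands(); note that "a" appears twice, like the same
  // HloInstruction feeding two inputs of a concatenate.
  const std::vector<std::string> operands = {"a", "b", "a", "c"};

  std::unordered_map<std::string, int64_t> to_unique_operand_id;
  std::vector<int64_t> operand_usage_count;
  for (const auto& operand : operands) {
    auto it = to_unique_operand_id.find(operand);
    if (it != to_unique_operand_id.end()) {
      ++operand_usage_count[it->second];
    } else {
      const int64_t unique_operand_id = to_unique_operand_id.size();
      to_unique_operand_id[operand] = unique_operand_id;
      operand_usage_count.push_back(1);
    }
  }

  // One emitted block per distinct operand; in the real emitter the usage
  // count determines how many incoming edges the operand's PHI node expects.
  for (const auto& entry : to_unique_operand_id) {
    std::cout << entry.first << " -> id " << entry.second << ", "
              << operand_usage_count[entry.second] << " use(s)\n";
  }
  return 0;
}
```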
@@ -2204,13 +2245,15 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( : iota->shape(); PrimitiveType component_element_type = component_shape.element_type(); llvm::Value* iota_result; - if (ShapeUtil::ElementIsIntegral(component_shape)) { + if (primitive_util::IsIntegralType(component_element_type) || + component_element_type == PRED) { iota_result = b_->CreateIntCast( elem_index_linear, llvm_ir::PrimitiveTypeToIrType(component_element_type, module_), /*isSigned=*/false); } else { - TF_RET_CHECK(ShapeUtil::ElementIsFloating(component_shape)) + TF_RET_CHECK( + primitive_util::IsFloatingPointType(component_element_type)) << component_element_type; llvm::Type* float_ir_type; if (component_element_type == BF16) { diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h index 45f620f3f33..b34bca55a48 100644 --- a/tensorflow/compiler/xla/service/executable.h +++ b/tensorflow/compiler/xla/service/executable.h @@ -61,7 +61,7 @@ struct ExecutionOutput { class Executable { public: explicit Executable( - std::unique_ptr hlo_module, + std::unique_ptr hlo_module, std::unique_ptr hlo_profile_printer_data, std::unique_ptr hlo_profile_index_map) : hlo_module_(std::move(hlo_module)), @@ -162,7 +162,7 @@ class Executable { return hlo_profile_printer_data_ != nullptr; } - const HloModule& module() const { return *hlo_module_; } + HloModule& module() const { return *hlo_module_; } const bool has_module() const { return hlo_module_ != nullptr; } @@ -199,7 +199,7 @@ class Executable { // HloModule this was compiled from. BufferAssignment keeps pointers to // HloInstructions owned by the HloModule so we need to keep the HloModule // around. - const std::unique_ptr hlo_module_; + const std::unique_ptr hlo_module_; // HloSnapshot this was compiled from. Null if not dumping executions. std::unique_ptr hlo_snapshot_; diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index b1629616acd..bfd1b6cb149 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -701,6 +701,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_cse", "//tensorflow/compiler/xla/service:hlo_dce", "//tensorflow/compiler/xla/service:hlo_element_type_converter", + "//tensorflow/compiler/xla/service:hlo_get_dimension_size_rewriter", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/compiler/xla/service:hlo_pass_pipeline", "//tensorflow/compiler/xla/service:hlo_proto", diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.cc index 4ce877f62a5..e81850db69e 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.cc @@ -77,7 +77,11 @@ bool CanImplementAsCudnnForwardConv(HloInstruction* conv) { return false; } - if (window_util::HasWindowReversal(conv->window())) { + // CuDNN can perform either cross correlation (no reversal), + // or convolution (all dimensions reversed). + if (dnums.input_spatial_dimensions_size() == 2 + ? 
!window_util::AllOrNoneReversed(conv->window()) + : window_util::HasWindowReversal(conv->window())) { return false; } return true; diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc index 492d290bf4a..3425e1b4942 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc @@ -138,6 +138,7 @@ Status RunCudnnConvImpl(CudnnConvParams params, const int num_dimensions = window.dimensions_size(); CHECK_LE(num_dimensions, 3); + CHECK_GE(num_dimensions, 1); // cuDNN does not support 1D convolutions. We therefore express 1D // convolutions as 2D convolutions where the first spatial dimension is 1. // This matches the behavior of TF (see definition of conv1d in @@ -148,10 +149,15 @@ Status RunCudnnConvImpl(CudnnConvParams params, output_shape.element_type()) << ShapeUtil::HumanString(output_shape); + // If one dimension is reversed, we need to have all dimensions reversed (so + // we're doing convolution not cross correlation). + const bool dims_reversed = window.dimensions()[0].window_reversal(); + CHECK_EQ(num_dimensions, dnums.input_spatial_dimensions_size()); CHECK_EQ(num_dimensions, dnums.kernel_spatial_dimensions_size()); CHECK_EQ(num_dimensions, dnums.output_spatial_dimensions_size()); for (const WindowDimension& dim : window.dimensions()) { + CHECK_EQ(dims_reversed, dim.window_reversal()); CHECK_EQ(dim.padding_low(), dim.padding_high()); CHECK_EQ(dim.base_dilation(), 1) << "cudnn does not support base dilation; it " @@ -198,6 +204,7 @@ Status RunCudnnConvImpl(CudnnConvParams params, ConvolutionDescriptor convolution_descriptor(effective_num_dimensions); convolution_descriptor.set_group_count(feature_group_count); + convolution_descriptor.set_convolution_not_crosscorr(dims_reversed); for (int dim = 0; dim < num_dimensions; ++dim) { convolution_descriptor .set_zero_padding( @@ -363,14 +370,12 @@ StatusOr GetCudnnConvParams( params.output_shape = &conv_result_shape; params.fusion.emplace(); auto& fusion = *params.fusion; - if (backend_config.activation_mode() < - static_cast(se::dnn::ActivationMode::kNumActivationModes)) { - fusion.mode = static_cast( - backend_config.activation_mode()); - } else { + if (!se::dnn::ActivationMode_IsValid(backend_config.activation_mode())) { return InternalError("Bad activation mode: %s", backend_config.ShortDebugString()); } + fusion.mode = static_cast( + backend_config.activation_mode()); fusion.side_input_scale = backend_config.side_input_scale(); params.input_buf = operand_buffers[0]; params.filter_buf = operand_buffers[1]; diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index 6dcdaf1cfe0..2ab754a4710 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -161,6 +161,16 @@ StatusOr GpuElementalIrEmitter::EmitFloatBinaryOp( PrimitiveType lhs_input_type = op->operand(0)->shape().element_type(); PrimitiveType rhs_input_type = op->operand(1)->shape().element_type(); PrimitiveType output_type = op->shape().element_type(); + HloOpcode opcode = op->opcode(); + + if (hlo_module_config_.debug_options().xla_gpu_enable_fast_min_max() && + (opcode == HloOpcode::kMaximum || opcode == HloOpcode::kMinimum)) { + return llvm_ir::EmitCallToIntrinsic( + opcode == HloOpcode::kMaximum ? 
llvm::Intrinsic::maxnum + : llvm::Intrinsic::minnum, + {lhs_value, rhs_value}, {lhs_value->getType()}, b_); + } + switch (op->opcode()) { case HloOpcode::kRemainder: { return EmitLibdeviceMathCall("__nv_fmod", {lhs_value, rhs_value}, diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc index 30c1f908896..470457935ac 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc @@ -229,7 +229,7 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { if (!absl::c_all_of(fusion->users(), [&](const HloInstruction* user) { return user->opcode() == HloOpcode::kFusion && (user->fusion_kind() == HloInstruction::FusionKind::kLoop || - (user->fusion_kind() == HloInstruction::FusionKind::kInput && + (IsReduceInputFusion(*user) && LayoutsAreReduceInputFusionFriendly(*fusion, *user))); })) { VLOG(3) << "Not merging " << fusion->name() diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index 57426327822..ae2e718db29 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -51,7 +51,7 @@ GpuExecutable::GpuExecutable( const string& ptx, const std::vector& cubin, std::pair compute_capability, std::unique_ptr thunk_schedule, - std::unique_ptr hlo_module, + std::unique_ptr hlo_module, std::unique_ptr assignment, std::unique_ptr hlo_profile_printer_data, std::unique_ptr hlo_profile_index_map) diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h index 0e276282e40..2b3c77f5b82 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h @@ -54,7 +54,7 @@ class GpuExecutable : public Executable { GpuExecutable(const string& ptx, const std::vector& cubin, std::pair compute_capability, std::unique_ptr thunk_schedule, - std::unique_ptr hlo_module, + std::unique_ptr hlo_module, std::unique_ptr assignment, std::unique_ptr hlo_profile_printer_data, std::unique_ptr hlo_profile_index_map); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc index 2d31fd5570c..452e763a8ea 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc @@ -55,7 +55,7 @@ bool LayoutsAreReduceInputFusionFriendly(const HloInstruction& producer, }); } -bool IsInputFusibleReduction(const HloInstruction& instr) { +bool IsReduceInputFusion(const HloInstruction& instr) { if (instr.IsMultiOutputFusion()) { for (const HloInstruction* operand : instr.fused_expression_root()->operands()) { @@ -67,17 +67,70 @@ bool IsInputFusibleReduction(const HloInstruction& instr) { return true; } } - return false; - } else if (instr.opcode() == HloOpcode::kFusion) { - if (IsReductionToVector(*instr.fused_expression_root())) { - CHECK(instr.fusion_kind() == HloInstruction::FusionKind::kInput) - << " Fusion rooted at reduction-to-vector op must be of kind kInput: " - << instr.ToString(); - return true; + } else if (instr.opcode() == HloOpcode::kFusion && + IsReductionToVector(*instr.fused_expression_root())) { + CHECK(instr.fusion_kind() == HloInstruction::FusionKind::kInput) + << " Fusion rooted at reduction-to-vector op must be of kind kInput: " + << instr.ToString(); + return true; + } + return false; +} + +bool 
IsInputFusibleReduction(const HloInstruction& instr) { + return IsReduceInputFusion(instr) || IsReductionToVector(instr); +} + +bool ShapesCompatibleForMultiOutputFusion(const HloInstruction& instr1, + const HloInstruction& instr2) { + // Returns the instructions that determines the emitter used for lowering, + // sometimes referred to as "the real hero". + auto get_real_hero = + [&](const HloInstruction* instr) -> const HloInstruction* { + if (instr->opcode() == HloOpcode::kFusion) { + auto fused_expression_root = instr->fused_expression_root(); + if (instr->IsMultiOutputFusion()) { + // If possible, we want to pick a reduction-to-vector operand of the + // fusion root, because it has the most constraints. + for (const auto* inst : fused_expression_root->operands()) { + if (IsReductionToVector(*inst)) { + return inst; + } + } + return fused_expression_root->operands()[0]; + } + return fused_expression_root; } + return instr; + }; + + // Multi-output fusion kernels share a common parallel loop. The loop + // dimenstions are determined by instruction shapes. + auto get_loop_shape = [&](const HloInstruction* element_instr) { + // Special-case reduction-to-vector ops: The loop dimensions are determined + // by the shape of the first operand. + if (IsReductionToVector(*element_instr)) { + return element_instr->operand(0)->shape(); + } + return element_instr->shape(); + }; + + // All shapes of the root tuple of multi-output fusions should agree, i.e. all + // root ops should have equal output shapes. An exception are + // reduction-to-vector ops. Here the input shapes of the reduction (first + // operand shape) and the reduction dimensions need to match. + auto* instr_1 = get_real_hero(&instr1); + auto* instr_2 = get_real_hero(&instr2); + // TODO(tjoerg): Relax the shape constraint. The datatype does not matter. + if (IsReductionToVector(*instr_1) && IsReductionToVector(*instr_2) && + (!ShapeUtil::Equal(instr_1->shape(), instr_2->shape()) || + instr_1->dimensions() != instr_2->dimensions())) { return false; } - return IsReductionToVector(instr); + // The elementwise output shapes must be the same (including layout). + // TODO(tjoerg): Further relax the constraint. The datatype does not matter. + return ShapeUtil::EqualIgnoringFpPrecision(get_loop_shape(instr_1), + get_loop_shape(instr_2)); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.h b/tensorflow/compiler/xla/service/gpu/gpu_fusible.h index f7c24a0d5bb..e9d7ba1c4cf 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.h @@ -33,16 +33,29 @@ namespace gpu { bool LayoutsAreReduceInputFusionFriendly(const HloInstruction& producer, const HloInstruction& reduce); -// Whether `instr` is fusible as root of a reduce input fusions, i.e. `instr` -// is either an unfused reduction-to-vector op, an input fusion rooted at a -// reduction-to-vector op, or a multi-output input fusion with at least one -// reduction-to-vector op root. // Note that reduction ops are lowered in different ways. Reduce input fusions // are lowered by IrEmitterUnnested::EmitReductionToVector and must be rooted at // reduction-to-vector ops. Other reduction ops are lowered by // GpuElementalIrEmitter and fused like elementwise ops. + +// Whether `instr` is an input fusion rooted at a reduction-to-vector op or a +// multi-output input fusion with at least one reduction-to-vector op root. 
+bool IsReduceInputFusion(const HloInstruction& instr); + +// Whether `instr` is fusible as root of a reduce input fusions, i.e. `instr` +// is either an unfused reduction-to-vector op or a reduce input fusion. bool IsInputFusibleReduction(const HloInstruction& instr); +// Whether instruction shapes are compatible for multi-output fusion, i.e. +// whether the emitters support lowering the resulting fusion. +// This function works for both, sibling and producer-conumser multi-output +// fusion. +// So far, multi-output fusion is supported for loop fusions and reduce +// input fusions only. It is up to the caller to ensure the instructions +// themselves are fusible! +bool ShapesCompatibleForMultiOutputFusion(const HloInstruction& instr1, + const HloInstruction& instr2); + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc index d91b7bc61fd..15d4ee206ce 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc @@ -178,7 +178,7 @@ TEST_F(GpuFusibleTest, EXPECT_TRUE(LayoutsAreReduceInputFusionFriendly(*loop_fusion, *reduce)); } -TEST_F(GpuFusibleTest, IsInputFusibleReduction_ReductionToVector) { +TEST_F(GpuFusibleTest, IsReduceInputFusion_ReductionToVector) { auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( ENTRY entry { c0 = f32[] parameter(0) @@ -191,10 +191,11 @@ TEST_F(GpuFusibleTest, IsInputFusibleReduction_ReductionToVector) { const HloInstruction* reduce = module->entry_computation()->root_instruction(); ASSERT_EQ(reduce->opcode(), HloOpcode::kReduce); + EXPECT_FALSE(IsReduceInputFusion(*reduce)); EXPECT_TRUE(IsInputFusibleReduction(*reduce)); } -TEST_F(GpuFusibleTest, IsInputFusibleReduction_ElementalReduction) { +TEST_F(GpuFusibleTest, IsReduceInputFusion_ElementalReduction) { auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( ENTRY entry { c0 = f32[] parameter(0) @@ -207,10 +208,11 @@ TEST_F(GpuFusibleTest, IsInputFusibleReduction_ElementalReduction) { const HloInstruction* reduce = module->entry_computation()->root_instruction(); ASSERT_EQ(reduce->opcode(), HloOpcode::kReduce); + EXPECT_FALSE(IsReduceInputFusion(*reduce)); EXPECT_FALSE(IsInputFusibleReduction(*reduce)); } -TEST_F(GpuFusibleTest, IsInputFusibleReduction_SingleOutputInputReduceFusion) { +TEST_F(GpuFusibleTest, IsReduceInputFusion_SingleOutputInputReduceFusion) { auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_reduction { c0 = f32[] parameter(0) @@ -225,10 +227,11 @@ TEST_F(GpuFusibleTest, IsInputFusibleReduction_SingleOutputInputReduceFusion) { const HloInstruction* reduce = module->entry_computation()->root_instruction(); ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion); + EXPECT_TRUE(IsReduceInputFusion(*reduce)); EXPECT_TRUE(IsInputFusibleReduction(*reduce)); } -TEST_F(GpuFusibleTest, IsInputFusibleReduction_SingleOutputLoopReduceFusion) { +TEST_F(GpuFusibleTest, IsReduceInputFusion_SingleOutputLoopReduceFusion) { auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_reduction { c0 = f32[] parameter(0) @@ -243,10 +246,11 @@ TEST_F(GpuFusibleTest, IsInputFusibleReduction_SingleOutputLoopReduceFusion) { const HloInstruction* reduce = module->entry_computation()->root_instruction(); ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion); + EXPECT_FALSE(IsReduceInputFusion(*reduce)); EXPECT_FALSE(IsInputFusibleReduction(*reduce)); } -TEST_F(GpuFusibleTest, 
IsInputFusibleReduction_MultiOutputInputReduceFusion) { +TEST_F(GpuFusibleTest, IsReduceInputFusion_MultiOutputInputReduceFusion) { auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_reduction { c0 = f32[] parameter(0) @@ -263,11 +267,12 @@ TEST_F(GpuFusibleTest, IsInputFusibleReduction_MultiOutputInputReduceFusion) { const HloInstruction* reduce = module->entry_computation()->root_instruction(); ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion); + EXPECT_TRUE(IsReduceInputFusion(*reduce)); EXPECT_TRUE(IsInputFusibleReduction(*reduce)); } TEST_F(GpuFusibleTest, - IsInputFusibleReduction_MultiOutputInputReduceFusionWithExtraOutputs) { + IsReduceInputFusion_MultiOutputInputReduceFusionWithExtraOutputs) { auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_reduction { c0 = f32[] parameter(0) @@ -284,10 +289,11 @@ TEST_F(GpuFusibleTest, const HloInstruction* reduce = module->entry_computation()->root_instruction(); ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion); + EXPECT_TRUE(IsReduceInputFusion(*reduce)); EXPECT_TRUE(IsInputFusibleReduction(*reduce)); } -TEST_F(GpuFusibleTest, IsInputFusibleReduction_MultiOutputLoopReduceFusion) { +TEST_F(GpuFusibleTest, IsReduceInputFusion_MultiOutputLoopReduceFusion) { auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_reduction { c0 = f32[] parameter(0) @@ -304,11 +310,12 @@ TEST_F(GpuFusibleTest, IsInputFusibleReduction_MultiOutputLoopReduceFusion) { const HloInstruction* reduce = module->entry_computation()->root_instruction(); ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion); + EXPECT_FALSE(IsReduceInputFusion(*reduce)); EXPECT_FALSE(IsInputFusibleReduction(*reduce)); } TEST_F(GpuFusibleTest, - IsInputFusibleReduction_MultiOutputLoopFusionReduceAndElementwiseOp) { + IsReduceInputFusion_MultiOutputLoopFusionReduceAndElementwiseOp) { auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_reduction { c0 = f32[] parameter(0) @@ -325,8 +332,304 @@ TEST_F(GpuFusibleTest, const HloInstruction* reduce = module->entry_computation()->root_instruction(); ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion); + EXPECT_FALSE(IsReduceInputFusion(*reduce)); EXPECT_FALSE(IsInputFusibleReduction(*reduce)); } +TEST_F(GpuFusibleTest, ShapesCompatibleForMultiOutputFusion_LoopFusions) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_computation_1 { + p0.1 = f32[6400]{0} parameter(0) + ROOT mul = f32[6400]{0} multiply(p0.1, p0.1) + } + + fused_computation_2 { + p0.2 = f32[6400]{0} parameter(0) + const.2 = f32[] constant(1) + ROOT div = f32[6400]{0} divide(p0.2, const.2) + } + + ENTRY entry { + p0 = f32[6400]{0} parameter(0) + fusion.1 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_1 + fusion.2 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_2 + ROOT root = (f32[6400]{0}, f32[6400]{0}) tuple(fusion.1, fusion.2) + })")) + .ValueOrDie(); + const HloInstruction* fusion_1 = + module->entry_computation()->root_instruction()->operand(0); + const HloInstruction* fusion_2 = + module->entry_computation()->root_instruction()->operand(1); + EXPECT_TRUE(ShapesCompatibleForMultiOutputFusion(*fusion_1, *fusion_2)); +} + +TEST_F(GpuFusibleTest, ShapesCompatibleForMultiOutputFusion_IgnoreFpPrecision) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_computation_1 { + p0.1 = f32[6400]{0} parameter(0) + ROOT mul = f32[6400]{0} multiply(p0.1, p0.1) + } + + fused_computation_2 { + p0.2 = f32[6400]{0} parameter(0) + ROOT convert = f16[6400]{0} convert(p0.2) + } + + 
ENTRY entry { + p0 = f32[6400]{0} parameter(0) + fusion.1 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_1 + fusion.2 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_2 + ROOT root = (f32[6400]{0}, f32[6400]{0}) tuple(fusion.1, fusion.2) + })")) + .ValueOrDie(); + const HloInstruction* fusion_1 = + module->entry_computation()->root_instruction()->operand(0); + const HloInstruction* fusion_2 = + module->entry_computation()->root_instruction()->operand(1); + EXPECT_TRUE(ShapesCompatibleForMultiOutputFusion(*fusion_1, *fusion_2)); +} + +TEST_F(GpuFusibleTest, ShapesCompatibleForMultiOutputFusion_Reduce) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_computation_1 { + p0.1 = f32[6400]{0} parameter(0) + ROOT mul = f32[6400]{0} multiply(p0.1, p0.1) + } + + ENTRY entry { + p0 = f32[6400]{0} parameter(0) + fusion.1 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_1 + const.2 = f32[] constant(0) + reduce = f32[] reduce(p0, const.2), dimensions={0}, to_apply=scalar_add + ROOT root = (f32[6400]{0}, f32[]) tuple(fusion.1, reduce) + })")) + .ValueOrDie(); + const HloInstruction* fusion = + module->entry_computation()->root_instruction()->operand(0); + const HloInstruction* reduce = + module->entry_computation()->root_instruction()->operand(1); + EXPECT_TRUE(ShapesCompatibleForMultiOutputFusion(*fusion, *reduce)); +} + +TEST_F(GpuFusibleTest, ShapesCompatibleForMultiOutputFusion_Elementwise) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_computation_1 { + p0.1 = f32[6400]{0} parameter(0) + ROOT mul = f32[6400]{0} multiply(p0.1, p0.1) + } + + ENTRY entry { + p0 = f32[6400]{0} parameter(0) + fusion.1 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_1 + const.2 = f32[] constant(1) + div = f32[6400]{0} divide(p0, const.2) + ROOT root = (f32[6400]{0}, f32[6400]{0}) tuple(fusion.1, div) + })")) + .ValueOrDie(); + const HloInstruction* fusion = + module->entry_computation()->root_instruction()->operand(0); + const HloInstruction* div = + module->entry_computation()->root_instruction()->operand(1); + EXPECT_TRUE(ShapesCompatibleForMultiOutputFusion(*fusion, *div)); +} + +TEST_F(GpuFusibleTest, + ShapesCompatibleForMultiOutputFusion_MultiOutputLoopFusion) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_computation_1 { + p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) + mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, p0.1) + exp = f32[8,1,5,16,1,1]{5,4,3,2,1,0} exponential(p0.1) + ROOT tuple = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) tuple(mul, exp) + } + + fused_computation_2 { + p0.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) + const.2 = f32[] constant(0) + ROOT add = f32[8,1,5,16,1,1]{5,4,3,2,1,0} add(p0.2, const.2) + } + + ENTRY entry { + p0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) + fusion.1 = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) fusion(p0), kind=kLoop, calls=fused_computation_1 + fusion.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_2 + gte0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=0 + gte1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=1 + ROOT root = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) tuple(gte0, gte1, fusion.2) + })")) + .ValueOrDie(); + const HloInstruction* fusion_1 = + module->entry_computation()->root_instruction()->operand(0)->operand(0); + const 
HloInstruction* fusion_2 = + module->entry_computation()->root_instruction()->operand(1)->operand(0); + EXPECT_TRUE(ShapesCompatibleForMultiOutputFusion(*fusion_1, *fusion_2)); +} + +TEST_F(GpuFusibleTest, ShapesCompatibleForMultiOutputFusion_UnfusedOps) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + ENTRY reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + c0 = f32[] constant(0) + exp = f32[2,2,2]{2,1,0} exponential(p0) + reduce = f32[2,2]{1,0} reduce(exp, c0), dimensions={2}, to_apply=scalar_add + ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(reduce, exp) + })")) + .ValueOrDie(); + const HloInstruction* reduce = + module->entry_computation()->root_instruction()->operand(0); + const HloInstruction* exp = + module->entry_computation()->root_instruction()->operand(1); + EXPECT_TRUE(ShapesCompatibleForMultiOutputFusion(*reduce, *exp)); +} + +TEST_F(GpuFusibleTest, ShapesCompatibleForMultiOutputFusion_DifferentLayouts) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + ENTRY reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + p1 = f32[2,2,2]{0,1,2} parameter(1) + c0 = f32[] constant(0) + exp = f32[2,2,2]{2,1,0} exponential(p0) + reduce = f32[2,2]{0,1} reduce(p1, c0), dimensions={2}, to_apply=scalar_add + ROOT root = (f32[2,2]{0,1}, f32[2,2,2]{2,1,0}) tuple(reduce, exp) + })")) + .ValueOrDie(); + const HloInstruction* reduce = + module->entry_computation()->root_instruction()->operand(0); + const HloInstruction* exp = + module->entry_computation()->root_instruction()->operand(1); + EXPECT_FALSE(ShapesCompatibleForMultiOutputFusion(*reduce, *exp)); +} + +TEST_F(GpuFusibleTest, + ShapesCompatibleForMultiOutputFusion_MultiOutputReduceFusion) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_select { + p1.1 = f32[2,2,2]{2,1,0} parameter(1) + c0 = f32[] constant(0) + broadcast = f32[2,2,2]{2,1,0} broadcast(f32[] c0), dimensions={} + greater-than = pred[2,2,2]{2,1,0} greater-than(f32[2,2,2]{2,1,0} p1.1, f32[2,2,2]{2,1,0} broadcast) + p0.1 = f32[2,2,2]{2,1,0} parameter(0) + ROOT select = f32[2,2,2]{2,1,0} select(pred[2,2,2]{2,1,0} greater-than, f32[2,2,2]{2,1,0} p0.1, f32[2,2,2]{2,1,0} broadcast) + } + + fused_reduce { + p0.2 = f32[2,2,2]{2,1,0} parameter(0) + c1 = f32[] constant(0) + r1 = f32[2,2]{1,0} reduce(p0.2, c1), dimensions={2}, to_apply=scalar_add + mul = f32[2,2,2]{2,1,0} multiply(p0.2, p0.2) + r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=scalar_add + ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2) + } + + ENTRY reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + p1 = f32[2,2,2]{2,1,0} parameter(1) + select = f32[2,2,2]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_select + fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(select), kind=kInput, calls=fused_reduce + gte0 = f32[2,2]{1,0} get-tuple-element(fusion), index=0 + gte1 = f32[2,2]{1,0} get-tuple-element(fusion), index=1 + ROOT root = (f32[2,2]{1,0}, f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(gte1, gte1, select) + })")) + .ValueOrDie(); + const HloInstruction* fusion_1 = + module->entry_computation()->root_instruction()->operand(0)->operand(0); + const HloInstruction* fusion_2 = + module->entry_computation()->root_instruction()->operand(1)->operand(0); + EXPECT_TRUE(ShapesCompatibleForMultiOutputFusion(*fusion_1, *fusion_2)); +} + +TEST_F(GpuFusibleTest, ShapesCompatibleForMultiOutputFusion_ReduceFusions) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_reduce_1 { + p0.1 = f32[2,2,2]{2,1,0} parameter(0) + c0 = f32[] constant(0) + ROOT 
reduce = f32[2,2]{1,0} reduce(f32[2,2,2]{2,1,0} p0.1, f32[] c0), dimensions={0}, to_apply=scalar_add + } + + fused_reduce_2 { + p0.2 = f32[2,2,2]{2,1,0} parameter(0) + mul = f32[2,2,2]{2,1,0} multiply(f32[2,2,2]{2,1,0} p0.2, f32[2,2,2]{2,1,0} p0.2) + c1 = f32[] constant(0) + ROOT reduce = f32[2,2]{1,0} reduce(f32[2,2,2]{2,1,0} mul, f32[] c1), dimensions={0}, to_apply=scalar_add + } + + ENTRY reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + p1 = f32[2,2,2]{2,1,0} parameter(1) + reduce_1 = f32[2,2]{1,0} fusion(p0), kind=kLoop, calls=fused_reduce_1 + reduce_2 = f32[2,2]{1,0} fusion(p1), kind=kLoop, calls=fused_reduce_2 + ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(reduce_1, reduce_2) + })")) + .ValueOrDie(); + const HloInstruction* fusion_1 = + module->entry_computation()->root_instruction()->operand(0); + const HloInstruction* fusion_2 = + module->entry_computation()->root_instruction()->operand(1); + EXPECT_TRUE(ShapesCompatibleForMultiOutputFusion(*fusion_1, *fusion_2)); +} + +TEST_F(GpuFusibleTest, + ShapesCompatibleForMultiOutputFusion_DifferentReduceDimensions) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_reduce_1 { + p0.1 = f32[2,2,2]{2,1,0} parameter(0) + c0 = f32[] constant(0) + ROOT reduce = f32[2,2]{1,0} reduce(f32[2,2,2]{2,1,0} p0.1, f32[] c0), dimensions={0}, to_apply=scalar_add + } + + fused_reduce_2 { + p0.2 = f32[2,2,2]{2,1,0} parameter(0) + mul = f32[2,2,2]{2,1,0} multiply(f32[2,2,2]{2,1,0} p0.2, f32[2,2,2]{2,1,0} p0.2) + c1 = f32[] constant(0) + ROOT reduce = f32[2,2]{1,0} reduce(f32[2,2,2]{2,1,0} mul, f32[] c1), dimensions={2}, to_apply=scalar_add + } + + ENTRY reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + p1 = f32[2,2,2]{2,1,0} parameter(1) + reduce_1 = f32[2,2]{1,0} fusion(p0), kind=kLoop, calls=fused_reduce_1 + reduce_2 = f32[2,2]{1,0} fusion(p1), kind=kLoop, calls=fused_reduce_2 + ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(reduce_1, reduce_2) + })")) + .ValueOrDie(); + const HloInstruction* fusion_1 = + module->entry_computation()->root_instruction()->operand(0); + const HloInstruction* fusion_2 = + module->entry_computation()->root_instruction()->operand(1); + EXPECT_FALSE(ShapesCompatibleForMultiOutputFusion(*fusion_1, *fusion_2)); +} + +TEST_F(GpuFusibleTest, + ShapesCompatibleForMultiOutputFusion_NoReductionToVector) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_element_wise { + p0.1 = f32[2,2,2]{2,1,0} parameter(0) + p1.1 = f32[2,2,2]{2,1,0} parameter(1) + ROOT add = f32[2,2,2]{2,1,0} add(p0.1, p1.1) + } + + fused_reduce { + p0.2 = f32[2,2,2]{2,1,0} parameter(0) + mul = f32[2,2,2]{2,1,0} multiply(f32[2,2,2]{2,1,0} p0.2, f32[2,2,2]{2,1,0} p0.2) + c1 = f32[] constant(0) + // Note that reduce is not a reduction-to-vector. 
+ ROOT reduce = f32[2,2]{1,0} reduce(f32[2,2,2]{2,1,0} mul, f32[] c1), dimensions={1}, to_apply=scalar_add + } + + ENTRY reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + p1 = f32[2,2,2]{2,1,0} parameter(1) + element_wise = f32[2,2,2]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_element_wise + fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(element_wise), kind=kLoop, calls=fused_reduce + ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(fusion, element_wise) + })")) + .ValueOrDie(); + const HloInstruction* fusion_1 = + module->entry_computation()->root_instruction()->operand(0); + const HloInstruction* fusion_2 = + module->entry_computation()->root_instruction()->operand(1); + EXPECT_FALSE(ShapesCompatibleForMultiOutputFusion(*fusion_1, *fusion_2)); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc index 91609c730b6..1126943624a 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc @@ -37,7 +37,7 @@ class GpuHloOrdering : public PredecessorHloOrdering { public: GpuHloOrdering(const HloModule* module, const StreamAssignment& stream_assignment, - const std::vector& thunk_launch_order); + const std::vector& thunk_launch_order); ~GpuHloOrdering() override = default; // Only the entry computation can possibly be sequentially ordered, and only @@ -56,7 +56,7 @@ class GpuHloOrdering : public PredecessorHloOrdering { GpuHloOrdering::GpuHloOrdering( const HloModule* module, const StreamAssignment& stream_assignment, - const std::vector& thunk_launch_order) + const std::vector& thunk_launch_order) : PredecessorHloOrdering(module) { // The entry computation has a total order when there's only one stream. if (stream_assignment.StreamCount() == 1) { @@ -150,7 +150,7 @@ GpuHloOrdering::GpuHloOrdering( // However, if the total order is A,B,D,C,E, then C and E can run // concurrently. void BFSLaunchOrder(const HloComputation* computation, - std::vector* launch_order) { + std::vector* launch_order) { // This topological sort uses two data structures: // 1. `incoming_edge_count` which keeps track of the number of incoming // edges to each HLO; @@ -158,9 +158,9 @@ void BFSLaunchOrder(const HloComputation* computation, // // The sorting algorithm repeatedly pops the top from the queue and deletes // that HLO from the graph, making more HLOs incoming-edge free. - std::deque queue; + std::deque queue; std::unordered_map incoming_edge_count; - for (const auto& hlo : computation->instructions()) { + for (auto* hlo : computation->instructions()) { if (hlo->operand_count() == 0) { queue.push_back(hlo); } else { @@ -172,10 +172,10 @@ void BFSLaunchOrder(const HloComputation* computation, } while (!queue.empty()) { - const HloInstruction* x = queue.front(); + HloInstruction* x = queue.front(); queue.pop_front(); launch_order->push_back(x); - for (const HloInstruction* y : x->users()) { + for (HloInstruction* y : x->users()) { --incoming_edge_count[y]; if (incoming_edge_count[y] == 0) { queue.push_back(y); @@ -195,14 +195,14 @@ StatusOr> GpuHloSchedule::Build( std::unique_ptr schedule(new GpuHloSchedule); // Initialize thunk_launch_order_, the total order of thunk launches. 
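For illustration only (not part of the patch): BFSLaunchOrder above is a Kahn-style topological sort driven by `incoming_edge_count` and a work queue. A self-contained sketch of the same idea on a plain adjacency list (hypothetical types, not XLA code):

```
#include <cstddef>
#include <deque>
#include <vector>

// Nodes are 0..n-1; users[i] lists the nodes that consume node i's output.
std::vector<int> BfsLaunchOrder(const std::vector<std::vector<int>>& users) {
  const std::size_t n = users.size();
  std::vector<int> incoming(n, 0);
  for (const auto& u : users) {
    for (int v : u) ++incoming[v];
  }

  std::deque<int> queue;  // nodes whose predecessors are all scheduled
  for (std::size_t i = 0; i < n; ++i) {
    if (incoming[i] == 0) queue.push_back(static_cast<int>(i));
  }

  std::vector<int> order;
  while (!queue.empty()) {
    int x = queue.front();
    queue.pop_front();
    order.push_back(x);
    for (int y : users[x]) {
      if (--incoming[y] == 0) queue.push_back(y);  // y became edge-free
    }
  }
  return order;  // a valid launch order; the BFS flavor favors concurrency
}
```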
- const HloComputation* entry_computation = module.entry_computation(); + HloComputation* entry_computation = module.entry_computation(); if (stream_assignment.StreamCount() == 1) { // All kernels are launched on a single stream, so there's no loss of // concurrency by optimizing for minimal memory usage. TF_ASSIGN_OR_RETURN( HloInstructionSequence sequence, ScheduleComputation( - *entry_computation, [pointer_size](const BufferValue& buffer) { + entry_computation, [pointer_size](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape(), pointer_size); })); schedule->thunk_launch_order_ = sequence.instructions(); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h index 07a7fc67aa5..7f224ffe4f0 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h @@ -46,7 +46,7 @@ class GpuHloSchedule { // Returns the total order of thunk launches, represented in terms of HLO // instructions. - const std::vector& ThunkLaunchOrder() const { + const std::vector& ThunkLaunchOrder() const { return thunk_launch_order_; } @@ -60,7 +60,7 @@ class GpuHloSchedule { private: GpuHloSchedule(); - std::vector thunk_launch_order_; + std::vector thunk_launch_order_; std::unique_ptr hlo_ordering_; }; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc index 6d3aed15ebe..91db7151f22 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc @@ -33,7 +33,7 @@ namespace gpu { class GpuHloScheduleTest : public HloTestBase { protected: - using HloVec = std::vector; + using HloVec = std::vector; // Pre-canned shapes. 
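For illustration only (not part of the patch): the lambda handed to ScheduleComputation above just maps a buffer to its size in bytes so the scheduler can minimize peak memory. A deliberately simplified model of that size function (the real ShapeUtil::ByteSizeOf handles many more cases; `ApproxByteSize` is hypothetical):

```
#include <cstdint>

// Arrays cost elements * bytes-per-element; a tuple is charged one pointer
// per element for its index table, which is why pointer_size is threaded
// through. This is a rough model, not the real implementation.
int64_t ApproxByteSize(bool is_tuple, int64_t num_elements,
                       int64_t bytes_per_element, int64_t pointer_size) {
  return is_tuple ? num_elements * pointer_size
                  : num_elements * bytes_per_element;
}

// e.g. an f32[2,2] buffer: ApproxByteSize(false, 4, 4, 8) == 16 bytes.
```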
Shape f32_2x2_ = ShapeUtil::MakeShape(F32, {2, 2}); @@ -44,7 +44,7 @@ class GpuHloScheduleTest : public HloTestBase { .ConsumeValueOrDie(); } - std::unique_ptr CreateNewUnverifiedModule() { + std::unique_ptr CreateNewVerifiedModule() { HloModuleConfig config; auto debug_options = GetDebugOptionsForTest(); debug_options.set_xla_gpu_disable_multi_streaming(false); @@ -79,7 +79,7 @@ TEST_F(GpuHloScheduleTest, SequentialMatMul) { HloInstruction* dot2 = builder.AddInstruction(CreateCanonicalDot(f32_2x2_, dot1, z)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build(dot2)); std::unique_ptr streams = AssignStreams(*module); @@ -139,7 +139,7 @@ TEST_F(GpuHloScheduleTest, SequentialAdd) { HloInstruction* add3 = builder.AddInstruction( HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, add1, add2)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build(add3)); std::unique_ptr streams = AssignStreams(*module); @@ -209,7 +209,7 @@ TEST_F(GpuHloScheduleTest, ConcurrentMatMul) { HloInstruction* add = builder.AddInstruction(CreateCanonicalDot(f32_2x2_, dot1, dot2)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build(add)); std::unique_ptr streams = AssignStreams(*module); @@ -288,7 +288,7 @@ TEST_F(GpuHloScheduleTest, LatticeMatMul) { HloInstruction* d40 = builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d30, d31)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build(d40)); std::unique_ptr streams = AssignStreams(*module); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc index 1c0a23fa3eb..f59da2caa18 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc @@ -65,8 +65,8 @@ HeuristicLayoutAssignment(const HloInstruction* instr, VLOG(2) << "Using heuristic to figure out layouts for " << instr->ToString(); - // Empirically we've found with Volta and cudnn 7 that backward-input convs - // with stride are significantly faster with NCHW layouts. + // Empirically we've found with Volta and cudnn <= 7.3 that backward-input + // convs with stride are significantly faster with NCHW layouts. // // We could have used a mixed layout combination, e.g. (NHWC, NCHW, NCHW), // which on paper gives good performance. However, there are two observations: @@ -75,11 +75,17 @@ HeuristicLayoutAssignment(const HloInstruction* instr, // * we've also observed that for mixed layouts, cuDNN transposes data back // and forth from a different layout combination. If we end up with // transposes anyway, we prefer to have them in XLA, as they can be fused. - // TODO(timshen): Figure out the exact condition. This may be achieved by - // auto-tuning layouts offline. 
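For illustration only (not part of the patch): the replacement code just below gates the NCHW heuristic on the reported cuDNN version, comparing (major, minor) lexicographically via std::tuple. A minimal standalone version of that comparison (the function name is hypothetical):

```
#include <tuple>

// True when the reported cuDNN version is at most 7.3, the range in which
// strided backward-input convolutions were observed to prefer NCHW.
bool CudnnPrefersNchwForStridedBwdInput(int major, int minor) {
  return std::make_tuple(major, minor) <= std::make_tuple(7, 3);
}

// CudnnPrefersNchwForStridedBwdInput(7, 3) -> true
// CudnnPrefersNchwForStridedBwdInput(7, 4) -> false
```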
- if (instr->custom_call_target() == kCudnnConvBackwardInputCallTarget && - window_util::HasStride(instr->window())) { - return kAllNCHW; + if (auto* dnn = stream_executor->AsDnn()) { + auto version_status = dnn->GetVersion(); + if (version_status.ok()) { + auto version = version_status.ConsumeValueOrDie(); + if (std::make_tuple(version.major_version(), version.minor_version()) <= + std::make_tuple(7, 3) && + instr->custom_call_target() == kCudnnConvBackwardInputCallTarget && + window_util::HasStride(instr->window())) { + return kAllNCHW; + } + } } // For other Volta f16 convolutions, use NHWC. diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc index 8cc76c872c6..2ffc8bfb49b 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc @@ -61,7 +61,7 @@ TEST_F(LayoutAssignmentTest, Elementwise) { HloInstruction::CreateParameter(1, ashape, "y")); auto add = builder.AddInstruction( HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, x, y)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build(add)); @@ -148,7 +148,7 @@ TEST_F(LayoutAssignmentTest, BatchNormInference) { {operand, scale, offset, mean, variance, epsilon, feature_index}, kCudnnBatchNormForwardInferenceCallTarget)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build(batchnorm)); @@ -217,7 +217,7 @@ TEST_F(LayoutAssignmentTest, BatchNormTraining) { batchnorm_shape, {operand, scale, offset, epsilon, feature_index}, kCudnnBatchNormForwardTrainingCallTarget)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build(batchnorm)); @@ -298,7 +298,7 @@ TEST_F(LayoutAssignmentTest, BatchNormGrad) { feature_index}, kCudnnBatchNormBackwardCallTarget)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build(batchnorm)); diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index 43f43b50e4a..6151dd8ff4c 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -80,7 +80,7 @@ bool IsIEEEFloatingPointScalarConstant(const HloInstruction* constant) { // This function limits the maximum number of operands to a fusion. // // There's a cap on how many parameters we can pass to a CUDA kernel, but -// exactly what that limit is is hazy, as it depends on (among other things) how +// exactly what that limit is hazy, as it depends on (among other things) how // much GPU constant memory is in use for other purposes. // // Moreover, we don't even know at the point that we're running fusion how many @@ -181,7 +181,8 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, return true; } } else if (consumer->operand_count() == 2 && - consumer->opcode() == HloOpcode::kAdd) { + consumer->opcode() == HloOpcode::kAdd && + consumer->operand(other_operand_index) != producer) { // Fuse a bias add into the output of the dot. 
return true; } diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index fb77bc4b8eb..688604cd36e 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -117,7 +117,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfDotUnfused) { auto reshape2 = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(S32, {1, 1, 1}), dot1)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(reshape2, computation->root_instruction()); EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) @@ -134,7 +134,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) { auto transpose2 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(S32, {1, 1}), dot1, {0, 1})); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(transpose2, computation->root_instruction()); EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) @@ -358,6 +358,29 @@ TEST_F(InstructionFusionTest, DotOutputFusionBiasAdd) { op::Parameter())); } +TEST_F(InstructionFusionTest, + DotOperationFusion_DontOutputFuseDuplicateOperands) { + absl::string_view module_string = R"( +HloModule module + +ENTRY main { + a = f32[50,60]{1,0} parameter(0) + b = f32[60,1]{1,0} parameter(1) + c = f32[50,1]{1,0} dot(a, b), lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT d = f32[50,1]{1,0} add(c, c) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool fused_something, + GpuInstructionFusion(/*may_duplicate=*/false).Run(module.get())); + EXPECT_FALSE(fused_something); + EXPECT_THAT(module->entry_computation()->root_instruction(), + Not(op::Fusion())); +} + // Compute sum(1/p0), where p0 has type f32, twice. Check that the division is // duplicated and fused into both reduces. 
TEST_F(InstructionFusionTest, FloatingPointDivIsCheap) { @@ -723,7 +746,7 @@ TEST_F(InstructionFusionTest, AvoidsLargeFusion) { sum = b.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sum, param)); } - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(b.Build()); EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) .Run(module.get()) diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 7fcdd805ed3..6693f66d62d 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -63,9 +63,6 @@ IrEmitter::IrEmitter(const HloModuleConfig& hlo_module_config, &ir_emitter_context->buffer_assignment(), &b_, module_, is_nested), hlo_module_config_(hlo_module_config) { - b_.setFastMathFlags(llvm_ir::GetFastMathFlags( - /*fast_math_enabled=*/hlo_module_config.debug_options() - .xla_gpu_enable_fast_math())); } Status IrEmitter::DefaultAction(HloInstruction* hlo) { @@ -97,6 +94,18 @@ Status IrEmitter::HandleBitcast(HloInstruction* bitcast) { return Status::OK(); } +Status IrEmitter::HandleAddDependency(HloInstruction* add_dependency) { + VLOG(2) << "HandleAddDependency: " << add_dependency->ToString(); + const HloInstruction* operand = add_dependency->operand(0); + // Add_Dependency is a no-op, but we still want to bind it to an llvm::Value + // sometimes, e.g., when it's operand is a constant or a bitcast of a + // constant. + if (bindings_.BoundToIrValue(*operand)) { + bindings_.BindHloToIrValue(*add_dependency, GetBasePointer(*operand)); + } + return Status::OK(); +} + Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) { auto operand = get_tuple_element->operand(0); CHECK(bindings_.BoundToIrValue(*operand)); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index 56c3f452006..2da46c01693 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -100,6 +100,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, Status HandleBatchNormInference(HloInstruction* batch_norm) override; Status HandleBatchNormTraining(HloInstruction* batch_norm) override; Status HandleBatchNormGrad(HloInstruction* batch_norm) override; + Status HandleAddDependency(HloInstruction* add_dependency) override; Status FinishVisit(HloInstruction* root) override { return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 87b6cd640ac..bbe1583c011 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -22,7 +22,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h" #include "absl/algorithm/container.h" -#include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/types/optional.h" @@ -65,11 +64,11 @@ limitations under the License. 
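For illustration only (not part of the patch): the HandleAddDependency handler added above treats kAddDependency as a pure pass-through: if the data operand already has an IR value, that same value is re-bound under the add-dependency instruction. A hypothetical map-based sketch of the pattern:

```
#include <string>
#include <unordered_map>

// Hypothetical binding table from instruction name to an opaque IR handle.
using Bindings = std::unordered_map<std::string, const void*>;

// An add-dependency style no-op: forward the operand's value, if any.
void BindPassThrough(Bindings& bindings, const std::string& op,
                     const std::string& operand) {
  auto it = bindings.find(operand);
  if (it != bindings.end()) {
    bindings[op] = it->second;  // same underlying value, new name
  }
}
```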
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" -#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" @@ -88,6 +87,8 @@ limitations under the License. namespace xla { namespace gpu { +using llvm_ir::KernelMappingScheme; + namespace { using absl::InlinedVector; @@ -546,91 +547,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { // TODO(b/112040122): Support variadic reduce. return Unimplemented("Variadic reduce is not supported on GPU"); } - VLOG(3) << "Emitting fused reduction to vector: " << fusion->ToString(); - std::vector> thunks; - absl::Span output_instructions = - root->opcode() == HloOpcode::kTuple - ? root->operands() - : absl::Span(&root, 1); - - // For multi-output fusion emit an initializer for each tuple element. - // Otherwise it's sufficient to just initialize the single output. - HloInstruction* first_reduce = nullptr; - for (int i = 0, e = output_instructions.size(); i != e; ++i) { - if (output_instructions[i]->opcode() == HloOpcode::kReduce) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr initializer_thunk, - BuildInitializerThunk(fusion, output_instructions[i] == root - ? ShapeIndex() - : ShapeIndex({i}))); - thunks.push_back(std::move(initializer_thunk)); - first_reduce = - first_reduce == nullptr ? output_instructions[i] : first_reduce; - } - } - CHECK(first_reduce != nullptr); - std::unique_ptr kernel_thunk = - BuildKernelThunk(fusion, /*implements_whole_instruction=*/false); - GpuElementalIrEmitter elemental_emitter( - hlo_module_config_, ir_emitter_context_->llvm_module(), &b_, - GetNestedComputer()); - FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(fusion), - &elemental_emitter); - TF_RETURN_IF_ERROR(root->Accept(&fused_emitter)); - - // For multi-output fusion CHECK the constraints and feed all the - // reduces into a single loop code generator. Single-output reduce - // fusion is a special case of that. - InlinedVector input_gens; - InlinedVector init_value_gens; - std::vector> - extra_output_gens; - InlinedVector reducers; - InlinedVector reduce_output_shapes; - for (int i = 0, e = output_instructions.size(); i != e; ++i) { - const HloInstruction* inst = output_instructions[i]; - ShapeIndex output_shape_index; - if (root->opcode() == HloOpcode::kTuple) { - output_shape_index = {i}; - } - if (inst->opcode() == HloOpcode::kReduce) { - CHECK(IsReductionToVector(*inst)) - << "Only reductions to vector are supported"; - // Shapes, layouts and dimensions must be the same for all reduces - // inside of this fusion. 
- CHECK(ShapeUtil::Equal(first_reduce->shape(), inst->shape())); - CHECK(ShapeUtil::Equal(first_reduce->operand(0)->shape(), - inst->operand(0)->shape())); - CHECK(ShapeUtil::Equal(first_reduce->operand(1)->shape(), - inst->operand(1)->shape())); - CHECK(first_reduce->dimensions() == inst->dimensions()); - input_gens.push_back(fused_emitter.GetGenerator(inst->operand(0))); - init_value_gens.push_back( - fused_emitter.GetGenerator(inst->operand(1))); - reducers.push_back(inst->to_apply()); - reduce_output_shapes.push_back(std::move(output_shape_index)); - } else { - // For extra outputs we can relax shape equality to allow different - // types (with the same number of elements). Layouts still have to - // match. - CHECK(ShapeUtil::CompatibleIgnoringElementType( - first_reduce->operand(0)->shape(), inst->shape())); - CHECK(LayoutUtil::Equal(first_reduce->operand(0)->shape().layout(), - inst->shape().layout())); - extra_output_gens.emplace_back(fused_emitter.GetGenerator(inst), - std::move(output_shape_index)); - } - } - const Shape& input_shape = first_reduce->operand(0)->shape(); - TF_CHECK_OK(EmitReductionToVector( - kernel_thunk.get(), first_reduce, input_shape, input_gens, - init_value_gens, first_reduce->dimensions(), reducers, - reduce_output_shapes, extra_output_gens)); - thunks.push_back(std::move(kernel_thunk)); - std::unique_ptr sequential_thunk = - absl::make_unique(std::move(thunks), fusion); - AddThunkToThunkSequence(std::move(sequential_thunk)); - return Status::OK(); + return EmitReductionToVector(fusion); } default: LOG(FATAL) << "Bad opcode for input fusion: " @@ -700,13 +617,12 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) { } Status IrEmitterUnnested::EmitExtraOutputsForReduce( - const HloInstruction* reduce, const IrArray::Index& index, + const HloInstruction* unnested_hlo, const IrArray::Index& index, absl::Span> extra_output_gens) { for (int i = 0; i != extra_output_gens.size(); ++i) { - const HloInstruction* output = reduce->parent()->FusionInstruction(); llvm::Value* extra_output_address = - GetIrArray(*output, *output, extra_output_gens[i].second) + GetIrArray(*unnested_hlo, *unnested_hlo, extra_output_gens[i].second) .EmitArrayElementAddress(index, &b_, "extra_output_element_address"); TF_ASSIGN_OR_RETURN(llvm::Value* const extra_output_ir_value, @@ -716,984 +632,13 @@ Status IrEmitterUnnested::EmitExtraOutputsForReduce( return Status::OK(); } -Status IrEmitterUnnested::EmitReductionToScalar( - KernelThunk* kernel_thunk, HloInstruction* reduce, const Shape& input_shape, - absl::Span input_gens, - absl::Span init_value_gens, - absl::Span reducers, - absl::Span reduce_output_shapes, - absl::Span> - extra_output_gens) { - // Number of elements processed by a single thread. - constexpr int64 kTileSize = 16; - int64 num_elems = ShapeUtil::ElementsIn(input_shape); - - // Round up the number of tiles to a multiple of the warp size. This is - // necessary for correctness. We launch one thread per tile, and if the - // number of threads isn't a multiple of the number of the warp size, our - // shuffles will read from inactive threads, producing undefined values. 
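For illustration only (not part of the patch): a worked example of the tile rounding discussed above and computed on the next line. With num_elems = 1000 and kTileSize = 16, CeilOfRatio yields 63 tiles, which is rounded up to 64 so the launched threads form whole 32-thread warps (standalone helper definitions shown for clarity):

```
#include <cstdint>

int64_t CeilOfRatio(int64_t a, int64_t b) { return (a + b - 1) / b; }

int64_t RoundUpToNearest(int64_t a, int64_t multiple) {
  return CeilOfRatio(a, multiple) * multiple;
}

// num_elems = 1000, kTileSize = 16, kWarpSize = 32:
//   CeilOfRatio(1000, 16)    == 63
//   RoundUpToNearest(63, 32) == 64 tiles -> 64 threads, exactly 2 warps
int64_t NumTiles(int64_t num_elems, int64_t tile_size, int64_t warp_size) {
  return RoundUpToNearest(CeilOfRatio(num_elems, tile_size), warp_size);
}
```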
- int64 num_tiles = - RoundUpToNearest(CeilOfRatio(num_elems, kTileSize), kWarpSize); - - Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout( - reduce->shape().element_type(), {num_tiles}, {0}); - LaunchDimensions launch_dimensions = CalculateLaunchDimensions( - tiled_input_shape, ir_emitter_context_->device_description()); - - llvm::Type* index_ty = - GetIndexTypeForKernel(reduce, launch_dimensions.launch_bound(), &b_); - - auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { - return llvm::ConstantInt::get(index_ty, c); - }; - - // Check whether every thread will process a full tile's worth of elements - // without reading outside the bounds of the input. If this is true, we can - // skip some bounds checks in the final algorithm. - bool all_threads_in_bounds = num_tiles * kTileSize == num_elems; - - // __global__ void full_reduce_kernel() { - // x_in_tiles = threadIdx.x + blockIdx.x * blockDim.x; - // x = x_in_tiles * kTileSize; - // - // partial_result = init_value; - // if (all_threads_in_bounds || x + kTileSize <= num_elems) { - // for (i = 0; i < kTileSize; ++i) { - // partial_result = Reducer(partial_result, input[x + i]); - // } - // } else { - // for (i = 0; i < kTileSize; ++i) { - // if (x + i < num_elems) { - // partial_result = Reducer(partial_result, input[x + i]); - // } - // } - // } - // for (i = warpSize / 2; i > 0; i /= 2) { - // partial_result = Reducer(partial_result, - // __shfl_down(partial_result, i)); - // } - // if (lane_id == 0) { - // AtomicReducer(&output[y], partial_result); - // } - // } - // - // // Choose num_blocks and threads_per_block such that: - // // - // // num_blocks * threads_per_block = - // // RoundUpToNextMultipleOf(Ceil(num_elems / kTileSize), warpSize), - // // - // // and threads_per_block is a multiple of warpSize. - // reduce_kernel // - auto loop_body_emitter = [=](const IrArray::Index& tile_index) -> Status { - const int num_reduces = reducers.size(); - llvm::Type* element_ir_type = - llvm_ir::PrimitiveTypeToIrType(input_shape.element_type(), module_); - std::vector partial_reduction_result_addresses; - for (int i = 0; i != num_reduces; ++i) { - llvm::Value* partial_reduction_result_address = - Alloca(element_ir_type, /*ArraySize=*/nullptr, - "partial_reduction_result." + llvm::Twine(i)); - TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, - init_value_gens[i](IrArray::Index(index_ty))); - Store(init_ir_value, partial_reduction_result_address); - partial_reduction_result_addresses.push_back( - partial_reduction_result_address); - } - - llvm::Value* x_in_tiles = tile_index[0]; - x_in_tiles = ZExtOrTrunc(x_in_tiles, index_ty); - - // Emit an inner for-loop that reduces the elements in the tile. - auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status { - std::unique_ptr tile_element_loop = - llvm_ir::ForLoop::EmitForLoop( - "element_id_in_tile", index_typed_constant(0), - index_typed_constant(kTileSize), index_typed_constant(1), &b_); - - // Emit the body of the partial reduction loop. - llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(), - &b_); - llvm::Value* x = - NSWAdd(NSWMul(x_in_tiles, index_typed_constant(kTileSize)), - tile_element_loop->GetIndVarValue()); - // Unless we know the tile is entirely in bounds, we have to emit a - // x-in-bounds check before reading from the input. 
- if (!tile_in_bounds) { - llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - ICmpULT(x, index_typed_constant(num_elems)), "x_in_bounds", &b_); - - // Emit code that reads the input element and accumulates it to - // the partial reduction result. - llvm_ir::SetToFirstInsertPoint(if_data.true_block, &b_); - } - - IrArray::Index input_index( - /*linear=*/x, input_shape, &b_); - llvm::Value* input_address = Alloca(element_ir_type); - for (int i = 0; i != num_reduces; ++i) { - TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, - input_gens[i](input_index)); - Store(input_ir_value, input_address); - TF_RETURN_IF_ERROR(EmitCallToNestedComputation( - *reducers[i], - {partial_reduction_result_addresses[i], input_address}, - partial_reduction_result_addresses[i])); - } - return EmitExtraOutputsForReduce(reduce, input_index, extra_output_gens); - }; - - // x_end = kTileSize + x_in_tiles * kTileSize, i.e., the location that's - // immediately beyond the tile. - llvm::Value* x_end = - NSWAdd(index_typed_constant(kTileSize), - NSWMul(x_in_tiles, index_typed_constant(kTileSize))); - // The tile is entirely in bound if all_threads_in_bounds or - // x_end <= num_elems. - llvm::Value* tile_in_bounds = - Or(ICmpULE(x_end, index_typed_constant(num_elems)), - b_.getInt1(all_threads_in_bounds)); - llvm_ir::LlvmIfData if_tile_in_bounds_data = - llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &b_); - llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.true_block, &b_); - TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/true)); - llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.false_block, &b_); - TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/false)); - - // After the if-then-else statement on tile_in_bounds, emit calls to - // shfl_down that accumulate the partial reduction results of all threads - // from the warp. - llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.after_block, &b_); - int bit_width = llvm_ir::GetSizeInBits(element_ir_type); - // bitcast cannot be applied to aggregate types (even packed ones), so we - // instead bitcast addresses of load/store to intN* of the same bit-width. - llvm::Type* shuffle_ir_type = element_ir_type->isStructTy() - ? b_.getIntNTy(bit_width) - : element_ir_type; - for (int shuffle_distance = kWarpSize / 2; shuffle_distance >= 1; - shuffle_distance /= 2) { - llvm::Value* result_from_other_lane = - Alloca(element_ir_type, nullptr, "result_from_other_lane"); - for (int i = 0; i != num_reduces; ++i) { - llvm::Value* partial_reduction_result = - Load(BitCast(partial_reduction_result_addresses[i], - shuffle_ir_type->getPointerTo()), - "partial_reduction_result"); - CHECK_EQ(launch_dimensions.threads_per_block() % kWarpSize, 0) - << "Requires block size a multiple of the warp size, otherwise we " - "will read undefined elements."; - Store(EmitFullWarpShuffleDown(partial_reduction_result, - b_.getInt32(shuffle_distance), &b_), - BitCast(result_from_other_lane, shuffle_ir_type->getPointerTo())); - TF_RETURN_IF_ERROR(EmitCallToNestedComputation( - *reducers[i], - {partial_reduction_result_addresses[i], result_from_other_lane}, - partial_reduction_result_addresses[i])); - } - } - - const HloInstruction* output = - reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce; - - // Emit an atomic operation that accumulates the partial reduction result of - // lane 0 (which holds the partially accumulated result for its warp) to the - // output element. 
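For illustration only (not part of the patch): the shfl_down loop being deleted above is the classic warp-level tree reduction. The plain C++ sketch below simulates what one 32-lane warp computes; the real kernel uses the __shfl_down intrinsic rather than an array, but the data flow is the same.

```
#include <array>
#include <cstddef>

constexpr std::size_t kWarpSize = 32;

// After log2(32) = 5 rounds, lane 0 holds the sum of all 32 lane values.
float WarpReduceSum(std::array<float, kWarpSize> lanes) {
  for (std::size_t distance = kWarpSize / 2; distance >= 1; distance /= 2) {
    for (std::size_t lane = 0; lane + distance < kWarpSize; ++lane) {
      // Models __shfl_down(value, distance): read lane + distance's value.
      lanes[lane] += lanes[lane + distance];
    }
  }
  return lanes[0];  // only lane 0's result is accumulated (atomically) below
}
```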
- llvm::Value* lane_id = - URem(x_in_tiles, index_typed_constant(kWarpSize), "lane_id"); - llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse( - ICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero", &b_); - llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &b_); - - for (int i = 0; i != num_reduces; ++i) { - llvm::Value* output_address = - GetIrArray(*output, *output, reduce_output_shapes[i]) - .EmitArrayElementAddress( - IrArray::Index( - /*linear=*/b_.getInt64(0), - ShapeUtil::GetSubshape(output->shape(), - reduce_output_shapes[i]), - &b_), - &b_, "output_element_address"); - TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation( - *reducers[i], output_address, partial_reduction_result_addresses[i])); - } - return Status::OK(); - }; - - // Emit a parallel loop that iterates through all input tiles, one per thread. - UpdateLaunchDimensions(launch_dimensions, kernel_thunk, - ir_emitter_context_->llvm_module()); - return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape, - launch_dimensions, &b_) - .EmitLoop(IrName(reduce), index_ty); -} - -Status IrEmitterUnnested::EmitColumnReduction( - KernelThunk* kernel_thunk, int64 height, int64 width, - HloInstruction* reduce, const Shape& input_shape, - absl::Span input_gens, - absl::Span init_value_gens, - absl::Span reducers, - absl::Span reduce_output_shapes, - absl::Span> - extra_output_gens) { - // Divide the input matrix into tiles of size KxL. For example, when the - // input matrix is 4x4, K=2, and L=1 the tiled matrix looks like - // - // 0123 - // 0123 - // 4567 - // 4567 // Numbers indicate tile IDs. - // - // Each tile is first partially reduced to a scalar by a thread, and then the - // scalar is accumulated to the output vector using atomic operations. - // - // We choose 128 as the tile size based on empirical evidence. It's big enough - // to reduce the amount of atomic adds in the end, maximizing the memory - // bandwidth. A tile width of 2 allows for high memory bandwidth utilization - // on 16b input data. - constexpr int64 kTileHeight = 128; - constexpr int64 kTileWidth = 2; - - // If the height is not a multiple of kTileHeight, we pad the bottom of the - // input matrix. - const int64 height_in_tiles = CeilOfRatio(height, kTileHeight); - // If width is not a multiple of kTileWidth the rightmost thread will process - // fewer input elements. - const int64 width_in_tiles = CeilOfRatio(width, kTileWidth); - Shape tiled_input_shape = - ShapeUtil::MakeShapeWithLayout(reduce->shape().element_type(), - {height_in_tiles, width_in_tiles}, {1, 0}); - LaunchDimensions launch_dimensions = CalculateLaunchDimensions( - tiled_input_shape, ir_emitter_context_->device_description()); - - // TODO(b/110211620): Convert to use i32 index_type when it is possible. 
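For illustration only (not part of the patch): in the column reduction above, each thread owns one kTileHeight x kTileWidth tile, and its linear thread index is split into tile coordinates exactly as in the pseudocode comment. A hypothetical helper making that mapping explicit:

```
#include <cstdint>
#include <utility>

// Map a thread's linear index onto (y_in_tiles, x_in_tiles) for a matrix
// tiled into height_in_tiles x width_in_tiles tiles, row-major over tiles.
std::pair<int64_t, int64_t> TileCoords(int64_t linear_index,
                                       int64_t width_in_tiles) {
  return {linear_index / width_in_tiles, linear_index % width_in_tiles};
}

// e.g. width_in_tiles = 4: linear_index 6 -> tile row 1, tile column 2.
```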
- llvm::Type* index_ty = b_.getInt64Ty(); - - auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { - return llvm::ConstantInt::get(index_ty, c); - }; - - // for (linear_index = threadIdx.x + blockIdx.x * blockDim.x; - // linear_index < height_in_tiles * width_in_tiles; - // linear_index += blockDim.x * gridDim.x) { - // y_in_tiles = linear_index / width_in_tiles; - // x_in_tiles = linear_index % width_in_tiles; - // - // partial_results[kTileWidth] = init_values; - // tile_in_y_bounds = height % kTileHeight == 0 || - // y_in_tiles * kTileHeight + kTileHeight <= height; - // tile_in_x_bounds = width % kTileWidth == 0 || - // x_in_tiles * kTileWidth + kTileWidth <= width; - // // The implementation handles y and x bound checks separately. - // if (tile_in_y_bounds && tile_in_x_bounds) { - // for (y_offset : range(kTileHeight)) { - // y = y_in_tiles * kTileHeight + y_offset; - // for (x_offset : range(kTileWidth)) { - // x = x_in_tiles * kTileWidth + x_offset; - // partial_result = Reducer(partial_result[x_offset], input[y][x]); - // } - // } - // } else { - // for (y_offset : range(kTileHeight)) { - // y = y_in_tiles * kTileHeight + y_offset; - // for (y_offset : range(kTileHeight)) { - // x = x_in_tiles * kTileWidth + x_offset; - // if (y < height && x < width) { - // partial_result = Reducer(partial_result, input[y][x]); - // } - // } - // } - // } - // for (x_offset : range(kTileWidth)) { - // AtomicReducer(&output[x + x_offset], partial_result[x_offset]); - // } - // } - auto loop_body_emitter = [=](const IrArray::Index& tile_index) -> Status { - const int num_reduces = reducers.size(); - // Emit the loop body that reduces one tile. - llvm::Type* element_ir_type = - llvm_ir::PrimitiveTypeToIrType(input_shape.element_type(), module_); - std::vector partial_reduction_result_addresses; - for (int i = 0; i != num_reduces; ++i) { - for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) { - llvm::Value* partial_reduction_result_address = - Alloca(element_ir_type, /*ArraySize=*/nullptr, - "partial_reduction_result." + - llvm::Twine(i * kTileWidth + x_offset)); - TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, - init_value_gens[i](IrArray::Index(index_ty))); - Store(init_ir_value, partial_reduction_result_address); - partial_reduction_result_addresses.push_back( - partial_reduction_result_address); - } - } - - // Emit an inner for-loop that partially reduces the elements in the given - // tile. - llvm::Value* y_in_tiles = tile_index[0]; - llvm::Value* x_in_tiles = tile_index[1]; - - y_in_tiles = ZExtOrTrunc(y_in_tiles, index_ty); - x_in_tiles = ZExtOrTrunc(x_in_tiles, index_ty); - - auto emit_tile_element_loop = [=](bool tile_in_y_bounds, - bool tile_in_x_bounds) -> Status { - std::unique_ptr tile_element_loop = - llvm_ir::ForLoop::EmitForLoop( - "element_id_in_tile", index_typed_constant(0), - index_typed_constant(kTileHeight), index_typed_constant(1), &b_); - - // Emit the body of the partial reduction loop. - llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(), - &b_); - llvm::Value* y = - NSWAdd(NSWMul(y_in_tiles, index_typed_constant(kTileHeight)), - tile_element_loop->GetIndVarValue()); - - // Unless we know that y is in bounds, we have to emit a check before - // reading from the input. - if (!tile_in_y_bounds) { - llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - ICmpULT(y, index_typed_constant(height)), "y_in_bounds", &b_); - - // Emit code that reads the input element and accumulates it to - // the partial reduction result. 
- llvm_ir::SetToFirstInsertPoint(if_data.true_block, &b_); - } - for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) { - llvm::Value* x = - NSWAdd(NSWMul(x_in_tiles, index_typed_constant(kTileWidth)), - index_typed_constant(x_offset)); - // Unless we know that x is in bounds, we have to emit a check before - // reading from the input. - if (!tile_in_x_bounds) { - llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - ICmpULT(x, index_typed_constant(width)), "x_in_bounds", &b_); - llvm_ir::SetToFirstInsertPoint(if_data.true_block, &b_); - } - llvm::Value* input_address = Alloca(element_ir_type); - // {y,x} is an index to input_matrix_shape [height,width]. We need to - // convert that to an index to input_shape (the shape of the operand of - // "reduce"). This conversion is composed of a transposition from - // input_shape to normalized_input_shape and a reshape from - // normalized_input_shape to input_matrix_shape. - const Shape normalized_input_shape = - ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( - input_shape); - auto input_shape_min2maj = LayoutUtil::MinorToMajor(input_shape); - const std::vector transpose_dimension_mapping( - input_shape_min2maj.rbegin(), input_shape_min2maj.rend()); - - const Shape input_matrix_shape = - ShapeUtil::MakeShapeWithDescendingLayout(input_shape.element_type(), - {height, width}); - const IrArray::Index input_matrix_index({y, x}, input_matrix_shape, - &b_); - const IrArray::Index input_index = - input_matrix_index - .SourceIndexOfReshape(input_matrix_shape, - normalized_input_shape, &b_) - .SourceIndexOfTranspose(normalized_input_shape, input_shape, - transpose_dimension_mapping, &b_); - for (int i = 0; i != num_reduces; ++i) { - TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, - input_gens[i](input_index)); - Store(input_ir_value, input_address); - TF_RETURN_IF_ERROR(EmitCallToNestedComputation( - *reducers[i], - {partial_reduction_result_addresses[i * kTileWidth + x_offset], - input_address}, - partial_reduction_result_addresses[i * kTileWidth + x_offset])); - TF_RETURN_IF_ERROR(EmitExtraOutputsForReduce(reduce, input_index, - extra_output_gens)); - } - } - return Status::OK(); - }; - - // y_end = kTileHeight + y_in_tiles * kTileHeight, i.e., the y location - // that's immediately beyond the tile. - llvm::Value* y_end = - NSWAdd(index_typed_constant(kTileHeight), - NSWMul(y_in_tiles, index_typed_constant(kTileHeight))); - // x_end = kTileWidth + x_in_tiles * kTileWidth, i.e., the x location - // that's immediately beyond the tile. - llvm::Value* x_end = - NSWAdd(index_typed_constant(kTileWidth), - NSWMul(x_in_tiles, index_typed_constant(kTileWidth))); - llvm::Value* tile_in_y_bounds = - Or(ICmpULE(y_end, index_typed_constant(height)), - b_.getInt1(height % kTileHeight == 0)); - llvm::Value* tile_in_x_bounds = - Or(ICmpULE(x_end, index_typed_constant(width)), - b_.getInt1(width % kTileWidth == 0)); - // The tile is in y bounds if "height" is a multiple of kTileHeight or - // y_end <= height. - llvm_ir::LlvmIfData if_tile_in_y_bounds_data = - llvm_ir::EmitIfThenElse(tile_in_y_bounds, "tile_in_y_bounds", &b_); - llvm_ir::SetToFirstInsertPoint(if_tile_in_y_bounds_data.true_block, &b_); - // The tile is in x bounds if "width" is a multiple of kTileWidth or - // x_end <= width. 
- llvm_ir::LlvmIfData if_tile_in_x_bounds_data = - llvm_ir::EmitIfThenElse(tile_in_x_bounds, "tile_in_x_bounds", &b_); - llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.true_block, &b_); - TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/true, - /*tile_in_x_bounds=*/true)); - llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.false_block, &b_); - TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/true, - /*tile_in_x_bounds=*/false)); - llvm_ir::SetToFirstInsertPoint(if_tile_in_y_bounds_data.false_block, &b_); - if_tile_in_x_bounds_data = - llvm_ir::EmitIfThenElse(tile_in_x_bounds, "tile_in_x_bounds", &b_); - llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.true_block, &b_); - TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/false, - /*tile_in_x_bounds=*/true)); - llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.false_block, &b_); - TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/false, - /*tile_in_x_bounds=*/false)); - - // After the nested if-then-else statement on tile_in_y_bounds and - // tile_in_x_bounds, emit atomic operations to accumulate the partial - // reduction result to the output element. - llvm_ir::SetToFirstInsertPoint(if_tile_in_y_bounds_data.after_block, &b_); - const HloInstruction* output = - reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce; - for (int i = 0; i != num_reduces; ++i) { - for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) { - llvm::Value* x = - NSWAdd(NSWMul(x_in_tiles, index_typed_constant(kTileWidth)), - index_typed_constant(x_offset)); - llvm::Value* output_address = - GetIrArray(*output, *output, reduce_output_shapes[i]) - .EmitArrayElementAddress( - IrArray::Index( - x, - ShapeUtil::GetSubshape(output->shape(), - reduce_output_shapes[i]), - &b_), - &b_, "output_element_address"); - TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation( - *reducers[i], output_address, - partial_reduction_result_addresses[i * kTileWidth + x_offset])); - } - } - return Status::OK(); - }; - - // Emit a parallel loop that iterate through all input tiles. - UpdateLaunchDimensions(launch_dimensions, kernel_thunk, - ir_emitter_context_->llvm_module()); - return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape, - launch_dimensions, &b_) - .EmitLoop(IrName(reduce), index_ty); -} - -static std::pair ComputeTilingSchemeForReduction( - int64 depth, int64 width, int64 kWarpSize) { - constexpr int64 kTargetNumElementsPerThread = 64; - int64 x_tile_size = kTargetNumElementsPerThread; - int64 z_tile_size = 1; - - // Only tile along the x dimension with tile size kTargetNumElementsPerThread - // if doing so doesn't require a slow version of loop with bound check on each - // dimension. A more sophisticated heuristics is to enable tile along the - // x dimension with tile size kTargetNumElementsPerThread when either width is - // a factor of (kWarpSize * kTargetNumElementsPerThread) or width is big - // enough so that only a small fraction of the threads execute the slow - // version of loop with bound check. 
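For illustration only (not part of the patch): the heuristic described above can be restated as a standalone function (same constants as the deleted ComputeTilingSchemeForReduction, hypothetical free-function form): keep the large 64-element x-tile only when the row width divides evenly into warpSize * 64 chunks, otherwise fall back to an 8x8 tile and shrink the z-tile until it divides depth.

```
#include <cstdint>
#include <utility>

// Returns {x_tile_size, z_tile_size}.
std::pair<int64_t, int64_t> TilingSchemeForRowReduction(int64_t depth,
                                                        int64_t width,
                                                        int64_t warp_size) {
  const int64_t kTargetNumElementsPerThread = 64;
  int64_t x_tile_size = kTargetNumElementsPerThread;
  int64_t z_tile_size = 1;
  if (width % (warp_size * kTargetNumElementsPerThread) != 0) {
    x_tile_size = 8;
    z_tile_size = 8;
    while (depth % z_tile_size != 0) --z_tile_size;  // largest divisor <= 8
  }
  return {x_tile_size, z_tile_size};
}

// e.g. depth = 12, width = 1000, warp_size = 32:
//   1000 % 2048 != 0, so x_tile_size = 8 and z_tile_size shrinks to 6.
```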
- if (width % (kWarpSize * kTargetNumElementsPerThread) != 0) { - x_tile_size = 8; - z_tile_size = 8; - while (depth % z_tile_size != 0) { - z_tile_size -= 1; - } - } - - return std::pair(x_tile_size, z_tile_size); -} - -Status IrEmitterUnnested::EmitRowReduction( - KernelThunk* kernel_thunk, int64 depth, int64 height, int64 width, - HloInstruction* reduce, const Shape& input_shape, - absl::Span input_gens, - absl::Span init_value_gens, - absl::Span reducers, - absl::Span reduce_output_shapes, - absl::Span> - extra_output_gens) { - // A naive algorithm is: - // 1. Divide the x dimension of the input tensor into tiles of size 1x1xX. - // 2. Partially reduces each tile to a scalar using one thread. - // 3. Accumulates that scalar to the output vector using atomic operations. - // - // for (linear_index = threadIdx.x + blockIdx.x * blockDim.x; - // linear_index < depth * height * width_in_tiles; - // linear_index += blockDim.x * gridDim.x) { - // int x_in_tiles = linear_index % width_in_tiles; - // int y = linear_index / width_in_tiles % height; - // int z = linear_index / (height * width_in_tiles); - // float partial_result = 0; - // for (element_id_in_tile : range(x_tile_size)) { - // int x = x_in_tiles * x_tile_size + element_id_in_tile; - // if (x < width) - // partial_result = reducer(partial_result, input[z][y][x]); - // } - // AtomicReducer(&output[y], partial_result); - // } - // - // Four optimizations are performed. - // - // 1. To coalesce global memory accesses, dilate the tile with a factor of 32 - // (i.e. the warp size). For example, suppose the width is 8x32=256. Instead - // of making each tile consecutive, we let make tile 0 column - // [0,32,64,...,224], tile 1 column [1,33,65,...,225], and so on. This ensures - // that threads in a warp access consecutive memory in one iteration (i.e. - // coalesced). In the above example, the warp that contains thread 0-31 - // accesses column 0-31 in the first iteration, and 32-63 in the second - // iteration, and so on. - // - // 2. Partially accumulate partial reduced results computed by threads in the - // same warp using shfl_down. Using shfl_down is faster than directly using - // atomic operations because shfl_down transfers the data between threads - // using shared memory and threads in the same warp run in lock step (thus no - // extra synchronization needed). See - // https://devblogs.nvidia.com/parallelforall/faster-parallel-reductions-kepler/ - // for details. The downside is, to produce correct results when using - // shfl_down, we need to guarantee threads in the same warp work on input - // elements with the same y, so the number of tiles in each row must be a - // multiple of 32. - // - // 3. Specialize the case that the entire tile is in bounds. When that is - // true, we don't need to emit "if(x 0; shuffle_distance /= 2) - // partial_result = Reducer( - // partial_result, - // __shfl_down_sync(CUDA_WARP_ALL, partial_result, shuffle_distance)); - // if (lane_id == 0) - // AtomicReducer(&output[y], partial_result); - // } - // - - int64 x_tile_size; - int64 z_tile_size; - std::tie(x_tile_size, z_tile_size) = - ComputeTilingSchemeForReduction(depth, width, kWarpSize); - - // Round the width in tiles up to the nearest multiple of kWarpSize, so that - // the use of shfl_down is valid. 
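For illustration only (not part of the patch): the coalescing scheme described in optimization 1 above dilates each thread's x-tile by the warp size, so that in any one loop iteration the 32 lanes of a warp touch 32 consecutive elements. A small helper reproducing the x-index formula used by the emitted loop below:

```
#include <cstdint>

// x-coordinate read by a lane in one iteration over its x-tile. For a fixed
// element_id, lanes 0..31 of a warp read consecutive x values, so the loads
// coalesce into a single memory transaction.
int64_t DilatedX(int64_t lane_id, int64_t warp_id, int64_t element_id,
                 int64_t x_tile_size, int64_t warp_size) {
  return lane_id + warp_size * (element_id + warp_id * x_tile_size);
}

// e.g. warp_size = 32, x_tile_size = 8, warp_id = 0:
//   element_id 0 -> lanes cover x =  0..31
//   element_id 1 -> lanes cover x = 32..63
```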
- const int64 width_in_tiles = - RoundUpToNearest(CeilOfRatio(width, x_tile_size), kWarpSize); - Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout( - reduce->shape().element_type(), - {depth / z_tile_size, height, width_in_tiles}, {2, 1, 0}); - LaunchDimensions launch_dimensions = CalculateLaunchDimensions( - tiled_input_shape, ir_emitter_context_->device_description()); - llvm::Type* index_ty = - GetIndexTypeForKernel(reduce, launch_dimensions.launch_bound(), &b_); - - auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { - return llvm::ConstantInt::get(index_ty, c); - }; - - auto loop_body_emitter = [=](const IrArray::Index& tile_index) { - const int num_reduces = reducers.size(); - llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType( - input_shape.element_type(), ir_emitter_context_->llvm_module()); - std::vector partial_reduction_result_addresses; - for (int i = 0; i != num_reduces; ++i) { - llvm::Value* partial_reduction_result_address = - Alloca(element_ir_type, /*ArraySize=*/nullptr, - "partial_reduction_result." + llvm::Twine(i)); - TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, - init_value_gens[i](IrArray::Index(index_ty))); - Store(init_ir_value, partial_reduction_result_address); - partial_reduction_result_addresses.push_back( - partial_reduction_result_address); - } - - llvm::Value* z_tile = tile_index[0]; - llvm::Value* y = tile_index[1]; - llvm::Value* x_tile = tile_index[2]; - - x_tile = ZExtOrTrunc(x_tile, index_ty); - - llvm::Value* warp_id = - UDiv(x_tile, index_typed_constant(kWarpSize), "warp_id"); - llvm::Value* lane_id = - URem(x_tile, index_typed_constant(kWarpSize), "lane_id"); - - // The x-location of the last element in this z-x-tile. - // last_x = lane_id + warpSize * (x_tile_size - 1 + warp_id * x_tile_size); - llvm::Value* last_x = NSWAdd( - lane_id, - NSWMul(index_typed_constant(kWarpSize), - NSWAdd(index_typed_constant(x_tile_size - 1), - NSWMul(warp_id, index_typed_constant(x_tile_size))))); - - KernelSupportLibrary ksl( - &b_, - /*unroll_mode=*/xla::llvm_ir::UnrollMode::kFullyUnroll, - /*prevent_vectorization=*/false); - - // Emit a for-loop that partially reduces the elements in the given - // z-x-tile. - auto emit_z_x_tile_element_loop = [&](bool x_tile_in_bounds, - int64 x_tile_loop_bound) -> Status { - auto emit_z_tile_element_loop = [&](llvm::Value* z_indvar) -> Status { - llvm::Value* z = - NSWAdd(z_indvar, NSWMul(index_typed_constant(z_tile_size), z_tile)); - TF_RETURN_IF_ERROR(ksl.For( - "x_tile", - /*start=*/index_typed_constant(0), - /*end=*/index_typed_constant(x_tile_loop_bound), - /*step=*/1, [&](llvm::Value* x_indvar) -> Status { - // x = lane_id + - // warpSize * (element_id_in_x_tile + warp_id * x_tile_size); - llvm::Value* x = NSWAdd( - lane_id, - NSWMul(index_typed_constant(kWarpSize), - NSWAdd(x_indvar, - NSWMul(warp_id, llvm::ConstantInt::get( - index_ty, x_tile_size))))); - - // Unless we know the x-tile is entirely in bounds, we have to - // emit a x-in-bounds check before reading from the input. - if (!x_tile_in_bounds) { - llvm_ir::LlvmIfData if_x_in_bounds_data = - llvm_ir::EmitIfThenElse( - ICmpULT(x, index_typed_constant(width)), "x_in_bounds", - &b_); - // Points b_ to the then-block. - llvm_ir::SetToFirstInsertPoint(if_x_in_bounds_data.true_block, - &b_); - } - - // Emit code that reads the input element and accumulates it - // to the partial reduction result. - llvm::Value* input_address = Alloca(element_ir_type); - { - // {z,y,x} is an index to input_3d_tensor_shape - // [depth,height,width]. 
We need to convert that to an index - // to input_shape (the shape of the operand of "reduce"). - // This conversion is composed of a transposition from - // input_shape to normalized_input_shape and a reshape from - // normalized_input_shape to input_3d_tensor_shape. - const Shape normalized_input_shape = ShapeUtil:: - MakeShapeWithDescendingLayoutAndSamePhysicalLayout( - input_shape); - auto input_shape_min2maj = - LayoutUtil::MinorToMajor(input_shape); - const std::vector transpose_dimension_mapping( - input_shape_min2maj.rbegin(), input_shape_min2maj.rend()); - const Shape input_3d_tensor_shape = - ShapeUtil::MakeShapeWithDescendingLayout( - input_shape.element_type(), {depth, height, width}); - const IrArray::Index input_3d_tensor_index( - {z, y, x}, input_3d_tensor_shape, &b_); - const IrArray::Index input_index = - input_3d_tensor_index - .SourceIndexOfReshape(input_3d_tensor_shape, - normalized_input_shape, &b_) - .SourceIndexOfTranspose( - normalized_input_shape, input_shape, - transpose_dimension_mapping, &b_); - - for (int i = 0; i != num_reduces; ++i) { - TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, - input_gens[i](input_index)); - Store(input_ir_value, input_address); - TF_RETURN_IF_ERROR(EmitCallToNestedComputation( - *reducers[i], - {partial_reduction_result_addresses[i], input_address}, - partial_reduction_result_addresses[i])); - } - return EmitExtraOutputsForReduce(reduce, input_index, - extra_output_gens); - } - })); - return Status::OK(); - }; - - return ksl.For("z_tile", - /*start=*/index_typed_constant(0), - /*end=*/index_typed_constant(z_tile_size), - /*step=*/1, emit_z_tile_element_loop); - }; - - llvm::Value* tile_in_bounds = - Or(b_.getInt1(width % (x_tile_size * kWarpSize) == 0), - ICmpULT(last_x, index_typed_constant(width))); - - TF_RETURN_IF_ERROR( - ksl.If(tile_in_bounds, - /*true_block_generator=*/ - [&]() -> Status { - return emit_z_x_tile_element_loop(/*x_tile_in_bounds=*/true, - x_tile_size); - }, - /*false_block_generator=*/ - [&]() -> Status { - return emit_z_x_tile_element_loop( - /*x_tile_in_bounds=*/false, - CeilOfRatio(width % (x_tile_size * kWarpSize), kWarpSize)); - })); - - // After accumulating the elements of the z_x_tile, emit calls to - // shfl_down that accumulate the partial reduction results of all - // threads in a warp. - int bit_width = llvm_ir::GetSizeInBits(element_ir_type); - // bitcast cannot be applied to aggregate types (even packed ones), so we - // instead bitcast addresses of load/store to intN* of the same bit-width. - llvm::Type* shuffle_ir_type = element_ir_type->isStructTy() - ? 
b_.getIntNTy(bit_width) - : element_ir_type; - for (int shuffle_distance = 16; shuffle_distance >= 1; - shuffle_distance /= 2) { - llvm::Value* result_from_other_lane = - Alloca(element_ir_type, nullptr, "result_from_other_lane"); - for (int i = 0; i != num_reduces; ++i) { - llvm::Value* partial_reduction_result = - Load(BitCast(partial_reduction_result_addresses[i], - shuffle_ir_type->getPointerTo()), - "partial_reduction_result"); - CHECK_EQ(launch_dimensions.threads_per_block() % kWarpSize, 0) - << "Requires block size a multiple of the warp size, otherwise we " - "will read undefined elements."; - Store(EmitFullWarpShuffleDown(partial_reduction_result, - b_.getInt32(shuffle_distance), &b_), - BitCast(result_from_other_lane, shuffle_ir_type->getPointerTo())); - TF_RETURN_IF_ERROR(EmitCallToNestedComputation( - *reducers[i], - {partial_reduction_result_addresses[i], result_from_other_lane}, - partial_reduction_result_addresses[i])); - } - } - - const HloInstruction* output = - reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce; - - // Emit an atomic operation that accumulates the partial reduction result of - // lane 0 (which holds the partially accumulated result for its warp) to the - // output element. - llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse( - ICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero", &b_); - llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &b_); - for (int i = 0; i != num_reduces; ++i) { - llvm::Value* output_address = - GetIrArray(*output, *output, reduce_output_shapes[i]) - .EmitArrayElementAddress( - IrArray::Index(y, - ShapeUtil::GetSubshape( - output->shape(), reduce_output_shapes[i]), - &b_), - &b_, "output_element_address"); - // We don't need to emit atomic operations if there is only one tile of - // results. 'depth' is the z dimension, 'width' is the x dimension. - if (z_tile_size >= depth && x_tile_size >= width) { - TF_RETURN_IF_ERROR(EmitCallToNestedComputation( - *reducers[i], - {output_address, partial_reduction_result_addresses[i]}, - output_address)); - } else { - TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation( - *reducers[i], output_address, - partial_reduction_result_addresses[i])); - } - } - return Status::OK(); - }; - - // Emit a parallel loop that iterates through every input tiles. - UpdateLaunchDimensions(launch_dimensions, kernel_thunk, - ir_emitter_context_->llvm_module()); - return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape, - launch_dimensions, &b_) - .EmitLoop(IrName(reduce), index_ty); -} - -// Figures out whether `reduce` is a row or column reduction, and which -// dimensions to reduce, and calls either `EmitRowReduction` or -// `EmitColumnReduction` as appropriate. -// Prerequisite: all the dimensions to keep are contiguous in the input layout -// and, if `reduce` is fused, the fused subgraph is pure -// elementwise. -Status IrEmitterUnnested::EmitReductionToVector( - KernelThunk* kernel_thunk, HloInstruction* reduce, const Shape& input_shape, - absl::Span input_gens, - absl::Span init_value_gens, - absl::Span dimensions_to_reduce, - absl::Span reducers, - absl::Span reduce_output_shapes, - absl::Span> - extra_output_gens) { - // This emission requires "reduce" to have an input layout. It is either set - // by LayoutAssignment (for a top-level kReduce) or by InstructionFusion (for - // a fused kReduce). 
- CHECK(input_shape.has_layout()) << "LayoutAssignment or InstructionFusion " - "doesn't set the input layout of " - << reduce->ToString(); - - // Specialize multi-dimensional-array-to-vector reduction. - std::vector input_dims_to_keep; - for (int64 input_dim = 0; input_dim < ShapeUtil::Rank(input_shape); - ++input_dim) { - if (std::find(dimensions_to_reduce.begin(), dimensions_to_reduce.end(), - input_dim) == dimensions_to_reduce.end()) { - input_dims_to_keep.push_back(input_dim); - } - } - - // Sort the dimensions to keep from minor to major, to facilitate checking - // whether another dimension is major or minor of them. - std::sort(input_dims_to_keep.begin(), input_dims_to_keep.end(), - [&input_shape](int64 dim_a, int64 dim_b) { - return PositionInContainer(LayoutUtil::MinorToMajor(input_shape), - dim_a) < - PositionInContainer(LayoutUtil::MinorToMajor(input_shape), - dim_b); - }); - // Now, if output rank is at least 1, `input_dims_to_keep.front()` is - // minormost and `input_dims_to_keep.back()` is majormost. - - // If the dimensions to keep are minormost, emit a column reduction. As all - // the dimensions to keep are contiguous, by prerequisite of - // `EmitReductionToVector`, we only need to check whether the minormost - // dimension of the input is to keep. - if (ShapeUtil::IsEffectiveScalar(reduce->shape())) { - return EmitReductionToScalar(kernel_thunk, reduce, input_shape, input_gens, - init_value_gens, reducers, - reduce_output_shapes, extra_output_gens); - } else if (input_dims_to_keep.front() == - LayoutUtil::Minor(input_shape.layout(), 0)) { - // Column reduction. Treat the result of "input" as a matrix whose width - // is the most minor dimension and height the product of other dimensions, - // and treat "reduce" as a column reduction of the input matrix. - const int64 width = ShapeUtil::ElementsIn(reduce->shape()); - // "width" can be zero, so don't do - // height = ShapeUtil::ElementsIn(input_shape) / width; - int64 height = 1; - for (int64 input_dim = 0; input_dim < ShapeUtil::Rank(input_shape); - ++input_dim) { - if (!std::count(input_dims_to_keep.begin(), input_dims_to_keep.end(), - input_dim)) { - height *= input_shape.dimensions(input_dim); - } - } - return EmitColumnReduction(kernel_thunk, height, width, reduce, input_shape, - input_gens, init_value_gens, reducers, - reduce_output_shapes, extra_output_gens); - } else { - // Reduce the row dimension of a matrix or reduce dimension 0 and 2 in a - // 3D tensor. The size of dimension 1 (the height) is the size of the - // dimension to keep, the size of dimension 0 (the depth) is the product - // of dimensions that are more major than the dimension to keep, and the - // size of dimension 2 (the width) is the product of more minor - // dimensions. 
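To make the row/column classification and the depth/height/width bookkeeping just described concrete, here is a toy host-side sketch; the struct and function names are invented for illustration and this is not the XLA helper itself.

```cuda
#include <cstdio>
#include <vector>

// Given the input dimension sizes in minor-to-major order and the positions
// (in that order) of the contiguous dimensions to keep, compute the
// (depth, height, width) view that a row reduction operates on.
struct RowReductionView { long long depth, height, width; };

RowReductionView ClassifyRowReduction(
    const std::vector<long long>& sizes_minor_to_major, int kept_begin,
    int kept_end /*exclusive*/) {
  RowReductionView v{1, 1, 1};
  for (int pos = 0; pos < (int)sizes_minor_to_major.size(); ++pos) {
    if (pos < kept_begin) v.width *= sizes_minor_to_major[pos];      // more minor
    else if (pos < kept_end) v.height *= sizes_minor_to_major[pos];  // kept
    else v.depth *= sizes_minor_to_major[pos];                       // more major
  }
  return v;
}

int main() {
  // f32[2, 3, 4, 5] with row-major layout => sizes {5, 4, 3, 2} minor to major.
  // Keep dimension 2 (size 4), i.e. position 1 in minor-to-major order; since
  // the kept dimension is not the minormost one, this is a row reduction.
  RowReductionView v = ClassifyRowReduction({5, 4, 3, 2}, /*kept_begin=*/1,
                                            /*kept_end=*/2);
  // Prints depth=6 height=4 width=5, the [depth, height, width] view used by
  // EmitRowReduction.  Keeping the minormost dimension instead would fall
  // into the column-reduction branch.
  std::printf("depth=%lld height=%lld width=%lld\n", v.depth, v.height,
              v.width);
  return 0;
}
```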
- int64 depth = 1; - int64 width = 1; - for (int64 input_dim = 0; input_dim < ShapeUtil::Rank(input_shape); - ++input_dim) { - if (PositionInContainer(LayoutUtil::MinorToMajor(input_shape), - input_dim) > - PositionInContainer(LayoutUtil::MinorToMajor(input_shape), - input_dims_to_keep.back())) { - depth *= input_shape.dimensions(input_dim); - } else if (PositionInContainer(LayoutUtil::MinorToMajor(input_shape), - input_dim) < - PositionInContainer(LayoutUtil::MinorToMajor(input_shape), - input_dims_to_keep.front())) { - width *= input_shape.dimensions(input_dim); - } - } - const int64 height = ShapeUtil::ElementsIn(reduce->shape()); - return EmitRowReduction(kernel_thunk, depth, height, width, reduce, - input_shape, input_gens, init_value_gens, reducers, - reduce_output_shapes, extra_output_gens); - } -} - Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { // TODO(b/112040122): Support multi-output reduce. if (!ShapeUtil::IsArray(reduce->shape())) { return Unimplemented("Multi-output reduce is not supported on GPU"); } - auto input = reduce->operand(0); - auto init_value = reduce->operand(1); - absl::Span dimensions_to_reduce(reduce->dimensions()); - HloComputation* reducer = reduce->to_apply(); - // HandleReduce specializes reduction from a multi-dimensional array to a 1D - // array. The specialized version requires an initializer thunk that - // initializes the output array to the initial value of the reduce. if (IsReductionToVector(*reduce)) { - TF_ASSIGN_OR_RETURN(std::unique_ptr initializer_thunk, - BuildInitializerThunk(reduce)); - std::vector> thunks; - thunks.push_back(std::move(initializer_thunk)); - std::unique_ptr kernel_thunk = - BuildKernelThunk(reduce, /*implements_whole_instruction=*/false); - - TF_CHECK_OK(EmitReductionToVector( - kernel_thunk.get(), reduce, input->shape(), - {[&](const IrArray::Index& index) { - return GetIrArray(*input, *reduce).EmitReadArrayElement(index, &b_); - }}, - {[&](const IrArray::Index& index) { - return GetIrArray(*init_value, *reduce) - .EmitReadArrayElement(index, &b_); - }}, - dimensions_to_reduce, {reducer}, {{}}, {})); - - thunks.push_back(std::move(kernel_thunk)); - - std::unique_ptr sequential_thunk = - absl::make_unique(std::move(thunks), reduce); - AddThunkToThunkSequence(std::move(sequential_thunk)); - return Status::OK(); + return EmitReductionToVector(reduce); } return IrEmitter::HandleReduce(reduce); @@ -1818,7 +763,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter( // Create the inner loop to iterate over the window. llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "inner"), &b_, index_type); - std::vector window_size; + DimensionVector window_size; for (const auto& dim : window.dimensions()) { window_size.push_back(dim.size()); CHECK_GT(dim.size(), 0); @@ -2171,7 +1116,18 @@ Status IrEmitterUnnested::HandleSelect(HloInstruction* select) { Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { std::vector> thunks; Shape keys_shape = sort->operand(0)->shape(); + int64 dimension_to_sort = sort->dimensions(0); + // In case there is a 'values' parameter that is a iota, we take note and use + // it later to ensure a stable sort. Otherwise, we don't guarantee a stable + // sort. 
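As a concrete picture of why an iota operand along the sort dimension makes the emitted sort stable, here is a sketch of one compare-and-swap step over (key, iota) pairs. The actual comparison is emitted by llvm_ir::EmitSortInPlace, so the function below only illustrates the tie-breaking idea; the names and the ascending direction are assumptions of this example.

```cuda
#include <cstdint>

// One compare-and-swap step of a bitonic-style sort.  The iota operand holds
// each element's original position along the sort dimension, so breaking ties
// on it yields a stable order; without such an operand no stability is
// guaranteed.
__host__ __device__ inline void CompareAndSwap(float* keys, int32_t* iota,
                                               int i, int j) {
  bool swap = keys[j] < keys[i] ||
              (keys[j] == keys[i] && iota[j] < iota[i]);  // tie-break on iota
  if (swap) {
    // Swap the keys and the iota values together so later stages keep seeing
    // consistent (key, original-index) pairs.
    float tk = keys[i]; keys[i] = keys[j]; keys[j] = tk;
    int32_t ti = iota[i]; iota[i] = iota[j]; iota[j] = ti;
  }
}
```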
+ int64 iota_values_parameter_index = -1; for (int64 i = 0; i < sort->operand_count(); ++i) { + if (i > 0 && sort->operand(i)->opcode() == HloOpcode::kIota && + ShapeUtil::ElementIsIntegral(sort->operand(i)->shape()) && + Cast(sort->operand(i))->iota_dimension() == + dimension_to_sort) { + iota_values_parameter_index = i; + } ShapeIndex shape_index = sort->operand_count() > 1 ? ShapeIndex({i}) : ShapeIndex({}); // We assume that the layout of all involved operands and outputs is the @@ -2196,7 +1152,6 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { } } - int64 dimension_to_sort = sort->dimensions(0); uint64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort); int64 num_stages = tensorflow::Log2Ceiling(dimension_to_sort_bound); CHECK_GE(1ULL << num_stages, dimension_to_sort_bound); @@ -2298,8 +1253,9 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { } } return llvm_ir::EmitSortInPlace( - dimension_to_sort, keys_array, values_arrays, IrName(sort), xor_masks, - &b_, launch_dimensions, + dimension_to_sort, keys_array, values_arrays, + iota_values_parameter_index, IrName(sort), xor_masks, &b_, + launch_dimensions, xor_masks.size() > 1 ? num_iterations_in_sort_dim : standard_num_iterations_in_sort_dim, kTileSize); @@ -2385,7 +1341,7 @@ Status IrEmitterUnnested::HandleCrossReplicaSum(HloInstruction* crs) { return Status::OK(); } -Status IrEmitterUnnested::HandleAfterAll(HloInstruction* gen_token) { +Status IrEmitterUnnested::HandleAfterAll(HloInstruction* after_all) { return Status::OK(); } @@ -3146,31 +2102,6 @@ std::vector IrEmitterUnnested::ConstructIrArrayForInputs( return param_arrays; } -int IrEmitterUnnested::ConstructOutputReducedShapeAndCastOutputIrArrayToShape( - const HloInstruction& hlo, const std::vector& output_arrays, - absl::Span reduced_output_dims, - std::vector* output_reduced_shapes, - std::vector* output_in_reduced_shape_arrays) { - int64 num_outputs = 1; - if (hlo.IsMultiOutputFusion()) { - num_outputs = ShapeUtil::TupleElementCount(hlo.shape()); - output_in_reduced_shape_arrays->reserve(num_outputs); - output_reduced_shapes->reserve(num_outputs); - for (int64 i = 0; i < num_outputs; ++i) { - output_reduced_shapes->push_back(ShapeUtil::MakeShapeWithDescendingLayout( - ShapeUtil::GetSubshape(hlo.shape(), {i}).element_type(), - reduced_output_dims)); - output_in_reduced_shape_arrays->push_back( - output_arrays[i].CastToShape((*output_reduced_shapes)[i], &b_)); - } - } else { - output_reduced_shapes->push_back(ShapeUtil::MakeShapeWithDescendingLayout( - hlo.shape().element_type(), reduced_output_dims)); - output_in_reduced_shape_arrays->push_back( - output_arrays[0].CastToShape((*output_reduced_shapes)[0], &b_)); - } - return num_outputs; -} int IrEmitterUnnested::ConstructInputReducedShapeAndCastInputIrArrayToShape( const HloInstruction& hlo, const std::vector& param_arrays, @@ -3230,82 +2161,854 @@ llvm::Value* GetBlockIdx(llvm::IRBuilder<>* builder, llvm::Type* index_ty, "block.id.x"); } -// Emits code to process up to (tile_size/num_rows) elements in a tile, given -// `emit_elem_function` is the function to emit code to process one element, `y` -// and `x` are the coordinates for the first element to process, and `index` is -// the index for the origin of the tile. Emits bounds check to ensure that each -// processed element is within the boundary defined by `tile_width` and -// `tile_height`. 
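The full-tile / partial-tile split described in the removed comment above (and implemented by the new EmitFullTile / EmitPartialTile pair that follows) can be summarized with a standalone CUDA kernel. The elementwise operation (a scale), the fixed 32x32 tile, and the 32x4 thread block are assumptions of this sketch; the emitter takes the real values from its KernelMappingScheme.

```cuda
#include <cuda_runtime.h>

constexpr int kTileSizeX = 32, kTileSizeY = 32;
constexpr int kNumThreadsX = 32, kNumThreadsY = 4;

// One thread block processes one kTileSizeY x kTileSizeX tile of a 2D array.
// A full interior tile takes the branch-free path; a partial tile on the
// right/bottom edge re-checks every coordinate against the tile bounds.
__global__ void ScaleTiled(float* data, int height, int width, float factor) {
  int tile_y0 = blockIdx.y * kTileSizeY;
  int tile_x0 = blockIdx.x * kTileSizeX;
  int tile_h = min(kTileSizeY, height - tile_y0);
  int tile_w = min(kTileSizeX, width - tile_x0);
  int x = threadIdx.x;  // 0 .. kNumThreadsX-1
  int y = threadIdx.y;  // 0 .. kNumThreadsY-1

  if (tile_h == kTileSizeY && tile_w == kTileSizeX) {
    // Full tile: every (y + i, x + j) is in bounds by construction.
    for (int i = 0; i < kTileSizeY; i += kNumThreadsY)
      for (int j = 0; j < kTileSizeX; j += kNumThreadsX)
        data[(tile_y0 + y + i) * width + tile_x0 + x + j] *= factor;
  } else {
    // Partial tile: guard both coordinates against the actual tile bounds.
    for (int j = 0; j < kTileSizeX; j += kNumThreadsX) {
      if (x + j >= tile_w) continue;
      for (int i = 0; i < kTileSizeY; i += kNumThreadsY) {
        if (y + i >= tile_h) continue;
        data[(tile_y0 + y + i) * width + tile_x0 + x + j] *= factor;
      }
    }
  }
}
```

It would be launched with a `dim3(kNumThreadsX, kNumThreadsY)` block and a grid of `ceil(width/32) x ceil(height/32)` blocks, so each thread handles `(kTileSizeX/kNumThreadsX) * (kTileSizeY/kNumThreadsY)` elements.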
-void EmitTiledElementalCodeWithBoundsCheck( - int64 tile_size, int64 num_rows, const IrArray::Index& index, - const string& loop_name, KernelSupportLibrary* ksl, - llvm::IRBuilder<>* builder, llvm::Value* y, llvm::Value* x, - llvm::Value* tile_width, llvm::Value* tile_height, - const std::function& - emit_elem_function) { - llvm::Type* index_ty = tile_width->getType(); - // Emits a constant value with index type. - auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { - return llvm::ConstantInt::get(index_ty, c); - }; - // Adds `addend` to the given `dim` of `index`. - auto offset_dim = [&](IrArray::Index index, llvm::Value* addend, int64 dim) { - index[dim] = builder->CreateAdd(index[dim], addend); - return index; - }; - - auto emit_full_tile = [&] { - for (int64 i = 0; i < tile_size; i += num_rows) { - auto source_idx = offset_dim(index, index_typed_constant(i), /*dim=*/1); - auto y_loc = builder->CreateAdd(index_typed_constant(i), y); - emit_elem_function(source_idx, y_loc); +void EmitFullTile(const KernelMappingScheme* mapping_scheme, + const IrArray::Index& tile_origin_index, + llvm::IRBuilder<>* builder, llvm::Value* y, llvm::Value* x, + llvm::Type* index_ty, + const std::function& emit_elem_function) { + int64 num_threads_x = mapping_scheme->GetNumberOfThreadsForDimensionX(); + int64 num_threads_y = mapping_scheme->GetNumberOfThreadsForDimensionY(); + int64 tile_size_x = mapping_scheme->GetTileSizeForDimensionX(); + int64 tile_size_y = mapping_scheme->GetTileSizeForDimensionY(); + for (int64 i = 0; i < tile_size_y; i += num_threads_y) { + IrArray::Index source_idx_y = + tile_origin_index.AddOffsetToDim(llvm::ConstantInt::get(index_ty, i), + KernelMappingScheme::DimY, builder); + llvm::Value* y_loc = + builder->CreateAdd(llvm::ConstantInt::get(index_ty, i), y); + for (int64 j = 0; j < tile_size_x; j += num_threads_x) { + IrArray::Index source_idx = + source_idx_y.AddOffsetToDim(llvm::ConstantInt::get(index_ty, j), + KernelMappingScheme::DimX, builder); + llvm::Value* x_loc = + builder->CreateAdd(llvm::ConstantInt::get(index_ty, j), x); + emit_elem_function(source_idx, y_loc, x_loc); } - }; + } +} + +void EmitPartialTile( + const KernelMappingScheme* mapping_scheme, + const IrArray::Index& tile_origin_index, const string& loop_name, + KernelSupportLibrary* ksl, llvm::IRBuilder<>* builder, llvm::Value* y, + llvm::Value* x, llvm::Value* tile_height, llvm::Value* tile_width, + llvm::Type* index_ty, + const std::function& emit_elem_function) { + int64 num_threads_x = mapping_scheme->GetNumberOfThreadsForDimensionX(); + int64 num_threads_y = mapping_scheme->GetNumberOfThreadsForDimensionY(); + int64 tile_size_x = mapping_scheme->GetTileSizeForDimensionX(); + + for (int64 j = 0; j < tile_size_x; j += num_threads_x) { + IrArray::Index source_idx = + tile_origin_index.AddOffsetToDim(llvm::ConstantInt::get(index_ty, j), + KernelMappingScheme::DimX, builder); + llvm::Value* x_loc = + builder->CreateAdd(llvm::ConstantInt::get(index_ty, j), x); + + ksl->IfReturnVoid( + loop_name + "_x_in_tile", builder->CreateICmpULT(x_loc, tile_width), + [&] { + // tile_height_bound = + // ceil(tile_height / num_threads_y) * num_threads_y + llvm::Value* ceiling_of_ratio = builder->CreateUDiv( + builder->CreateAdd(tile_height, llvm::ConstantInt::get( + index_ty, num_threads_y - 1)), + llvm::ConstantInt::get(index_ty, num_threads_y)); + llvm::Value* tile_height_bound = builder->CreateMul( + ceiling_of_ratio, + llvm::ConstantInt::get(index_ty, num_threads_y)); + ksl->ForReturnVoid( + loop_name, 
/*start=*/llvm::ConstantInt::get(index_ty, 0), + /*end=*/tile_height_bound, + /*step=*/llvm::ConstantInt::get(index_ty, num_threads_y), + [&](llvm::Value* y_indvar) { + llvm::Value* y_loc = builder->CreateAdd(y_indvar, y); + ksl->IfReturnVoid( + loop_name + "_y_in_tile", + builder->CreateICmpULT(y_loc, tile_height), [&] { + emit_elem_function( + source_idx.AddOffsetToDim( + y_indvar, KernelMappingScheme::DimY, builder), + y_loc, x_loc); + }); + }); + }); + } +} + +// Emits code to process up to +// (tile_size_x/num_threads_x * tile_size_y/num_threads_y) elements in a tile, +// given `emit_elem_function` is the function to emit code to process one +// element, `y` and `x` are the intra-tile coordinates for the first element +// to process, and `index` is the index for the origin of the tile. Information +// about tile_size_x/y and num_threads_x/y are stored in `mapping_scheme`. Emits +// bounds check to ensure that each processed element is within the boundary +// defined by `tile_width` and `tile_height`. +void EmitTiledElementalCodeWithBoundsCheck( + const KernelMappingScheme* mapping_scheme, + const IrArray::Index& tile_origin_index, const string& loop_name, + KernelSupportLibrary* ksl, llvm::IRBuilder<>* builder, llvm::Value* y, + llvm::Value* x, llvm::Value* tile_height, llvm::Value* tile_width, + const std::function& emit_elem_function) { + int64 tile_size_x = mapping_scheme->GetTileSizeForDimensionX(); + int64 tile_size_y = mapping_scheme->GetTileSizeForDimensionY(); + llvm::Type* index_ty = tile_width->getType(); - auto emit_last_row = [&] { - ksl->IfReturnVoid("x_in_tile", builder->CreateICmpULT(x, tile_width), [&] { - // tile_height_upper_bound = - // ceil(tile_height / num_rows) * num_rows - auto tile_height_upper_bound = builder->CreateMul( - builder->CreateUDiv( - builder->CreateAdd(tile_height, - index_typed_constant(num_rows - 1)), - index_typed_constant(num_rows)), - index_typed_constant(num_rows)); - ksl->ForReturnVoid( - loop_name, /*start=*/index_typed_constant(0), - /*end=*/tile_height_upper_bound, - /*step=*/index_typed_constant(num_rows), [&](llvm::Value* y_indvar) { - auto y_loc = builder->CreateAdd(y_indvar, y); - ksl->IfReturnVoid( - "y_in_tile", builder->CreateICmpULT(y_loc, tile_height), [&] { - emit_elem_function(offset_dim(index, y_indvar, /*dim=*/1), - y_loc); - }); - }); - }); - }; ksl->IfReturnVoid( - "full_tile", + loop_name + "_full_tile", builder->CreateAnd( - builder->CreateICmpEQ(index_typed_constant(tile_size), tile_width), - builder->CreateICmpEQ(index_typed_constant(tile_size), tile_height)), - emit_full_tile, emit_last_row); + builder->CreateICmpEQ(llvm::ConstantInt::get(index_ty, tile_size_x), + tile_width), + builder->CreateICmpEQ(llvm::ConstantInt::get(index_ty, tile_size_y), + tile_height)), + [&] { + EmitFullTile(mapping_scheme, tile_origin_index, builder, y, x, index_ty, + emit_elem_function); + }, + [&] { + EmitPartialTile(mapping_scheme, tile_origin_index, loop_name, ksl, + builder, y, x, tile_height, tile_width, index_ty, + emit_elem_function); + }); } } // namespace +// Emits code to process a tensor element in a tile for the given kCopy HLO that +// performs a 0-2-1 transpose. +// +// index: The index for the first output element in the normalized tensor. The +// normalized tensor is the resulting tensor after collapsing contiguous +// dimensions that play the same role in the transpose. +// y_loc: The y coordinate within a tile. +// x_loc: The x coordinate within a tile. 
+// kernel_info: Other information to support the kernel code generation. +void IrEmitterUnnested::EmitTileElementForCopy( + HloInstruction* hlo, const llvm_ir::IrArray::Index& index, + const KernelCodegenInfo* kernel_info, llvm::Value* y_loc, + llvm::Value* x_loc) { + llvm_ir::TiledParameterInfo* tiled_param_info = + kernel_info->GetTiledParameterInfo(); + // TODO(jlebar): Add AA metadata to this load. + llvm::Instruction* load_from_shmem_buffer = + Load(GEP(tiled_param_info->GetBufferForParameter(0), + {b_.getInt64(0), x_loc, y_loc}), + "output_element"); + llvm_ir::IrArray output_array = GetIrArray(*hlo, *hlo); + Shape output_reduced_shape = ShapeUtil::MakeShapeWithDescendingLayout( + hlo->shape().element_type(), + kernel_info->GetKernelMappingScheme()->GetDimensionsInElements()); + // When the output_reduced_shape is a 0-2-1 transpose of the input shape, + // the 0-2-1 transpose is achieved through EmitWriteArrayElement. + output_array.CastToShape(output_reduced_shape, &b_) + .EmitWriteArrayElement(index, load_from_shmem_buffer, &b_); +} + +// Emits code to process a tensor element in a tile for the given kLoop fusion +// HLO containing parameters that are 0-2-1 transpose of its outputs. +// +// index: The index for the first output element in the normalized tensor, that +// is the resulting tensor after collapsing contiguous dimensions that play +// the same role in the transpose. +// kernel_info: Other information to support the kernel code generation. +// y_loc: The y coordinate within a tile. +// x_loc: The x coordinate within a tile. +void IrEmitterUnnested::EmitTileElementForFusion( + HloInstruction* hlo, const llvm_ir::IrArray::Index& index, + const KernelCodegenInfo* kernel_info, llvm::Value* y_loc, + llvm::Value* x_loc) { + llvm_ir::TiledParameterInfo* tiled_param_info = + kernel_info->GetTiledParameterInfo(); + std::vector output_arrays = ConstructIrArrayForOutputs(*hlo); + GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_, + GetNestedComputer()); + FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(hlo), + &elem_emitter); + tiled_param_info->set_y(y_loc); + tiled_param_info->set_x(x_loc); + fused_emitter.SetTiledParameterInfo(tiled_param_info); + TF_CHECK_OK(hlo->fused_expression_root()->Accept(&fused_emitter)); + IrArray::Index untiled_index = + kernel_info->GetKernelMappingScheme()->GetUnnormalizedIndex( + index, output_arrays[0].GetShape()); + const llvm_ir::ElementGenerator& output_generator = + fused_emitter.GetRootGenerator(); + llvm::Value* output_value = output_generator(untiled_index).ValueOrDie(); + if (hlo->IsMultiOutputFusion()) { + DCHECK(output_value->getType()->isStructTy()); + DCHECK_EQ(output_value->getType()->getStructNumElements(), + output_arrays.size()); + for (int64 i = 0; i < output_arrays.size(); ++i) { + output_arrays[i].EmitWriteArrayElement( + untiled_index, ExtractValue(output_value, i), &b_); + } + } else { + output_arrays[0].EmitWriteArrayElement(untiled_index, output_value, &b_); + } +} + +// Information to support the code generation for a tiled reduction kernel. 
+using AddressVector = InlinedVector; +class ReductionCodegenInfo : public IrEmitterUnnested::KernelCodegenInfo { + public: + explicit ReductionCodegenInfo(llvm_ir::KernelMappingScheme* mapping_scheme, + bool is_row_reduction) + : KernelCodegenInfo(mapping_scheme), + current_output_linear_index_address_(nullptr), + current_output_inbound_address_(nullptr), + is_row_reduction_(is_row_reduction) {} + + void SetCurrentOutputLinearIndexAddress(llvm::AllocaInst* a) { + current_output_linear_index_address_ = a; + } + // Returns the address of the memory that stores the linear index of the + // current output. Since we are processing reduction to contiguous physical + // dimensions, this linear index is the linear index of the 1D output array. + llvm::AllocaInst* GetCurrentOutputLinearIndexAddress() const { + return current_output_linear_index_address_; + } + + void SetCurrentOutputInboundAddress(llvm::AllocaInst* a) { + current_output_inbound_address_ = a; + } + + llvm::AllocaInst* GetCurrentOutputInboundAddress() const { + return current_output_inbound_address_; + } + + AddressVector* GetMutablePartialResultAddresses() { + return &partial_result_addresses_; + } + const AddressVector& GetPartialResultAddresses() const { + return partial_result_addresses_; + } + + AddressVector* GetMutableReductionInputAddresses() { + return &reduction_input_addresses_; + } + const AddressVector& GetReductionInputAddresses() const { + return reduction_input_addresses_; + } + + InlinedVector* GetMutableReducers() { return &reducers_; } + const InlinedVector& GetReducers() const { + return reducers_; + } + int GetNumberOfReduces() const { return reducers_.size(); } + + InlinedVector* GetMutableReductionOutputShapeIndices() { + return &reduction_output_shape_indices_; + } + const InlinedVector& GetReductionOutputShapeIndices() const { + return reduction_output_shape_indices_; + } + + bool IsRowReduction() const { return is_row_reduction_; } + + // Return the dimension that is being reduced between DimX and DimY. + int GetReducedDimensionEnum() const { + return IsRowReduction() ? llvm_ir::KernelMappingScheme::DimX + : llvm_ir::KernelMappingScheme::DimY; + } + + // Return the dimension that is being ketp between DimX and DimY. + int GetKeptDimensionEnum() const { + return IsRowReduction() ? llvm_ir::KernelMappingScheme::DimY + : llvm_ir::KernelMappingScheme::DimX; + } + + private: + AddressVector partial_result_addresses_; + AddressVector reduction_input_addresses_; + InlinedVector reducers_; + InlinedVector reduction_output_shape_indices_; + llvm::AllocaInst* current_output_linear_index_address_; + llvm::AllocaInst* current_output_inbound_address_; + bool is_row_reduction_; +}; + +namespace { +// Returns a group of instructions that generate the output for the kernel +// containing the given HLO instruction. The result may be an unnested kReduce +// HLO, a nested kReduce HLO of a kInput fusion, or the operands of the tuple +// for a multiple output fusion. +absl::Span GetOutputInstructions( + HloInstruction* const* reduce_or_tuple_pointer) { + HloOpcode opcode = (*reduce_or_tuple_pointer)->opcode(); + CHECK(opcode == HloOpcode::kReduce || opcode == HloOpcode::kTuple); + return opcode == HloOpcode::kTuple + ? 
(*reduce_or_tuple_pointer)->operands() + : absl::Span(reduce_or_tuple_pointer, 1); +} + +const HloInstruction* GetFirstReduceInstruction( + absl::Span instructions) { + auto first_reduce_iter = + absl::c_find_if(instructions, [](const HloInstruction* inst) { + return inst->opcode() == HloOpcode::kReduce; + }); + CHECK_NE(first_reduce_iter, instructions.end()); + return *first_reduce_iter; +} + +}; // namespace + +void IrEmitterUnnested::EmitPrologueForOneReduction( + HloInstruction* unnested_hlo, HloInstruction* reduce_inst, int reduce_idx, + KernelCodegenInfo* kernel_info, GpuElementalIrEmitter* elemental_emitter, + ShapeIndex output_shape_index) { + ReductionCodegenInfo* reduction_info = + static_cast(kernel_info); + + InlinedVector* reducers = + reduction_info->GetMutableReducers(); + CHECK(IsReductionToVector(*reduce_inst)); + reducers->push_back(reduce_inst->to_apply()); + + InlinedVector* reduction_output_shape_indices = + reduction_info->GetMutableReductionOutputShapeIndices(); + reduction_output_shape_indices->push_back(std::move(output_shape_index)); + + AddressVector* reduction_input_addresses = + reduction_info->GetMutableReductionInputAddresses(); + llvm::Type* element_type = llvm_ir::PrimitiveTypeToIrType( + reduce_inst->shape().element_type(), ir_emitter_context_->llvm_module()); + llvm::AllocaInst* reduction_input_address = Alloca(element_type); + reduction_input_addresses->push_back(reduction_input_address); + + AddressVector* partial_result_addresses = + reduction_info->GetMutablePartialResultAddresses(); + llvm::AllocaInst* partial_result_address = + Alloca(element_type, /*ArraySize=*/nullptr, + "partial_reduction_result." + llvm::Twine(reduce_idx)); + partial_result_addresses->push_back(partial_result_address); + + // Initialize the partial result with the initial value of the reduction. + llvm::Value* init_ir_value; + if (unnested_hlo->opcode() == HloOpcode::kFusion) { + HloInstruction* init_value_operand = reduce_inst->mutable_operand(1); + FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(unnested_hlo), + elemental_emitter); + + TF_CHECK_OK(init_value_operand->Accept(&fused_emitter)); + init_ir_value = + fused_emitter + .GetGenerator(init_value_operand)(IrArray::Index(b_.getInt32Ty())) + .ValueOrDie(); + } else { + const HloInstruction* init_value = unnested_hlo->operand(1); + init_ir_value = + GetIrArray(*init_value, *unnested_hlo) + .EmitReadArrayElement(IrArray::Index(b_.getInt32Ty()), &b_); + } + + Store(init_ir_value, partial_result_address); +} + +void IrEmitterUnnested::EmitPrologueForReduction( + HloInstruction* unnested_hlo, KernelCodegenInfo* kernel_info) { + VLOG(10) << "Emit prologue for reduction " << unnested_hlo->ToString(); + // Find the unnested kReduce or the tuple that contains a list of kReduce. + HloInstruction* reduce_or_tuple = unnested_hlo->opcode() == HloOpcode::kFusion + ? 
unnested_hlo->fused_expression_root() + : unnested_hlo; + absl::Span output_instructions = + GetOutputInstructions(&reduce_or_tuple); + ReductionCodegenInfo* reduction_info = + static_cast(kernel_info); + GpuElementalIrEmitter elemental_emitter(hlo_module_config_, + ir_emitter_context_->llvm_module(), + &b_, GetNestedComputer()); + const HloInstruction* first_reduce = nullptr; + for (int i = 0, e = output_instructions.size(); i != e; ++i) { + if (output_instructions[i]->opcode() != HloOpcode::kReduce) { + continue; + } + HloInstruction* reduce_inst = output_instructions[i]; + if (first_reduce == nullptr) { + first_reduce = reduce_inst; + } else { + CHECK(first_reduce->dimensions() == reduce_inst->dimensions()); + } + ShapeIndex output_shape_index; + if (reduce_or_tuple->opcode() == HloOpcode::kTuple) { + output_shape_index = {i}; + } + + EmitPrologueForOneReduction(unnested_hlo, reduce_inst, i, kernel_info, + &elemental_emitter, + std::move(output_shape_index)); + } + + // Allocate stack storage to store the current output linear index and record + // the address of the storage. + reduction_info->SetCurrentOutputLinearIndexAddress( + Alloca(reduction_info->GetIndexType())); + + if (!reduction_info->IsRowReduction()) { + llvm::Type* bool_ty = b_.getInt1Ty(); + llvm::AllocaInst* output_inbound_addr = Alloca(bool_ty); + Store(llvm::ConstantInt::get(bool_ty, 0), output_inbound_addr); + reduction_info->SetCurrentOutputInboundAddress(output_inbound_addr); + } +} + +void IrEmitterUnnested::EmitFullWarpShuffleDownLoopForAllReduces( + const InlinedVector& reducers, + const AddressVector& partial_result_addresses) { + for (int distance = 16; distance >= 1; distance /= 2) { + for (int i = 0; i != reducers.size(); ++i) { + llvm::Type* element_type = + partial_result_addresses[i]->getType()->getElementType(); + int bit_width = llvm_ir::GetSizeInBits(element_type); + llvm::Value* result_from_other_lane = Alloca( + element_type, nullptr, "result_from_other_lane" + llvm::Twine(i)); + // Bitcast cannot be applied to aggregate types (even packed ones), so + // we bitcast addresses of load/store to intN* of the same bit-width. + llvm::Type* shuffled_value_type = + element_type->isStructTy() ? 
b_.getIntNTy(bit_width) : element_type; + auto convert_pointer_for_shuffle = [&](llvm::Value* ptr) { + return BitCast(ptr, shuffled_value_type->getPointerTo()); + }; + llvm::Value* partial_result = + Load(convert_pointer_for_shuffle(partial_result_addresses[i]), + "partial_reduction_result"); + Store(EmitFullWarpShuffleDown(partial_result, b_.getInt32(distance), &b_), + convert_pointer_for_shuffle(result_from_other_lane)); + TF_CHECK_OK(EmitCallToNestedComputation( + *reducers[i], {partial_result_addresses[i], result_from_other_lane}, + partial_result_addresses[i])); + } + } +} + +void IrEmitterUnnested::EmitEpilogueForReduction( + HloInstruction* unnested_hlo, KernelCodegenInfo* kernel_info) { + ReductionCodegenInfo* reduction_info = + static_cast(kernel_info); + int num_reduces = reduction_info->GetNumberOfReduces(); + const AddressVector& partial_result_addresses = + reduction_info->GetPartialResultAddresses(); + const InlinedVector& reducers = + reduction_info->GetReducers(); + const InlinedVector& reduction_output_shape_indices = + reduction_info->GetReductionOutputShapeIndices(); + + if (reduction_info->IsRowReduction()) { + EmitFullWarpShuffleDownLoopForAllReduces(reducers, + partial_result_addresses); + llvm::Value* lane_id = reduction_info->GetLaneId(); + llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse( + ICmpEQ(lane_id, llvm::ConstantInt::get(lane_id->getType(), 0)), + "lane_id_is_zero", &b_); + llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &b_); + } else { + llvm::Value* output_inbound_addr = + reduction_info->GetCurrentOutputInboundAddress(); + llvm::Value* output_inbound = Load(output_inbound_addr); + llvm_ir::LlvmIfData if_output_inbound_data = llvm_ir::EmitIfThenElse( + ICmpEQ(output_inbound, + llvm::ConstantInt::get(output_inbound->getType(), 1)), + "output_inbound", &b_); + llvm_ir::SetToFirstInsertPoint(if_output_inbound_data.true_block, &b_); + } + + // Emit an atomic operation that accumulates the partial reduction to the + // output element. For row reduction, this is only for lane 0 due to the + // if-statement emitted above. + for (int i = 0; i != num_reduces; ++i) { + IrArray::Index element_index( + /*linear=*/Load(reduction_info->GetCurrentOutputLinearIndexAddress(), + "output_linear_addr"), + ShapeUtil::GetSubshape(unnested_hlo->shape(), + reduction_output_shape_indices[i]), + &b_); + llvm::Value* output_address = + GetIrArray(*unnested_hlo, *unnested_hlo, + reduction_output_shape_indices[i]) + .EmitArrayElementAddress(element_index, &b_, + "output_element_address"); + // Do not emit atomic operations if each element in the reduction result is + // computed by one block, that is the dimension being reduced has only one + // block. 
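When a partial result does have to be combined atomically, that is when more than one block contributes to the same output element, reducers that do not map onto a single hardware atomic are typically lowered to a compare-and-swap loop. A sketch of that pattern is shown below, with `fmaxf` standing in for the nested reduction computation; it illustrates the idea only and is not the IR the emitter produces.

```cuda
#include <cuda_runtime.h>

// Read the current output, combine it with the partial result, and only
// commit if no other thread wrote in between; otherwise retry with the
// freshly observed value.
__device__ void AtomicReduceMax(float* output_address, float partial_result) {
  int* address_as_int = reinterpret_cast<int*>(output_address);
  int old_bits = *address_as_int;
  int assumed_bits;
  do {
    assumed_bits = old_bits;
    float combined = fmaxf(__int_as_float(assumed_bits), partial_result);
    old_bits = atomicCAS(address_as_int, assumed_bits,
                         __float_as_int(combined));
  } while (old_bits != assumed_bits);  // another thread intervened; retry
}
```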
+ const llvm_ir::KernelMappingScheme* mapping_scheme = + reduction_info->GetKernelMappingScheme(); + if (mapping_scheme->GetTileBlockSizeForDimension( + llvm_ir::KernelMappingScheme::DimZ) == 1 && + mapping_scheme->GetTileBlockSizeForDimension( + reduction_info->GetReducedDimensionEnum()) == 1) { + TF_CHECK_OK(EmitCallToNestedComputation( + *reducers[i], {output_address, partial_result_addresses[i]}, + output_address)); + } else { + TF_CHECK_OK(EmitAtomicOperationForNestedComputation( + *reducers[i], output_address, partial_result_addresses[i])); + } + } +} + +void IrEmitterUnnested::EmitTileElementForReduction( + HloInstruction* unnested_hlo, const llvm_ir::IrArray::Index& index, + const KernelCodegenInfo* kernel_info, llvm::Value* y_loc, + llvm::Value* x_loc) { + VLOG(10) << "Emit tile element for reduce " << unnested_hlo->ToString(); + HloInstruction* reduce_or_tuple = unnested_hlo->opcode() == HloOpcode::kFusion + ? unnested_hlo->fused_expression_root() + : unnested_hlo; + llvm_ir::TiledParameterInfo* tiled_param_info = + kernel_info->GetTiledParameterInfo(); + tiled_param_info->set_y(y_loc); + tiled_param_info->set_x(x_loc); + + // Record the linear address for the current reduction. + const ReductionCodegenInfo* reduction_info = + dynamic_cast(kernel_info); + Store(index[reduction_info->GetKeptDimensionEnum()], + reduction_info->GetCurrentOutputLinearIndexAddress()); + if (!reduction_info->IsRowReduction()) { + llvm::Type* bool_ty = b_.getInt1Ty(); + llvm::AllocaInst* output_inbound_addr = + reduction_info->GetCurrentOutputInboundAddress(); + Store(llvm::ConstantInt::get(bool_ty, 1), output_inbound_addr); + } + + InlinedVector input_gens; + std::vector> + extra_output_gens; + GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_, + GetNestedComputer()); + FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(unnested_hlo), + &elem_emitter); + absl::Span output_instructions = + GetOutputInstructions(&reduce_or_tuple); + // Construct the ElementGenerator for each reduction and extra output in the + // the group of output instructions. + if (unnested_hlo->opcode() == HloOpcode::kFusion) { + fused_emitter.SetTiledParameterInfo(tiled_param_info); + TF_CHECK_OK(unnested_hlo->fused_expression_root()->Accept(&fused_emitter)); + + for (int i = 0, e = output_instructions.size(); i != e; ++i) { + const HloInstruction* inst = output_instructions[i]; + ShapeIndex output_shape_index; + if (reduce_or_tuple->opcode() == HloOpcode::kTuple) { + output_shape_index = {i}; + } + if (inst->opcode() == HloOpcode::kReduce) { + input_gens.push_back(fused_emitter.GetGenerator(inst->operand(0))); + } else { + extra_output_gens.emplace_back(fused_emitter.GetGenerator(inst), + std::move(output_shape_index)); + } + } + } else { + input_gens.push_back([&](const IrArray::Index& index) { + return GetIrArray(*unnested_hlo->operand(0), *unnested_hlo) + .EmitReadArrayElement(index, &b_); + }); + } + + IrArray::Index input_index = + reduction_info->GetKernelMappingScheme()->GetUnnormalizedIndex( + index, + GetFirstReduceInstruction(output_instructions)->operand(0)->shape()); + const AddressVector& partial_reduction_result_addresses = + reduction_info->GetPartialResultAddresses(); + const AddressVector& reduction_input_addresses = + reduction_info->GetReductionInputAddresses(); + const InlinedVector& reducers = + reduction_info->GetReducers(); + + // Emit code to generate the input and perform the reduction computation for + // each reduction instruction. 
+ for (int i = 0; i != reducers.size(); ++i) { + llvm::Value* const input_ir_value = input_gens[i](input_index).ValueOrDie(); + Store(input_ir_value, reduction_input_addresses[i]); + TF_CHECK_OK(EmitCallToNestedComputation( + *reducers[i], + {partial_reduction_result_addresses[i], reduction_input_addresses[i]}, + partial_reduction_result_addresses[i])); + } + + // Emit code to generate the output for the non-reduction instructions in the + // fusion, if any. + TF_CHECK_OK( + EmitExtraOutputsForReduce(unnested_hlo, input_index, extra_output_gens)); +} + +// Emits a kernel for the hlo instruction using the given tiling scheme. +void IrEmitterUnnested::EmitBlock(const TileGenerator& emit_one_tile, + const KernelCodegenInfo* kernel_info, + KernelSupportLibrary& ksl, + llvm::Type* index_ty) { + KernelMappingScheme* mapping_scheme = kernel_info->GetKernelMappingScheme(); + absl::Span dims_in_tile = mapping_scheme->GetDimensionsInTiles(); + absl::Span dims_in_block = + mapping_scheme->GetDimensionsInBlocks(); + absl::Span block_sizes = mapping_scheme->GetBlockSizes(); + auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { + return llvm::ConstantInt::get(index_ty, c); + }; + + // Emit all the tiles for a given dimension in a tile block. + auto emit_tiles_for_block_dim = + [&](const string& loop_name, const IrArray::Index& starting_tile, + int dim_id, + const std::function + emit_next_block_dim) { + if (block_sizes[dim_id] == 1) { + emit_next_block_dim(starting_tile); + } else { + llvm::Value* starting_tile_index_for_dim = starting_tile[dim_id]; + llvm::Value* block_size_for_dim = + index_typed_constant(block_sizes[dim_id]); + llvm::Value* block_id_for_dim = + b_.CreateUDiv(starting_tile_index_for_dim, block_size_for_dim); + llvm::Value* last_block_for_dim = + index_typed_constant(dims_in_block[dim_id] - 1); + llvm::Value* last_block_size_for_dim = index_typed_constant( + dims_in_tile[dim_id] - + (dims_in_block[dim_id] - 1) * block_sizes[dim_id]); + llvm::Value* num_tiles_in_block = + Select(ICmpEQ(last_block_for_dim, block_id_for_dim), + last_block_size_for_dim, block_size_for_dim); + + ksl.ForReturnVoid( + loop_name, + /*start=*/index_typed_constant(0), + /*end=*/num_tiles_in_block, + /*step=*/1, [&](llvm::Value* block_dim_induction_var) { + IrArray::Index tile_index = starting_tile.AddOffsetToDim( + block_dim_induction_var, dim_id, &b_); + emit_next_block_dim(tile_index); + }); + } + }; + + absl::Span reduced_dims = + mapping_scheme->GetDimensionsInElements(); + const bool block_contains_multi_tiles = + mapping_scheme->GetNumberOfTilesInOneBlock() > 1; + + // Emit the tile with a given tile_index, by calculating the tight bounds for + // each dimension of the tile and then calling emit_one_tile. + auto emit_one_tile_for_tile_index = [&](const IrArray::Index& tile_index) { + std::vector output_tile_bounds(3); + for (int i = KernelMappingScheme::DimY; i < KernelMappingScheme::DimTot; + ++i) { + int64 tile_size_for_dim = mapping_scheme->GetTileSizeForDimension(i); + // Only last row or column may not have full size. 
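A quick numeric check of the "last tile may be partial" bound computed below; the sizes are made up for illustration.

```cuda
#include <cstdio>

// For a dimension of dim_size elements tiled by tile_size, every tile has
// bound tile_size except possibly the last one.
long long TileBound(long long dim_size, long long tile_size,
                    long long tile_index) {
  long long num_tiles = (dim_size + tile_size - 1) / tile_size;
  return tile_index == num_tiles - 1
             ? dim_size - (num_tiles - 1) * tile_size
             : tile_size;
}

int main() {
  // 100 elements tiled by 32 => 4 tiles with bounds 32, 32, 32, 4.
  for (long long t = 0; t < 4; ++t)
    std::printf("tile %lld bound %lld\n", t, TileBound(100, 32, t));
  return 0;
}
```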
+ llvm::Value* is_last_row = + ICmpEQ(tile_index[i], index_typed_constant(dims_in_tile[i] - 1)); + int64 partial_row_size = + reduced_dims[i] - (dims_in_tile[i] - 1) * tile_size_for_dim; + output_tile_bounds[i] = + Select(is_last_row, index_typed_constant(partial_row_size), + index_typed_constant(tile_size_for_dim), "tile_bound"); + } + + IrArray::Index tile_origin = + mapping_scheme->GetElementIndexForTileOrigin(tile_index); + emit_one_tile(tile_origin, output_tile_bounds, block_contains_multi_tiles); + }; + + const IrArray::Index starting_block = + mapping_scheme->EmitBlockIndex(index_ty); + const IrArray::Index starting_tile_for_dim_z = + mapping_scheme->GetTileIndexForBlockOrigin(starting_block); + + // Emit the three dimensional block of tiles. + emit_tiles_for_block_dim( + "block_dim_z", starting_tile_for_dim_z, KernelMappingScheme::DimZ, + [&](const IrArray::Index& starting_tile_for_dim_y) { + emit_tiles_for_block_dim( + "block_dim_y", starting_tile_for_dim_y, KernelMappingScheme::DimY, + [&](const IrArray::Index& starting_tile_for_dim_x) { + emit_tiles_for_block_dim("block_dim_x", starting_tile_for_dim_x, + KernelMappingScheme::DimX, + emit_one_tile_for_tile_index); + }); + }); +} + +// Emits a kernel for the hlo instruction using the given kernel mapping scheme. +// +// unnested_hlo: The unnested hlo instruction for which the kernel is generated. +// Currently, these hlo instructions are supported: kLoop fusion, kCopy. +// tiled_param_ids: The IDs for the parameters that are 0-2-1 transpose of +// other tensors with the same dimensions and need to be tiled and tranposed. +// mapping_scheme: The tiling scheme to use. +// kernel_generator: Contains function objects for code generation, such as +// element generator, block prologue and epilogue generators. +// kernel_info: Represent other information to support the code generation +// of the tiled kernel for the hlo. +LaunchDimensions IrEmitterUnnested::EmitKernel( + HloInstruction* unnested_hlo, absl::Span tiled_param_ids, + const KernelCodeGenerator& kernel_generator, + KernelCodegenInfo* kernel_info) { + KernelMappingScheme* mapping_scheme = kernel_info->GetKernelMappingScheme(); + + std::vector param_arrays = ConstructIrArrayForInputs(*unnested_hlo); + int64 num_params = param_arrays.size(); + // Allocate shared memory buffers to store the tiled inputs. + std::vector param_shmem_buffers(num_params, nullptr); + for (int64 id : tiled_param_ids) { + const HloInstruction* param = unnested_hlo->operand(id); + param_shmem_buffers[id] = + mapping_scheme->GetSharedMemoryBufferForElementType( + llvm_ir::PrimitiveTypeToIrType(param->shape().element_type(), + module_), + IrName(unnested_hlo, StrCat("tile", id))); + VLOG(3) << "Added shmem buffer for parameter " << id << ": " + << llvm_ir::DumpToString(*param_shmem_buffers[id]); + } + + LaunchDimensions launch_dimensions = LaunchDimensions( + mapping_scheme->GetNumberOfBlocks(), mapping_scheme->GetThreadsPerTile()); + llvm::Type* index_ty = GetIndexTypeForKernel( + unnested_hlo, launch_dimensions.launch_bound(), &b_); + auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { + return llvm::ConstantInt::get(index_ty, c); + }; + + // For each tiled parameter, cast its input IrArray to the corresponding + // reduced shape and keep the reduced shape live during IR emission. 
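The launch dimensions constructed a few lines above reduce to simple arithmetic over the mapping scheme. A back-of-the-envelope sketch, assuming one tile per block, 32x32 tiles on the two minor dimensions, and 32x4 threads per tile (all of which the emitter really reads from the KernelMappingScheme):

```cuda
#include <cstdio>

int main() {
  long long dims_in_elements[3] = {6, 100, 200};  // normalized {Z, Y, X}
  long long tile_sizes[3] = {1, 32, 32};
  long long num_threads_y = 4, num_threads_x = 32;

  long long num_blocks = 1;
  for (int i = 0; i < 3; ++i) {
    long long dims_in_tiles =
        (dims_in_elements[i] + tile_sizes[i] - 1) / tile_sizes[i];
    num_blocks *= dims_in_tiles;  // here: 6 * 4 * 7 = 168 (one tile per block)
  }
  long long threads_per_block = num_threads_x * num_threads_y;  // 128
  std::printf("launch <<<%lld, %lld>>>\n", num_blocks, threads_per_block);
  return 0;
}
```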
+ std::vector param_in_reduced_shape_arrays; + std::vector param_reduced_shapes; + absl::Span reduced_dims = + mapping_scheme->GetDimensionsInElements(); + int num_shapes = ConstructInputReducedShapeAndCastInputIrArrayToShape( + *unnested_hlo, param_arrays, param_shmem_buffers, reduced_dims, + ¶m_reduced_shapes, ¶m_in_reduced_shape_arrays); + DCHECK_EQ(num_shapes, num_params); + + // Calculate the starting element coordinate within a tile for the current + // thread, (y, x) from thread_id. + llvm::Value* x; + llvm::Value* y; + std::tie(y, x) = mapping_scheme->EmitThreadYXCoordinate(index_ty); + + kernel_info->SetLaneId( + mapping_scheme->GetNumberOfThreadsForDimensionX() == kWarpSize ? x + : nullptr); + kernel_info->SetIndexType(index_ty); + + KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll); + // Curry a few parameters to EmitTiledElementalCodeWithBoundsCheck. + auto emit_tiled_elemental_code_with_bounds_check = + [&](const IrArray::Index& index, const string& loop_name, + llvm::Value* tile_height, llvm::Value* tile_width, + const std::function& emit_elem_function) { + EmitTiledElementalCodeWithBoundsCheck(mapping_scheme, index, loop_name, + &ksl, &b_, y, x, tile_height, + tile_width, emit_elem_function); + }; + + auto emit_one_tile = [&](const IrArray::Index& output_tile_origin, + absl::Span output_tile_bounds, + bool block_contains_multi_tiles) { + // Calculate the input tile origin from the output tile origin. + const IrArray::Index input_tile_origin( + Permute({0, 2, 1}, output_tile_origin.multidim())); + + const IrArray::Index input_index = + input_tile_origin.AddOffsetToDim(x, KernelMappingScheme::DimX, &b_) + .AddOffsetToDim(y, KernelMappingScheme::DimY, &b_); + + // If shared memory transpose is needed, wait for all threads to reach this + // point, lest we copy a value from tile to output before the other thread + // copies it from input to tile. This is `__syncthreads` in CUDA. + if (!tiled_param_ids.empty()) { + // Copy input parameter values to shared memory buffers: + // tile[y, x] = input[index] + // Note that tile_width and tile_height are flipped here because we are + // reading a transposed tile. + emit_tiled_elemental_code_with_bounds_check( + input_index, "input", output_tile_bounds[2], output_tile_bounds[1], + [&](const IrArray::Index& index, llvm::Value* y_loc, + llvm::Value* x_loc) { + for (int64 id : tiled_param_ids) { + IrArray& input_in_logical_shape = + param_in_reduced_shape_arrays[id]; + llvm::Value* shmem_buffer = param_shmem_buffers[id]; + // TODO(jlebar): Add AA metadata to this store. Tile buffers are + // global variables, so LLVM can't infer much about it. + Store(input_in_logical_shape.EmitReadArrayElement( + index, &b_, "input_element"), + GEP(shmem_buffer, {index_typed_constant(0), y_loc, x_loc})); + } + }); + + // Wait for all threads to reach this point using `__syncthreads` in CUDA. + llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, &b_); + } + + llvm_ir::TiledParameterInfo tiled_param_info(param_shmem_buffers, y, x); + kernel_info->SetTiledParamInfo(&tiled_param_info); + + const IrArray::Index output_index = + output_tile_origin.AddOffsetToDim(x, KernelMappingScheme::DimX, &b_) + .AddOffsetToDim(y, KernelMappingScheme::DimY, &b_); + + // Write to output[index] by emitting code like normal, except that values + // for the tiled parameters are read from the shmem buffers. 
+ emit_tiled_elemental_code_with_bounds_check( + output_index, "output", output_tile_bounds[1], output_tile_bounds[2], + [&](const IrArray::Index& index, llvm::Value* y_loc, + llvm::Value* x_loc) { + kernel_generator.GetTileElementGenerator()(unnested_hlo, index, + kernel_info, y_loc, x_loc); + }); + + // If a tile block contains multiple tiles and shared memory buffers are + // used, we need to wait for all threads to finish using the shared memory + // buffer for the current tile before we move on to process the next tile + // and overwrite the shared memory buffers. + if (block_contains_multi_tiles && !tiled_param_ids.empty()) { + llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, &b_); + } + }; + + const BlockPrologueGenerator& block_prologue_generator = + kernel_generator.GetBlockPrologueGenerator(); + if (block_prologue_generator) { + block_prologue_generator(unnested_hlo, kernel_info); + } + + EmitBlock(std::move(emit_one_tile), kernel_info, ksl, index_ty); + + const BlockEpilogueGenerator& block_epilogue_generator = + kernel_generator.GetBlockEpilogueGenerator(); + if (block_epilogue_generator) { + block_epilogue_generator(unnested_hlo, kernel_info); + } + + // For multioutput fusion, emit a tuple with pointers to all the individual + // outputs. + if (unnested_hlo->IsMultiOutputFusion()) { + std::vector output_arrays = + ConstructIrArrayForOutputs(*unnested_hlo); + llvm_ir::EmitTuple(GetIrArray(*unnested_hlo, *unnested_hlo), output_arrays, + &b_, module_); + } + + return launch_dimensions; +} + // Emits a kernel for the given hlo instruction using a tiled 0-2-1 transpose // algorithm to improve the memory access patterns for the input parameters -// which have a shape that is a 0-2-1 transpose of the output tensors. +// with a shape that is a 0-2-1 transpose of the output tensor shape. // // For the purpose of tiling, the output tensors have a logical shape of three -// components 0-2-1 while the relevant input parameters have a logical shape of -// three components 0-1-2 in the order major to minor. The x- and y- dimensions -// of the tensors are tiled in square tiles of edge length `kTileSize`. Each -// thread block of `kTileSize` x `kNumRows` threads transposes one tile: each -// thread copies kTileSize/kNumRows elements from the input to a shared memory -// tile, then the otherwise "regular hlo kernel" reads from the shared memory -// instead of the original input. +// components 0-2-1 while the relevant input parameters have a logical shape +// of three components 0-1-2 in the order major to minor. The x- and y- +// dimensions of the tensors are tiled in square tiles with an edge length +// `kTileSize`. Each thread block of `kTileSize` x `kNumRows` threads +// transposes one tile: each thread copies kTileSize/kNumRows elements from +// the input to a shared memory tile, then the otherwise "regular HLO kernel" +// reads from the shared memory instead of the original input. // // This is similar to the following CUDA algorithm in TensorFlow: // https://goo.gl/MStRV6. @@ -3313,219 +3016,37 @@ void EmitTiledElementalCodeWithBoundsCheck( // `kTileSize` should usually be same as warp size. We currently choose 32 for // `kTileSize` and 4 for `kNumRows`. The CUDA algorithm uses 8 for `kNumRows`. // -// TODO(b/33320379): Here each block transposes 1 tile. It may be more efficient -// to launch fewer blocks so each transposes many tiles. +// TODO(b/33320379): Here each block transposes 1 tile. 
It may be more +// efficient to launch fewer blocks so each transposes many tiles. LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( HloInstruction* hlo, absl::Span reduced_output_dims, absl::Span tiled_param_ids) { - // Parameters for the tiling algorithm. - constexpr int64 kTileSize = 32; - constexpr int64 kNumRows = 4; - constexpr int64 kThreadsPerTile = kTileSize * kNumRows; - - // Construct IrArrays for the inputs and outputs. - std::vector output_arrays = ConstructIrArrayForOutputs(*hlo); - int64 num_outputs = output_arrays.size(); - std::vector param_arrays = ConstructIrArrayForInputs(*hlo); - int64 num_params = param_arrays.size(); - - // Allocate shared memory buffers to store the tiled inputs. - std::vector param_shmem_buffers(num_params, nullptr); - for (int64 id : tiled_param_ids) { - const HloInstruction* param = hlo->operand(id); - // Add 1 to the minor dimension to reduce shared memory bank conflicts. - llvm::Type* tile_type = llvm::ArrayType::get( - llvm::ArrayType::get(llvm_ir::PrimitiveTypeToIrType( - param->shape().element_type(), module_), - kTileSize + 1), - kTileSize); - auto* tile_base_ptr = llvm_ir::AllocateSharedMemoryTile( - b_.GetInsertBlock()->getParent()->getParent(), tile_type, - IrName(hlo, StrCat("tile", id))); - param_shmem_buffers[id] = tile_base_ptr; - VLOG(3) << "Added shmem buffer for parameter " << id << ": " - << llvm_ir::DumpToString(*tile_base_ptr); - } - - // The 0-2-1 shape of the tiling scheme is the reduced shape of the HLO result - // for the purpose of tiling. Calculate the logical output dimensions in the - // tile from the reduced output dimensions. - std::vector output_dims_in_tiles = std::vector( - reduced_output_dims.begin(), reduced_output_dims.end()); - CHECK_EQ(output_dims_in_tiles.size(), 3); - for (int i = 1; i < 3; ++i) { - output_dims_in_tiles[i] = - CeilOfRatio(output_dims_in_tiles[i], kTileSize); - } - const int64 num_tiles = - absl::c_accumulate(output_dims_in_tiles, 1, std::multiplies()); - LaunchDimensions launch_dimensions(num_tiles, kThreadsPerTile); - - llvm::Type* index_ty = - GetIndexTypeForKernel(hlo, launch_dimensions.launch_bound(), &b_); - auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { - return llvm::ConstantInt::get(index_ty, c); - }; - - // Cast each output IrArray to its corresponding reduced shape and keep the - // reduced shape live during IR emission. - std::vector output_in_reduced_shape_arrays; - std::vector output_reduced_shapes; - CHECK_EQ(ConstructOutputReducedShapeAndCastOutputIrArrayToShape( - *hlo, output_arrays, reduced_output_dims, &output_reduced_shapes, - &output_in_reduced_shape_arrays), - num_outputs); - - // For each tiled parameter, cast its input IrArray to the corresponding - // reduced shape and keep the reduced shape live during IR emission. - std::vector param_in_reduced_shape_arrays; - std::vector param_reduced_shapes; - CHECK_EQ(ConstructInputReducedShapeAndCastInputIrArrayToShape( - *hlo, param_arrays, param_shmem_buffers, reduced_output_dims, - ¶m_reduced_shapes, ¶m_in_reduced_shape_arrays), - num_params); - - // Calculate the starting element coordinate within a tile for the current - // thread, (y, x) from thread_id. - llvm::Value* x; - llvm::Value* y; - std::tie(y, x) = CalculateYXCoordinateWithinTile( - &b_, index_typed_constant(kTileSize), kThreadsPerTile); - - // Calculate the index for the current output tile from block_id. 
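For readers who want a concrete picture of the tiled 0-2-1 transpose the comment above describes, the following is a minimal standalone CUDA sketch of the same pattern. It is an illustration only, not the code this change emits: the kernel, its arguments, the launch configuration, and the row-major indexing are assumptions, while `kTileSize = 32`, `kNumRows = 4`, the `+ 1` shared-memory padding, and the barrier follow the conventions discussed in this file. The major-most dimension of the 0-2-1 shape is omitted for brevity; each 32x4 thread block transposes one 32x32 tile, so every thread copies kTileSize/kNumRows = 8 elements per pass.

```
// Hypothetical standalone sketch; assumes a launch of
//   dim3 block(kTileSize, kNumRows);
//   dim3 grid(CeilDiv(width, kTileSize), CeilDiv(height, kTileSize));
// where `in` has shape [height, width] and `out` has shape [width, height].
constexpr int kTileSize = 32;
constexpr int kNumRows = 4;

__global__ void TransposeMinorDims(const float* in, float* out, int height,
                                   int width) {
  // The extra column reduces shared-memory bank conflicts.
  __shared__ float tile[kTileSize][kTileSize + 1];

  int x = blockIdx.x * kTileSize + threadIdx.x;
  int y = blockIdx.y * kTileSize + threadIdx.y;

  // tile[y, x] = input[y, x]; the loop strides so the block covers the tile.
  for (int j = 0; j < kTileSize; j += kNumRows) {
    if (x < width && y + j < height) {
      tile[threadIdx.y + j][threadIdx.x] = in[(y + j) * width + x];
    }
  }
  __syncthreads();  // The nvvm_barrier0 intrinsic emitted by the IR emitter.

  // Read the tile with the x/y roles swapped and write the transposed output.
  x = blockIdx.y * kTileSize + threadIdx.x;
  y = blockIdx.x * kTileSize + threadIdx.y;
  for (int j = 0; j < kTileSize; j += kNumRows) {
    if (x < height && y + j < width) {
      out[(y + j) * height + x] = tile[threadIdx.x][threadIdx.y + j];
    }
  }
}
```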
- const IrArray::Index output_tile_index( - GetBlockIdx(&b_, index_ty, num_tiles), - ShapeUtil::MakeShapeWithDescendingLayout(PRED /*arbitrary*/, - output_dims_in_tiles), - &b_); - - // Output tile origin is the index for the first element of the current output - // tile. - const IrArray::Index output_tile_origin = [&] { - IrArray::Index index = output_tile_index; - for (int i = 1; i < 3; ++i) { - index[i] = Mul(output_tile_index[i], index_typed_constant(kTileSize), - "tile_origin." + std::to_string(i)); - } - return index; - }(); - - // Calculate the input tile origin from the output tile origin. - const IrArray::Index input_tile_origin( - Permute({0, 2, 1}, output_tile_origin.multidim())); - - // Calculate the current output tile bounds in each of the logical dimensions. - std::vector output_tile_bounds(3); - for (int i = 1; i < 3; ++i) { - // Only last row or column may not have full size. - output_tile_bounds[i] = - Select(ICmpEQ(output_tile_index[i], - index_typed_constant(output_dims_in_tiles[i] - 1)), - index_typed_constant(reduced_output_dims[i] - - (output_dims_in_tiles[i] - 1) * kTileSize), - index_typed_constant(kTileSize), "kTileSize"); - } - - KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll); - - // Curry a few parameters to EmitTiledElementalCodeWithBoundsCheck. - auto emit_tiled_elemental_code_with_bounds_check = - [&](const IrArray::Index& index, const string& loop_name, - llvm::Value* tile_width, llvm::Value* tile_height, - const std::function& - emit_elem_function) { - EmitTiledElementalCodeWithBoundsCheck( - kTileSize, kNumRows, index, loop_name, &ksl, &b_, y, x, tile_width, - tile_height, emit_elem_function); - }; - - // Adds `addend` to the given `dim` of `index`. - auto offset_dim = [&](IrArray::Index index, llvm::Value* addend, int64 dim) { - index[dim] = Add(index[dim], addend); - return index; - }; - const IrArray::Index input_index = - offset_dim(offset_dim(input_tile_origin, x, /*dim=*/2), y, /*dim=*/1); - - // Copy input parameter values to shared memory buffers: - // tile[y, x] = input[index] - emit_tiled_elemental_code_with_bounds_check( - input_index, "input", output_tile_bounds[1], output_tile_bounds[2], - [&](const IrArray::Index& index, llvm::Value* y_loc) { - for (int64 id : tiled_param_ids) { - IrArray& input_in_logical_shape = param_in_reduced_shape_arrays[id]; - llvm::Value* shmem_buffer = param_shmem_buffers[id]; - // TODO(jlebar): Add AA metadata to this store. Tile buffers are - // global variables, so LLVM can't infer much about it. - Store(input_in_logical_shape.EmitReadArrayElement(index, &b_, - "input_element"), - GEP(shmem_buffer, {index_typed_constant(0), y_loc, x})); - } - }); - - // Wait for all threads to reach this point, lest we copy a value from tile to - // output before the other thread copies it from input to tile. - // This is `__syncthreads` in CUDA. - llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, &b_); - - llvm_ir::TiledParameterInfo tiled_param_info(param_shmem_buffers, y, x); - - const IrArray::Index output_index = - offset_dim(offset_dim(output_tile_origin, x, /*dim=*/2), y, /*dim=*/1); - - // Write to output[index] by emitting code like normal, except that values for - // the tiled parameters are read from the shmem buffers. 
+ constexpr int kNumRows = 4; + KernelMappingScheme mapping_scheme( + reduced_output_dims, /*tile_size_y=*/kWarpSize, + /*tile_size_x=*/kWarpSize, /*req_block_sizes=*/{1, 1, 1}, + /*num_threads_y=*/kNumRows, + /*num_threads_x=*/kWarpSize, &b_); + TileElementGenerator element_generator; if (hlo->opcode() == HloOpcode::kCopy) { - emit_tiled_elemental_code_with_bounds_check( - output_index, "output", output_tile_bounds[2], output_tile_bounds[1], - [&](const IrArray::Index& index, llvm::Value* y_loc) { - // TODO(jlebar): Add AA metadata to this load. - llvm::Instruction* load_from_shmem_buffer = - Load(GEP(param_shmem_buffers[0], {b_.getInt64(0), x, y_loc}), - "output_element"); - output_in_reduced_shape_arrays[0].EmitWriteArrayElement( - index, load_from_shmem_buffer, &b_); - }); + element_generator = [&](HloInstruction* hlo, + const llvm_ir::IrArray::Index& index, + const KernelCodegenInfo* kernel_info, + llvm::Value* y_loc, llvm::Value* x_loc) { + EmitTileElementForCopy(hlo, index, kernel_info, y_loc, x_loc); + }; } else { - CHECK_EQ(hlo->opcode(), HloOpcode::kFusion); - emit_tiled_elemental_code_with_bounds_check( - output_index, "output", output_tile_bounds[2], output_tile_bounds[1], - [&](const IrArray::Index& index, llvm::Value* y_loc) { - GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_, - GetNestedComputer()); - FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(hlo), - &elem_emitter); - tiled_param_info.set_y(y_loc); - fused_emitter.SetTiledParameterInfo(&tiled_param_info); - TF_CHECK_OK(hlo->fused_expression_root()->Accept(&fused_emitter)); - IrArray::Index untiled_index = llvm_ir::GetUnreducedOutputIndex( - index, output_reduced_shapes[0], output_arrays[0].GetShape(), - &b_); - const llvm_ir::ElementGenerator& output_generator = - fused_emitter.GetRootGenerator(); - llvm::Value* output_value = - output_generator(untiled_index).ValueOrDie(); - if (hlo->IsMultiOutputFusion()) { - CHECK(output_value->getType()->isStructTy()); - CHECK_EQ(output_value->getType()->getStructNumElements(), - output_in_reduced_shape_arrays.size()); - for (int64 i = 0; i < output_in_reduced_shape_arrays.size(); ++i) { - output_in_reduced_shape_arrays[i].EmitWriteArrayElement( - index, ExtractValue(output_value, i), &b_); - } - } else { - output_in_reduced_shape_arrays[0].EmitWriteArrayElement( - index, output_value, &b_); - } - }); + DCHECK_EQ(hlo->opcode(), HloOpcode::kFusion); + element_generator = [&](HloInstruction* hlo, + const llvm_ir::IrArray::Index& index, + const KernelCodegenInfo* kernel_info, + llvm::Value* y_loc, llvm::Value* x_loc) { + EmitTileElementForFusion(hlo, index, kernel_info, y_loc, x_loc); + }; } - - // For multioutput fusion, emit a tuple with all the individual outputs. - if (hlo->IsMultiOutputFusion()) { - llvm_ir::EmitTuple(GetIrArray(*hlo, *hlo), output_arrays, &b_, module_); - } - - return launch_dimensions; + KernelCodegenInfo kernel_info(&mapping_scheme); + KernelCodeGenerator kernel_generator(std::move(element_generator)); + return EmitKernel(hlo, tiled_param_ids, kernel_generator, &kernel_info); } namespace { @@ -3562,8 +3083,8 @@ bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { ? ShapeUtil::GetSubshape(hlo->shape(), {0}) : hlo->shape(); - // If the output_shape is reduced to 021 shape, find all the parameters of the - // hlo that are in the corresponding 012 shape. + // If the output_shape is reduced to 021 shape, find all the parameters of + // the HLO that are in the corresponding 012 shape. 
std::vector params_012; optional> reduced_dims_021; for (int64 operand_idx = 0; operand_idx < hlo->operand_count(); @@ -3600,9 +3121,9 @@ bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { } // Each of our shared memory tiles has 32*33 elements (so ~4kb, if the - // elements are of size 4 bytes), and CUDA has an architectural limit of 48kb - // shared memory per SM. (This is increased to 96kb in Volta, but we don't - // use this, in part because it eats into our L1 cache space.) + // elements are of size 4 bytes), and CUDA has an architectural limit of + // 48kb shared memory per SM. (This is increased to 96kb in Volta, but we + // don't use this, in part because it eats into our L1 cache space.) // // For correctness we need to ensure that we don't make more than 48kb worth // of shmem tiles per block. And for performance, we'd probably like to use @@ -3610,9 +3131,9 @@ bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { // gpu core. // // We say without benchmarks that we want at least 3 threads/block, - // corresponding to 3 shmem tiles if the elements are 32 bits wide. We choose - // which params get the shmem transpose treatment arbitrarily; it's not clear - // if there's a Right Choice. + // corresponding to 3 shmem tiles if the elements are 32 bits wide. We + // choose which params get the shmem transpose treatment arbitrarily; it's + // not clear if there's a Right Choice. // // This is only sound if tiled transposes are the only place where we use // shared memory in fusions. If in the future other fusible ops use shared @@ -3645,6 +3166,246 @@ bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { return true; } +namespace { +// Checks that the outputs of a fusion with reduction are consistent. +Status AreFusedReductionOutputsConsistent( + absl::Span output_instructions, + const HloInstruction* first_reduce) { + for (const HloInstruction* inst : output_instructions) { + if (inst->opcode() == HloOpcode::kReduce) { + // Shapes, layouts and dimensions must be the same for all reduces + // inside of this fusion. + TF_RET_CHECK(ShapeUtil::Equal(first_reduce->shape(), inst->shape())); + TF_RET_CHECK(ShapeUtil::Equal(first_reduce->operand(0)->shape(), + inst->operand(0)->shape())); + TF_RET_CHECK(ShapeUtil::Equal(first_reduce->operand(1)->shape(), + inst->operand(1)->shape())); + TF_RET_CHECK(first_reduce->dimensions() == inst->dimensions()); + } else { + // For extra outputs we can relax shape equality to allow different + // types (with the same number of elements). Layouts still have to + // match. + TF_RET_CHECK(ShapeUtil::CompatibleIgnoringElementType( + first_reduce->operand(0)->shape(), inst->shape())); + TF_RET_CHECK(LayoutUtil::Equal(first_reduce->operand(0)->shape().layout(), + inst->shape().layout())); + } + } + return Status::OK(); +} + +// Finds the dimensions to keep for the reduction, sorts and returns the +// dimensions from minor to major. +DimensionVector GetDimensionsToKeepMinorToMajor( + const Shape& input_shape, absl::Span dims_to_reduce) { + DimensionVector input_dims(ShapeUtil::Rank(input_shape), 0); + absl::c_iota(input_dims, 0); + DimensionVector input_dims_to_keep; + for (int input_dim : input_dims) { + auto it = absl::c_find_if(dims_to_reduce, [&](int64 dim_to_reduce) { + return dim_to_reduce == input_dim; + }); + if (it == dims_to_reduce.end()) { + input_dims_to_keep.push_back(input_dim); + } + } + + // Sort the dimensions to keep from minor to major. 
+ absl::c_sort(input_dims_to_keep, [&input_shape](int64 dim_a, int64 dim_b) { + return PositionInContainer(LayoutUtil::MinorToMajor(input_shape), dim_a) < + PositionInContainer(LayoutUtil::MinorToMajor(input_shape), dim_b); + }); + + VLOG(10) << "dims to keep minor to major" + << absl::StrJoin(input_dims_to_keep, ","); + return input_dims_to_keep; +} + +// Given the input shape and dimensions to reduce for the reduction to vector, +// returns : +// num_kept: the number of elements in the contiguous dimensions to keep. +// num_reduced_major: the number of elements in the dimensions to reduce that +// are more major than the dimensions to keep. +// num_reduced_minor: the number of elements in the dimensions to reduce that +// are more minor than the dimensions to kept. +std::tuple GetReductionToVectorDimensions( + const Shape& input_shape, absl::Span dims_to_reduce) { + DimensionVector input_dims_to_keep_minor_to_major = + GetDimensionsToKeepMinorToMajor(input_shape, dims_to_reduce); + CHECK(LayoutUtil::AreDimensionsConsecutive( + input_shape.layout(), input_dims_to_keep_minor_to_major)); + int num_reduced_major = 1, num_kept = 1, num_reduced_minor = 1; + if (input_dims_to_keep_minor_to_major.empty()) { + return std::make_tuple(num_reduced_major, num_kept, num_reduced_minor); + } + DimensionVector input_dims(ShapeUtil::Rank(input_shape), 0); + absl::c_iota(input_dims, 0); + absl::Span minor_to_major = + LayoutUtil::MinorToMajor(input_shape); + for (int input_dim : input_dims) { + int64 curr_dim_size = input_shape.dimensions(input_dim); + if (PositionInContainer(minor_to_major, input_dim) > + PositionInContainer(minor_to_major, + input_dims_to_keep_minor_to_major.back())) { + num_reduced_major *= curr_dim_size; + } else if (PositionInContainer(minor_to_major, input_dim) < + PositionInContainer(minor_to_major, + input_dims_to_keep_minor_to_major.front())) { + num_reduced_minor *= curr_dim_size; + } else { + num_kept *= curr_dim_size; + } + } + + return std::make_tuple(num_reduced_major, num_kept, num_reduced_minor); +} + +std::tuple ComputeMappingSchemeAndReductionKind( + const HloInstruction* first_reduce, llvm::IRBuilder<>* b) { + int64 depth = 1; + int64 height = 1; + int64 width = 1; + bool is_row_reduction = true; + int64 tile_size_x = 1; + int64 tile_size_y = 1; + int64 block_size_y = 1; + int64 block_size_z = 1; + int64 num_threads_x = 1; + int64 num_threads_y = 1; + const Shape& input_shape = first_reduce->operand(0)->shape(); + int64 num_input_elems = ShapeUtil::ElementsIn(input_shape); + int64 num_output_elems = ShapeUtil::ElementsIn(first_reduce->shape()); + int64 num_reduced_major, num_kept, num_reduced_minor; + std::tie(num_reduced_major, num_kept, num_reduced_minor) = + GetReductionToVectorDimensions(input_shape, first_reduce->dimensions()); + CHECK_EQ(num_output_elems, num_kept); + + if (num_kept == 1) { + // Scalar reduction is a special row reduction with depth = height = 1. + width = num_input_elems; + tile_size_x = kWarpSize * 16; + num_threads_x = kWarpSize; + } else if (num_reduced_minor == 1) { + // Column reduction reduces inputs with dimension [height, width], where + // width is the minor dimension, to dimension [width]. + height = num_reduced_major; + width = num_kept; + is_row_reduction = false; + tile_size_x = std::min(kWarpSize, num_kept); + // The old Column reduction algorithm uses kTileHeight = 128. We choose + // tile_size_y * block_size_y = 128 to match the value of kTileHeight. 
Using + // a non-trivial block_size_y here is a way to avoid unrolling all the 128 + // iterations. + tile_size_y = 32; + block_size_y = 4; + num_threads_x = tile_size_x; + } else { + // Row reduction reduces inputs with dimension [depth, height, width], + // where width is the most minor dimension, to dimension [height] . + depth = num_reduced_major; + height = num_kept; + width = num_reduced_minor; + num_threads_x = kWarpSize; + if (width % (kWarpSize * 64) == 0) { + tile_size_x = kWarpSize * 64; + } else { + tile_size_x = kWarpSize * 8; + block_size_z = 8; + while (depth % block_size_z != 0) { + block_size_z -= 1; + } + } + } + DCHECK_EQ(depth * height * width, num_input_elems); + VLOG(10) << "is_row_reduction " << is_row_reduction << depth << " " << height + << " " << width; + + DimensionVector dims_in_elem{depth, height, width}; + DimensionVector req_block_sizes{block_size_z, block_size_y, 1}; + llvm_ir::KernelMappingScheme mapping_scheme(dims_in_elem, tile_size_y, + tile_size_x, req_block_sizes, + num_threads_y, num_threads_x, b); + return std::make_tuple(mapping_scheme, is_row_reduction); +} + +} // namespace + +Status IrEmitterUnnested::EmitReductionToVector(HloInstruction* unnested_hlo) { + VLOG(10) << "Emitting reduction to vector " << unnested_hlo->ToString(); + + HloInstruction* reduce_or_tuple = unnested_hlo->opcode() == HloOpcode::kFusion + ? unnested_hlo->fused_expression_root() + : unnested_hlo; + absl::Span output_instructions = + GetOutputInstructions(&reduce_or_tuple); + const HloInstruction* first_reduce = + GetFirstReduceInstruction(output_instructions); + + if (output_instructions.size() > 1) { + TF_RETURN_IF_ERROR( + AreFusedReductionOutputsConsistent(output_instructions, first_reduce)); + } + + // Build an initializer thunk to initialize each reduction output. + std::vector> thunks; + for (int i = 0, e = output_instructions.size(); i != e; ++i) { + if (output_instructions[i]->opcode() != HloOpcode::kReduce) { + continue; + } + TF_ASSIGN_OR_RETURN( + std::unique_ptr initializer_thunk, + BuildInitializerThunk(unnested_hlo, + (output_instructions[i] == reduce_or_tuple) + ? ShapeIndex() + : ShapeIndex({i}))); + thunks.push_back(std::move(initializer_thunk)); + } + + // Build a kernel thunk to compute all the outputs. + std::unique_ptr kernel_thunk = + BuildKernelThunk(unnested_hlo, /*implements_whole_instruction=*/false); + + const Shape& input_shape = first_reduce->operand(0)->shape(); + // The layout of a reduction input is either set by LayoutAssignment for + // unnested kReduce or by InstructionFusion for fused kReduce. 
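As a worked illustration of the (num_reduced_major, num_kept, num_reduced_minor) decomposition used above, here is a simplified standalone sketch. It is not XLA code: shapes are assumed dense and row major, the kept dimensions are assumed contiguous, and every name in it is hypothetical.

```
// Sketch of the reduction-dimension bookkeeping, for illustration only.
#include <cstdint>
#include <iostream>
#include <vector>

struct ReductionDims {
  int64_t num_reduced_major = 1;  // reduced dims more major than the kept dims
  int64_t num_kept = 1;           // contiguous dims that survive the reduction
  int64_t num_reduced_minor = 1;  // reduced dims more minor than the kept dims
};

// `dims` are listed major to minor (row major); reduce[i] says whether dim i
// is reduced. The kept dims are assumed contiguous, as in the code above.
ReductionDims Classify(const std::vector<int64_t>& dims,
                       const std::vector<bool>& reduce) {
  ReductionDims r;
  int first_kept = -1, last_kept = -1;
  for (int i = 0; i < static_cast<int>(dims.size()); ++i) {
    if (!reduce[i]) {
      if (first_kept < 0) first_kept = i;
      last_kept = i;
    }
  }
  for (int i = 0; i < static_cast<int>(dims.size()); ++i) {
    if (first_kept >= 0 && i >= first_kept && i <= last_kept) {
      r.num_kept *= dims[i];
    } else if (first_kept >= 0 && i < first_kept) {
      r.num_reduced_major *= dims[i];
    } else {
      r.num_reduced_minor *= dims[i];
    }
  }
  return r;
}

int main() {
  // f32[100,64], reduce dim 1 (minor): row reduction, height=100, width=64.
  ReductionDims row = Classify({100, 64}, {false, true});
  // f32[100,64], reduce dim 0 (major): column reduction, height=100, width=64.
  ReductionDims col = Classify({100, 64}, {true, false});
  // f32[100,64], reduce both dims: scalar reduction over 6400 elements.
  ReductionDims scalar = Classify({100, 64}, {true, true});
  std::cout << row.num_kept << " " << col.num_kept << " " << scalar.num_kept
            << "\n";  // 100 64 1
}
```

The mapping-scheme code above then treats num_kept == 1 as a scalar reduction, num_reduced_minor == 1 as a column reduction, and everything else as a row reduction.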
+ CHECK(input_shape.has_layout()) << "LayoutAssignment or InstructionFusion " + "doesn't set the input layout of " + << first_reduce->ToString(); + + bool is_row_reduction; + llvm_ir::KernelMappingScheme mapping_scheme; + std::tie(mapping_scheme, is_row_reduction) = + ComputeMappingSchemeAndReductionKind(first_reduce, &b_); + ReductionCodegenInfo reduction_info(&mapping_scheme, is_row_reduction); + KernelCodeGenerator kernel_generator( + /*tile_element_generator=*/ + [&](HloInstruction* hlo, const llvm_ir::IrArray::Index& index, + const KernelCodegenInfo* kernel_info, llvm::Value* y_loc, + llvm::Value* x_loc) { + EmitTileElementForReduction(hlo, index, kernel_info, y_loc, x_loc); + }, + /*block_prologue_generator=*/ + [&](HloInstruction* hlo, KernelCodegenInfo* kernel_info) { + EmitPrologueForReduction(hlo, kernel_info); + }, + /*block_epilogue_generator*/ + [&](HloInstruction* hlo, KernelCodegenInfo* kernel_info) { + EmitEpilogueForReduction(hlo, kernel_info); + }); + + LaunchDimensions launch_dimensions = + EmitKernel(unnested_hlo, {}, kernel_generator, &reduction_info); + UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(), + ir_emitter_context_->llvm_module()); + + thunks.push_back(std::move(kernel_thunk)); + std::unique_ptr sequential_thunk = + absl::make_unique(std::move(thunks), unnested_hlo); + AddThunkToThunkSequence(std::move(sequential_thunk)); + + return Status::OK(); +} + Status IrEmitterUnnested::EmitConstantGlobals() { for (const BufferAllocation& allocation : ir_emitter_context_->buffer_assignment().Allocations()) { @@ -3666,10 +3427,10 @@ Status IrEmitterUnnested::EmitConstantGlobals() { } // These globals will be looked up by name by GpuExecutable so we need to - // give them an external linkage. Not all of their uses are visible in the - // LLVM IR (e.g. TupleThunk) so we can't give then a linkage that merely - // preserves their names (like available_externally), we also need to ensure - // that they stick around even if they're "unused". + // give them an external linkage. Not all of their uses are visible in + // the LLVM IR (e.g. TupleThunk) so we can't give then a linkage that + // merely preserves their names (like available_externally), we also need + // to ensure that they stick around even if they're "unused". // // We may have to be more more clever here in the future if we notice that // we're keeping around too many globals because of their linkage. diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index 334c0b3c20b..85a0e5328c4 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -16,9 +16,11 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_ +#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter.h" #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" +#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" #include "tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h" namespace xla { @@ -47,6 +49,99 @@ namespace gpu { // class IrEmitterUnnested : public IrEmitter { public: + // Parameter block_contains_multi_tiles indicates whether a tile block + // consists of multiple tiles or not. 
If the tile block contains only one + // tile, there is no need to use atomic operation to accumulate a local result + // to a global result to implement reduction. + using TileGenerator = + std::function output_tile_bounds, + bool block_contains_multi_tiles)>; + // KernelCodegenInfo records the common information to support the code + // generation for a kernel to process tensor elements by blocks. A block of + // tensor elements may contain one or multiple tiles. The code generators that + // generate code for tile elements or block prologue/epilogue refer to this + // class in their prototypes. If the implementations of such code generators + // require other information that are specific to the HLO instructions, the + // implementations need to define and use derived classes of this class. + class KernelCodegenInfo { + public: + explicit KernelCodegenInfo(llvm_ir::KernelMappingScheme* mapping_scheme) + : mapping_scheme_(mapping_scheme), + tiled_param_info_(nullptr), + lane_id_(nullptr), + index_ty_(nullptr) {} + virtual ~KernelCodegenInfo() {} + + void SetLaneId(llvm::Value* v) { lane_id_ = v; } + void SetIndexType(llvm::Type* t) { index_ty_ = t; } + void SetTiledParamInfo(llvm_ir::TiledParameterInfo* tiled_param_info) { + CHECK_EQ(tiled_param_info_, nullptr); + tiled_param_info_ = tiled_param_info; + } + + llvm::Value* GetLaneId() const { return lane_id_; } + llvm_ir::KernelMappingScheme* GetKernelMappingScheme() const { + return mapping_scheme_; + } + llvm_ir::TiledParameterInfo* GetTiledParameterInfo() const { + return tiled_param_info_; + } + llvm::Type* GetIndexType() const { return index_ty_; } + + private: + llvm_ir::KernelMappingScheme* mapping_scheme_; + llvm_ir::TiledParameterInfo* tiled_param_info_; + llvm::Value* lane_id_; + llvm::Type* index_ty_; + }; + + // A function object to prepare for the code generation for a tile block. + using BlockPrologueGenerator = + std::function; + // A function object to finalize the code generation for a tile block. + using BlockEpilogueGenerator = + std::function; + // A function object to generate code to process one element in a tile. + // + // hlo: the instruction for which the code is generated for. + // index: the index for the first output element of the current thread. + // y_loc: The y coordinate within a tile. + // x_loc: The x coordinate within a tile. + // kernel_info: Other information to support the kernel code generation. + using TileElementGenerator = std::function; + + // KernelCodeGenerator records the code generator objects that generate code + // for tile elements or tile block prologue/epilogue. 
+ class KernelCodeGenerator { + public: + explicit KernelCodeGenerator( + TileElementGenerator tile_element_generator, + BlockPrologueGenerator block_prologue_generator = {}, + BlockEpilogueGenerator block_epilogue_generator = {}) + : tile_element_generator_(std::move(tile_element_generator)), + block_prologue_generator_(std::move(block_prologue_generator)), + block_epilogue_generator_(std::move(block_epilogue_generator)) {} + + const TileElementGenerator& GetTileElementGenerator() const { + return tile_element_generator_; + } + const BlockPrologueGenerator& GetBlockPrologueGenerator() const { + return block_prologue_generator_; + } + const BlockEpilogueGenerator& GetBlockEpilogueGenerator() const { + return block_epilogue_generator_; + } + + private: + TileElementGenerator tile_element_generator_; + BlockPrologueGenerator block_prologue_generator_; + BlockEpilogueGenerator block_epilogue_generator_; + }; + IrEmitterUnnested(const HloModuleConfig& hlo_module_config, const HloComputation* hlo_computation, IrEmitterContext* ir_emitter_context); @@ -82,7 +177,7 @@ class IrEmitterUnnested : public IrEmitter { Status HandleSort(HloInstruction* sort) override; Status HandleTupleSelect(HloInstruction* tuple_select) override; Status HandleCrossReplicaSum(HloInstruction* crs) override; - Status HandleAfterAll(HloInstruction* gen_token) override; + Status HandleAfterAll(HloInstruction* after_all) override; Status EmitTargetElementLoop( const HloInstruction& hlo, @@ -111,82 +206,14 @@ class IrEmitterUnnested : public IrEmitter { // Helper for writing extra outputs from inside a reduce kernel. Status EmitExtraOutputsForReduce( - const HloInstruction* reduce, const llvm_ir::IrArray::Index& index, + const HloInstruction* unnested_hlo, const llvm_ir::IrArray::Index& index, absl::Span> extra_output_gens); - // EmitColumnReduction and EmitRowReduction emit code for column and row - // reduction of a matrix and/or 3D tensor. Row and column reduction have - // different memory access pattern, so for performance their implementations - // are significantly different. + // Generates code for reduction to contiguous dimensions. // - // Emits code that reduces a matrix of shape [height x width] to a vector of - // [width]. Other parameters have the same meaning as those of - // `EmitReductionToVector`. Note that input shape might not be - // [height x width], but can be bitcast to [height x width] with "height" - // being the major dimension. - Status EmitColumnReduction( - KernelThunk* kernel_thunk, int64 height, int64 width, - HloInstruction* reduce, const Shape& input_shape, - absl::Span input_gens, - absl::Span init_value_gens, - absl::Span reducers, - absl::Span reduce_output_shapes, - absl::Span> - extra_output_gens); - - // Emits code that reduces a 3D tensor of shape [depth x height x width] to a - // vector of shape [height]. Other parameters have the same meaning as those - // of `EmitReductionToVector`. Note that input shape might not be - // [depth x height x width], but can be bitcast to [depth x height x width] - // with "depth" being the most major dimension. - Status EmitRowReduction( - KernelThunk* kernel_thunk, int64 depth, int64 height, int64 width, - HloInstruction* reduce, const Shape& input_shape, - absl::Span input_gens, - absl::Span init_value_gens, - absl::Span reducers, - absl::Span reduce_output_shapes, - absl::Span> - extra_output_gens); - - // Emits code that reduces a tensor of arbitrary rank to a scalar. 
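To summarize how the KernelCodegenInfo and KernelCodeGenerator interfaces declared above are meant to fit together, here is a simplified, self-contained sketch of the callback structure. All types are stand-ins, and the real EmitKernel emits LLVM IR instead of looping on the host; this only models the call order (prologue once per block, the element generator once per tile element, epilogue once per block).

```
// Simplified, hypothetical model of the generator-callback pattern.
#include <functional>
#include <iostream>
#include <utility>

struct Hlo {};          // stands in for HloInstruction
struct CodegenInfo {};  // stands in for KernelCodegenInfo

using TileElementGenerator =
    std::function<void(Hlo*, int index, CodegenInfo*, int y, int x)>;
using BlockPrologueGenerator = std::function<void(Hlo*, CodegenInfo*)>;
using BlockEpilogueGenerator = std::function<void(Hlo*, CodegenInfo*)>;

class KernelCodeGenerator {
 public:
  explicit KernelCodeGenerator(TileElementGenerator tile,
                               BlockPrologueGenerator prologue = {},
                               BlockEpilogueGenerator epilogue = {})
      : tile_(std::move(tile)),
        prologue_(std::move(prologue)),
        epilogue_(std::move(epilogue)) {}

  // EmitKernel-style driver: optional prologue, per-element generator for
  // every element of every tile, optional epilogue.
  void EmitKernel(Hlo* hlo, CodegenInfo* info, int tiles, int tile_dim) const {
    if (prologue_) prologue_(hlo, info);
    for (int t = 0; t < tiles; ++t) {
      for (int y = 0; y < tile_dim; ++y) {
        for (int x = 0; x < tile_dim; ++x) {
          tile_(hlo, t * tile_dim * tile_dim + y * tile_dim + x, info, y, x);
        }
      }
    }
    if (epilogue_) epilogue_(hlo, info);
  }

 private:
  TileElementGenerator tile_;
  BlockPrologueGenerator prologue_;
  BlockEpilogueGenerator epilogue_;
};

int main() {
  Hlo hlo;
  CodegenInfo info;
  int emitted = 0;
  KernelCodeGenerator gen(
      [&](Hlo*, int, CodegenInfo*, int, int) { ++emitted; },
      [&](Hlo*, CodegenInfo*) { std::cout << "prologue\n"; },
      [&](Hlo*, CodegenInfo*) { std::cout << "epilogue\n"; });
  gen.EmitKernel(&hlo, &info, /*tiles=*/2, /*tile_dim=*/4);
  std::cout << emitted << " elements emitted\n";  // 32 elements emitted
}
```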
- Status EmitReductionToScalar( - KernelThunk* kernel_thunk, HloInstruction* reduce, - const Shape& input_shape, - absl::Span input_gens, - absl::Span init_value_gens, - absl::Span reducers, - absl::Span reduce_output_shapes, - absl::Span> - extra_output_gens); - - // Figures out whether `reduce` is a row or column reduction, and which - // dimensions to reduce, and calls either `EmitRowReduction` or - // `EmitColumnReduction` as appropriate. `input_shape` is the shape of the - // input array, which is the operand of the Reduce instruction if unfused or - // of the Fusion instruction if fused. `input_gen` and `init_value_gen` - // generate elements of the input and the initial value. Other parameters mean - // the same as for `HandleReduce`. - // - // Multiple reduces can be emitted in the same loop, assuming they have the - // same input and output shapes, and the same reduce dimensions. - // - // extra_output_gens can contain extra generators for intermediate outputs. - // These must have the same shape as the reduce input as they are computed - // when the reduce inputs are being read. - // - // Prerequisite: `IsReductionToVector(*reduce)` - Status EmitReductionToVector( - KernelThunk* kernel_thunk, HloInstruction* reduce, - const Shape& input_shape, - absl::Span input_gens, - absl::Span init_value_gens, - absl::Span dimensions_to_reduce, - absl::Span reducers, - absl::Span reduce_output_shapes, - absl::Span> - extra_output_gens); + // Prerequisite: `IsReductionToVector(*unnested_hlo)` + Status EmitReductionToVector(HloInstruction* unnested_hlo); // Emits code for an in-place scatter, modifying `thunk`s launch dimensions in // the process. `scatter` may be fused, scatter indices are taken from @@ -205,22 +232,55 @@ class IrEmitterUnnested : public IrEmitter { LaunchDimensions EmitHlo021Tile(HloInstruction* hlo, absl::Span reduced_output_dims, absl::Span tiled_param_ids); + // Emits a kernel for an unnested HLO instruction. + LaunchDimensions EmitKernel(HloInstruction* unnested_hlo, + absl::Span param_ids, + const KernelCodeGenerator& kernel_generator, + KernelCodegenInfo* kernel_info); + void EmitBlock(const TileGenerator& emit_one_tile, + const KernelCodegenInfo* kernel_info, + KernelSupportLibrary& ksl, llvm::Type* index_ty); + // Emits code to process a tensor element in a tile for the given kCopy HLO + // that performs a 0-2-1 transpose. + void EmitTileElementForCopy(HloInstruction* hlo, + const llvm_ir::IrArray::Index& index, + const KernelCodegenInfo* kernel_info, + llvm::Value* y_loc, llvm::Value* x_loc); + // Emits code to process a tensor element in a tile for the given kLoop fusion + // HLO containing parameters that are 0-2-1 transpose of its outputs. + void EmitTileElementForFusion(HloInstruction* hlo, + const llvm_ir::IrArray::Index& index, + const KernelCodegenInfo* kernel_info, + llvm::Value* y_loc, llvm::Value* x_loc); + // Emits code to process a tensor element in a tile for the given input hlo + // that is either a unnested kReduce or a kInput fusion. + void EmitTileElementForReduction(HloInstruction* unnested_hlo, + const llvm_ir::IrArray::Index& index, + const KernelCodegenInfo* kernel_info, + llvm::Value* y_loc, llvm::Value* x_loc); + // Prepares for the code generation for a tile block of a reduction kernel. 
+ void EmitPrologueForReduction(HloInstruction* unnested_hlo, + KernelCodegenInfo* kernel_info); + void EmitPrologueForOneReduction(HloInstruction* unnested_hlo, + HloInstruction* reduce_inst, int reduce_idx, + KernelCodegenInfo* kernel_info, + GpuElementalIrEmitter* elemental_emitter, + ShapeIndex output_shape_index); + // Wraps up the code generation for a tile block of a reduction kernel. + void EmitEpilogueForReduction(HloInstruction* unnested_hlo, + KernelCodegenInfo* kernel_info); + // For each reducer, emits the shuffle-down loop to accumulate the partial + // result to the global result. + void EmitFullWarpShuffleDownLoopForAllReduces( + const absl::InlinedVector& reducers, + const absl::InlinedVector& + partial_result_addresses); // Generates the IrArray for each input of an hlo and returns a vector that // constains such IrArrays. std::vector ConstructIrArrayForInputs( const HloInstruction& hlo); - // For each output of the `hlo` instruction, constructs the reduced shape for - // the output with the given `reduced_output_dims` and cast the original - // output IrArray element in `output_arrays` to the reduced shape. Returns - // the number of outputs. - int ConstructOutputReducedShapeAndCastOutputIrArrayToShape( - const HloInstruction& hlo, - const std::vector& output_arrays, - absl::Span reduced_output_dims, - std::vector* output_reduced_shapes, - std::vector* output_in_reduced_shape_arrays); // For each input of the `hlo` instruction, checks its value in // `param_buffers` to find out whether the input has a reduced shape. If the // input has a reduced shape, constructs the reduced shape for the input and diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc index 8751e3a9c2a..24f07e68973 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc @@ -177,13 +177,6 @@ std::unique_ptr GetTargetMachine( } TargetOptions target_options = InitTargetOptionsFromCodeGenFlags(); - llvm_ir::SetTargetOptions( - /*fast_math_enabled=*/hlo_module_config.debug_options() - .xla_gpu_enable_fast_math(), - &target_options); - - // Enable FMA synthesis. - target_options.AllowFPOpFusion = FPOpFusion::Fast; // Set the verbose assembly options. target_options.MCOptions.AsmVerbose = false; @@ -453,18 +446,21 @@ void GPUBackendInit(const HloModuleConfig& hlo_module_config) { // * 3-6 gives similar results as 2; // * >6 start hurting the performance of at least dot product kernels. // - // TODO(jingyue): The current threshold only considers the numbr of IR + // TODO(jingyue): The current threshold only considers the number of IR // instructions which do not accurately reflect the true cost. We need a // better cost model. FeedLLVMWithFlags({"-bonus-inst-threshold=2"}); - // TODO(b/22073864): Increase limit when scan memory dependency. - // This helps to reduce more redundant load instructions. + // Increase limit when scanning memory dependencies. This helps to reduce + // more redundant load instructions. // // The specific value is currently large enough for s3d in shoc benchmark, // which contains a lot of load instructions and many arithmetic instructions // between those loads. FeedLLVMWithFlags({"-memdep-block-scan-limit=500"}); + // Use div.approx -- it matters for some float-division heavy benchmarks. 
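The shuffle-down loop referred to by EmitFullWarpShuffleDownLoopForAllReduces above is the standard warp-level reduction idiom. A minimal CUDA sketch of that idiom is shown here for orientation only; it is not the IR the emitter produces, and the accumulation into the global result (for example via atomicAdd when several blocks contribute to one output) is an assumption about the surrounding code.

```
// Standard warp-level sum via shuffle-down, for illustration.
// Each lane starts with a partial result; after the loop, lane 0 holds the
// sum of all 32 lanes and can fold it into the global result.
__inline__ __device__ float WarpReduceSum(float partial) {
  constexpr int kWarpSize = 32;
  for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
    partial += __shfl_down_sync(0xffffffff, partial, offset);
  }
  return partial;  // Valid on lane 0.
}
```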
+ FeedLLVMWithFlags({"-nvptx-prec-divf32=0"}); + llvm_ir::InitializeLLVMCommandLineOptions(hlo_module_config); // Initialize the NVPTX target; it's the only target we link with, so call its diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc index d9b06828e2b..01fddcede64 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc @@ -41,50 +41,7 @@ GpuMultiOutputFusion::GpuMultiOutputFusion() : MultiOutputFusion(INT64_MAX) {} bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1, HloInstruction* instr2) { - auto get_element_instr = - [&](const HloInstruction* instr) -> const HloInstruction* { - const HloInstruction* element_instr = instr; - if (instr->opcode() == HloOpcode::kFusion) { - auto fused_expression_root = instr->fused_expression_root(); - if (instr->IsMultiOutputFusion()) { - // If possible, we want to pick a reduce operand of the fusion root, - // because it has the most constraints. - for (const auto* inst : fused_expression_root->operands()) { - if (IsReductionToVector(*inst)) { - return inst; - } - } - return fused_expression_root->operands()[0]; - } else { - element_instr = fused_expression_root; - } - } - return element_instr; - }; - - auto get_element_shape = [&](const HloInstruction* element_instr) { - // Special handling of kReduce instructions -- the fusion - // applies to the first operand. - if (IsReductionToVector(*element_instr)) { - return element_instr->operand(0)->shape(); - } - return element_instr->shape(); - }; - - // The shapes in all tuple operands should agree, unless it is a reduce. - // In that case, the operand of the reduce needs to have the same shape - // as the other tuple operands, but also we need to compare the output - // shapes of the reduces. - auto* element_instr_1 = get_element_instr(instr1); - auto* element_instr_2 = get_element_instr(instr2); - if (element_instr_1->opcode() == HloOpcode::kReduce && - element_instr_2->opcode() == HloOpcode::kReduce && - !ShapeUtil::Equal(element_instr_1->shape(), element_instr_2->shape())) { - return false; - } - // The elementwise output shapes must be the same (including layout). - return ShapeUtil::EqualIgnoringFpPrecision( - get_element_shape(element_instr_1), get_element_shape(element_instr_2)); + return ShapesCompatibleForMultiOutputFusion(*instr1, *instr2); } bool GpuMultiOutputFusion::IsFusible(HloInstruction* instr) { @@ -205,7 +162,7 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { VLOG(3) << producer->name() << " is not a loop fusion."; continue; } - if (!ShapesCompatibleForFusion(producer, consumer)) { + if (!ShapesCompatibleForMultiOutputFusion(*producer, *consumer)) { VLOG(3) << producer->name() << " has an incompatible shape."; continue; } diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc index dc221f22a74..d16c87ba5c6 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc @@ -580,7 +580,7 @@ TEST_F(MultiOutputFusionTest, AvoidsLargeFusion) { // ... // where each of the (pi * pj)'s is represented as a fusion node so that // multi-output fusion will pay attention to it. 
- auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); HloComputation::Builder b(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {10, 100}); diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index de04ed85c30..e934cbda176 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -67,6 +67,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_cse.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_element_type_converter.h" +#include "tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" @@ -173,13 +174,16 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true); + pipeline.AddPass(); + // BatchNormExpander can create zero-sized ops, so zero-sized HLO // elimination has to come after that pass. pipeline.AddPass(); - pass.AddPass( - /*is_layout_sensitive=*/false, + AlgebraicSimplifierOptions options( [](const Shape&, const Shape&) { return false; }); + options.set_enable_permutation_sort_replacement(true); + pass.AddPass(options); pass.AddPass(); pass.AddPass(); pass.AddPass(); @@ -248,11 +252,13 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // The LayoutAssignment pass may leave behind kCopy instructions which are // duplicate or NOPs, so remove them with algebraic simplification and CSE. - pipeline.AddPass>( - /*is_layout_sensitive=*/true, + AlgebraicSimplifierOptions options( /*valid_bitcast_callback=*/[](const Shape&, const Shape&) { return true; }); + options.set_is_layout_sensitive(true); + options.set_enable_permutation_sort_replacement(true); + pipeline.AddPass>(options); // Choose the fastest algorithm for each conv. // @@ -810,7 +816,7 @@ std::vector NVPTXCompiler::CompilePtxOrGetCachedResult(const string& ptx, // binaries are not available. We don't want to spam logs with // identical warnings in this case. - // TODO(zhengxq): we should implement a LOG_FIRST_N and LOG_EVERY_N + // TODO(jlebar): we should implement a LOG_FIRST_N and LOG_EVERY_N // for more general usage. 
static std::atomic warning_done(false); log_warning = !warning_done.exchange(true); diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc index f2ef11e1e6a..31a5d7a8c04 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc @@ -30,7 +30,7 @@ namespace gpu { class StreamAssignmentTest : public HloTestBase { protected: - std::unique_ptr CreateNewUnverifiedModule() { + std::unique_ptr CreateNewVerifiedModule() { HloModuleConfig config; auto debug_options = GetDebugOptionsForTest(); debug_options.set_xla_gpu_disable_multi_streaming(false); @@ -55,7 +55,7 @@ TEST_F(StreamAssignmentTest, SequentialMatMul) { HloInstruction* dot2 = builder.AddInstruction(CreateCanonicalDot(f32_2x2_, dot1, z)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build(dot2)); std::unique_ptr assignment = AssignStreams(*module); @@ -76,7 +76,7 @@ TEST_F(StreamAssignmentTest, ConcurrentMatMul) { HloInstruction* add = builder.AddInstruction( HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, dot1, dot2)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build(add)); std::unique_ptr assignment = AssignStreams(*module); @@ -120,7 +120,7 @@ TEST_F(StreamAssignmentTest, LatticeMatMul) { HloInstruction* d40 = builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d30, d31)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build(d40)); std::unique_ptr assignment = AssignStreams(*module); diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h index d2f30ae7bc4..d917320e363 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h @@ -26,7 +26,7 @@ namespace gpu { // Tests that verify IR or PTX emitted by the GPU backend is as expected. class GpuCodegenTest : public LlvmIrGenTestBase { protected: - // Like HloTestBase::CreateNewUnverifiedModule(), with a flag for configuring + // Like HloTestBase::CreateNewVerifiedModule(), with a flag for configuring // the ftz option. std::unique_ptr CreateNewUnverifiedModuleWithFTZ(bool ftz); diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc index 268b48a1cad..a1ed8499040 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc @@ -46,7 +46,7 @@ TEST_F(GpuCopyTest, UseMemcpy) { std::unique_ptr computation = builder.Build(); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(std::move(computation)); // There should not be any kernel prefixed "copy". diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_ftz_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_ftz_test.cc index d0ccd8619bd..5e524faab18 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_ftz_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_ftz_test.cc @@ -75,16 +75,16 @@ class GpuFtzDisabledTest : public GpuFtzTest { // Check that we emit mul.ftz.f32 when in ftz mode, and plain mul.f32 otherwise. 
TEST_F(GpuFtzEnabledTest, MultiplyFtz) { CompileAndVerifyPtx(CreateBinaryOpModule(HloOpcode::kMultiply), R"( - CHECK-NOT: mul.f32 - CHECK: mul.ftz.f32 - CHECK-NOT: mul.f32 + CHECK-NOT: mul.rn.f32 + CHECK: mul.rn.ftz.f32 + CHECK-NOT: mul.rn.f32 )"); } TEST_F(GpuFtzDisabledTest, MultiplyFtz) { CompileAndVerifyPtx(CreateBinaryOpModule(HloOpcode::kMultiply), R"( - CHECK-NOT: mul.ftz.f32 - CHECK: mul.f32 - CHECK-NOT: mul.ftz.f32 + CHECK-NOT: mul.rn.ftz.f32 + CHECK: mul.rn.f32 + CHECK-NOT: mul.rn.ftz.f32 )"); } diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc index da8e513a2c3..6814be779e0 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc @@ -51,7 +51,7 @@ TEST_F(GpuIndexTest, CompatibleUseLinearIndex) { builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {5, 7, 2}), HloOpcode::kGe, param_x, param_y)); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(builder.Build()); // Check the optimized IR as the unoptimized IR contains dead udiv and urem. diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc index ea1fee040dd..3019215c015 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc @@ -48,7 +48,7 @@ TEST_F(GpuLdgTest, LdgForParamRead) { HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, param)); std::unique_ptr computation = builder.Build(); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(std::move(computation)); CompileAndVerifyPtx(std::move(hlo_module), R"( @@ -73,7 +73,7 @@ TEST_F(GpuLdgTest, LdgForNonParamRead) { builder.AddInstruction(HloInstruction::CreateTuple({add, square})); std::unique_ptr computation = builder.Build(); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(std::move(computation)); CompileAndVerifyPtx(std::move(hlo_module), R"( @@ -95,7 +95,7 @@ TEST_F(GpuLdgTest, LdgForNonParamRead) { // reduce in the foreseeable future. But if that turns out to be wrong, I give // you, future reader, permission to delete this test. 
TEST_F(GpuLdgTest, NoLdgWhenSharingBuffer) { - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloComputation* reduce_computation; diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc index 14285459b5a..ca0a78034d7 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc @@ -47,7 +47,7 @@ TEST_F(GpuNoAliasTest, Concat) { std::unique_ptr computation = builder.Build(); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(std::move(computation)); CompileAndVerifyIr(std::move(hlo_module), diff --git a/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc b/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc index 141f3219387..6b2d76764a0 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc @@ -45,7 +45,7 @@ void ThunkSchedule::AddDependenciesOnTransitiveOperands( ThunkSchedule::ThunkSchedule( std::unique_ptr thunks, std::unique_ptr stream_assignment, - const std::vector& hlo_total_order) + const std::vector& hlo_total_order) : thunks_(std::move(thunks)), stream_assignment_(std::move(stream_assignment)) { std::unordered_map hlo_to_thunk; @@ -53,7 +53,7 @@ ThunkSchedule::ThunkSchedule( InsertOrDie(&hlo_to_thunk, thunk->hlo_instruction(), thunk.get()); } - for (const HloInstruction* hlo : hlo_total_order) { + for (HloInstruction* hlo : hlo_total_order) { if (hlo_to_thunk.count(hlo)) { thunk_total_order_.push_back(FindOrDie(hlo_to_thunk, hlo)); } diff --git a/tensorflow/compiler/xla/service/gpu/thunk_schedule.h b/tensorflow/compiler/xla/service/gpu/thunk_schedule.h index d3352994f84..43b628a1baf 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk_schedule.h +++ b/tensorflow/compiler/xla/service/gpu/thunk_schedule.h @@ -46,7 +46,7 @@ class ThunkSchedule { public: ThunkSchedule(std::unique_ptr thunks, std::unique_ptr stream_assignment, - const std::vector& hlo_total_order); + const std::vector& hlo_total_order); // Returns the total order of executing all the thunks. const std::vector& TotalOrder() const { return thunk_total_order_; } diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc index c7f51127649..2dce7749bbd 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc @@ -29,7 +29,7 @@ namespace { class WhileTransformerTest : public HloTestBase { protected: WhileTransformerTest() - : module_(CreateNewUnverifiedModule()), + : module_(CreateNewVerifiedModule()), induction_variable_shape_(ShapeUtil::MakeShape(S32, {})), data_shape_(ShapeUtil::MakeShape(F32, {8})), condition_result_shape_(ShapeUtil::MakeShape(PRED, {})) {} diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index fad3215fc81..dc40b9446ad 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -258,7 +258,7 @@ class HeapSimulatorTracker { // Constructor for testing a single entry computation. 
HeapSimulatorTracker( const string& name, std::unique_ptr computation, - const std::vector& instruction_sequence) { + const std::vector& instruction_sequence) { HloModuleConfig config; module_ = absl::make_unique(name, config); module_->AddEntryComputation(std::move(computation)); @@ -286,7 +286,7 @@ class HeapSimulatorTracker { // Similar to the single entry computation constructor above, but runs the // simulation over the entire module. void RunWholeModule( - const std::vector& full_module_sequence) { + const std::vector& full_module_sequence) { points_to_analysis_ = TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); @@ -294,7 +294,7 @@ class HeapSimulatorTracker { HloSchedule schedule(module_.get()); absl::flat_hash_map reverse_position; for (int i = 0; i < full_module_sequence.size(); ++i) { - const HloInstruction* instruction = full_module_sequence[i]; + HloInstruction* instruction = full_module_sequence[i]; schedule.GetOrCreateSequence(instruction->parent()) .push_back(instruction); reverse_position[instruction] = full_module_sequence.size() - i; diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index dbab62f847e..414c6327124 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -51,7 +51,7 @@ message HloInstructionProto { string name = 1; string opcode = 2; - xla.Shape shape = 3; + xla.ShapeProto shape = 3; xla.OpMetadata metadata = 7; @@ -132,7 +132,7 @@ message HloInstructionProto { string custom_call_opaque = 53; // Shape of outfeed request. - xla.Shape outfeed_shape = 29; + xla.ShapeProto outfeed_shape = 29; // Describes the dimension numbers used for a dot operation xla.DotDimensionNumbers dot_dimension_numbers = 30; @@ -190,7 +190,7 @@ message HloInstructionProto { // 'operand_shapes_with_layout' must contain a shape with layout for each // operand. bool constrain_layout = 56; - repeated Shape operand_shapes_with_layout = 57; + repeated xla.ShapeProto operand_shapes_with_layout = 57; } // Serialization of HloComputation. @@ -205,7 +205,8 @@ message HloComputationProto { repeated HloInstructionProto instructions = 2; // The program shape (with layout) of this computation. - xla.ProgramShape program_shape = 4; + + xla.ProgramShapeProto program_shape = 4; // The id of this computation. int64 id = 5; @@ -251,6 +252,41 @@ message HloInputOutputAliasProto { repeated AliasEntryProto entries = 1; } +message DynamicParameterBindingProto { + // A list of bindings which indicates that the `target_dim_num` in + // the subshape `target_param_index` of parameter `target_param_num` + // is a dynamic dimension and its real dynamic size is represented + // by `dynamic_param_index` in parameter `dynamic_param_num`. + // + // As an example, imagine we have a program: + // + // ENTRY main { + // a = f32[] parameter(0) + // b = f32[10] parameter(1) + // ROOT root = (f32[], f32[10]) tuple(%a, %b) + // } + // + // Let's say 'b' (param index 1) is a dynamic shape whose input has + // an upperbound of 10 and real size is determined at runtime.'a' + // represents the real size of b's first dimension. 
+ // + // In this case, the fields are set in the following way: + // dynamic_param_num = 1 + // dynamic_param_index = {} + // target_param_num = 0 + // target_param_index = {} + // target_param_dim = 0 + message Binding { + int64 dynamic_param_num = 1; + repeated int64 dynamic_param_index = 2; + int64 target_param_num = 3; + repeated int64 target_param_index = 4; + int64 target_param_dim_num = 5; + } + + repeated Binding entries = 1; +} + // Serialization of HloModule. message HloModuleProto { string name = 1; @@ -262,7 +298,7 @@ message HloModuleProto { repeated HloComputationProto computations = 3; // The host program shape (with layout) of the entry computation. - xla.ProgramShape host_program_shape = 4; + xla.ProgramShapeProto host_program_shape = 4; // The id of this module. int64 id = 5; @@ -272,6 +308,8 @@ message HloModuleProto { // Describes alias information between inputs and outputs. HloInputOutputAliasProto input_output_alias = 8; + + DynamicParameterBindingProto dynamic_parameter_binding = 9; } // Serialization of LogicalBuffer. diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 0c20d207ddb..ff122b529bd 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -499,7 +499,7 @@ HloComputationProto HloComputation::ToProto() const { proto.add_instructions()->Swap(&instruction_proto); } proto.set_root_id(root_instruction()->unique_id()); - *proto.mutable_program_shape() = ComputeProgramShape(); + *proto.mutable_program_shape() = ComputeProgramShape().ToProto(); return proto; } @@ -711,6 +711,8 @@ bool HloComputation::operator==(const HloComputation& other) const { return eq(root_instruction(), other.root_instruction()); } +uint64 HloComputation::Hash() const { return root_instruction()->Hash(); } + Status HloComputation::ReplaceWithNewInstruction( HloInstruction* old_instruction, std::unique_ptr new_instruction) { @@ -795,7 +797,7 @@ Status HloComputation::AcceptWithOperandOrder( template Status HloComputation::AcceptOrdered( DfsHloVisitorBase* visitor, - const std::vector& order) const { + const std::vector& order) const { VLOG(3) << "Accepting visitor with order."; for (HloInstruction* root : CollectUnreachableRoots()) { TF_RET_CHECK(std::find(order.begin(), order.end(), root) != order.end()) @@ -825,9 +827,9 @@ Status HloComputation::AcceptOrdered( // Explicit instantiations. template Status HloComputation::AcceptOrdered( - DfsHloVisitor*, const std::vector&) const; + DfsHloVisitor*, const std::vector&) const; template Status HloComputation::AcceptOrdered( - ConstDfsHloVisitor*, const std::vector&) const; + ConstDfsHloVisitor*, const std::vector&) const; Status HloComputation::Accept( const std::function& visitor_func) { diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index fc7d2035e5b..c584e4c7ca5 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -264,6 +264,12 @@ class HloComputation { // Return whether `*this` and `other` are functionally equivalent. bool operator==(const HloComputation& other) const; + // Generates a hash value of an HLO computation. Hash considers + // information on opcode, shape, operands, and typically a root instruction. + // This function returns the same hash value for equivalent HLO computations, + // with respect to HloInstruction::Identical() method. 
+ uint64 Hash() const; + // Replaces old instruction with newly created instruction. Removes old // instruction from computation. Updates uses and root instruction. Status ReplaceWithNewInstruction( @@ -301,7 +307,7 @@ class HloComputation { // be a topological sort of all instructions in the computation. template Status AcceptOrdered(DfsHloVisitorBase* visitor, - const std::vector& order) const; + const std::vector& order) const; // Same as Accept() above, but the visitor is given as a function. Status Accept(const std::function& visitor_func); diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index 1e7a6e197f5..8b50cfa9aed 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -65,7 +65,7 @@ class HloComputationTest : public HloTestBase { }; TEST_F(HloComputationTest, GetEmbeddedComputationsEmpty) { - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto negate_computation = module->AddEntryComputation(CreateNegateComputation()); EXPECT_TRUE(negate_computation->MakeEmbeddedComputationsList().empty()); @@ -73,7 +73,7 @@ TEST_F(HloComputationTest, GetEmbeddedComputationsEmpty) { TEST_F(HloComputationTest, GetEmbeddedComputationsOneComputation) { // Create computation which calls one other computation. - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto negate_computation = module->AddEmbeddedComputation(CreateNegateComputation()); auto map_computation = @@ -85,7 +85,7 @@ TEST_F(HloComputationTest, GetEmbeddedComputationsOneComputation) { TEST_F(HloComputationTest, GetEmbeddedComputationsDiamond) { // Create computations with a diamond-shaped callgraph. 
- auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto negate_computation = module->AddEmbeddedComputation(CreateNegateComputation()); auto map1_computation = @@ -119,7 +119,7 @@ TEST_F(HloComputationTest, PostOrderSingleton) { auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->MakeInstructionPostOrder(), ElementsAre(constant)); } @@ -134,7 +134,7 @@ TEST_F(HloComputationTest, PostOrderSimple) { HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); auto negate2 = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, negate1)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->MakeInstructionPostOrder(), ElementsAre(constant, negate1, negate2)); @@ -170,7 +170,7 @@ TEST_F(HloComputationTest, PostOrderDisconnectedInstructions) { HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto constant4 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->MakeInstructionPostOrder(), UnorderedElementsAre(constant1, constant2, constant3, constant4)); @@ -192,7 +192,7 @@ TEST_F(HloComputationTest, PostOrderWithMultipleRoots) { r0f32_, HloOpcode::kAdd, constant2, constant3)); auto add3 = builder.AddInstruction(HloInstruction::CreateBinary( r0f32_, HloOpcode::kAdd, constant1, constant3)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); auto post_order = computation->MakeInstructionPostOrder(); EXPECT_EQ(6, post_order.size()); @@ -217,7 +217,7 @@ TEST_F(HloComputationTest, VisitWithMultipleRoots) { constant2, constant3)); builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, constant1, constant3)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Visitor which keeps track of which instructions have been visited. class TestVisitor : public DfsHloVisitorWithDefault { @@ -257,7 +257,7 @@ TEST_F(HloComputationTest, DeepCopyArray) { auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({1.0, 2.0, 3.0}))); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); auto copy = computation->DeepCopyInstruction(constant).ValueOrDie(); @@ -274,7 +274,7 @@ TEST_F(HloComputationTest, DeepCopyTuple) { auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); auto tuple_copy = computation->DeepCopyInstruction(tuple).ValueOrDie(); @@ -376,7 +376,7 @@ TEST_F(HloComputationTest, DeepCopyToken) { // copied. 
auto builder = HloComputation::Builder(TestName()); auto token = builder.AddInstruction(HloInstruction::CreateToken()); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); auto copy = computation->DeepCopyInstruction(token).ValueOrDie(); @@ -393,7 +393,7 @@ TEST_F(HloComputationTest, DeepCopyTokenTuple) { HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({token, constant})); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); auto copy = computation->DeepCopyInstruction(tuple).ValueOrDie(); @@ -440,7 +440,7 @@ TEST_F(HloComputationTest, RemoveInstructionWithDuplicateOperand) { r0f32_, HloOpcode::kAdd, dead_negate, dead_negate)); auto negate = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(4, computation->instruction_count()); EXPECT_THAT(computation->root_instruction(), op::Negate(constant)); @@ -466,7 +466,7 @@ TEST_F(HloComputationTest, CloneWithControlDependency) { HloInstruction::CreateParameter(0, r0f32_, "param0")); auto negate = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, param)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build(/*root_instruction=*/add)); @@ -505,7 +505,7 @@ TEST_F(HloComputationTest, Stringification) { 2, PrecisionConfig::DEFAULT); builder.AddInstruction( HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto options = HloPrintOptions().set_print_metadata(false); @@ -540,7 +540,7 @@ TEST_F(HloComputationTest, StringificationIndent) { 2, PrecisionConfig::DEFAULT); builder.AddInstruction( HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto options = @@ -576,7 +576,7 @@ TEST_F(HloComputationTest, StringificationCanonical) { 2, PrecisionConfig::DEFAULT); builder.AddInstruction( HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto options = HloPrintOptions().set_print_metadata(false); diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc index d12f920722e..4f81dc94e57 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc @@ -22,21 +22,22 @@ limitations under the License. 
#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/types.h" -namespace op = xla::testing::opcode_matchers; - namespace xla { namespace { +namespace m = xla::match; + using HloConstantFoldingTest = HloTestBase; TEST_F(HloConstantFoldingTest, ConvertF32ToS64) { @@ -49,13 +50,14 @@ TEST_F(HloConstantFoldingTest, ConvertF32ToS64) { auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Convert(input)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Convert().WithOperand(0, m::Op().Is(input)))); HloConstantFolding const_folder; TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); EXPECT_TRUE(result); - EXPECT_THAT(computation->root_instruction(), op::Constant()); + EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Constant())); EXPECT_EQ(computation->root_instruction()->literal().GetFirstElement(), 42); } @@ -70,13 +72,14 @@ TEST_F(HloConstantFoldingTest, ConvertS64ToF32) { auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Convert(input)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Convert().WithOperand(0, m::Op().Is(input)))); HloConstantFolding const_folder; TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); EXPECT_TRUE(result); - EXPECT_THAT(computation->root_instruction(), op::Constant()); + EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Constant())); EXPECT_EQ(computation->root_instruction()->literal().GetFirstElement(), 42.0f); } @@ -91,13 +94,14 @@ TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) { auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Convert(input)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Convert().WithOperand(0, m::Op().Is(input)))); HloConstantFolding const_folder; TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); EXPECT_TRUE(result); - EXPECT_THAT(computation->root_instruction(), op::Constant()); + EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Constant())); EXPECT_EQ(computation->root_instruction()->literal().Get({0}), 42); EXPECT_EQ(computation->root_instruction()->literal().Get({1}), 19); } @@ -138,7 +142,7 @@ TEST_F(HloConstantFoldingTest, Concatenate) { EXPECT_TRUE(result); HloInstruction* root = computation->root_instruction(); - EXPECT_THAT(root, op::Constant()); + EXPECT_THAT(root, GmockMatch(m::Constant())); EXPECT_TRUE(ShapeUtil::Equal(root->shape(), shape)); } } @@ -165,7 +169,7 @@ TEST_F(HloConstantFoldingTest, Slice) { EXPECT_TRUE(result); HloInstruction* root = 
computation->root_instruction(); - EXPECT_THAT(root, op::Constant()); + EXPECT_THAT(root, GmockMatch(m::Constant())); EXPECT_TRUE(ShapeUtil::Equal(root->shape(), shape)); } @@ -190,7 +194,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) { EXPECT_TRUE(result); HloInstruction* root = computation->root_instruction(); - EXPECT_THAT(root, op::Constant()); + EXPECT_THAT(root, GmockMatch(m::Constant())); EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), shape)); using NativeT = typename primitive_util::PrimitiveTypeToNative::type; @@ -240,7 +244,8 @@ TEST_F(HloConstantFoldingTest, ConstantFoldReduceNoLayout) { TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(m.get())); EXPECT_FALSE(result); - EXPECT_THAT(m->entry_computation()->root_instruction(), op::Reduce()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Reduce())); } const char* const kConstantFoldLargePad = R"( @@ -260,7 +265,7 @@ TEST_F(HloConstantFoldingTest, DoesNotFoldLargePad) { EXPECT_FALSE(result); EXPECT_THAT(module->entry_computation()->root_instruction(), - op::Pad(op::Constant(), op::Constant())); + GmockMatch(m::Pad(m::Constant(), m::Constant()))); } } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index fdfb38b858c..df7d3826dba 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -419,6 +419,21 @@ Status HloCostAnalysis::HandleTranspose(const HloInstruction*) { } Status HloCostAnalysis::HandleAfterAll(const HloInstruction*) { + // This instruction is used to enforce ordering at compile time. No code is + // emitted. + current_should_compute_bottleneck_time_ = false; + current_properties_[kBytesAccessedKey] = 0; + current_properties_[kOptimalSecondsKey] = 0; + return Status::OK(); +} + +Status HloCostAnalysis::HandleAddDependency( + const HloInstruction* add_dependency) { + // This instruction is used to enforce ordering at compile time. No code is + // emitted. 
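  // Illustrative consequence (sketch only; `shape_size_fn` is a hypothetical
  // ShapeSizeFunction, not part of this change): once the analysis has run,
  // ordering-only instructions report zero cost, e.g.
  //
  //   HloCostAnalysis analysis(shape_size_fn);
  //   TF_CHECK_OK(computation->Accept(&analysis));
  //   EXPECT_EQ(analysis.bytes_accessed(*add_dependency), 0);
  //   EXPECT_EQ(analysis.optimal_seconds(*add_dependency), 0);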
+ current_should_compute_bottleneck_time_ = false; + current_properties_[kBytesAccessedKey] = 0; + current_properties_[kOptimalSecondsKey] = 0; return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index 8ced9d776e1..33983119c9b 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -101,6 +101,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleBroadcast(const HloInstruction* broadcast) override; Status HandlePad(const HloInstruction* pad) override; Status HandleReshape(const HloInstruction* reshape) override; + Status HandleAddDependency(const HloInstruction* add_dependency) override; Status HandleAfterAll(const HloInstruction* token) override; Status HandleTranspose(const HloInstruction* transpose) override; Status HandleWhile(const HloInstruction* xla_while) override; diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc index 6a15b3440c6..ff32faf298d 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc @@ -387,7 +387,7 @@ TEST_F(FusionCostAnalysis, LoopFusion) { HloInstruction::CreateBinary(r2f32, HloOpcode::kSubtract, mul, clamp)); auto tuple = HloInstruction::CreateTuple({sub, sub, mul, c1}); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {sub, mul, exp, clamp, add}, HloInstruction::FusionKind::kLoop); @@ -429,7 +429,7 @@ TEST_F(FusionCostAnalysis, NoLayout) { auto add = builder.AddInstruction(HloInstruction::CreateBinary( shape_with_layout, HloOpcode::kAdd, c1, broadcast)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {add, broadcast}, HloInstruction::FusionKind::kLoop); @@ -472,7 +472,7 @@ TEST_F(DomainCostAnalysis, DomainCost) { auto domain = builder.AddInstruction( HloInstruction::CreateDomain(tuple->shape(), tuple, nullptr, nullptr)); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(builder.Build()); EXPECT_EQ(hlo_module->entry_computation()->root_instruction(), domain); diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index 5dcf6bc985f..3ed3d3c11c7 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -466,6 +466,21 @@ bool HloDataflowAnalysis::UpdateDomainValueSet(HloInstruction* domain) { return changed; } +bool HloDataflowAnalysis::UpdateAddDependencyValueSet( + HloInstruction* add_dependency) { + // AddDependency just forwards the value of its zero-th operand. 
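  // For example, given
  //
  //   %token = token[] after-all()
  //   %ad    = f32[3] add-dependency(f32[3] %data, token[] %token)
  //
  // %ad carries exactly the HloValue defined by %data; the token operand only
  // constrains ordering and contributes no values of its own here.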
+ CHECK_EQ(add_dependency->opcode(), HloOpcode::kAddDependency); + const InstructionValueSet& operand_set = + GetInstructionValueSet(add_dependency->operand(0)); + InstructionValueSet& add_dependency_set = + GetInstructionValueSet(add_dependency); + if (operand_set != add_dependency_set) { + add_dependency_set = operand_set; + return true; + } + return false; +} + bool HloDataflowAnalysis::UpdateGetTupleElementValueSet(HloInstruction* gte) { CHECK_EQ(gte->opcode(), HloOpcode::kGetTupleElement); bool changed = false; @@ -622,6 +637,8 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet( HloInstruction* instruction) { // Recompute from operands. switch (instruction->opcode()) { + case HloOpcode::kAddDependency: + return UpdateAddDependencyValueSet(instruction); case HloOpcode::kBitcast: return UpdateBitcastValueSet(instruction); case HloOpcode::kDomain: @@ -795,6 +812,7 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { define_all_values(); } break; + case HloOpcode::kAddDependency: case HloOpcode::kWhile: case HloOpcode::kCall: case HloOpcode::kConditional: diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h index abac398c04f..ece17fc4c3e 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h @@ -193,6 +193,7 @@ class HloDataflowAnalysis { bool UpdateSendValueSet(HloInstruction* send); bool UpdateTupleValueSet(HloInstruction* tuple); bool UpdateWhileValueSet(HloInstruction* xla_while); + bool UpdateAddDependencyValueSet(HloInstruction* add_dependency); // Propagate the dataflow through the module. void Propagate(); diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 6422346c101..f7a1f19a6f5 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -43,7 +43,7 @@ using ::testing::UnorderedElementsAre; class HloDataflowAnalysisTest : public HloTestBase, public ::testing::WithParamInterface { protected: - HloDataflowAnalysisTest() : module_(CreateNewUnverifiedModule()) {} + HloDataflowAnalysisTest() : module_(CreateNewVerifiedModule()) {} // Run dataflow analysis on the member module. For convenience returns a // reference to the generated analysis stored in analysis_. @@ -1877,6 +1877,30 @@ TEST_P(HloDataflowAnalysisTest, NestedConditionals) { } } +TEST_P(HloDataflowAnalysisTest, AddDependency) { + string module_string = R"( +HloModule AddDependency +ENTRY %AddDependency (p: f32[3]) -> f32[3] { + %p = f32[3] parameter(0) + %token = token[] after-all() + ROOT %add_dep = f32[3] add-dependency(f32[3] %p, token[] %token) +} +)"; + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseHloString(module_string, GetModuleConfigForTest())); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr analysis, + HloDataflowAnalysis::Run(*module)); + const HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kAddDependency); + + // The after-all and parameter should define a value. Add-dependency should + // not. 
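  // Concretely, the two HloValues are the f32[3] array defined by %p and the
  // token defined by %token; %add_dep merely forwards the value of %p.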
+ EXPECT_EQ(analysis->values().size(), 2); + EXPECT_FALSE(analysis->ValueIsDefinedAt(root)); +} + INSTANTIATE_TEST_CASE_P(HloDataflowAnalysisInstantiation, HloDataflowAnalysisTest, ::testing::Values(false, true)); diff --git a/tensorflow/compiler/xla/service/hlo_dce_test.cc b/tensorflow/compiler/xla/service/hlo_dce_test.cc index 6c8095d3977..1fa4259a3e4 100644 --- a/tensorflow/compiler/xla/service/hlo_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dce_test.cc @@ -59,7 +59,7 @@ TEST_F(HloDceTest, NoDeadCode) { builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(3, computation->instruction_count()); @@ -110,7 +110,7 @@ TEST_F(HloDceTest, DeadParameters) { builder.AddInstruction(HloInstruction::CreateUnary( live_param->shape(), HloOpcode::kNegate, live_param)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(5, computation->instruction_count()); @@ -150,7 +150,7 @@ TEST_F(HloDceTest, ControlDependencies) { builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Add a control dependency between two instructions. @@ -175,7 +175,7 @@ TEST_F(HloDceTest, ControlDependencies) { // Tests that a dead call instruction is removed. TEST_F(HloDceTest, DeadInstructionWithCalledComputation) { - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); Shape shape = ShapeUtil::MakeShape(F32, {}); // Called computation for the call instruction. @@ -323,7 +323,7 @@ TEST_F(HloDceTest, CalledComputationWithNestedSideEffect) { } TEST_F(HloDceTest, RemoveDeadSubcomputation) { - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloComputation::Builder subcomp_builder("reduction_subcomp"); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 7fcafafc097..3a7652a8dc8 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -39,6 +39,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_query.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/window_util.h" @@ -396,6 +397,16 @@ StatusOr HloEvaluator::EvaluateDotOp( return Evaluate(cloned_instruction.get()); } +Status HloEvaluator::HandleBitcast(HloInstruction* bitcast) { + const Literal& operand_literal = GetEvaluatedLiteralFor(bitcast->operand(0)); + Literal result(bitcast->shape()); + TF_RET_CHECK(operand_literal.size_bytes() == result.size_bytes()); + memcpy(result.untyped_data(), operand_literal.untyped_data(), + operand_literal.size_bytes()); + evaluated_[bitcast] = std::move(result); + return Status::OK(); +} + Status HloEvaluator::HandleParameter(HloInstruction* parameter) { CHECK_LT(parameter->parameter_number(), arg_literals_.size()); const Literal* input_literal = arg_literals_[parameter->parameter_number()]; @@ -1046,8 +1057,15 @@ Status HloEvaluator::HandleBroadcast(HloInstruction* broadcast) { return Status::OK(); } -Status HloEvaluator::HandleAfterAll(HloInstruction* token) { - evaluated_[token] = LiteralUtil::CreateToken(); +Status HloEvaluator::HandleAfterAll(HloInstruction* after_all) { + evaluated_[after_all] = LiteralUtil::CreateToken(); + return Status::OK(); +} + +Status HloEvaluator::HandleAddDependency(HloInstruction* add_dependency) { + // AddDedendency just forwards its zero-th operand. + evaluated_[add_dependency] = + GetEvaluatedLiteralFor(add_dependency->operand(0)).Clone(); return Status::OK(); } @@ -1279,10 +1297,10 @@ StatusOr EvaluateSortInternal(HloInstruction* sort, key_value_vector.push_back( std::make_pair(keys_data[i], values_data[i])); } - std::sort(key_value_vector.begin(), key_value_vector.end(), - [](const kv_pair& a, const kv_pair& b) { - return SafeLess(a.first, b.first); - }); + std::stable_sort(key_value_vector.begin(), key_value_vector.end(), + [](const kv_pair& a, const kv_pair& b) { + return SafeLess(a.first, b.first); + }); std::vector result_keys; // We use a InlinedVector here because we need to convert it to an // absl::Span later, and this would not work with std::vector. diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index 07f8d0aad4a..45ed8131dc6 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -144,6 +144,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // Operations that are type-agnostic or always return a specific type, such as // HandleIsFinite where boolean is always returned. 
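  // HandleBitcast also belongs in this group: the evaluator reinterprets the
  // operand literal's bytes in the result shape (a memcpy between equal-sized
  // buffers), so no per-element-type arithmetic is involved.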
// + Status HandleBitcast(HloInstruction* bitcast) override; + Status HandleParameter(HloInstruction* parameter) override; Status HandleConstant(HloInstruction* constant) override; @@ -180,7 +182,9 @@ class HloEvaluator : public DfsHloVisitorWithDefault { Status HandleBroadcast(HloInstruction* broadcast) override; - Status HandleAfterAll(HloInstruction* token) override; + Status HandleAfterAll(HloInstruction* after_all) override; + + Status HandleAddDependency(HloInstruction* add_dependency) override; Status HandleSort(HloInstruction* sort) override; @@ -221,16 +225,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault { const Literal& operand_literal) { const auto shape = instruction->shape(); const auto* operand = instruction->operand(0); - - // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is - // removed. - if (!ShapeUtil::SameDimensions(shape, operand->shape())) { - return Unimplemented( - "Implicit broadcasting is currently unsupported in HLO evaluator " - "Shape Mismatch: %s vs %s", - ShapeUtil::HumanString(shape), - ShapeUtil::HumanString(operand->shape())); - } + TF_RET_CHECK(ShapeUtil::SameDimensions(shape, operand->shape())); Literal result(shape); TF_RETURN_IF_ERROR( diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index d95b6ad04f2..4eaaab20ea0 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/reference_util.h" @@ -35,6 +36,7 @@ limitations under the License. #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -2765,6 +2767,33 @@ ENTRY main { EXPECT_TRUE(LiteralTestUtil::Equal(arg, actual)); } +TEST_P(HloEvaluatorTest, Bitcast) { + // Regression test for b/114735354. 
+ constexpr absl::string_view hlo_text_base = R"( +HloModule Bitcast + +ENTRY main { + param = %s[32,121]{1,0} parameter(0) + ROOT bitcast = %s[121,32,1]{0,1,2} bitcast(%s[32,121]{1,0} param) +} +)"; + string hlo_text; + if (use_bfloat16_) { + hlo_text = absl::StrFormat(hlo_text_base, "bf16", "bf16", "bf16"); + } else { + hlo_text = absl::StrFormat(hlo_text_base, "f32", "f32", "f32"); + } + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + auto args = MakeFakeArguments(m_.get()).ConsumeValueOrDie(); + Literal actual = Evaluate({&args[0]}); + if (use_bfloat16_) { + EXPECT_TRUE( + absl::c_equal(args[0].data(), actual.data())); + } else { + EXPECT_TRUE(absl::c_equal(args[0].data(), actual.data())); + } +} + INSTANTIATE_TEST_CASE_P(HloEvaluatorTest_Instantiation, HloEvaluatorTest, ::testing::ValuesIn(use_bf16_params)); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index ebed875eb49..b87fc3e3401 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -161,9 +161,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { HloOpcodeString(hlo_instruction->opcode())); } - // TODO(b/35950897): many of the stl functions used in the handlers are not - // overloaded for every XLA primitive type. - template ::value>::type* = nullptr> @@ -596,7 +593,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } - Status HandleDivide(HloInstruction* divide) { + Status HandleDivide(HloInstruction* divide) override { return HandleDivide(divide); } @@ -1556,10 +1553,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const auto& row_data = row_to_sort.data(); std::vector result_data(row_data.begin(), row_data.end()); - std::sort(result_data.begin(), result_data.end(), - [](const NativeT& a, const NativeT& b) { - return SafeLess(a, b); - }); + std::stable_sort(result_data.begin(), result_data.end(), + [](const NativeT& a, const NativeT& b) { + return SafeLess(a, b); + }); Literal sorted_row(ShapeUtil::MakeShape(keys->shape().element_type(), {sort_dim_elements})); sorted_row.PopulateR1(absl::Span(result_data)); @@ -2546,12 +2543,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template ::value || - std::is_same::value || - std::is_same::value>::type* = nullptr> + std::is_integral::value || + std::is_floating_point::value>::type* = nullptr> Status HandleIota(HloInstruction* instruction) { auto* iota = Cast(instruction); - std::vector data(iota->shape().dimensions(iota->iota_dimension())); + // Avoid using std::vector since std::vector does not convert to + // absl::Span. 
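  // (std::vector<bool> packs its elements into bits and hands out proxy
  // references rather than real bool lvalues, so it has no data() returning
  // bool* for absl::Span to view. An InlinedVector stores contiguous bools:
  //
  //   absl::InlinedVector<bool, 1> bits(8);
  //   absl::Span<const bool> view = bits;  // OK
  //   // std::vector<bool> b(8); absl::Span<const bool> bad = b;  // no data()
  // )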
+ absl::InlinedVector data( + iota->shape().dimensions(iota->iota_dimension())); std::iota(data.begin(), data.end(), 0); auto result = LiteralUtil::CreateR1(data); @@ -2568,9 +2567,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } template ::value || - std::is_same::value || - std::is_same::value)>::type* = nullptr> + !(std::is_integral::value || + std::is_floating_point::value)>::type* = nullptr> Status HandleIota(HloInstruction* iota) { return InvalidArgument("Unsupported type for iota"); } @@ -2722,17 +2720,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const auto shape = instruction->shape(); const auto* lhs = instruction->operand(0); const auto* rhs = instruction->operand(1); - - // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast - // is removed. - if (!(ShapeUtil::SameDimensions(shape, rhs->shape()) && - ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()))) { - return Unimplemented( - "Implicit broadcasting is currently unsupported in HLO evaluator " - "Shape Mismatch: %s vs %s vs %s: ", - ShapeUtil::HumanString(shape), ShapeUtil::HumanString(lhs->shape()), - ShapeUtil::HumanString(rhs->shape())); - } + TF_RET_CHECK(ShapeUtil::SameDimensions(shape, rhs->shape())); + TF_RET_CHECK(ShapeUtil::SameDimensions(lhs->shape(), rhs->shape())); const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); @@ -2756,19 +2745,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const auto* lhs = instruction->operand(0); const auto* rhs = instruction->operand(1); const auto* ehs = instruction->operand(2); - - // TODO(b/35950897, b/27796129): add DCHECK back once implicit - // broadcast is removed. - if (!(ShapeUtil::SameDimensions(shape, lhs->shape()) && - ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()) && - ShapeUtil::SameDimensions(rhs->shape(), ehs->shape()))) { - return Unimplemented( - "Implicit broadcasting is currently unsupported in HLO evaluator " - "Shape Mismatch: %s vs %s vs %s vs %s: ", - ShapeUtil::HumanString(shape), ShapeUtil::HumanString(lhs->shape()), - ShapeUtil::HumanString(rhs->shape()), - ShapeUtil::HumanString(ehs->shape())); - } + TF_RET_CHECK(ShapeUtil::SameDimensions(shape, lhs->shape())); + TF_RET_CHECK(ShapeUtil::SameDimensions(lhs->shape(), rhs->shape())); + TF_RET_CHECK(ShapeUtil::SameDimensions(rhs->shape(), ehs->shape())); const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); diff --git a/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.cc b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.cc new file mode 100644 index 00000000000..c919dbd82d3 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.cc @@ -0,0 +1,61 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h" + +#include "absl/algorithm/container.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" + +namespace xla { + +namespace { + +StatusOr ReplaceGetSize(HloInstruction* instr) { + if (instr->opcode() != HloOpcode::kGetDimensionSize) { + return false; + } + HloComputation* computation = instr->parent(); + + TF_ASSIGN_OR_RETURN(auto legal_shape, + ShapeInference::InferGetDimensionSizeShape( + instr->operand(0)->shape(), instr->dimension())); + TF_RET_CHECK(ShapeUtil::Equal(instr->shape(), legal_shape)); + TF_RET_CHECK(ShapeUtil::HasPrimitiveType(instr->shape(), U32)); + uint32 size = instr->operand(0)->shape().dimensions(instr->dimension()); + HloInstruction* new_instr = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(size))); + TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(new_instr)); + return true; +} + +} // namespace + +StatusOr HloGetDimensionSizeRewriter::Run(HloModule* module) { + bool changed = false; + HloProto proto; + *proto.mutable_hlo_module() = module->ToProto(); + for (auto* computation : module->computations()) { + for (auto instruction : computation->instructions()) { + TF_ASSIGN_OR_RETURN(bool replaced, ReplaceGetSize(instruction)); + changed = changed || replaced; + } + } + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h new file mode 100644 index 00000000000..30f44c23a83 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h @@ -0,0 +1,36 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_GET_DIMENSION_SIZE_REWRITER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_GET_DIMENSION_SIZE_REWRITER_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// Pass to replace a kGetDimensionSize instruction with a constant instruction. 
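// For example, with %p = f32[3,4] parameter(0), the instruction
//
//   %size = u32[] get-dimension-size(%p), dimensions={1}
//
// is replaced by the equivalent
//
//   %size = u32[] constant(4)
//
// because the dimension's extent is statically known from %p's shape.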
+class HloGetDimensionSizeRewriter : public HloModulePass { + public: + absl::string_view name() const override { + return "hlo-get-dimension-size-rewriter"; + } + + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_GET_DIMENSION_SIZE_REWRITER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter_test.cc b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter_test.cc new file mode 100644 index 00000000000..a86aebdd5b6 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter_test.cc @@ -0,0 +1,83 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +namespace op = xla::testing::opcode_matchers; + +class HloGetDimensionSizeRewriterTest : public HloTestBase { + protected: + HloGetDimensionSizeRewriterTest() {} +}; + +TEST_F(HloGetDimensionSizeRewriterTest, Ok) { + auto module = ParseHloString(R"( +HloModule _ +ENTRY gds { + p = s32[3,4] parameter(0) + size0 = u32[] get-dimension-size(p), dimensions={0} + size1 = u32[] get-dimension-size(p), dimensions={1} + ROOT mul = u32[] multiply(size0, size1) +})") + .ValueOrDie(); + HloGetDimensionSizeRewriter pass; + EXPECT_TRUE(pass.Run(module.get()).ValueOrDie()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Multiply(op::Constant(), op::Constant())); +} + +TEST_F(HloGetDimensionSizeRewriterTest, IllegalType) { + auto module = ParseHloString(R"( +HloModule _ +ENTRY gds { + p = s32[3]{0} parameter(0) + ROOT gds = s64[] get-dimension-size(p), dimensions={0} +})") + .ValueOrDie(); + HloGetDimensionSizeRewriter pass; + EXPECT_FALSE(pass.Run(module.get()).ok()); +} + +TEST_F(HloGetDimensionSizeRewriterTest, IllegalDimension) { + auto module = ParseHloString(R"( +HloModule _ +ENTRY gds { + p = f32[2,5] parameter(0) + ROOT gds = u32[] get-dimension-size(p), dimensions={2} +})") + .ValueOrDie(); + HloGetDimensionSizeRewriter pass; + EXPECT_FALSE(pass.Run(module.get()).ok()); +} + +} // namespace +} // namespace xla diff --git 
a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 05cc1593e4e..302eca656be 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -38,6 +39,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/window_util.h" @@ -111,11 +113,6 @@ class NodeFilter { result == kSomeUsersOmitted; } - bool ShowFusionSubcomputation(const HloInstruction* instr) const { - CHECK_EQ(instr->opcode(), HloOpcode::kFusion); - return Show(instr) && !SomeOrAllOperandsOmitted(instr); - } - private: std::function filter_; }; @@ -240,34 +237,28 @@ string HtmlLikeStringSanitize(absl::string_view s) { // it to a short string lets us tell the user what the subcomputation is without // drawing it as a graph. optional MatchTrivialComputation(const HloComputation* computation) { + namespace m = match; + if (computation->instruction_count() != 3) { return nullopt; } - HloInstruction* root = computation->root_instruction(); - if (root->operand_count() != 2) { + const HloInstruction *param0, *param1; + if (!Match(root, m::Op() + .WithNumOperands(2) + .WithShape(m::Shape().IsEffectiveScalar()) + .WithBinaryOperandsAnyOrder( + m::Parameter(¶m0, 0) + .WithShape(m::Shape().IsEffectiveScalar()), + m::Parameter(¶m1, 1) + .WithShape(m::Shape().IsEffectiveScalar())))) { return nullopt; } - // Check that both of the operands to the root are parameters. - const HloInstruction* operand0 = root->operand(0); - const HloInstruction* operand1 = root->operand(1); - if (operand0->opcode() != HloOpcode::kParameter || - operand1->opcode() != HloOpcode::kParameter) { - return nullopt; - } - - // Check that the two operands of root are param0 and param1. All of the - // opcodes we recognize are commutative, so we're OK with either order. - auto n0 = operand0->parameter_number(); - auto n1 = operand1->parameter_number(); - if (!(n0 == 0 && n1 == 1) && !(n1 == 0 && n0 == 1)) { - return nullopt; - } - - // If the params are reversed, check that the operation being performed is - // commutative. - if (n0 == 1) { + // If the params are reversed (i.e. operand0 is param1 and operand1 is + // param0), check that the operation being performed is commutative. + if (root->operand(0) == param1) { + CHECK_EQ(root->operand(1), param0); switch (root->opcode()) { case HloOpcode::kLe: case HloOpcode::kGe: @@ -279,13 +270,6 @@ optional MatchTrivialComputation(const HloComputation* computation) { } } - // Check that the root and params are all effective scalars. - if (!ShapeUtil::IsEffectiveScalar(root->shape()) || - !ShapeUtil::IsEffectiveScalar(operand0->shape()) || - !ShapeUtil::IsEffectiveScalar(operand1->shape())) { - return nullopt; - } - // If we recognize the root's opcode, we've successfully pattern-matched! switch (root->opcode()) { case HloOpcode::kAdd: @@ -578,7 +562,7 @@ bool HloDotDumper::ShouldShowSubcomputation(const HloComputation* subcomp) { // Show the subcomputation if we're showing any of its members. 
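  // The scan below must be over `subcomp`'s own instructions, not over
  // `computation_` (the computation being dumped); otherwise the decision
  // would not depend on `subcomp` at all.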
return std::any_of( - computation_->instructions().begin(), computation_->instructions().end(), + subcomp->instructions().begin(), subcomp->instructions().end(), [&](const HloInstruction* instr) { return filter_.Show(instr); }); } @@ -987,6 +971,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kGetTupleElement: case HloOpcode::kTrace: case HloOpcode::kAfterAll: + case HloOpcode::kAddDependency: case HloOpcode::kTuple: return kWhite; case HloOpcode::kBroadcast: @@ -1267,12 +1252,12 @@ const HloInstruction* HloDotDumper::GetNodeForEdge( class GraphRendererRegistry { public: - void AddRenderer(GraphRendererInterface* graph_renderer) { + void SetRenderer(std::shared_ptr graph_renderer) { tensorflow::mutex_lock lock(mu_); graph_renderer_ = graph_renderer; } - GraphRendererInterface* GetDefaultRenderer() { + std::shared_ptr GetDefaultRenderer() { tensorflow::mutex_lock lock(mu_); return graph_renderer_; } @@ -1284,20 +1269,21 @@ class GraphRendererRegistry { private: tensorflow::mutex mu_; - GraphRendererInterface* graph_renderer_ = nullptr; + std::shared_ptr graph_renderer_ GUARDED_BY(mu_); }; } // namespace -Registrar::Registrar(GraphRendererInterface* dumper) { - GraphRendererRegistry::Default()->AddRenderer(dumper); +Registrar::Registrar(std::shared_ptr dumper) { + GraphRendererRegistry::Default()->SetRenderer(dumper); } namespace { // Gets a NodeFilter that includes roughly all instructions whose distance from // root is <= radius. -NodeFilter MakeNodeFilter(const HloInstruction* root, int64 radius) { +NodeFilter MakeNodeRadiusAroundFilter(const HloInstruction* root, + int64 radius) { // First, find the neighborhood of nodes with distance from root <= radius. // These nodes are our initial set of "normal" nodes. std::unordered_map nodes; @@ -1404,6 +1390,56 @@ NodeFilter MakeNodeFilter(const HloInstruction* root, int64 radius) { }); } +// Gets a node filter that includes nodes on all paths from `from` to `to`. If +// the all-paths set contains more than max_nodes elements, includes the nodes +// on the shortest paths and sets hit_limit to true. +NodeFilter MakeNodeFromToFilter(const HloInstruction* from, + const HloInstruction* to, int64 max_nodes, + bool* hit_limit) { + *hit_limit = false; + + // Elements in the queue are paths through the graph. + std::deque> queue; + queue.push_front({from}); + + // Compute the set of nodes we want to show using a slightly-modified + // Djikstra's algorithm. The only real difference is, rather than stopping + // when we find a (shortest) path, we continue until we've found max_nodes + // nodes on some path. + std::unordered_set visited; + std::unordered_set to_display = {from, to}; + while (!queue.empty() && to_display.size() < max_nodes) { + std::vector path = std::move(queue.front()); + queue.pop_front(); + if (!visited.insert(path.back()).second) { + continue; + } + + for (const auto* user : path.back()->users()) { + if (user == to) { + auto it = path.begin(); + for (; it != path.end() && to_display.size() < max_nodes; ++it) { + to_display.insert(*it); + } + if (it != path.end()) { + *hit_limit = true; + } + } else if (!visited.count(user)) { + auto new_path = path; + new_path.push_back(user); + queue.push_back(std::move(new_path)); + } + } + } + + return NodeFilter([=](const HloInstruction* instr) { + if (instr == from || instr == to) { + return kHighlightNode; + } + return to_display.count(instr) ? 
kNormalNode : kHideNode; + }); +} + string SaveGraph(const string& graph, GraphRendererInterface::GraphKind graph_kind, const string& dest_path) { @@ -1483,7 +1519,7 @@ string DumpNeighborhoodAround(const HloInstruction& node, int radius, auto debug_options = node.GetModule()->config().debug_options(); string label = StrCat("Neighborhood of ", radius, " nodes around ", node.name()); - NodeFilter filter = MakeNodeFilter(&node, radius); + NodeFilter filter = MakeNodeRadiusAroundFilter(&node, radius); string graph = HloDotDumper(node.parent(), label, debug_options, show_backend_config, /*profile=*/nullptr, filter) @@ -1491,6 +1527,29 @@ string DumpNeighborhoodAround(const HloInstruction& node, int radius, return ExportGraph(graph, GraphRendererInterface::DOT_GRAPH, debug_options); } +string DumpAllPathsFromTo(const HloInstruction& from, const HloInstruction& to, + int64 max_nodes, bool show_backend_config) { + CHECK_EQ(from.parent(), to.parent()) << "Nodes must be in same computation!"; + auto debug_options = from.GetModule()->config().debug_options(); + + bool hit_limit = false; + NodeFilter filter = MakeNodeFromToFilter(&from, &to, max_nodes, &hit_limit); + string label; + if (!hit_limit) { + label = StrCat("All paths from ", from.name(), " to ", to.name()); + } else { + label = StrCat(max_nodes, " nodes on the shortest paths from ", from.name(), + " to ", to.name(), + "
<br/><br/>***SHOWING ONLY A SUBSET OF ALL PATHS BETWEEN " "NODES***<br/><br/>
"); + } + string graph = + HloDotDumper(from.parent(), label, debug_options, show_backend_config, + /*profile=*/nullptr, filter) + .Dump(); + return ExportGraph(graph, GraphRendererInterface::DOT_GRAPH, debug_options); +} + void DumpText(const HloModule& module, const string& label, const string& directory_path, bool do_prefix) { Env* env = Env::Default(); diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.h b/tensorflow/compiler/xla/service/hlo_graph_dumper.h index 0b11f34abb7..de1eefab776 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.h +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.h @@ -66,6 +66,12 @@ string DumpGraph(const HloComputation& computation, const string& label, string DumpNeighborhoodAround(const HloInstruction& node, int radius, bool show_backend_config = false); +// Dumps nodes on any of the paths from `from` to `to`. If there are more than +// max_nodes on all paths, restricts to the max_nodes nodes on the shortest +// paths. +string DumpAllPathsFromTo(const HloInstruction& from, const HloInstruction& to, + int64 max_nodes, bool show_backend_config = false); + // Dumps the HloModule::ToString() as a file into the provided directory path // suffixed with the provided label. // @@ -87,13 +93,13 @@ void DumpText(const HloModule& module, const string& label, // Class that registers a graph renderer. class Registrar { public: - Registrar(GraphRendererInterface* dumper); + Registrar(std::shared_ptr dumper); }; -#define XLA_INTERNAL_REGISTER_GRAPH_RENDERER(factory, ctr, ...) \ - static ::xla::hlo_graph_dumper::Registrar \ - XLA_INTERNAL_REGISTER_GRAPH_RENDERER_NAME(ctr)(new factory, \ - ##__VA_ARGS__) +#define XLA_INTERNAL_REGISTER_GRAPH_RENDERER(factory, ctr, ...) \ + static ::xla::hlo_graph_dumper::Registrar \ + XLA_INTERNAL_REGISTER_GRAPH_RENDERER_NAME(ctr)( \ + std::make_shared(), ##__VA_ARGS__) // __COUNTER__ must go through another macro to be properly expanded #define XLA_INTERNAL_REGISTER_GRAPH_RENDERER_NAME(ctr) ___##ctr##__object_ diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 26786ee950b..21b1dbc1676 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -93,7 +93,8 @@ StatusOr> HloInstruction::CreateFromProto( [&computation_map](int64 id) { return computation_map.contains(id); })) << proto.name() << " instruction references invalid computation id(s)"; - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(proto.shape())); + Shape shape(proto.shape()); + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape)); switch (opcode) { // Ops migrated to subclasses. 
@@ -101,23 +102,23 @@ StatusOr> HloInstruction::CreateFromProto( TF_RET_CHECK(proto.operand_ids_size() == 3) << "BatchNormTraining instruction should have 3 operands but sees " << proto.operand_ids_size(); - instruction = CreateBatchNormTraining( - proto.shape(), operands(0), operands(1), operands(2), proto.epsilon(), - proto.feature_index()); + instruction = + CreateBatchNormTraining(shape, operands(0), operands(1), operands(2), + proto.epsilon(), proto.feature_index()); break; case HloOpcode::kBatchNormInference: TF_RET_CHECK(proto.operand_ids_size() == 5) << "BatchNormInference instruction should have 5 operands but sees " << proto.operand_ids_size(); instruction = CreateBatchNormInference( - proto.shape(), operands(0), operands(1), operands(2), operands(3), + shape, operands(0), operands(1), operands(2), operands(3), operands(4), proto.epsilon(), proto.feature_index()); break; case HloOpcode::kBatchNormGrad: TF_RET_CHECK(proto.operand_ids_size() == 5) << "BatchNormGrad instruction should have 5 operands but sees " << proto.operand_ids_size(); - instruction = CreateBatchNormGrad(proto.shape(), operands(0), operands(1), + instruction = CreateBatchNormGrad(shape, operands(0), operands(1), operands(2), operands(3), operands(4), proto.epsilon(), proto.feature_index()); break; @@ -127,7 +128,7 @@ StatusOr> HloInstruction::CreateFromProto( << proto.operand_ids_size(); std::vector fft_length(proto.fft_length().begin(), proto.fft_length().end()); - instruction = CreateFft(proto.shape(), operands(0), proto.fft_type(), + instruction = CreateFft(shape, operands(0), proto.fft_type(), absl::Span(fft_length)); break; } @@ -148,7 +149,7 @@ StatusOr> HloInstruction::CreateFromProto( TF_RET_CHECK(proto.operand_ids_size() == 1) << "Recv instruction should have 1 operand but sees " << proto.operand_ids_size(); - instruction = CreateRecv(proto.shape().tuple_shapes(0), operands(0), + instruction = CreateRecv(shape.tuple_shapes(0), operands(0), proto.channel_id(), proto.is_host_transfer()); break; case HloOpcode::kRecvDone: @@ -161,7 +162,7 @@ StatusOr> HloInstruction::CreateFromProto( TF_RET_CHECK(proto.operand_ids_size() == 1) << "Reverse instruction should have 1 operand but sees " << proto.operand_ids_size(); - instruction = CreateReverse(proto.shape(), operands(0), + instruction = CreateReverse(shape, operands(0), std::vector(proto.dimensions().begin(), proto.dimensions().end())); break; @@ -170,7 +171,7 @@ StatusOr> HloInstruction::CreateFromProto( << "Concatenate instruction should have 1 dimension but sees " << proto.dimensions_size(); instruction = - CreateConcatenate(proto.shape(), all_operands(), proto.dimensions(0)); + CreateConcatenate(shape, all_operands(), proto.dimensions(0)); break; case HloOpcode::kReduce: TF_RET_CHECK(proto.operand_ids_size() % 2 == 0) @@ -188,7 +189,7 @@ StatusOr> HloInstruction::CreateFromProto( absl::MakeSpan(reduce_operands) .subspan(reduce_operands.size() / 2, reduce_operands.size()); instruction = - CreateReduce(proto.shape(), inputs, init_values, + CreateReduce(shape, inputs, init_values, std::vector(proto.dimensions().begin(), proto.dimensions().end()), computations(0)); @@ -203,7 +204,7 @@ StatusOr> HloInstruction::CreateFromProto( auto sort_operands = all_operands(); HloInstruction* keys = sort_operands[0]; instruction = CreateSort( - proto.shape(), proto.dimensions(0), keys, + shape, proto.dimensions(0), keys, absl::Span(sort_operands).subspan(1)); break; } @@ -212,7 +213,7 @@ StatusOr> HloInstruction::CreateFromProto( << "Transpose instruction should have 1 
operand but sees " << proto.operand_ids_size(); instruction = - CreateTranspose(proto.shape(), operands(0), + CreateTranspose(shape, operands(0), std::vector(proto.dimensions().begin(), proto.dimensions().end())); break; @@ -221,7 +222,7 @@ StatusOr> HloInstruction::CreateFromProto( << "Broadcast instruction should have 1 operand but sees " << proto.operand_ids_size(); instruction = - CreateBroadcast(proto.shape(), operands(0), + CreateBroadcast(shape, operands(0), std::vector(proto.dimensions().begin(), proto.dimensions().end())); break; @@ -229,7 +230,7 @@ StatusOr> HloInstruction::CreateFromProto( TF_RET_CHECK(proto.called_computation_ids_size() == 1) << "Map instruction should have 1 called computation but sees " << proto.called_computation_ids_size(); - instruction = CreateMap(proto.shape(), all_operands(), computations(0)); + instruction = CreateMap(shape, all_operands(), computations(0)); break; case HloOpcode::kSlice: { TF_RET_CHECK(proto.operand_ids_size() == 1) @@ -242,8 +243,8 @@ StatusOr> HloInstruction::CreateFromProto( slice_limits.push_back(slice_dimensions.limit()); slice_strides.push_back(slice_dimensions.stride()); } - instruction = CreateSlice(proto.shape(), operands(0), slice_starts, - slice_limits, slice_strides); + instruction = CreateSlice(shape, operands(0), slice_starts, slice_limits, + slice_strides); break; } case HloOpcode::kConstant: { @@ -253,7 +254,7 @@ StatusOr> HloInstruction::CreateFromProto( Literal::CreateFromProto(proto.literal())); instruction = CreateConstant(std::move(literal)); } else { - instruction = absl::make_unique(proto.shape()); + instruction = absl::make_unique(shape); } break; } @@ -284,55 +285,54 @@ StatusOr> HloInstruction::CreateFromProto( tensorflow::gtl::FindPtrOrNull(computation_map, fusion_id); TF_RET_CHECK(fused_computation != nullptr) << "No fusion computation with id " << fusion_id; - instruction = CreateFusion(proto.shape(), fusion_kind, all_operands(), - fused_computation); + instruction = + CreateFusion(shape, fusion_kind, all_operands(), fused_computation); break; } case HloOpcode::kRng: - instruction = - CreateRng(proto.shape(), proto.distribution(), all_operands()); + instruction = CreateRng(shape, proto.distribution(), all_operands()); break; case HloOpcode::kParameter: - instruction = CreateParameter(proto.parameter_number(), proto.shape(), - proto.name()); + instruction = + CreateParameter(proto.parameter_number(), shape, proto.name()); break; case HloOpcode::kGetTupleElement: TF_RET_CHECK(proto.operand_ids_size() == 1) << "GetTupleElement instruction should have 1 operand but sees " << proto.operand_ids_size(); - instruction = CreateGetTupleElement(proto.shape(), operands(0), - proto.tuple_index()); + instruction = + CreateGetTupleElement(shape, operands(0), proto.tuple_index()); break; case HloOpcode::kReducePrecision: TF_RET_CHECK(proto.operand_ids_size() == 1) << "ReducePrecision instruction should have 1 operand but sees " << proto.operand_ids_size(); - instruction = - CreateReducePrecision(proto.shape(), operands(0), - proto.exponent_bits(), proto.mantissa_bits()); + instruction = CreateReducePrecision( + shape, operands(0), proto.exponent_bits(), proto.mantissa_bits()); break; case HloOpcode::kInfeed: { - TF_RET_CHECK(ShapeUtil::IsTuple(proto.shape()) && - (ShapeUtil::TupleElementCount(proto.shape()) == 2)) + TF_RET_CHECK(ShapeUtil::IsTuple(shape) && + (ShapeUtil::TupleElementCount(shape) == 2)) << "Infeed should have a tuple shape with 2 operands, but has: " - << proto.shape(); - const Shape& data_shape = - 
ShapeUtil::GetTupleElementShape(proto.shape(), 0); + << shape; + const Shape& data_shape = ShapeUtil::GetTupleElementShape(shape, 0); TF_RET_CHECK(proto.operand_ids_size() == 1) << "Infeed instruction should have 1 operand but sees " << proto.operand_ids_size(); instruction = CreateInfeed(data_shape, operands(0), proto.infeed_config()); } break; - case HloOpcode::kOutfeed: + case HloOpcode::kOutfeed: { TF_RET_CHECK(proto.operand_ids_size() == 2) << "Outfeed instruction should have 2 operands but sees " << proto.operand_ids_size(); + Shape outfeed_shape(proto.outfeed_shape()); TF_RETURN_IF_ERROR( - ShapeUtil::ValidateShapeWithOptionalLayout(proto.outfeed_shape())); - instruction = CreateOutfeed(proto.outfeed_shape(), operands(0), - operands(1), proto.outfeed_config()); + ShapeUtil::ValidateShapeWithOptionalLayout(outfeed_shape)); + instruction = CreateOutfeed(outfeed_shape, operands(0), operands(1), + proto.outfeed_config()); break; + } case HloOpcode::kCrossReplicaSum: { TF_RET_CHECK(proto.called_computation_ids_size() == 1) << "CrossReplicaSum should have 1 called computation but sees " @@ -342,7 +342,7 @@ StatusOr> HloInstruction::CreateFromProto( all_reduce_id = proto.all_reduce_id(); } instruction = CreateCrossReplicaSum( - proto.shape(), all_operands(), computations(0), + shape, all_operands(), computations(0), /*replica_groups=*/ std::vector(proto.replica_groups().begin(), proto.replica_groups().end()), @@ -352,7 +352,7 @@ StatusOr> HloInstruction::CreateFromProto( } case HloOpcode::kAllToAll: { instruction = CreateAllToAll( - proto.shape(), all_operands(), + shape, all_operands(), /*replica_groups=*/ std::vector(proto.replica_groups().begin(), proto.replica_groups().end())); @@ -368,8 +368,8 @@ StatusOr> HloInstruction::CreateFromProto( source_target_pairs[i].first = proto.source_target_pairs(i).source(); source_target_pairs[i].second = proto.source_target_pairs(i).target(); } - instruction = CreateCollectivePermute(proto.shape(), operands(0), - source_target_pairs); + instruction = + CreateCollectivePermute(shape, operands(0), source_target_pairs); break; } case HloOpcode::kConvolution: { @@ -382,7 +382,7 @@ StatusOr> HloInstruction::CreateFromProto( precision_config.mutable_operand_precision()->Resize( proto.operand_ids_size(), PrecisionConfig::DEFAULT); instruction = CreateConvolve( - proto.shape(), operands(0), operands(1), + shape, operands(0), operands(1), std::max(proto.feature_group_count(), 1), proto.window(), proto.convolution_dimension_numbers(), precision_config); break; @@ -394,7 +394,7 @@ StatusOr> HloInstruction::CreateFromProto( TF_RET_CHECK(proto.called_computation_ids_size() == 1) << "ReduceWindow should have 1 called computation but sees " << proto.called_computation_ids_size(); - instruction = CreateReduceWindow(proto.shape(), operands(0), operands(1), + instruction = CreateReduceWindow(shape, operands(0), operands(1), proto.window(), computations(0)); break; case HloOpcode::kSelectAndScatter: @@ -404,9 +404,9 @@ StatusOr> HloInstruction::CreateFromProto( TF_RET_CHECK(proto.called_computation_ids_size() == 2) << "SelectAndScatter should have 2 called computations but sees " << proto.called_computation_ids_size(); - instruction = CreateSelectAndScatter( - proto.shape(), operands(0), computations(0), proto.window(), - operands(1), operands(2), computations(1)); + instruction = CreateSelectAndScatter(shape, operands(0), computations(0), + proto.window(), operands(1), + operands(2), computations(1)); break; case HloOpcode::kCustomCall: if (proto.constrain_layout()) { 
@@ -414,16 +414,17 @@ StatusOr> HloInstruction::CreateFromProto( // vector of pointers essentially) so create a vector of shapes to pass // in. std::vector operand_shapes; - for (const Shape& shape : proto.operand_shapes_with_layout()) { - operand_shapes.push_back(shape); + for (const ShapeProto& shape_proto : + proto.operand_shapes_with_layout()) { + operand_shapes.emplace_back(shape_proto); } - instruction = CreateCustomCall( - proto.shape(), all_operands(), proto.custom_call_target(), - operand_shapes, proto.custom_call_opaque()); + instruction = + CreateCustomCall(shape, all_operands(), proto.custom_call_target(), + operand_shapes, proto.custom_call_opaque()); } else { - instruction = CreateCustomCall(proto.shape(), all_operands(), - proto.custom_call_target(), - proto.custom_call_opaque()); + instruction = + CreateCustomCall(shape, all_operands(), proto.custom_call_target(), + proto.custom_call_opaque()); } if (proto.has_window()) { static_cast(instruction.get()) @@ -443,8 +444,8 @@ StatusOr> HloInstruction::CreateFromProto( << "Pad instruction should have 2 operands but sees " << proto.operand_ids_size(); TF_RET_CHECK(proto.has_padding_config()); - instruction = CreatePad(proto.shape(), operands(0), operands(1), - proto.padding_config()); + instruction = + CreatePad(shape, operands(0), operands(1), proto.padding_config()); break; case HloOpcode::kDynamicSlice: { TF_RET_CHECK(proto.operand_ids_size() == 2) @@ -452,8 +453,8 @@ StatusOr> HloInstruction::CreateFromProto( << proto.operand_ids_size(); std::vector slice_sizes(proto.dynamic_slice_sizes_size()); absl::c_copy(proto.dynamic_slice_sizes(), slice_sizes.begin()); - instruction = CreateDynamicSlice(proto.shape(), operands(0), operands(1), - slice_sizes); + instruction = + CreateDynamicSlice(shape, operands(0), operands(1), slice_sizes); break; } case HloOpcode::kGather: { @@ -469,7 +470,7 @@ StatusOr> HloInstruction::CreateFromProto( for (int64 bound : proto.gather_slice_sizes()) { gather_slice_sizes.push_back(bound); } - instruction = CreateGather(proto.shape(), operands(0), operands(1), + instruction = CreateGather(shape, operands(0), operands(1), *gather_dimension_numbers, gather_slice_sizes); break; } @@ -485,16 +486,15 @@ StatusOr> HloInstruction::CreateFromProto( auto scatter_dimension_numbers = absl::make_unique( proto.scatter_dimension_numbers()); - instruction = - CreateScatter(proto.shape(), operands(0), operands(1), operands(2), - computations(0), *scatter_dimension_numbers); + instruction = CreateScatter(shape, operands(0), operands(1), operands(2), + computations(0), *scatter_dimension_numbers); break; } case HloOpcode::kIota: TF_RET_CHECK(proto.dimensions_size() == 1) << "Iota instruction should have 1 dimension but sees " << proto.dimensions_size(); - instruction = CreateIota(proto.shape(), proto.dimensions(0)); + instruction = CreateIota(shape, proto.dimensions(0)); break; case HloOpcode::kDot: { TF_RET_CHECK(proto.has_dot_dimension_numbers()) @@ -506,8 +506,8 @@ StatusOr> HloInstruction::CreateFromProto( precision_config.mutable_operand_precision()->Resize( proto.operand_ids_size(), PrecisionConfig::DEFAULT); instruction = absl::make_unique( - proto.shape(), operands(0), operands(1), - proto.dot_dimension_numbers(), precision_config); + shape, operands(0), operands(1), proto.dot_dimension_numbers(), + precision_config); break; } case HloOpcode::kDomain: { @@ -529,7 +529,7 @@ StatusOr> HloInstruction::CreateFromProto( exit_hlo_sharding = std::make_shared(sharding); } instruction = absl::make_unique( - proto.shape(), 
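In the constrained-layout custom-call case above, the operand shapes are now rebuilt from their serialized ShapeProto form. A hedged restatement with the stripped template arguments filled in (treat the exact container types as an assumption):

```
// emplace_back relies on the Shape(const ShapeProto&) conversion constructor.
std::vector<Shape> operand_shapes;
for (const ShapeProto& shape_proto : proto.operand_shapes_with_layout()) {
  operand_shapes.emplace_back(shape_proto);
}
instruction =
    CreateCustomCall(shape, all_operands(), proto.custom_call_target(),
                     operand_shapes, proto.custom_call_opaque());
```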
operands(0), + shape, operands(0), absl::make_unique(entry_hlo_sharding), absl::make_unique(exit_hlo_sharding)); break; @@ -537,11 +537,11 @@ StatusOr> HloInstruction::CreateFromProto( case HloOpcode::kGetDimensionSize: TF_RET_CHECK(proto.operand_ids_size() == 1); TF_RET_CHECK(proto.dimensions_size() == 1); - instruction = CreateGetDimensionSize(proto.shape(), operands(0), - proto.dimensions(0)); + instruction = + CreateGetDimensionSize(shape, operands(0), proto.dimensions(0)); break; default: { - instruction = absl::WrapUnique(new HloInstruction(opcode, proto.shape())); + instruction = absl::WrapUnique(new HloInstruction(opcode, shape)); for (const int64 operand_id : proto.operand_ids()) { instruction->AppendOperand(instruction_map.at(operand_id)); } @@ -855,6 +855,16 @@ HloInstruction::CreateCollectivePermute( new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape())); } +/* static */ std::unique_ptr +HloInstruction::CreateAddDependency(HloInstruction* data_operand, + HloInstruction* token_operand) { + auto instruction = absl::WrapUnique( + new HloInstruction(HloOpcode::kAddDependency, data_operand->shape())); + instruction->AppendOperand(data_operand); + instruction->AppendOperand(token_operand); + return instruction; +} + /* static */ std::unique_ptr HloInstruction::CreateWhile( const Shape& shape, HloComputation* condition, HloComputation* body, HloInstruction* init) { @@ -1394,6 +1404,10 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( clone = CreateAfterAll(new_operands); } break; + case HloOpcode::kAddDependency: + CHECK_EQ(new_operands.size(), 2); + clone = CreateAddDependency(new_operands[0], new_operands[1]); + break; } // SetupDerivedInstruction will setup the precision_config_ field. SetupDerivedInstruction(clone.get()); @@ -1680,6 +1694,7 @@ bool HloInstruction::IdenticalSlowPath( // This opcode has complex or special behavior so just return false. case HloOpcode::kAfterAll: + case HloOpcode::kAddDependency: return false; // Remaining instructions with special values. 
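CreateAddDependency above backs the new `add-dependency` opcode: it forwards the data operand's value while also consuming a token, so token-carried ordering constraints attach to the data value. A hedged builder sketch that mirrors the AddDependency parser test added later in this change (the builder and variable names are illustrative):

```
// Force 'exp' to be ordered after 'neg' even though it only consumes 'p'.
HloComputation::Builder builder("add_dependency_example");
Shape r0f32 = ShapeUtil::MakeShape(F32, {});
HloInstruction* p = builder.AddInstruction(
    HloInstruction::CreateParameter(0, r0f32, "p"));
HloInstruction* neg = builder.AddInstruction(
    HloInstruction::CreateUnary(r0f32, HloOpcode::kNegate, p));
HloInstruction* token = builder.AddInstruction(
    HloInstruction::CreateAfterAll({neg}));
HloInstruction* p_after_token = builder.AddInstruction(
    HloInstruction::CreateAddDependency(p, token));
builder.AddInstruction(
    HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, p_after_token));
```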
@@ -1745,6 +1760,26 @@ bool HloInstruction::IdenticalSlowPath( return false; } +uint64 HloInstruction::Hash() const { + using tensorflow::Hash64Combine; + + uint64 hash_value = Hash64Combine(0, static_cast(opcode())); + hash_value = Hash64Combine(hash_value, ShapeUtil::Hash(shape())); + + if (!IsCrossModuleAllReduce()) { + if (!operands().empty()) { + for (size_t i = 0; i < operands().size(); ++i) { + hash_value = Hash64Combine(hash_value, operand(i)->Hash()); + } + } + } + + hash_value = Hash64Combine(hash_value, InnerHash()); + return hash_value; +} + +uint64 HloInstruction::InnerHash() const { return 13; } + void HloInstruction::RemoveUser(HloInstruction* user) { auto set_it = user_set_.find(user); CHECK(set_it != user_set_.end()); @@ -1900,6 +1935,11 @@ void HloInstruction::set_while_body(HloComputation* computation) { called_computations_[kBodyComputationIndex] = computation; } +HloInstruction* HloInstruction::while_init() const { + CHECK_EQ(HloOpcode::kWhile, opcode_); + return operands_[0]; +} + HloComputation* HloInstruction::true_computation() const { CHECK_EQ(HloOpcode::kConditional, opcode_); return called_computations_[kTrueComputationIndex]; @@ -2214,7 +2254,7 @@ HloInstructionProto HloInstruction::ToProto() const { proto.set_id(unique_id_); proto.set_name(name_); proto.set_opcode(HloOpcodeString(opcode_)); - *proto.mutable_shape() = shape_; + *proto.mutable_shape() = shape_.ToProto(); for (const HloInstruction* operand : operands_) { proto.add_operand_ids(operand->unique_id()); } @@ -2462,6 +2502,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleDomain(this); case HloOpcode::kAfterAll: return visitor->HandleAfterAll(this); + case HloOpcode::kAddDependency: + return visitor->HandleAddDependency(this); case HloOpcode::kIota: return visitor->HandleIota(this); case HloOpcode::kGetDimensionSize: @@ -2623,36 +2665,6 @@ Status HloInstruction::AcceptWithOperandOrder( return Status::OK(); } -namespace { - -// Returns true if the given order is a topological sort of the instructions -// it contains. -bool OrderIsTopologicalSort(const std::vector& order) { - // Create a map from instruction to its position in 'order'. - std::unordered_map order_position; - for (int i = 0; i < order.size(); i++) { - if (!order_position.insert({order[i], i}).second) { - // Instruction order[i] is duplicated in the order. - return false; - } - } - // Verify that the operand of each instruction in the order is also in the - // order *and* the operand's position is earlier (defs are before uses for - // all ops). 
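The new Hash()/InnerHash() pair is designed around the Identical() contract: instructions that compare Identical must hash equally, and cross-module all-reduces skip operand hashing just as Identical skips their operand comparison. A hedged sketch of the kind of hash-bucketed lookup this enables (the map and helper here are hypothetical, not part of this change):

```
#include "absl/container/flat_hash_map.h"

// Hash() only promises Identical(a, b) => a.Hash() == b.Hash(), so collisions
// must fall back to the precise Identical() check.
absl::flat_hash_map<uint64, std::vector<const HloInstruction*>> buckets;

const HloInstruction* FindEquivalent(const HloInstruction* instr) {
  auto it = buckets.find(instr->Hash());
  if (it == buckets.end()) return nullptr;
  for (const HloInstruction* candidate : it->second) {
    if (candidate->Identical(*instr)) return candidate;
  }
  return nullptr;
}
```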
- for (auto* instruction : order) { - for (auto* operand : instruction->operands()) { - if (!ContainsKey(order_position, operand) || - order_position.at(operand) >= order_position.at(instruction)) { - return false; - } - } - } - - return true; -} - -} // namespace - Status HloInstruction::Accept( const std::function& visitor_func) { FunctionVisitor visitor(visitor_func); @@ -3022,6 +3034,16 @@ const PrecisionConfig& HloInstruction::precision_config() const { LOG(FATAL) << "Unimplemented method."; } +PrecisionConfig* HloInstruction::mutable_precision_config() { + if (auto* convolution = DynCast(this)) { + return convolution->mutable_precision_config(); + } + if (auto* dot = DynCast(this)) { + return dot->mutable_precision_config(); + } + LOG(FATAL) << "Unimplemented method."; +} + HloModule* HloInstruction::GetModule() const { if (parent_) { return parent_->parent(); @@ -3064,6 +3086,10 @@ int64 HloInstruction::concatenate_dimension() const { return Cast(this)->concatenate_dimension(); } +int64 HloInstruction::dimension() const { + return Cast(this)->dimension(); +} + bool HloInstruction::IsRank2Transpose() const { auto transpose = DynCast(this); return transpose != nullptr && transpose->IsRank2Transpose(); @@ -3243,6 +3269,11 @@ absl::optional HloInstruction::all_reduce_id() const { return Cast(this)->all_reduce_id(); } +void HloInstruction::set_all_reduce_id( + const absl::optional& all_reduce_id) { + return Cast(this)->set_all_reduce_id(all_reduce_id); +} + const ConvolutionDimensionNumbers& HloInstruction::convolution_dimension_numbers() const { if (auto convolution = DynCast(this)) { diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 818d4ede0f3..a54716217d6 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -770,6 +770,9 @@ class HloInstruction { static std::unique_ptr CreateGetDimensionSize( const Shape& shape, HloInstruction* operand, int64 dimension); + static std::unique_ptr CreateAddDependency( + HloInstruction* data_operand, HloInstruction* token_operand); + // Returns the opcode for this instruction. HloOpcode opcode() const { return opcode_; } @@ -883,11 +886,15 @@ class HloInstruction { return false; } - // Use an explicit loop rather than ContainerEquals, because copying around - // std::functions may be too expensive in some cases. - for (size_t i = 0; i < operands().size(); ++i) { - if (!eq_operands(operand(i), other.operand(i))) { - return false; + // Two AllReduces are Identical if they have the same all_reduce_id. + // Their operands don't have to be Identical. + if (!IsCrossModuleAllReduce()) { + // Use an explicit loop rather than ContainerEquals, because copying + // around std::functions may be too expensive in some cases. + for (size_t i = 0; i < operands().size(); ++i) { + if (!eq_operands(operand(i), other.operand(i))) { + return false; + } } } @@ -898,6 +905,12 @@ class HloInstruction { return IdenticalSlowPath(other, eq_computations); } + // Generates a hash value of an HLO instruction. Hash considers + // information on opcode, shape, operands, and typically a root instruction. + // This function returns the same hash value for equivalent HLO instructions, + // with respect to HloInstruction::Identical() method. + uint64 Hash() const; + // Returns whether the instruction has a constant operand. 
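mutable_precision_config() above mirrors the existing const accessor: it is only meaningful on kConvolution and kDot instructions and hits LOG(FATAL) otherwise. A hedged example of raising a dot's operand precision through it ('dot' is an assumed HloInstruction* with opcode kDot):

```
// PrecisionConfig is a proto; its repeated operand_precision field gets the
// generated size/set accessors.
PrecisionConfig* config = dot->mutable_precision_config();
for (int i = 0; i < config->operand_precision_size(); ++i) {
  config->set_operand_precision(i, PrecisionConfig::HIGHEST);
}
```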
bool HasConstantOperand() const; @@ -997,6 +1010,8 @@ class HloInstruction { void set_while_condition(HloComputation* while_condition); void set_while_body(HloComputation* while_body); + HloInstruction* while_init() const; + // Gets/sets the true and false HloComputation for Conditional. The setters // should only be called by HloModule or HloComputation methods. // @@ -1257,6 +1272,7 @@ class HloInstruction { // superior. // Precondition: opcode must be kConvolution or kDot. const PrecisionConfig& precision_config() const; + PrecisionConfig* mutable_precision_config(); // Sets the debug metadata for this instruction. void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; } @@ -1317,6 +1333,9 @@ class HloInstruction { // Delegates to HloConcatenateInstruction::concatenate_dimension. int64 concatenate_dimension() const; + // Delegates to HloGetDimensionSizeInstruction::dimension. + int64 dimension() const; + // Returns whether this instruction does a rank-2 transposition. bool IsRank2Transpose() const; @@ -1435,6 +1454,7 @@ class HloInstruction { // Delegates to HloAllReduceInstruction::all_reduce_id. absl::optional all_reduce_id() const; + void set_all_reduce_id(const absl::optional& all_reduce_id); // Returns data on the window in a windowed operation such as // convolution. @@ -1599,6 +1619,10 @@ class HloInstruction { const std::function& eq_computations) const; + // Generates a hash value specific to a particular type of an instruction. + // This function typically considers the inner root instruction. + virtual uint64 InnerHash() const; + // Creates an n-ary elementwise operation. static std::unique_ptr CreateNary( const Shape& shape, HloOpcode opcode, diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 4c765aa375c..1ea02cf9c03 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -370,6 +370,11 @@ HloAllReduceInstruction::HloAllReduceInstruction( AppendComputation(reduce_computation); } +void HloAllReduceInstruction::set_all_reduce_id( + const absl::optional& all_reduce_id) { + all_reduce_id_ = all_reduce_id; +} + HloInstructionProto HloAllReduceInstruction::ToProto() const { HloInstructionProto proto = HloCollectiveInstruction::ToProto(); // Proto3 is so sad. 
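while_init() declared above is a small convenience: for a kWhile instruction the loop's initial value is always operand 0, and the CHECK in the definition keeps the accessor from being used on other opcodes. A hedged usage sketch ('while_instr' is an assumed kWhile instruction):

```
// Read the loop's pieces symmetrically.
HloInstruction* init = while_instr->while_init();        // operand 0
HloComputation* cond = while_instr->while_condition();
HloComputation* body = while_instr->while_body();
CHECK(ShapeUtil::Compatible(init->shape(), while_instr->shape()));
```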
@@ -1367,6 +1372,10 @@ bool HloFusionInstruction::IdenticalSlowPath( other.fused_instructions_computation()); } +uint64 HloFusionInstruction::InnerHash() const { + return fused_instructions_computation()->Hash(); +} + std::unique_ptr HloFusionInstruction::CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { @@ -1610,7 +1619,7 @@ HloOutfeedInstruction::HloOutfeedInstruction(const Shape& outfeed_shape, HloInstructionProto HloOutfeedInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); proto.set_outfeed_config(outfeed_config()); - *proto.mutable_outfeed_shape() = outfeed_shape(); + *proto.mutable_outfeed_shape() = outfeed_shape().ToProto(); return proto; } @@ -1862,7 +1871,7 @@ HloInstructionProto HloCustomCallInstruction::ToProto() const { if (layout_constrained()) { proto.set_constrain_layout(true); for (const Shape& shape : operand_shapes_with_layout_) { - *proto.add_operand_shapes_with_layout() = shape; + *proto.add_operand_shapes_with_layout() = shape.ToProto(); } } return proto; diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index d43a8973ccf..b5c28137a14 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -252,6 +252,7 @@ class HloAllReduceInstruction : public HloCollectiveInstruction { } absl::optional all_reduce_id() const { return all_reduce_id_; } + void set_all_reduce_id(const absl::optional& all_reduce_id); // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -742,6 +743,8 @@ class HloFusionInstruction : public HloInstruction { const HloInstruction& other, const std::function& eq_computations) const override; + uint64 InnerHash() const override; + // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, @@ -954,6 +957,7 @@ class HloConvolutionInstruction : public HloInstruction { // information but it is presumed that the alternate lowering is strictly // superior. const PrecisionConfig& precision_config() const { return precision_config_; } + PrecisionConfig* mutable_precision_config() { return &precision_config_; } string ToCategory() const override; // Returns a serialized representation of this instruction. @@ -1325,6 +1329,7 @@ class HloDotInstruction : public HloInstruction { // information but it is presumed that the alternate lowering is strictly // superior. const PrecisionConfig& precision_config() const { return precision_config_; } + PrecisionConfig* mutable_precision_config() { return &precision_config_; } // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; diff --git a/tensorflow/compiler/xla/service/hlo_lexer.h b/tensorflow/compiler/xla/service/hlo_lexer.h index 3e2f8bcd52f..d6a2b292a39 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.h +++ b/tensorflow/compiler/xla/service/hlo_lexer.h @@ -20,6 +20,7 @@ limitations under the License. 
#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/service/hlo_token.h" +#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/compiler/xla/service/hlo_matchers.cc b/tensorflow/compiler/xla/service/hlo_matchers.cc index 5269cad94d3..d28e79d41ad 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.cc +++ b/tensorflow/compiler/xla/service/hlo_matchers.cc @@ -237,8 +237,4 @@ void PrintTo(const HloInstruction* inst, ::std::ostream* os) { *os << (inst ? inst->ToString() : "nullptr"); } -void PrintTo(HloInstruction* inst, ::std::ostream* os) { - PrintTo(const_cast(inst), os); -} - } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index 170ec93a334..235efb19ce4 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -385,7 +385,6 @@ std::vector Pointers(const Container& container) { // Tell GMock to print HloInstruction* by value, so error messages are nice. // Has to be in the same namespace as 'HloInstruction'. void PrintTo(const HloInstruction* inst, ::std::ostream* os); -void PrintTo(HloInstruction* inst, ::std::ostream* os); } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc index 234fcd266aa..d2740bcce26 100644 --- a/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc @@ -73,7 +73,7 @@ class ListScheduler { // Construct and return a memory-minimizing sequence of HLO instructions // containing the given HLO computation. static StatusOr Run( - const HloComputation& computation, + HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const absl::flat_hash_map& @@ -98,7 +98,7 @@ class ListScheduler { // comparison operators. using Priority = std::pair; - ListScheduler(const HloComputation& computation, + ListScheduler(HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const absl::flat_hash_map& @@ -111,7 +111,7 @@ class ListScheduler { // instruction. An HLO instruction "uses" a LogicalBuffer if the // LogicalBuffer is in an operand of the instruction as indicated by // points-to analysis. - for (auto* instruction : computation.instructions()) { + for (auto* instruction : computation->instructions()) { absl::flat_hash_set instr_uses; for (auto* operand : instruction->operands()) { points_to_analysis.GetPointsToSet(operand).ForEachElement( @@ -126,13 +126,13 @@ class ListScheduler { // Create map containing the number of unscheduled uses (hlo instructions) // of each logical buffer. - for (auto* instruction : computation.instructions()) { + for (auto* instruction : computation->instructions()) { for (auto* buffer : points_to_analysis.GetBuffersDefinedByInstruction(instruction)) { unscheduled_use_count_[buffer] = 0; } } - for (auto* instruction : computation.instructions()) { + for (auto* instruction : computation->instructions()) { for (const LogicalBuffer* buffer : buffer_uses_.at(instruction)) { ++unscheduled_use_count_[buffer]; } @@ -141,7 +141,7 @@ class ListScheduler { // Buffers live out of the computation have an implicit use at the end of // the computation. 
for (const LogicalBuffer* live_out_buffer : - points_to_analysis.GetPointsToSet(computation.root_instruction()) + points_to_analysis.GetPointsToSet(computation->root_instruction()) .CreateFlattenedSet()) { ++unscheduled_use_count_[live_out_buffer]; } @@ -157,7 +157,7 @@ class ListScheduler { // HloInstruction, plus some cached metadata, saved for the purposes of making // BytesFreedIfScheduled fast. struct ReadyListEntry { - const HloInstruction* instruction; + HloInstruction* instruction; // The total size of all buffers defined by this instruction. int64 bytes_defined; @@ -171,7 +171,7 @@ class ListScheduler { }; // Creates a ReadyListEntry for the given instruction. - ReadyListEntry MakeReadyListEntry(const HloInstruction* instruction) { + ReadyListEntry MakeReadyListEntry(HloInstruction* instruction) { ReadyListEntry entry; entry.instruction = instruction; @@ -250,13 +250,13 @@ class ListScheduler { // Populate the ready list with instructions which have no operands or // control predecessors. absl::flat_hash_map unscheduled_pred_count; - for (auto* instruction : computation_.instructions()) { + for (auto* instruction : computation_->instructions()) { // TODO(b/34466113): Replace this and above with successors() or // predecessors() when these methods are added to HloInstruction. - for (const HloInstruction* user : instruction->users()) { + for (HloInstruction* user : instruction->users()) { unscheduled_pred_count[user]++; } - for (const HloInstruction* succ : instruction->control_successors()) { + for (HloInstruction* succ : instruction->control_successors()) { unscheduled_pred_count[succ]++; } } @@ -275,7 +275,7 @@ class ListScheduler { ready_instructions[inst] = it; }; - for (auto* instruction : computation_.instructions()) { + for (auto* instruction : computation_->instructions()) { if (instruction->operands().empty() && instruction->control_predecessors().empty()) { add_to_ready_queue(instruction); @@ -287,7 +287,7 @@ class ListScheduler { // schedule. auto best_it = ready_queue.end(); --best_it; - const HloInstruction* best = best_it->second.instruction; + HloInstruction* best = best_it->second.instruction; VLOG(2) << "Schedule instruction: " << best->ToShortString() << " Bytes freed: " << best_it->first.first; ready_queue.erase(best_it); @@ -348,13 +348,13 @@ class ListScheduler { } } } - CHECK_EQ(schedule.size(), computation_.instruction_count()); - CHECK_EQ(scheduled_instructions_.size(), computation_.instruction_count()); + CHECK_EQ(schedule.size(), computation_->instruction_count()); + CHECK_EQ(scheduled_instructions_.size(), computation_->instruction_count()); return schedule; } - const HloComputation& computation_; + HloComputation* computation_; const TuplePointsToAnalysis& points_to_analysis_; const LogicalBuffer::SizeFunction& size_function_; // Computations are analyzed in post-order. 
When scheduling an instruction @@ -386,13 +386,13 @@ int64 SumLogicalBufferSizes( } StatusOr ScheduleComputationHelper( - const HloComputation& computation, + HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const MemorySchedulerAlgorithm& algorithm, const absl::flat_hash_map& memory_by_computation) { - VLOG(2) << "Computation: " << computation.name(); + VLOG(2) << "Computation: " << computation->name(); if (algorithm) { return algorithm(computation, points_to_analysis, size_function, memory_by_computation); @@ -404,17 +404,17 @@ StatusOr ScheduleComputationHelper( } // namespace StatusOr DFSMemoryScheduler( - const HloComputation& computation, + HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const absl::flat_hash_map& memory_by_computation) { // These variables are a hack to prevent overflows. int64 cumulative_total_size = 0; - int64 total_hlos = computation.parent()->instruction_count(); + int64 total_hlos = computation->parent()->instruction_count(); absl::flat_hash_map extra_users; absl::flat_hash_map total_sizes; - for (const HloInstruction* hlo : computation.MakeInstructionPostOrder()) { + for (const HloInstruction* hlo : computation->MakeInstructionPostOrder()) { if (ListScheduler::IgnoreInstruction(*hlo)) { extra_users[hlo] = 0; total_sizes[hlo] = 0; @@ -448,8 +448,8 @@ StatusOr DFSMemoryScheduler( total_sizes[hlo] = std::min(total_sizes[hlo], cumulative_total_size); extra_users[hlo] = std::min(extra_users[hlo], total_hlos); } - CHECK_EQ(extra_users.size(), computation.instruction_count()); - CHECK_EQ(total_sizes.size(), computation.instruction_count()); + CHECK_EQ(extra_users.size(), computation->instruction_count()); + CHECK_EQ(total_sizes.size(), computation->instruction_count()); // Construct a total order based on DFS post-order, visiting operands in // decreasing cumulative extra user order, and next by cumulative size, with a @@ -459,7 +459,7 @@ StatusOr DFSMemoryScheduler( sequence.push_back(hlo); return Status::OK(); }); - TF_RETURN_IF_ERROR(computation.AcceptWithOperandOrder( + TF_RETURN_IF_ERROR(computation->AcceptWithOperandOrder( &visitor, [&extra_users, &total_sizes](const HloInstruction* a, const HloInstruction* b) { if (extra_users[a] != extra_users[b]) { @@ -470,12 +470,12 @@ StatusOr DFSMemoryScheduler( } return a->name() < b->name(); })); - CHECK_EQ(sequence.size(), computation.instruction_count()); + CHECK_EQ(sequence.size(), computation->instruction_count()); return sequence; } // namespace xla StatusOr ListMemoryScheduler( - const HloComputation& computation, + HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const absl::flat_hash_map& @@ -485,16 +485,16 @@ StatusOr ListMemoryScheduler( } StatusOr PostOrderMemoryScheduler( - const HloComputation& computation, + HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const absl::flat_hash_map& memory_by_computation) { - return HloInstructionSequence(computation.MakeInstructionPostOrder()); + return HloInstructionSequence(computation->MakeInstructionPostOrder()); } StatusOr DefaultMemoryScheduler( - const HloComputation& computation, + HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const absl::flat_hash_map& @@ -513,7 +513,7 @@ 
StatusOr DefaultMemoryScheduler( memory_by_computation)); TF_ASSIGN_OR_RETURN(const int64 list_memory, HeapSimulator::MinimumMemoryForComputation( - computation, list_sequence, points_to_analysis, + *computation, list_sequence, points_to_analysis, size_function, &memory_by_computation)); VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory); @@ -522,7 +522,7 @@ StatusOr DefaultMemoryScheduler( size_function, memory_by_computation)); TF_ASSIGN_OR_RETURN(const int64 dfs_memory, HeapSimulator::MinimumMemoryForComputation( - computation, dfs_sequence, points_to_analysis, + *computation, dfs_sequence, points_to_analysis, size_function, &memory_by_computation)); VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory); @@ -532,7 +532,7 @@ StatusOr DefaultMemoryScheduler( memory_by_computation)); TF_ASSIGN_OR_RETURN(const int64 post_order_memory, HeapSimulator::MinimumMemoryForComputation( - computation, post_order_sequence, points_to_analysis, + *computation, post_order_sequence, points_to_analysis, size_function, &memory_by_computation)); VLOG(2) << "Min-memory post order sequence: " << HumanReadableNumBytes(post_order_memory); @@ -555,17 +555,17 @@ StatusOr DefaultMemoryScheduler( } StatusOr ScheduleModule( - const HloModule& module, const LogicalBuffer::SizeFunction& size_function, + HloModule* module, const LogicalBuffer::SizeFunction& size_function, const MemorySchedulerAlgorithm& algorithm) { - HloSchedule schedule(&module); + HloSchedule schedule(module); TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, - TuplePointsToAnalysis::Run(&module)); + TuplePointsToAnalysis::Run(module)); absl::flat_hash_map memory_by_computation; - for (const auto* computation : module.MakeComputationPostOrder()) { + for (auto* computation : module->MakeComputationPostOrder()) { if (!computation->IsFusionComputation()) { TF_ASSIGN_OR_RETURN(HloInstructionSequence computation_sequence, ScheduleComputationHelper( - *computation, *points_to_analysis, size_function, + computation, *points_to_analysis, size_function, algorithm, memory_by_computation)); memory_by_computation[computation] = HeapSimulator::MinimumMemoryForComputation( @@ -583,11 +583,11 @@ StatusOr ScheduleModule( } StatusOr ScheduleComputation( - const HloComputation& computation, + HloComputation* computation, const LogicalBuffer::SizeFunction& size_function) { - CHECK(!computation.IsFusionComputation()); + CHECK(!computation->IsFusionComputation()); TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, - TuplePointsToAnalysis::Run(computation.parent())); + TuplePointsToAnalysis::Run(computation->parent())); absl::flat_hash_map empty_map; return ScheduleComputationHelper(computation, *points_to_analysis, size_function, nullptr, empty_map); @@ -600,7 +600,7 @@ HloMemoryScheduler::HloMemoryScheduler( StatusOr HloMemoryScheduler::Run(HloModule* module) { TF_ASSIGN_OR_RETURN(HloSchedule schedule, - ScheduleModule(*module, size_function_, algorithm_)); + ScheduleModule(module, size_function_, algorithm_)); TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule))); return true; } diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler.h b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h index cca5dc49398..7227bfb27c7 100644 --- a/tensorflow/compiler/xla/service/hlo_memory_scheduler.h +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h @@ -36,14 +36,14 @@ namespace xla { // that describes buffer aliasing, together with a target-specific size function // that maps a tensor's 
logical size to its padded size. typedef std::function( - const HloComputation&, const TuplePointsToAnalysis&, + HloComputation*, const TuplePointsToAnalysis&, const LogicalBuffer::SizeFunction&, const absl::flat_hash_map&)> MemorySchedulerAlgorithm; // List scheduler StatusOr ListMemoryScheduler( - const HloComputation& computation, + HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const absl::flat_hash_map& @@ -51,7 +51,7 @@ StatusOr ListMemoryScheduler( // DFS-order scheduler StatusOr DFSMemoryScheduler( - const HloComputation& computation, + HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const absl::flat_hash_map& @@ -59,7 +59,7 @@ StatusOr DFSMemoryScheduler( // Naive Post Order scheduler StatusOr PostOrderMemoryScheduler( - const HloComputation& computation, + HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const absl::flat_hash_map& @@ -69,7 +69,7 @@ StatusOr PostOrderMemoryScheduler( // and the DFS scheduler, and chooses whichever returns a lower min-memory, // not accounting for fragmentation. StatusOr DefaultMemoryScheduler( - const HloComputation& computation, + HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const absl::flat_hash_map& @@ -79,13 +79,13 @@ StatusOr DefaultMemoryScheduler( // the computation. size_function is the function returning the number of bytes // required for a LogicalBuffer. StatusOr ScheduleModule( - const HloModule& module, const LogicalBuffer::SizeFunction& size_function, + HloModule* module, const LogicalBuffer::SizeFunction& size_function, const MemorySchedulerAlgorithm& algorithm = {}); // Computes the schedule for a single computation. // Currently only used by the GPU backend. StatusOr ScheduleComputation( - const HloComputation& computation, + HloComputation* computation, const LogicalBuffer::SizeFunction& size_function); // A pass which schedules the HLO instructions in a module. The HloModule's diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc index 984a6266abb..bc0d7e2bc00 100644 --- a/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc @@ -65,7 +65,7 @@ TEST_F(HloSchedulingTest, LastUseScheduledFirst) { auto sub = builder.AddInstruction( HloInstruction::CreateBinary(vec, HloOpcode::kSubtract, add, negate)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); HloMemoryScheduler scheduler([](const BufferValue& buffer) { @@ -78,7 +78,7 @@ TEST_F(HloSchedulingTest, LastUseScheduledFirst) { TF_ASSERT_OK(module->schedule().Verify()); // Verify that all instructions are in the sequence. - const std::vector& sequence = + const std::vector& sequence = module->schedule().sequence(module->entry_computation()).instructions(); EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size()); @@ -124,9 +124,9 @@ ENTRY root { }; TF_ASSERT_OK_AND_ASSIGN( HloSchedule schedule, - ScheduleModule(*module, size_fn, ListMemoryScheduler)); + ScheduleModule(module.get(), size_fn, ListMemoryScheduler)); // Verify that all instructions are in the sequence. 
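With the scheduler entry points switched from const references to pointers, callers now pass the module or computation pointer directly, as the updated tests below do. A condensed sketch of the new calling convention (the size function is the usual byte-size lambda; the pointer size is an assumed value):

```
auto size_fn = [](const BufferValue& buffer) {
  return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
};
TF_ASSERT_OK_AND_ASSIGN(
    HloSchedule schedule,
    ScheduleModule(module.get(), size_fn, ListMemoryScheduler));
TF_ASSERT_OK(schedule.Verify());
```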
- const std::vector& sequence = + const std::vector& sequence = schedule.sequence(module->entry_computation()).instructions(); EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size()); @@ -172,15 +172,16 @@ TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) { builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, tuple_elm, abs_abs2)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule, - ScheduleModule(*module, - [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf( - buffer.shape(), TUPLE_SIZE); - }, - ListMemoryScheduler)); + ScheduleModule( + module.get(), + [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), + TUPLE_SIZE); + }, + ListMemoryScheduler)); // Verify that all instructions are in the sequence. EXPECT_EQ(module->entry_computation()->instruction_count(), @@ -218,19 +219,19 @@ TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) { builder.AddInstruction( HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, tuple_elm, exp)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto fusion = computation->CreateFusionInstruction( {tuple, mul, add}, HloInstruction::FusionKind::kLoop); TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule, - ScheduleModule(*module, - [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf( - buffer.shape(), 2); - }, - ListMemoryScheduler)); + ScheduleModule( + module.get(), + [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), 2); + }, + ListMemoryScheduler)); // Verify that all instructions are in the sequence. EXPECT_EQ(module->entry_computation()->instruction_count(), @@ -252,7 +253,7 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) { HloInstruction::CreateParameter(0, r1f32, "cond_param")); HloInstruction* zero_vector = cond_builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{0, 0, 0, 0}}))); + LiteralUtil::CreateR1({0, 0, 0, 0}))); cond_builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, cond_param, zero_vector)); auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build()); @@ -284,7 +285,7 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) { }; TF_ASSERT_OK_AND_ASSIGN( HloSchedule schedule, - ScheduleModule(*module, size_fn, ListMemoryScheduler)); + ScheduleModule(module.get(), size_fn, ListMemoryScheduler)); // Verify that all instructions are in the sequence. 
auto entry_computation = module->entry_computation(); EXPECT_EQ(module->entry_computation()->instruction_count(), diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 14bf17f4be1..fe8371384c0 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -240,8 +240,10 @@ HloModuleProto HloModule::ToProto() const { *proto.mutable_schedule() = schedule().ToProto().ValueOrDie(); } *proto.mutable_host_program_shape() = - entry_computation_layout().ComputeProgramShape(); + entry_computation_layout().ComputeProgramShape().ToProto(); *proto.mutable_input_output_alias() = input_output_alias_config().ToProto(); + *proto.mutable_dynamic_parameter_binding() = + dynamic_parameter_binding().ToProto(); return proto; } @@ -255,7 +257,7 @@ StatusOr> HloModule::CreateFromProto( // the entry parameters and root. TF_RET_CHECK(proto.has_host_program_shape()) << "No program shape found in the proto"; - const auto& expected_program_shape = proto.host_program_shape(); + ProgramShape expected_program_shape(proto.host_program_shape()); TF_RET_CHECK(expected_program_shape.parameters_size() == module_config.entry_computation_layout().parameter_count()); for (int i = 0; i < expected_program_shape.parameters_size(); ++i) { @@ -325,6 +327,10 @@ StatusOr> HloModule::CreateFromProto( // Because we didn't uniquify the names or the ids, double-check that the // instruction and computation names and ids are unique from the proto. + TF_ASSIGN_OR_RETURN(module->dynamic_parameter_binding_, + DynamicParameterBinding::CreateFromProto( + proto.dynamic_parameter_binding())); + absl::flat_hash_set computation_names; absl::flat_hash_set instruction_names; absl::flat_hash_set computation_ids; @@ -363,9 +369,9 @@ StatusOr HloModule::CreateModuleConfigFromProto( const HloModuleProto& module, const DebugOptions& debug_options) { TF_RET_CHECK(module.has_host_program_shape()) << "No program shape found in the proto"; - const auto& program_shape = module.host_program_shape(); + ProgramShape program_shape(module.host_program_shape()); - HloModuleConfig module_config(program_shape); + HloModuleConfig module_config(ProgramShape{program_shape}); module_config.set_debug_options(debug_options); // The module config is constructed with default layouts regardless of what is diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index 8a1f999e3ab..7b9cbf9a53a 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -28,6 +28,7 @@ limitations under the License. #include "absl/types/optional.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/iterator_util.h" +#include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_clone_context.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -103,11 +104,7 @@ class HloModule { HloCloneContext* context = nullptr); // Return a pointer to the entry computation of the module. 
- const HloComputation* entry_computation() const { - CHECK_NE(nullptr, entry_computation_); - return entry_computation_; - } - HloComputation* entry_computation() { + HloComputation* entry_computation() const { CHECK_NE(nullptr, entry_computation_); return entry_computation_; } @@ -135,6 +132,12 @@ class HloModule { return config_.entry_computation_layout(); } + // Generates a hash value of an HLO module. Hash considers + // information on opcode, shape, operands, and typically a root instruction. + // This function returns the same hash value for equivalent HLO modules, + // with respect to HloInstruction::Identical() method. + uint64 Hash() const { return entry_computation()->Hash(); } + // Gets the computations in this module. // // Returns a view of HloComputation*s, so you can iterate over this in the @@ -232,6 +235,16 @@ class HloModule { return input_output_alias_config_; } + // DynamicParameterBinding holds the list of bindings that indicates which + // parameter dimensions are dynamic and which parameters represent their + // runtime value. + DynamicParameterBinding& dynamic_parameter_binding() { + return dynamic_parameter_binding_; + } + const DynamicParameterBinding& dynamic_parameter_binding() const { + return dynamic_parameter_binding_; + } + // Returns an id that is unique to this module across all modules created over // the lifetime of this process. int unique_id() const { return unique_id_; } @@ -285,6 +298,9 @@ class HloModule { // alias_config indicates the alias information of input/output buffers that // are expected from the module. HloInputOutputAliasConfig input_output_alias_config_; + + // Bindings for dynamic parameter mapping. + DynamicParameterBinding dynamic_parameter_binding_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc index 3ae67e4e5ee..620cb7e01ad 100644 --- a/tensorflow/compiler/xla/service/hlo_module_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_test.cc @@ -63,7 +63,7 @@ class HloModuleTest : public HloTestBase { TEST_F(HloModuleTest, OneComputationPostOrder) { // Create a module with a single computation. - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(CreateConstantComputation()); EXPECT_THAT(module->MakeComputationPostOrder(), @@ -72,7 +72,7 @@ TEST_F(HloModuleTest, OneComputationPostOrder) { TEST_F(HloModuleTest, TwoComputationsPostOrder) { // Create a module with two unconnected computations. - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation1 = module->AddEntryComputation(CreateConstantComputation()); auto computation2 = module->AddEmbeddedComputation(CreateConstantComputation()); @@ -88,7 +88,7 @@ TEST_F(HloModuleTest, TwoComputationsPostOrder) { TEST_F(HloModuleTest, CloneTest) { // Create and copy a module with a diamond call graph of computations. - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation1 = module->AddEmbeddedComputation(CreateConstantComputation()); auto computation2 = @@ -111,7 +111,7 @@ TEST_F(HloModuleTest, CloneTest) { } TEST_F(HloModuleTest, CloneHasFusion) { - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); // Create the fused computation. 
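HloModule now carries a DynamicParameterBinding that is serialized next to the schedule and program shape, and HloModule::Hash() delegates to the entry computation. A hedged sketch of the proto round-trip these additions support (GetDebugOptionsFromFlags is assumed to be the usual flag helper; ToProto, CreateModuleConfigFromProto, and CreateFromProto come from the surrounding hunks):

```
// Serialize and restore: dynamic parameter bindings survive the round-trip,
// and equivalent modules keep equal hashes.
HloModuleProto proto = module->ToProto();
TF_ASSIGN_OR_RETURN(HloModuleConfig config,
                    HloModule::CreateModuleConfigFromProto(
                        proto, GetDebugOptionsFromFlags()));
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> restored,
                    HloModule::CreateFromProto(proto, config));
CHECK_EQ(module->Hash(), restored->Hash());
```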
HloComputation* fused_computation; @@ -154,7 +154,7 @@ TEST_F(HloModuleTest, CloneHasFusion) { TEST_F(HloModuleTest, DiamondComputationsPostOrder) { // Create a module with a diamond call graph of computations. - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation1 = module->AddEmbeddedComputation(CreateConstantComputation()); auto computation2 = @@ -174,7 +174,7 @@ TEST_F(HloModuleTest, DiamondComputationsPostOrder) { TEST_F(HloModuleTest, LargeConstantToString) { // Create a module with a single computation. - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder("Constant"); std::vector values(16, 42.0); builder.AddInstruction( @@ -194,8 +194,8 @@ TEST_F(HloModuleTest, LargeConstantToString) { } TEST_F(HloModuleTest, UniqueModuleId) { - auto module_a = CreateNewUnverifiedModule(); - auto module_b = CreateNewUnverifiedModule(); + auto module_a = CreateNewVerifiedModule(); + auto module_b = CreateNewVerifiedModule(); EXPECT_NE(module_a->unique_id(), module_b->unique_id()); } diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index 70c7d70b41c..127cfd165a5 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -47,6 +47,8 @@ namespace xla { #define HLO_OPCODE_LIST(V) \ V(kAbs, "abs") \ V(kAdd, "add") \ + V(kAddDependency, "add-dependency") \ + V(kAfterAll, "after-all", kHloOpcodeIsVariadic) \ V(kAllToAll, "all-to-all") \ V(kAtan2, "atan2") \ V(kBatchNormGrad, "batch-norm-grad") \ @@ -84,7 +86,6 @@ namespace xla { V(kGather, "gather") \ V(kGe, "greater-than-or-equal-to", kHloOpcodeIsComparison) \ V(kGetDimensionSize, "get-dimension-size") \ - V(kAfterAll, "after-all", kHloOpcodeIsVariadic) \ V(kGetTupleElement, "get-tuple-element") \ V(kGt, "greater-than", kHloOpcodeIsComparison) \ V(kImag, "imag") \ diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index f5f99bece18..ca6a154809b 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -356,8 +356,7 @@ void SequentialHloOrdering::Initialize() { // Create a map from instruction to its order position. TF_DCHECK_OK(schedule_.Verify()); for (const auto& computation_sequence : schedule_.sequences()) { - const std::vector& order = - computation_sequence.second.instructions(); + const auto& order = computation_sequence.second.instructions(); for (int i = 0; i < order.size(); ++i) { InsertOrDie(&order_position_, order[i], i); } diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc index 2ab8aa57f6e..3ca77e60cd5 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc @@ -53,7 +53,7 @@ TEST_F(HloOrderingTest, InstructionsInDifferentComputations) { // %c = Constant(42.0f) // // This results in a diamond-shaped callgraph. 
- auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); auto builder_c = HloComputation::Builder("C"); @@ -126,7 +126,7 @@ TEST_F(HloOrderingTest, InstructionsInWhileComputations) { // %constant = Constant(1.0) // return While(%constant, body, condition) // - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); auto body_builder = HloComputation::Builder("body"); @@ -176,7 +176,7 @@ TEST_F(HloOrderingTest, InstructionsInWhileComputations) { TEST_F(HloOrderingTest, ParametersDefinedBeforeOthers) { // Entry parameter should always be defined before other instruction. - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( @@ -209,7 +209,7 @@ TEST_F(HloOrderingTest, ValuesInWhileComputations) { // %while = While(%constant, body, condition) // %add = Add(%constant, %while) // - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); auto body_builder = HloComputation::Builder("body"); @@ -407,7 +407,7 @@ TEST_F(HloOrderingTest, // %dead = Constant(123.0) // // %root should interfere with %dead. - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); auto builder = HloComputation::Builder(TestName()); @@ -455,7 +455,7 @@ TEST_F(HloOrderingTest, // ROOT %call = call({%c}), subcomputation // // %root should interfere with %dead. - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); auto subbuilder = HloComputation::Builder(TestName() + ".sub"); diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 4390145c6bd..9b5bb5d0bd6 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -47,11 +47,11 @@ const double kF16max = 65504; // Creates and returns a schedule created using the order of the instructions in // the HloComputation::instructions() vectors in the module. 
-HloSchedule ScheduleFromInstructionOrder(const HloModule* module) { +HloSchedule ScheduleFromInstructionOrder(HloModule* module) { HloSchedule schedule(module); - for (const HloComputation* computation : module->computations()) { + for (HloComputation* computation : module->computations()) { if (!computation->IsFusionComputation()) { - for (const HloInstruction* instruction : computation->instructions()) { + for (HloInstruction* instruction : computation->instructions()) { schedule.GetOrCreateSequence(computation).push_back(instruction); } } @@ -850,6 +850,15 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, } break; } + case HloOpcode::kAddDependency: { + if (!ParseOperands(&operands, /*expected_size=*/2) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateAddDependency(operands[0], operands[1])); + break; + } case HloOpcode::kSort: { optional> dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index c59bdc0a0b3..ab71f011ac9 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -21,7 +21,8 @@ limitations under the License. #include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" -#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h" #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -29,7 +30,7 @@ limitations under the License. 
namespace xla { namespace { -namespace op = ::xla::testing::opcode_matchers; +namespace m = ::xla::match; using absl::string_view; struct TestData { @@ -195,7 +196,7 @@ ENTRY %add_constants () -> f32[] { R"(HloModule TupleConstant_module ENTRY %TupleConstant.v1 () -> (f32[2,1], f32[2]) { - ROOT %constant = (f32[2,1]{1,0}, f32[2]{0}) constant((f32[2,1], f32[2]) ( f32[2,1] { { 1 }, { 2 } }, {2, 42} )) + ROOT %constant = (f32[2,1]{1,0}, f32[2]{0}) constant((f32[2,1], f32[2]) ( f32[2,1] { {1}, {2} }, {2, 42} )) } )" @@ -587,7 +588,7 @@ ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_ R"(HloModule BasicTraining_module ENTRY %BasicTraining.v4 () -> (f32[2,2,1,2], f32[2], f32[2]) { - %constant = f32[2,2,1,2]{3,2,1,0} constant(f32[2,2,1,2] { { /*i0=0*/ { /*i1=0*/ {1, 2} }, { /*i1=1*/ {3, 4} } }, { /*i0=1*/ { /*i1=0*/ {5, 6} }, { /*i1=1*/ {7, 8} } } }) + %constant = f32[2,2,1,2]{3,2,1,0} constant(f32[2,2,1,2] { { /*i0=0*/ { /*i1=0*/ { 1, 2 } }, { /*i1=1*/ { 3, 4 } } }, { /*i0=1*/ { /*i1=0*/ { 5, 6 } }, { /*i1=1*/ { 7, 8 } } } }) %constant.1 = f32[2]{0} constant({2, 3}) %constant.2 = f32[2]{0} constant({1, 2}) ROOT %batch-norm-training = (f32[2,2,1,2]{3,2,1,0}, f32[2]{0}, f32[2]{0}) batch-norm-training(f32[2,2,1,2]{3,2,1,0} %constant, f32[2]{0} %constant.1, f32[2]{0} %constant.2), epsilon=0.001, feature_index=3 @@ -1241,7 +1242,38 @@ ENTRY Sort { } )" + }, +// AfterAll with multiple operands +{ +"AfterAllWithMultipleOperands", +R"(HloModule AfterAllWithMultipleOperands + +ENTRY AfterAllWithMultipleOperands { + p0 = f32[] parameter(0) + token0 = token[] after-all() + token1 = token[] after-all() + ROOT after-all = token[] after-all(p0, token0, token1) } + +)" +}, +// AddDependency +// A dependency chain is created from 'neg' to 'exp' using tokens. 
+{ +"AddDependency", +R"(HloModule AddDependency + +ENTRY AddDependency { + p = f32[] parameter(0) + neg = f32[] negate(p) + token = token[] after-all(neg) + p_after_token = f32[] add-dependency(p, token) + exp = f32[] exponential(p_after_token) + ROOT sum = f32[] add(neg, exp) +} + +)" +}, }); // clang-format on } @@ -1862,7 +1894,8 @@ ENTRY ReduceR3ToR2 { )"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(original)); ASSERT_NE(module->entry_computation(), nullptr); - EXPECT_THAT(module->entry_computation()->root_instruction(), op::Reduce()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Reduce())); } TEST_F(HloParserTest, ParseSharding) { @@ -1922,7 +1955,7 @@ TEST(HloParserSingleOpTest, SingleOp) { const HloComputation* computation = module->entry_computation(); ASSERT_NE(computation, nullptr); EXPECT_THAT(computation->root_instruction(), - op::Multiply(op::Parameter(0), op::Parameter(1))); + GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(1)))); } TEST(HloParserSingleOpTest, SingleOpNoShapeProducesError) { @@ -1950,7 +1983,7 @@ TEST(HloParserSingleOpTest, SingleOpNoNames) { const HloComputation* computation = module->entry_computation(); ASSERT_NE(computation, nullptr); EXPECT_THAT(computation->root_instruction(), - op::Multiply(op::Parameter(0), op::Parameter(1))); + GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(1)))); } TEST(HloParserSingleOpTest, CanonicalOp) { @@ -1959,7 +1992,7 @@ TEST(HloParserSingleOpTest, CanonicalOp) { const HloComputation* computation = module->entry_computation(); ASSERT_NE(computation, nullptr); EXPECT_THAT(computation->root_instruction(), - op::Multiply(op::Parameter(0), op::Parameter(1))); + GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(1)))); EXPECT_EQ( computation->root_instruction()->ToString(HloPrintOptions::Canonical()), text); @@ -2013,7 +2046,11 @@ TEST(HloParserSingleOpTest, SingleOpWithNested) { const HloComputation* computation = module->entry_computation(); ASSERT_NE(computation, nullptr); EXPECT_THAT(computation->root_instruction(), - op::Fusion(op::Parameter(0), op::Parameter(1))); + GmockMatch(m::Op() + .WithOpcode(HloOpcode::kFusion) + .WithNumOperands(2) + .WithOperand(0, m::Parameter(0)) + .WithOperand(1, m::Parameter(1)))); } TEST(HloParserSingleOpTest, SingleOpWithNested_DoesNotExist) { @@ -2057,7 +2094,7 @@ TEST(HloParserSingleOpTest, ConvolutionTrivialFeatureGroupCount) { const HloComputation* computation = module->entry_computation(); ASSERT_NE(computation, nullptr); EXPECT_THAT(computation->root_instruction(), - op::Convolution(op::Parameter(0), op::Parameter(1))); + GmockMatch(m::Convolution(m::Parameter(0), m::Parameter(1)))); auto* convolution = Cast(computation->root_instruction()); EXPECT_EQ(convolution->feature_group_count(), 1); @@ -2121,8 +2158,10 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { module->schedule().is_computation_scheduled(module->entry_computation())); EXPECT_THAT( module->schedule().sequence(module->entry_computation()).instructions(), - ::testing::ElementsAre(op::Parameter(), op::Broadcast(), op::Parameter(), - op::Multiply(), op::Parameter(), op::Add())); + ::testing::ElementsAre( + GmockMatch(m::Parameter()), GmockMatch(m::Broadcast()), + GmockMatch(m::Parameter()), GmockMatch(m::Multiply()), + GmockMatch(m::Parameter()), GmockMatch(m::Add()))); } TEST_F(HloParserTest, IsScheduledIsTrueDifferentOrder) { @@ -2148,8 +2187,10 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { 
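The parser now accepts `add-dependency`, and the tests above switch from the opcode matchers to GmockMatch over the pattern matcher. A hedged sketch combining the two (ParseHloString, GmockMatch, and m::Op().WithOpcode are all taken from the surrounding hunks; the HLO text is illustrative):

```
const char* const kHloText = R"(
HloModule AddDependencyExample
ENTRY AddDependencyExample {
  p = f32[] parameter(0)
  token = token[] after-all(p)
  ROOT p_after_token = f32[] add-dependency(p, token)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kHloText));
EXPECT_THAT(module->entry_computation()->root_instruction(),
            GmockMatch(m::Op().WithOpcode(HloOpcode::kAddDependency)));
```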
module->schedule().is_computation_scheduled(module->entry_computation())); EXPECT_THAT( module->schedule().sequence(module->entry_computation()).instructions(), - ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Parameter(), - op::Broadcast(), op::Multiply(), op::Add())); + ::testing::ElementsAre( + GmockMatch(m::Parameter()), GmockMatch(m::Parameter()), + GmockMatch(m::Parameter()), GmockMatch(m::Broadcast()), + GmockMatch(m::Multiply()), GmockMatch(m::Add()))); } TEST_F(HloParserTest, CustomCallWrongNumberofOperandConstraints) { diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.cc b/tensorflow/compiler/xla/service/hlo_proto_util.cc index cf33668f5bf..981d06ce101 100644 --- a/tensorflow/compiler/xla/service/hlo_proto_util.cc +++ b/tensorflow/compiler/xla/service/hlo_proto_util.cc @@ -48,7 +48,7 @@ StatusOr> CreateModuleFromProto( return std::move(module); } -StatusOr> EntryComputationParameterShapes( +StatusOr> EntryComputationParameterShapes( const HloProto& hlo_proto) { if (!hlo_proto.has_hlo_module()) { return NotFound("HloProto missing HloModuleProto."); @@ -57,15 +57,16 @@ StatusOr> EntryComputationParameterShapes( return NotFound("HloProto missing program shape."); } - std::vector parameter_shapes; + std::vector parameter_shapes; const auto& program_shape = hlo_proto.hlo_module().host_program_shape(); - for (const Shape& shape : program_shape.parameters()) { + for (const ShapeProto& shape : program_shape.parameters()) { parameter_shapes.push_back(&shape); } return parameter_shapes; } -StatusOr EntryComputationOutputShape(const HloProto& hlo_proto) { +StatusOr EntryComputationOutputShape( + const HloProto& hlo_proto) { if (!hlo_proto.has_hlo_module()) { return NotFound("HloProto missing HloModuleProto."); } diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.h b/tensorflow/compiler/xla/service/hlo_proto_util.h index 1db82dd6fca..31ea2aaffd9 100644 --- a/tensorflow/compiler/xla/service/hlo_proto_util.h +++ b/tensorflow/compiler/xla/service/hlo_proto_util.h @@ -43,12 +43,13 @@ StatusOr> CreateModuleFromProto( // Returns the shapes of the parameters of the entry computation. Shape pointers // refer to shapes inside of the given HloProto. -StatusOr> EntryComputationParameterShapes( +StatusOr> EntryComputationParameterShapes( const HloProto& hlo_proto); // Returns the shape of the output of the entry computation. The shape pointer // refers to the output shape inside of the given HloProto. -StatusOr EntryComputationOutputShape(const HloProto& hlo_proto); +StatusOr EntryComputationOutputShape( + const HloProto& hlo_proto); } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 49e46ecd00e..48add75523f 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -130,10 +130,10 @@ using ItemList = absl::InlinedVector; // before arbitrary elements. class InstructionList { public: - explicit InstructionList(const std::vector& order) { + explicit InstructionList(const HloInstructionSequence& order) { int64 position = 0; Item* last = nullptr; - for (const HloInstruction* inst : order) { + for (HloInstruction* inst : order.instructions()) { // Add a new item to the linked list. 
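The hlo_proto_util change above makes EntryComputationParameterShapes hand back ShapeProto pointers rather than Shape pointers. The sketch below is a hypothetical caller-side helper (its name and the need for it are assumptions) showing how Shape values could be rebuilt from those protos, using the Shape-from-proto construction that llvm_util.cc adopts later in this diff.

```cpp
// Hypothetical helper: rebuild xla::Shape values from the ShapeProto pointers
// now returned by EntryComputationParameterShapes.
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/service/hlo_proto_util.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"

namespace xla {

StatusOr<std::vector<Shape>> EntryParameterShapesAsShapes(
    const HloProto& hlo_proto) {
  TF_ASSIGN_OR_RETURN(std::vector<const ShapeProto*> shape_protos,
                      EntryComputationParameterShapes(hlo_proto));
  std::vector<Shape> shapes;
  shapes.reserve(shape_protos.size());
  for (const ShapeProto* proto : shape_protos) {
    Shape shape(*proto);                                  // Proto -> Shape.
    TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape));  // Sanity check.
    shapes.push_back(std::move(shape));
  }
  return shapes;
}

}  // namespace xla
```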
Item* item = new Item; item->next = nullptr; @@ -151,7 +151,7 @@ class InstructionList { // to be monotonically increasing through the list, and so is still useful // for quickly(-ish) determining the order of arbitrary instructions in // the list. - item->instruction = const_cast(inst); + item->instruction = inst; item->position = position; position++; @@ -927,7 +927,7 @@ Item* PickRematerializationCandidate( StatusOr HloRematerialization::ComputePeakMemory( const HloComputation* computation, - const std::vector& order) const { + const HloInstructionSequence& order) const { InstructionList instruction_list(order); MemoryUsageTracker tracker(computation, size_function_, *points_to_analysis_, instruction_list); @@ -971,8 +971,7 @@ StatusOr HloRematerialization::RematerializeComputation( << HumanReadableNumBytes(computation_peak_memory_.at(computation)); CHECK(!ContainsKey(rematerialized_computations_, computation)); - InstructionList instruction_list( - schedule->sequence(computation).instructions()); + InstructionList instruction_list(schedule->sequence(computation)); MemoryUsageTracker memory_tracker(computation, size_function_, *points_to_analysis_, instruction_list); bool changed = false; @@ -1184,7 +1183,7 @@ StatusOr HloRematerialization::RematerializeComputation( sequence.clear(); for (auto* item = instruction_list.first(); item != nullptr; item = instruction_list.next(item)) { - const HloInstruction* instruction = item->instruction; + HloInstruction* instruction = item->instruction; sequence.push_back(instruction); } rematerialized_computations_.insert(computation); @@ -1235,10 +1234,8 @@ StatusOr HloRematerialization::Run(HloModule* module) { if (node.context() == CallContext::kSequential) { TF_ASSIGN_OR_RETURN( computation_peak_memory_[node.computation()], - ComputePeakMemory(node.computation(), - module->schedule() - .sequence(node.computation()) - .instructions())); + ComputePeakMemory(node.computation(), module->schedule().sequence( + node.computation()))); } return Status::OK(); }, diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h index 70d83c04f07..a07d348041b 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.h +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h @@ -87,9 +87,8 @@ class HloRematerialization : public HloModulePass { // peak memory is the maximum total size of all live HLO instruction values at // any program point. 'order' is the order in which the HLO instructions will // be emitted which is used to determine lifespans of HLO values. - StatusOr ComputePeakMemory( - const HloComputation* computation, - const std::vector& order) const; + StatusOr ComputePeakMemory(const HloComputation* computation, + const HloInstructionSequence& order) const; // Returns the peak memory usage of the called computations for the given // instruction. Zero is returned if the instruction calls no computations. diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index 3f0ca342b4c..5a9b820a9d7 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -205,6 +205,40 @@ StatusOr HloRunner::ExecuteWithDeviceBuffers( /*profile=*/profile); } +StatusOr HloRunner::ExecuteWithDeviceBuffers( + std::unique_ptr executable, + const absl::Span arguments, + ExecutionProfile* profile) { + // Get service run options. 
+ se::Stream stream(backend().default_stream_executor()); + stream.Init(); + ServiceExecutableRunOptions service_run_options = + GetServiceRunOptionsForDevice(backend().default_device_ordinal(), &stream, + nullptr); + + TF_ASSIGN_OR_RETURN( + ScopedShapedBuffer retval, + executable->ExecuteOnStreamWrapper(&service_run_options, + /*profile=*/profile, arguments)); + TF_RETURN_IF_ERROR(stream.BlockHostUntilDone()); + return std::move(retval); +} + +StatusOr HloRunner::ExecuteWithDeviceBuffers( + std::unique_ptr executable, + const absl::Span arguments, + ExecutionProfile* profile) { + std::vector argument_pointers; + argument_pointers.reserve(arguments.size()); + for (const auto& argument : arguments) { + argument_pointers.push_back(&argument); + } + return ExecuteWithDeviceBuffers( + /*executable=*/std::move(executable), + /*arguments=*/argument_pointers, + /*profile=*/profile); +} + StatusOr> HloRunner::ExecuteReplicated( std::unique_ptr module, const ReplicatedExecuteOptions& options) { diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h index 2e934bf66ae..bb792cf8c98 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.h +++ b/tensorflow/compiler/xla/service/hlo_runner.h @@ -136,6 +136,21 @@ class HloRunner { const absl::Span arguments, bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); + StatusOr ExecuteWithDeviceBuffers( + std::unique_ptr executable, + const absl::Span arguments, + ExecutionProfile* profile = nullptr); + + StatusOr ExecuteWithDeviceBuffers( + std::unique_ptr executable, + const absl::Span arguments, + ExecutionProfile* profile = nullptr); + + // Creates an executable object given an HLO module. If run_hlo_passes is + // true, the HLO passes will be run as part of compilation. + StatusOr> CreateExecutable( + std::unique_ptr module, bool run_hlo_passes); + // Executes a given HLO module into a set of replicas, and returns a map // with the replica number as key, and the corresponding returned literal as // value. @@ -152,11 +167,6 @@ class HloRunner { const Backend& backend() const; private: - // Creates an executable object given an HLO module. If run_hlo_passes is - // true, the HLO passes will be run before. - StatusOr> CreateExecutable( - std::unique_ptr module, bool run_hlo_passes); - // Creates a ServiceExecutableRunOptions object to configure a run on device, // using the provided stream object. If device_assignment is not nullptr, it // will be used to configure the replication parameters. 
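CreateExecutable moves into the public API above, alongside the new ExecuteWithDeviceBuffers overloads that accept a pre-built executable. Below is a hedged sketch of how the two might be combined; the wrapper name is illustrative and error handling is minimal.

```cpp
// Sketch only: compile an HLO module once via the newly public
// CreateExecutable, then run it on buffers that already live on the device
// using the executable-taking ExecuteWithDeviceBuffers overload.
#include <memory>
#include <utility>

#include "tensorflow/compiler/xla/service/hlo_runner.h"
#include "tensorflow/compiler/xla/status_macros.h"

namespace xla {

StatusOr<ScopedShapedBuffer> CompileAndRunOnDeviceBuffers(
    HloRunner* runner, std::unique_ptr<HloModule> module,
    absl::Span<const ShapedBuffer* const> arguments) {
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<Executable> executable,
      runner->CreateExecutable(std::move(module), /*run_hlo_passes=*/true));
  return runner->ExecuteWithDeviceBuffers(std::move(executable), arguments,
                                          /*profile=*/nullptr);
}

}  // namespace xla
```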
Replicated executions diff --git a/tensorflow/compiler/xla/service/hlo_schedule.cc b/tensorflow/compiler/xla/service/hlo_schedule.cc index a5780b7551a..8f6eb974c51 100644 --- a/tensorflow/compiler/xla/service/hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/hlo_schedule.cc @@ -46,8 +46,8 @@ namespace xla { << "No computation exists in HLO module with id " << computation_id; const HloComputation* computation = comp_it->second; - absl::flat_hash_map id_to_instruction; - for (const HloInstruction* instruction : computation->instructions()) { + absl::flat_hash_map id_to_instruction; + for (HloInstruction* instruction : computation->instructions()) { id_to_instruction[instruction->unique_id()] = instruction; } @@ -81,9 +81,8 @@ StatusOr HloSchedule::ToProto() const { return std::move(proto); } -void HloSchedule::set_sequence( - const HloComputation* computation, - absl::Span sequence) { +void HloSchedule::set_sequence(const HloComputation* computation, + absl::Span sequence) { set_sequence(computation, HloInstructionSequence(sequence)); } @@ -114,8 +113,8 @@ Status HloSchedule::UpdateComputationSchedule( const HloComputation* computation) { // Map from unique ID to HloInstruction pointer for instructions in the // computation. - absl::flat_hash_map id_to_instruction; - for (const HloInstruction* instruction : computation->instructions()) { + absl::flat_hash_map id_to_instruction; + for (HloInstruction* instruction : computation->instructions()) { InsertOrDie(&id_to_instruction, instruction->unique_id(), instruction); } @@ -128,7 +127,7 @@ Status HloSchedule::UpdateComputationSchedule( // Map from HloInstruction X to newly added instructions (instruction is in // computation, but not in schedule) which use X. If an instruction is not in // the map, then it has no users which are newly added instructions. - absl::flat_hash_map> + absl::flat_hash_map> new_instruction_uses; // For each newly added instruction, this is the count of the instruction's @@ -138,9 +137,9 @@ Status HloSchedule::UpdateComputationSchedule( // Create a worklist of newly added instructions which are ready to be added // to the schedule. Initialize worklist with those that have zero operands. - std::queue worklist; + std::queue worklist; - for (const HloInstruction* instruction : computation->instructions()) { + for (HloInstruction* instruction : computation->instructions()) { if (ids_in_schedule.count(instruction->unique_id()) == 0) { // This is a newly added instruction which is not in the schedule. if (instruction->operands().empty()) { @@ -161,17 +160,17 @@ Status HloSchedule::UpdateComputationSchedule( // Lambda which schedules all instructions on the worklist. auto schedule_worklist = [&]() { while (!worklist.empty()) { - const HloInstruction* instruction = worklist.front(); + HloInstruction* instruction = worklist.front(); worklist.pop(); new_sequence.push_back(instruction); - std::vector* new_users = + std::vector* new_users = tensorflow::gtl::FindOrNull(new_instruction_uses, instruction); if (new_users != nullptr) { // This just-scheduled instruction has users which are newly added to // the module. Update the number of unscheduled operands and push the // newly added instruction to the worklist if it is ready to // schedule. 
- for (const HloInstruction* new_user : *new_users) { + for (HloInstruction* new_user : *new_users) { unscheduled_operand_count.at(new_user)--; CHECK_GE(unscheduled_operand_count.at(new_user), 0); if (unscheduled_operand_count.at(new_user) == 0) { diff --git a/tensorflow/compiler/xla/service/hlo_schedule.h b/tensorflow/compiler/xla/service/hlo_schedule.h index 0a714101ee5..486ddbf499d 100644 --- a/tensorflow/compiler/xla/service/hlo_schedule.h +++ b/tensorflow/compiler/xla/service/hlo_schedule.h @@ -35,14 +35,14 @@ class HloInstructionSequence { public: HloInstructionSequence() = default; explicit HloInstructionSequence( - absl::Span instructions) { - for (const HloInstruction* instruction : instructions) { + absl::Span instructions) { + for (HloInstruction* instruction : instructions) { push_back(instruction); } } // Adds the instruction to the end of the sequence. - void push_back(const HloInstruction* instruction) { + void push_back(HloInstruction* instruction) { instruction_sequence_.push_back(instruction); id_sequence_.push_back(instruction->unique_id()); } @@ -56,7 +56,7 @@ class HloInstructionSequence { int64 size() const { return instruction_sequence_.size(); } // Returns the sequence of HLO instructions. - const std::vector& instructions() const { + const std::vector& instructions() const { return instruction_sequence_; } @@ -65,7 +65,7 @@ class HloInstructionSequence { private: // The sequence as HloInstructions. - std::vector instruction_sequence_; + std::vector instruction_sequence_; // The sequence of HLO instructions, represented by their unique IDs. The // sequence is stored as both HloInstructions and unique IDs because the @@ -98,7 +98,7 @@ class HloSchedule { // Sets the sequence for the given computation to the given sequence. void set_sequence(const HloComputation* computation, - absl::Span sequence); + absl::Span sequence); void set_sequence(const HloComputation* computation, HloInstructionSequence sequence); diff --git a/tensorflow/compiler/xla/service/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/hlo_schedule_test.cc index 1424569ac1f..0e56e6f760e 100644 --- a/tensorflow/compiler/xla/service/hlo_schedule_test.cc +++ b/tensorflow/compiler/xla/service/hlo_schedule_test.cc @@ -56,10 +56,10 @@ ENTRY main { ParseHloString(module_str)); TF_ASSERT_OK_AND_ASSIGN( HloSchedule schedule, - ScheduleModule(*module, [](const BufferValue& buffer) { + ScheduleModule(module.get(), [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape()); })); - const std::vector& entry_schedule = + const auto& entry_schedule = schedule.sequence(module->entry_computation()).instructions(); EXPECT_EQ(entry_schedule.size(), 6); @@ -90,7 +90,7 @@ ENTRY main { ParseHloString(module_str)); TF_ASSERT_OK_AND_ASSIGN( HloSchedule schedule, - ScheduleModule(*module, [](const BufferValue& buffer) { + ScheduleModule(module.get(), [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape()); })); @@ -139,7 +139,7 @@ ENTRY main { ParseHloString(module_str)); TF_ASSERT_OK_AND_ASSIGN( HloSchedule schedule, - ScheduleModule(*module, [](const BufferValue& buffer) { + ScheduleModule(module.get(), [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape()); })); @@ -183,7 +183,7 @@ ENTRY main { ParseHloString(module_str)); TF_ASSERT_OK_AND_ASSIGN( HloSchedule schedule, - ScheduleModule(*module, [](const BufferValue& buffer) { + ScheduleModule(module.get(), [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape()); })); @@ -244,7 +244,7 @@ 
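The UpdateComputationSchedule logic above appends instructions that are new to the schedule once all of their blocking operands have been scheduled. The following is a self-contained, generic sketch of that ready-list idea over an illustrative Node type; it is not the actual XLA implementation.

```cpp
// Ready-list scheduling sketch: a new node becomes schedulable once all of its
// *new* operands are scheduled (operands already in the old schedule do not
// block it).
#include <queue>
#include <unordered_map>
#include <unordered_set>
#include <vector>

struct Node {
  std::vector<Node*> operands;  // Nodes that must be scheduled first.
  std::vector<Node*> users;     // Nodes that consume this node's result.
};

std::vector<Node*> ScheduleNewNodes(const std::vector<Node*>& new_nodes) {
  const std::unordered_set<Node*> is_new(new_nodes.begin(), new_nodes.end());
  std::unordered_map<Node*, int> unscheduled_new_operands;
  std::queue<Node*> worklist;
  for (Node* node : new_nodes) {
    const std::unordered_set<Node*> unique_operands(node->operands.begin(),
                                                    node->operands.end());
    int count = 0;
    for (Node* operand : unique_operands) {
      if (is_new.count(operand) > 0) ++count;  // Only new operands gate readiness.
    }
    unscheduled_new_operands[node] = count;
    if (count == 0) worklist.push(node);  // Ready to append immediately.
  }
  std::vector<Node*> appended_sequence;
  while (!worklist.empty()) {
    Node* node = worklist.front();
    worklist.pop();
    appended_sequence.push_back(node);
    for (Node* user : node->users) {
      auto it = unscheduled_new_operands.find(user);
      if (it != unscheduled_new_operands.end() && --it->second == 0) {
        worklist.push(user);  // Last blocking operand just got scheduled.
      }
    }
  }
  return appended_sequence;
}
```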
ENTRY %WhileLoop () -> s32[] { ParseHloString(module_str)); TF_ASSERT_OK_AND_ASSIGN( HloSchedule schedule, - ScheduleModule(*module, [](const BufferValue& buffer) { + ScheduleModule(module.get(), [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/sizeof(void*)); })); @@ -313,7 +313,7 @@ ENTRY %WhileLoop () -> s32[] { ParseHloString(module_str)); TF_ASSERT_OK_AND_ASSIGN( HloSchedule schedule, - ScheduleModule(*module, [](const BufferValue& buffer) { + ScheduleModule(module.get(), [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/sizeof(void*)); })); diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc index 88329c89979..f5061304456 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc @@ -253,7 +253,7 @@ StatusOr ApplyShardingFromUsers(HloInstruction* instruction, instruction->shape(), HloSharding::AssignDevice(kUnassignedDevice)); for (HloInstruction* user : instruction->users()) { if (user->opcode() == HloOpcode::kDomain && - domain.exit_domains.count(const_cast(user)) > 0) { + domain.exit_domains.count(user) > 0) { // If a user is a domain and it is registered in the domain exits, then // the instruction sharding is taken directly from the domain, and no // further users need to be visited. diff --git a/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc b/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc index 11994d99c93..c1073911ea9 100644 --- a/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc +++ b/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc @@ -66,7 +66,7 @@ class HloSubcomputationUnificationTest : public HloTestBase { }; TEST_F(HloSubcomputationUnificationTest, UnifyIdentities) { - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto callee1 = @@ -103,7 +103,7 @@ TEST_F(HloSubcomputationUnificationTest, UnifyIdentities) { } TEST_F(HloSubcomputationUnificationTest, UnifyAdditions) { - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto callee1 = @@ -184,7 +184,7 @@ TEST_F(HloSubcomputationUnificationTest, DifferentParameterShapes) { // Regression test for b/31466798. Checks that entry_computation is still valid // after unification. TEST_F(HloSubcomputationUnificationTest, TwoIdenticalComputations) { - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); for (int i = 0; i < 2; ++i) { HloComputation::Builder builder("pow"); auto x = diff --git a/tensorflow/compiler/xla/service/hlo_value.h b/tensorflow/compiler/xla/service/hlo_value.h index b6670d409b9..1f01b0bb365 100644 --- a/tensorflow/compiler/xla/service/hlo_value.h +++ b/tensorflow/compiler/xla/service/hlo_value.h @@ -166,9 +166,6 @@ class HloValue : public BufferValue { // Whether this value is live out of the HLO module. bool live_out_of_module_ = false; - - // Whether this value is live out of its computation. 
- bool live_out_of_computation_ = false; }; std::ostream& operator<<(std::ostream& out, const HloValue& hlo_value); diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 27fd685a69a..77db7b098a3 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -753,13 +753,19 @@ Status ShapeVerifier::HandleAfterAll(HloInstruction* token) { for (const HloInstruction* operand : token->operands()) { operand_shapes.push_back(&operand->shape()); } - return CheckShape(token, ShapeInference::InferAfterAllShape(operand_shapes)); + return CheckShape(token, ShapeUtil::MakeTokenShape()); +} + +Status ShapeVerifier::HandleAddDependency(HloInstruction* add_dependency) { + TF_RETURN_IF_ERROR(CheckOperandCount(add_dependency, 2)); + TF_RETURN_IF_ERROR(CheckIsTokenOperand(add_dependency, 1)); + return CheckShape(add_dependency, add_dependency->operand(0)->shape()); } Status ShapeVerifier::HandleGetDimensionSize(HloInstruction* get_size) { - return CheckShape( - get_size, ShapeInference::InferGetDimensionSizeShape( - get_size->operand(0)->shape(), get_size->dimensions(0))); + return CheckShape(get_size, + ShapeInference::InferGetDimensionSizeShape( + get_size->operand(0)->shape(), get_size->dimension())); } Status ShapeVerifier::CheckShape(const HloInstruction* instruction, @@ -1373,9 +1379,8 @@ class InstructionVerifier : public DfsHloVisitorWithDefault { const Layout& operand_layout = operand_shape.layout(); TF_RET_CHECK(LayoutUtil::Equal(result_layout, operand_layout)) << "Instruction shouldn't change layouts " - << instruction->ToString() << " From " - << ShapeUtil::HumanString(result_shape) << " To " - << ShapeUtil::HumanString(operand_shape); + << instruction->ToString() << " From " << result_shape << " To " + << operand_shape; } } } @@ -1426,6 +1431,8 @@ StatusOr HloVerifier::Run(HloModule* module) { return target_metadata_->ShapeSize(shape); })); + TF_RETURN_IF_ERROR(module->dynamic_parameter_binding().Verify(*module)); + return false; } diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index 9fbfd6a21c1..e4d0c3d6957 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -95,6 +95,7 @@ class ShapeVerifier : public DfsHloVisitor { Status HandleScatter(HloInstruction* scatter) override; Status HandleAfterAll(HloInstruction* token) override; Status HandleGetDimensionSize(HloInstruction* get_size) override; + Status HandleAddDependency(HloInstruction* add_dependency) override; Status FinishVisit(HloInstruction*) override { return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc index 5ddfe0a944f..4bc557e4e62 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -35,6 +35,10 @@ namespace { using ::testing::HasSubstr; +std::unique_ptr CreateUnverifiedModule() { + return absl::make_unique("module", HloModuleConfig()); +} + // This class cannot be converted to use HloTestBase. It explicitly // uses HloTestBase to create and test malformed HLOs. 
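For clarity, the constraints that the new ShapeVerifier::HandleAddDependency above enforces can be restated as a free function: exactly two operands, the second a token, and a result shape identical to the data operand's shape. This is a restatement for illustration, not the verifier's actual code.

```cpp
// Restatement of the add-dependency shape rules checked above.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"

namespace xla {

Status CheckAddDependencyShape(const HloInstruction* add_dependency) {
  TF_RET_CHECK(add_dependency->operand_count() == 2);
  TF_RET_CHECK(ShapeUtil::Equal(add_dependency->operand(1)->shape(),
                                ShapeUtil::MakeTokenShape()));
  TF_RET_CHECK(ShapeUtil::Equal(add_dependency->shape(),
                                add_dependency->operand(0)->shape()));
  return Status::OK();
}

}  // namespace xla
```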
class HloVerifierTest : public HloTestBase { @@ -66,7 +70,7 @@ TEST_F(HloVerifierTest, NullInstructionParent) { HloInstruction::CreateParameter(0, scalar_shape, "param")); HloInstruction* negate = builder.AddInstruction( HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateUnverifiedModule(); module->AddEntryComputation(builder.Build()); TF_ASSERT_OK(verifier().Run(module.get()).status()); @@ -85,7 +89,7 @@ TEST_F(HloVerifierTest, NullComputationParent) { HloInstruction::CreateParameter(0, scalar_shape, "param")); builder.AddInstruction( HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateUnverifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); TF_ASSERT_OK(verifier().Run(module.get()).status()); @@ -104,7 +108,7 @@ TEST_F(HloVerifierTest, DifferentOperandParents) { HloInstruction::CreateParameter(0, scalar_shape, "param")); HloInstruction* negate = builder.AddInstruction( HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateUnverifiedModule(); module->AddEntryComputation(builder.Build()); HloComputation::Builder emb_builder(TestName()); @@ -138,7 +142,7 @@ TEST_F(HloVerifierTest, ResetsShapeVerifierState) { builder.AddInstruction( HloInstruction::CreateBinary(s2, HloOpcode::kMultiply, add, add)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateUnverifiedModule(); module->AddEntryComputation(builder.Build()); // Run the verifier twice. It should fail both times, because it shouldn't @@ -303,7 +307,7 @@ TEST_F(HloVerifierTest, NegativeInteriorPaddingNotAllowed) { HloInstruction::CreateConstant(LiteralUtil::Zero(F32))), padding_config)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateUnverifiedModule(); module->AddEntryComputation(builder.Build()); auto status = verifier().Run(module.get()).status(); @@ -327,7 +331,7 @@ TEST_F(HloVerifierTest, PadNegativeInteriorDilationNotAllowed) { HloInstruction::CreateConstant(LiteralUtil::Zero(F32).Clone())), padding_config)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateUnverifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_THAT(verifier().Run(module.get()).status().error_message(), diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc index 20cc18f9815..98246d5403e 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc @@ -481,8 +481,8 @@ ENTRY main { const char* expected_root_expression = R"( (scalar-indexed-const (constant s32[2,1,1,1,6] s32[2,1,1,1,6] { - { /*i0=0*/ { /*i1=0*/ { /*i2=0*/ {1, 2, 3, 4, 5, 6} } } }, - { /*i0=1*/ { /*i1=0*/ { /*i2=0*/ {1, 2, 3, 4, 5, 6} } } } }) + { /*i0=0*/ { /*i1=0*/ { /*i2=0*/ { 1, 2, 3, 4, 5, 6 } } } }, + { /*i0=1*/ { /*i1=0*/ { /*i2=0*/ { 1, 2, 3, 4, 5, 6 } } } } }) (reshape %indices to s32[]) 0->[]) )"; @@ -512,8 +512,8 @@ ENTRY main { const char* expected_root_expression = R"( (scalar-indexed-const (constant s32[2,1,1,6] s32[2,1,1,6] { - { /*i0=0*/ { /*i1=0*/ {1, 2, 3, 4, 5, 6} } }, - { /*i0=1*/ { /*i1=0*/ {1, 2, 3, 4, 5, 6} } } }) + { /*i0=0*/ { /*i1=0*/ { 1, 2, 3, 4, 5, 6 } } }, + { /*i0=1*/ { /*i1=0*/ { 1, 2, 3, 4, 5, 6 } } } }) (reshape %indices to s32[5]) 0->[2]) )"; 
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 7f2d7e7cffc..7559ed1bab8 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -103,7 +103,6 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kShiftRightLogical: case HloOpcode::kSlice: case HloOpcode::kSubtract: - case HloOpcode::kAfterAll: case HloOpcode::kTranspose: case HloOpcode::kTuple: case HloOpcode::kTupleSelect: @@ -116,7 +115,10 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kSin: return ShapeUtil::ElementIsComplex(instruction.shape()); - // Expensive instructions. + // Expensive instructions or unusual instructions for which fusion is + // nonsensical. + case HloOpcode::kAddDependency: + case HloOpcode::kAfterAll: case HloOpcode::kAtan2: case HloOpcode::kBatchNormGrad: case HloOpcode::kBatchNormInference: @@ -455,8 +457,13 @@ StatusOr InstructionFusion::Run(HloModule* module) { computation_ = computation; reachability_ = HloReachabilityMap::Build(computation_); - HloInstructionSet do_not_duplicate = - ComputeGloballyUnfusible(computation_->MakeInstructionPostOrder()); + HloInstructionSet do_not_duplicate; + // If we allow duplications, we need to compute which instructions we do not + // want to duplicate based on a global analysis of the graph. + if (may_duplicate_) { + do_not_duplicate = + ComputeGloballyUnfusible(computation_->MakeInstructionPostOrder()); + } auto fusion_queue = GetFusionQueue(computation_); // Instruction fusion effectively fuses edges in the computation graph @@ -564,8 +571,8 @@ HloInstruction* InstructionFusion::FuseIntoMultiOutput( bool InstructionFusion::MultiOutputFusionCreatesCycle( HloInstruction* producer, HloInstruction* consumer) { auto is_reachable = [&](const HloInstruction* a, const HloInstruction* b) { - // A consumer operand may have been multii-output fused into a parallel - // consumer and thus be missing from the oridinal reachability map. + // A consumer operand may have been multi-output fused into a parallel + // consumer and thus be missing from the original reachability map. 
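A hedged sketch of driving the pass the way the new tests below do: with may_duplicate=false, the pass now skips the global ComputeGloballyUnfusible analysis entirely, since no instruction may be duplicated anyway. The helper name is illustrative.

```cpp
// Run instruction fusion with duplication disabled.
#include "tensorflow/compiler/xla/service/instruction_fusion.h"

namespace xla {

StatusOr<bool> RunFusionWithoutDuplication(HloModule* module) {
  InstructionFusion fusion(InstructionFusion::IsExpensive,
                           /*may_duplicate=*/false);
  return fusion.Run(module);  // True iff any instructions were fused.
}

}  // namespace xla
```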
if (!reachability_->IsPresent(a) || !reachability_->IsPresent(b)) { reachability_ = HloReachabilityMap::Build(consumer->parent()); } diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc index 39904bd54b0..58b7135cea7 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc @@ -117,7 +117,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfParameterUnfused) { auto reshape1 = builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {1, 1}), param0)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(reshape1, computation->root_instruction()); EXPECT_FALSE( @@ -133,7 +133,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastSimpleReshapeOfParameterUnfused) { auto reshape1 = builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {1, 1}), param0)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(reshape1, computation->root_instruction()); EXPECT_FALSE( @@ -149,7 +149,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfParameterUnfused) { auto transpose1 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(S32, {}), param0, {})); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(transpose1, computation->root_instruction()); EXPECT_FALSE( @@ -394,6 +394,56 @@ TEST_F(InstructionFusionTest, AllowEffectiveUnaryDuplication) { .ValueOrDie()); } +TEST_F(InstructionFusionTest, FuseDiamondGraphsNoDuplication) { + auto module = ParseHloString(R"( + HloModule test_module + ENTRY Test { + p0 = f32[100] parameter(0) + p1 = f32[100] parameter(1) + add = f32[100] add(p0, p1) + slice1 = f32[99] slice(add), slice={[0:99:1]} + slice2 = f32[99] slice(add), slice={[1:100:1]} + ROOT add2 = f32[99] add(slice1, slice2) + })") + .ValueOrDie(); + EXPECT_TRUE( + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/false) + .Run(module.get()) + .ValueOrDie()) + << module->ToString(); + + HloInstruction* root = module->entry_computation()->root_instruction(); + // 'add' would originally need to be duplicated if fused. However after its + // two users 'slice1' and 'slice2' are fused into 'add2', 'add' has only one + // user and can now be also fused. + EXPECT_THAT(root, op::Fusion(op::Parameter(), op::Parameter())); +} + +TEST_F(InstructionFusionTest, FuseDiamondGraphsAllowDuplication) { + auto module = ParseHloString(R"( + HloModule test_module + ENTRY Test { + p0 = f32[100] parameter(0) + p1 = f32[100] parameter(1) + add = f32[100] add(p0, p1) + slice1 = f32[99] slice(add), slice={[0:99:1]} + slice2 = f32[99] slice(add), slice={[1:100:1]} + ROOT add2 = f32[99] add(slice1, slice2) + })") + .ValueOrDie(); + EXPECT_TRUE( + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()) + << module->ToString(); + + HloInstruction* root = module->entry_computation()->root_instruction(); + // 'add' would originally need to be duplicated if fused. However after its + // two users 'slice1' and 'slice2' are fused into 'add2', 'add' has only one + // user and can now be also fused. 
+ EXPECT_THAT(root, op::Fusion(op::Parameter(), op::Parameter())); +} + TEST_F(InstructionFusionTest, WideningConvertsAreAlwaysDuplicableIntoConsumers) { auto module = ParseHloString(R"( diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc index a06d6113e84..7635fbfed6f 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable.cc @@ -37,7 +37,7 @@ namespace xla { namespace interpreter { InterpreterExecutable::InterpreterExecutable( - std::unique_ptr hlo_module, + std::unique_ptr hlo_module, std::unique_ptr evaluator) : Executable(std::move(hlo_module), /*hlo_profile_printer=*/nullptr, /*hlo_profile_index_map=*/nullptr), diff --git a/tensorflow/compiler/xla/service/interpreter/executable.h b/tensorflow/compiler/xla/service/interpreter/executable.h index 3b1ebce0c75..bda13d37636 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.h +++ b/tensorflow/compiler/xla/service/interpreter/executable.h @@ -42,7 +42,7 @@ namespace interpreter { // buffer allocation. Refer to interpreter/README.md for more. class InterpreterExecutable : public Executable { public: - InterpreterExecutable(std::unique_ptr hlo_module, + InterpreterExecutable(std::unique_ptr hlo_module, std::unique_ptr evaluator); ~InterpreterExecutable() override; diff --git a/tensorflow/compiler/xla/service/interpreter/executor.cc b/tensorflow/compiler/xla/service/interpreter/executor.cc index 4fb67bd0b72..e3e5fa71543 100644 --- a/tensorflow/compiler/xla/service/interpreter/executor.cc +++ b/tensorflow/compiler/xla/service/interpreter/executor.cc @@ -78,9 +78,14 @@ port::Status XlaInterpreterExecutor::SynchronousMemcpy( return port::Status::OK(); } -bool XlaInterpreterExecutor::HostCallback(Stream *stream, - std::function callback) { - AsExecutorStream(stream)->EnqueueTask(callback); +bool XlaInterpreterExecutor::HostCallback( + Stream *stream, std::function callback) { + AsExecutorStream(stream)->EnqueueTask([callback]() { + port::Status s = callback(); + if (!s.ok()) { + LOG(WARNING) << "Host callback failed: " << s; + } + }); return true; } diff --git a/tensorflow/compiler/xla/service/interpreter/executor.h b/tensorflow/compiler/xla/service/interpreter/executor.h index fbb99457847..400c3051546 100644 --- a/tensorflow/compiler/xla/service/interpreter/executor.h +++ b/tensorflow/compiler/xla/service/interpreter/executor.h @@ -125,7 +125,8 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface { return port::Status{port::error::UNIMPLEMENTED, ""}; } - bool HostCallback(Stream *stream, std::function callback) override; + bool HostCallback(Stream *stream, + std::function callback) override; port::Status AllocateEvent(Event *event) override { return port::Status{port::error::UNIMPLEMENTED, ""}; diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index a9041192220..eddef850cf5 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -2000,6 +2000,7 @@ bool LayoutAssignment::InstructionCanChangeLayout( switch (instruction->opcode()) { case HloOpcode::kAbs: case HloOpcode::kAdd: + case HloOpcode::kAddDependency: case HloOpcode::kAnd: case HloOpcode::kAtan2: case HloOpcode::kBitcastConvert: diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc 
index 2400b7bb7c4..311bd789054 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -31,6 +31,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" @@ -42,11 +44,10 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" -namespace op = xla::testing::opcode_matchers; - namespace xla { namespace { +namespace m = xla::match; using ::testing::ElementsAre; class LayoutAssignmentTest : public HloTestBase { @@ -328,11 +329,10 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) { // %tuple.1 = Tuple(%copy) layout=({0,1}) // %tuple.2 = Tuple(%tuple.0, %tuple.1) layout=(({1,0}), ({0,1})) // - EXPECT_TRUE( - AlgebraicSimplifier(/*is_layout_sensitive=*/true, - [](const Shape&, const Shape&) { return false; }) - .Run(m.get()) - .ValueOrDie()); + AlgebraicSimplifierOptions options( + [](const Shape&, const Shape&) { return false; }); + options.set_is_layout_sensitive(true); + EXPECT_TRUE(AlgebraicSimplifier(options).Run(m.get()).ValueOrDie()); HloInstruction* root = m->entry_computation()->root_instruction(); // Verify layout of the root and the root's operands. EXPECT_TRUE(ShapeUtil::Equal(result_shape, root->shape())); @@ -343,7 +343,8 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) { // Verify the structure of the HLO graph. 
EXPECT_THAT(root, - op::Tuple(op::Tuple(constant), op::Tuple(op::Copy(constant)))); + GmockMatch(m::Tuple(m::Tuple(m::Op().Is(constant)), + m::Tuple(m::Copy(m::Op().Is(constant)))))); } TEST_F(LayoutAssignmentTest, ElementwiseAndReshape) { @@ -947,9 +948,11 @@ TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) { HloInstruction* root = compiled_module->entry_computation()->root_instruction(); Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {1, 0}); - EXPECT_THAT(root, op::Add(op::Parameter(), - op::Slice(AllOf(op::Copy(op::Parameter(1)), - op::ShapeWithLayout(shape_copy))))); + EXPECT_THAT( + root, + GmockMatch(m::Add( + m::Parameter(), + m::Slice(m::Copy(m::Parameter(1)).WithShapeEqualTo(&shape_copy))))); } TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) { @@ -977,10 +980,11 @@ TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) { compiled_module->entry_computation()->root_instruction(); Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {1, 0}); EXPECT_THAT(root, - op::Add(op::Parameter(), - op::DynamicSlice(AllOf(op::Copy(op::Parameter(1)), - op::ShapeWithLayout(shape_copy)), - op::Parameter(2)))); + GmockMatch(m::Add( + m::Parameter(), + m::DynamicSlice( + m::Copy(m::Parameter(1)).WithShapeEqualTo(&shape_copy), + m::Parameter(2))))); } TEST_F(LayoutAssignmentTest, CopyConcatOperandToAvoidImplicitLayoutChange) { @@ -1008,11 +1012,12 @@ TEST_F(LayoutAssignmentTest, CopyConcatOperandToAvoidImplicitLayoutChange) { HloInstruction* root = compiled_module->entry_computation()->root_instruction(); Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {3, 5}, {1, 0}); - EXPECT_THAT(root, - op::Add(op::Parameter(), - op::Concatenate(AllOf(op::Copy(op::Parameter(1)), - op::ShapeWithLayout(shape_copy)), - op::Parameter(2)))); + EXPECT_THAT( + root, + GmockMatch(m::Add( + m::Parameter(), + m::Concatenate(m::Copy(m::Parameter(1)).WithShapeEqualTo(&shape_copy), + m::Parameter(2))))); } TEST_F(LayoutAssignmentTest, @@ -1039,7 +1044,8 @@ TEST_F(LayoutAssignmentTest, .ConsumeValueOrDie(); HloInstruction* root = compiled_module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Convolution(op::Parameter(0), op::Parameter(1))); + EXPECT_THAT(root, + GmockMatch(m::Convolution(m::Parameter(0), m::Parameter(1)))); } TEST_F(LayoutAssignmentTest, PropagatingLayoutFromResultToOperand) { @@ -1063,8 +1069,9 @@ TEST_F(LayoutAssignmentTest, PropagatingLayoutFromResultToOperand) { HloInstruction* root = compiled_module->entry_computation()->root_instruction(); Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {0, 1}); - EXPECT_THAT(root, op::Slice(AllOf(op::Copy(op::Parameter(0)), - op::ShapeWithLayout(shape_copy)))); + EXPECT_THAT(root, + GmockMatch(m::Slice( + m::Copy(m::Parameter(0)).WithShapeEqualTo(&shape_copy)))); } TEST_F(LayoutAssignmentTest, TupleCopyOnLayoutMismatch) { @@ -1150,7 +1157,7 @@ ENTRY %CustomCallWithNotLayoutConstrained (p: f32[42,2,3]) -> f32[1,2,3,4] { AssignLayouts(m.get(), &computation_layout); HloInstruction* root = m->entry_computation()->root_instruction(); - ASSERT_THAT(root, op::CustomCall(op::Parameter())); + ASSERT_THAT(root, GmockMatch(m::CustomCall(m::Parameter()))); ExpectLayoutIs(root->shape(), {3, 2, 0, 1}); ExpectLayoutIs(root->operand(0)->shape(), {0, 2, 1}); } @@ -1166,7 +1173,7 @@ ENTRY %CustomCallWithNotLayoutConstrained (p: f32[42,2,3]) -> f32[1,2,3,4] { AssignLayouts(m.get(), &computation_layout); HloInstruction* root = 
m->entry_computation()->root_instruction(); - ASSERT_THAT(root, op::CustomCall(op::Parameter())); + ASSERT_THAT(root, GmockMatch(m::CustomCall(m::Parameter()))); ExpectLayoutIs(root->shape(), {0, 2, 3, 1}); ExpectLayoutIs(root->operand(0)->shape(), {0, 1, 2}); } @@ -1197,7 +1204,7 @@ ENTRY %CustomCallWithLayoutConstraints (p0: f32[4,4], p1: f32[2,3]) -> f32[1,2,3 // The custom call should be partially encapsulated in kCopy instructions // because of the layout mismatches. ASSERT_THAT(m->entry_computation()->root_instruction(), - op::Copy(op::CustomCall(op::Copy(), op::Parameter()))); + GmockMatch(m::Copy(m::CustomCall(m::Copy(), m::Parameter())))); const HloInstruction* custom_call = m->entry_computation()->root_instruction()->operand(0); @@ -1223,7 +1230,7 @@ ENTRY %CustomCallLayoutConstrainedZeroOperands () -> f32[1,2,3,4] { AssignLayouts(m.get(), &computation_layout); ASSERT_THAT(m->entry_computation()->root_instruction(), - op::Copy(op::CustomCall())); + GmockMatch(m::Copy(m::CustomCall()))); const HloInstruction* custom_call = m->entry_computation()->root_instruction()->operand(0); @@ -1257,7 +1264,7 @@ ENTRY %CustomCallLayoutConstrainedTupleOperand (p0: f32[4,4], p1: f32[2,3]) -> f ExpectLayoutIs(root->shape(), {2, 1, 0, 3}); ASSERT_THAT(m->entry_computation()->root_instruction(), - op::Copy(op::CustomCall(op::Tuple()))); + GmockMatch(m::Copy(m::CustomCall(m::Tuple())))); const HloInstruction* custom_call = m->entry_computation()->root_instruction()->operand(0); diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h index f4b05f29c38..d6d84994ee1 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h @@ -25,6 +25,7 @@ limitations under the License. #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/logging.h" @@ -108,6 +109,14 @@ class IrArray { Index(absl::Span multidim, llvm::Value* linear, const Shape& shape); + // Returns an index that adds `addend` to the given `dim` of the object. + Index AddOffsetToDim(llvm::Value* addend, int64 dim, + llvm::IRBuilder<>* b) const { + IrArray::Index index = *this; + index[dim] = b->CreateAdd(index[dim], addend); + return index; + } + const std::vector& multidim() const { return multidim_; } llvm::Value* linear() const { return linear_; } diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc index e5fbdbd51b8..1aa85eb8d2d 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc @@ -52,6 +52,29 @@ Shape MergeDimensions(absl::Span segs, const Shape& shape) { return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(), dimensions); } + +// Given an index for a shape, return the equivalent new index if the shape is +// reshaped to another shape. 
+IrArray::Index GetReshapedIndex(const IrArray::Index& index, const Shape& shape, + const Shape& reshaped_shape, + llvm::IRBuilder<>* b) { + auto bounds = shape.dimensions(); + auto minor_to_major = shape.layout().minor_to_major(); + llvm::Value* linear_index = index.GetConstantWithIndexType(0); + int64 multiplier = 1; + for (int i = 0; i < index.size(); ++i) { + int64 dim = minor_to_major[i]; + llvm::Value* addend = b->CreateMul( + index[dim], index.GetConstantWithIndexType(multiplier), "linearizing", + /*HasNUW=*/true, /*HasNSW=*/true); + linear_index = b->CreateAdd(linear_index, addend, "", + /*HasNUW=*/true, /*HasNSW=*/true); + multiplier *= bounds[dim]; + } + + return IrArray::Index(linear_index, reshaped_shape, b); +} + } // namespace absl::optional > FindTranspose021(const Shape& a, @@ -60,28 +83,30 @@ absl::optional > FindTranspose021(const Shape& a, return absl::nullopt; } - std::vector perm(a.dimensions().size()); - { - auto layout_a_orig = LayoutUtil::MinorToMajor(a); - std::vector layout_a(layout_a_orig.rbegin(), layout_a_orig.rend()); - auto layout_b_orig = LayoutUtil::MinorToMajor(b); - std::vector layout_b(layout_b_orig.rbegin(), layout_b_orig.rend()); - for (size_t i = 0; i < perm.size(); ++i) { - perm[i] = PositionInContainer(layout_b, layout_a[i]); - } + std::vector permutation(a.dimensions().size()); + absl::Span minor_to_major_a = LayoutUtil::MinorToMajor(a); + std::vector major_to_minor_a(minor_to_major_a.rbegin(), + minor_to_major_a.rend()); + absl::Span minor_to_major_b = LayoutUtil::MinorToMajor(b); + std::vector major_to_minor_b(minor_to_major_b.rbegin(), + minor_to_major_b.rend()); + for (size_t i = 0; i < permutation.size(); ++i) { + permutation[i] = PositionInContainer(major_to_minor_b, major_to_minor_a[i]); } - auto segs = ConsecutiveSegments(perm); - if ((3 == segs.size() && 0 == perm[0]) || 2 == segs.size()) { - Shape norm_a = + + std::vector segments = ConsecutiveSegments(permutation); + if ((3 == segments.size() && 0 == permutation[0]) || 2 == segments.size()) { + Shape descending_layout_shape = ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(a); - Shape reduced_a = MergeDimensions(segs, norm_a); - auto reduced_a_dims = reduced_a.dimensions(); + Shape normalized_shape = MergeDimensions(segments, descending_layout_shape); + absl::Span normalized_dims = + AsInt64Slice(normalized_shape.dimensions()); std::vector dims_021; - if (2 == segs.size()) { + if (2 == segments.size()) { // The logical component-0 is of size one. 
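The GetReshapedIndex helper above linearizes the source index against the source shape and then delinearizes it against the reshaped shape. The block below is a host-side illustration of that linearization in plain integer arithmetic, not IR emission; names are illustrative.

```cpp
// Walk the dimensions from minor to major, scale each index component by the
// running stride, and accumulate. Delinearizing the result against the target
// shape then yields the reshaped index.
#include <cstdint>
#include <vector>

int64_t LinearizeIndex(const std::vector<int64_t>& index,
                       const std::vector<int64_t>& bounds,
                       const std::vector<int64_t>& minor_to_major) {
  int64_t linear = 0;
  int64_t multiplier = 1;                   // Stride of the current dimension.
  for (size_t i = 0; i < index.size(); ++i) {
    const int64_t dim = minor_to_major[i];  // i-th minor physical dimension.
    linear += index[dim] * multiplier;
    multiplier *= bounds[dim];              // Stride of the next, more major, dim.
  }
  return linear;
}
```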
- dims_021 = {1, reduced_a_dims[1], reduced_a_dims[0]}; + dims_021 = {1, normalized_dims[1], normalized_dims[0]}; } else { - dims_021 = {reduced_a_dims[0], reduced_a_dims[2], reduced_a_dims[1]}; + dims_021 = {normalized_dims[0], normalized_dims[2], normalized_dims[1]}; } return dims_021; @@ -90,27 +115,117 @@ absl::optional > FindTranspose021(const Shape& a, return absl::nullopt; } -IrArray::Index GetUnreducedOutputIndex( - const IrArray::Index& reduced_output_index, - const Shape& reduced_output_shape, const Shape& unreduced_output_shape, - llvm::IRBuilder<>* b) { - auto bounds = reduced_output_shape.dimensions(); - auto minor_to_major = reduced_output_shape.layout().minor_to_major(); - llvm::Value* linear_index = reduced_output_index.GetConstantWithIndexType(0); - int64 multiplier = 1; - for (int i = 0; i < reduced_output_index.size(); ++i) { - int64 dim = minor_to_major[i]; - llvm::Value* addend = - b->CreateMul(reduced_output_index[dim], - reduced_output_index.GetConstantWithIndexType(multiplier), - "linearizing", - /*HasNUW=*/true, /*HasNSW=*/true); - linear_index = b->CreateAdd(linear_index, addend, "", - /*HasNUW=*/true, /*HasNSW=*/true); - multiplier *= bounds[dim]; - } +KernelMappingScheme::KernelMappingScheme( + absl::Span dims_in_elems, int64 tile_size_y, int64 tile_size_x, + absl::Span req_block_sizes, int64 num_threads_y, + int64 num_threads_x, llvm::IRBuilder<>* b) + : b_(b), + dims_in_elems_(dims_in_elems.begin(), dims_in_elems.end()), + tile_sizes_{1, tile_size_y, tile_size_x}, + num_threads_x_(num_threads_x), + num_threads_y_(num_threads_y) { + DCHECK_EQ(dims_in_elems_.size(), 3); + DCHECK_EQ(req_block_sizes.size(), 3); - return IrArray::Index(linear_index, unreduced_output_shape, b); + DCHECK_EQ(tile_size_y % num_threads_y_, 0); + DCHECK_EQ(tile_size_x % num_threads_x_, 0); + + dims_in_tiles_ = ElementWiseCeilOfRatio(dims_in_elems_, tile_sizes_); + block_sizes_.reserve(req_block_sizes.size()); + absl::c_transform(req_block_sizes, dims_in_tiles_, + std::back_inserter(block_sizes_), + [](const int64 requested_size, const int64 max_size) { + return std::min(requested_size, max_size); + }); + dims_in_blocks_ = ElementWiseCeilOfRatio(dims_in_tiles_, block_sizes_); + + VLOG(10) << "dims_in_elems_ = [" << absl::StrJoin(dims_in_elems_, ",") << "]"; + VLOG(10) << "dims_in_tiles_ = [" << absl::StrJoin(dims_in_tiles_, ",") << "]"; + VLOG(10) << "dims_in_blocks_ = [" << absl::StrJoin(dims_in_blocks_, ",") + << "]"; +} + +IrArray::Index KernelMappingScheme::GetUnnormalizedIndex( + const IrArray::Index& normalized_shape_index, + const Shape& unnormalized_shape) { + DCHECK_EQ(normalized_shape_index.size(), dims_in_elems_.size()); + Shape output_shape = ShapeUtil::MakeShapeWithDescendingLayout( + unnormalized_shape.element_type(), GetDimensionsInElements()); + return GetReshapedIndex(normalized_shape_index, output_shape, + unnormalized_shape, b_); +} + +IrArray::Index KernelMappingScheme::EmitBlockIndex(llvm::Type* index_ty) { + llvm::Value* block_id = llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, b_); + llvm_ir::AddRangeMetadata(0, GetNumberOfBlocks(), + llvm::cast(block_id)); + llvm::Value* linear_block_id = + b_->CreateIntCast(block_id, index_ty, /*isSigned=*/true, "block.id.x"); + return IrArray::Index(linear_block_id, + ShapeUtil::MakeShapeWithDescendingLayout( + PRED /*arbitrary*/, dims_in_blocks_), + b_); +} + +IrArray::Index KernelMappingScheme::GetTileIndexForBlockOrigin( + const IrArray::Index& block_index) { + IrArray::Index tile_index = 
block_index; + for (int i = 0; i < block_sizes_.size(); ++i) { + tile_index[i] = b_->CreateMul( + block_index[i], + llvm::ConstantInt::get(block_index[i]->getType(), block_sizes_[i]), + "block_origin." + std::to_string(i)); + } + return tile_index; +} + +IrArray::Index KernelMappingScheme::GetElementIndexForTileOrigin( + const IrArray::Index& tile_index) { + IrArray::Index elem_index = tile_index; + for (int i = DimY; i < DimTot; ++i) { + elem_index[i] = + b_->CreateMul(tile_index[i], + llvm::ConstantInt::get(tile_index[i]->getType(), + GetTileSizeForDimension(i)), + "tile_origin." + std::to_string(i)); + } + return elem_index; +} + +llvm::GlobalVariable* KernelMappingScheme::GetSharedMemoryBufferForElementType( + llvm::Type* elem_ty, absl::string_view buffer_name) { + // If shared memory transpose is needed, we use square tiles. + CHECK_EQ(GetTileSizeForDimensionX(), GetTileSizeForDimensionY()); + + // For Nvidia GPUs, the warp size is 32 threads and the shared memory is + // organized into 32 banks. We usually use the warp size or a multiple of + // the warp size as the tile size. This may cause all elements in the + // same column of a tile to use the same memory bank and therefore lead to + // shared memory bank conflicts. Adding 1 to the minor dimension of the shared + // memory buffer can reduce such shared memory bank conflicts. + llvm::Type* buffer_type = llvm::ArrayType::get( + llvm::ArrayType::get(elem_ty, GetTileSizeForDimension(DimX) + 1), + GetTileSizeForDimension(DimY)); + return llvm_ir::AllocateSharedMemoryTile(b_->GetInsertBlock()->getModule(), + buffer_type, buffer_name); +} + +std::tuple +KernelMappingScheme::EmitThreadYXCoordinate(llvm::Type* index_ty) { + // Calculate the (y, x) coordinate of the thread in the 2D view of the thread + // block defined by (num_thread_y, num_thread_x) from thread_id. + llvm::CallInst* thread_id_raw = llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, b_); + llvm_ir::AddRangeMetadata(0, GetThreadsPerTile(), thread_id_raw); + llvm::Value* thread_id_int = + b_->CreateIntCast(thread_id_raw, index_ty, + /*isSigned=*/true, "thread.id.x"); + llvm::Value* num_thread_x = + llvm::ConstantInt::get(index_ty, GetNumberOfThreadsForDimensionX()); + llvm::Value* x = b_->CreateURem(thread_id_int, num_thread_x); + llvm::Value* y = b_->CreateUDiv(thread_id_int, num_thread_x); + return std::make_tuple(y, x); } } // namespace llvm_ir diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h index 5ea05b3188a..7277aeac8ad 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h @@ -28,23 +28,165 @@ namespace llvm_ir { // If a shape can be viewed as three logical components 0-1-2 in the order of // major to minor, a 0-2-1-transpose changes the order of such logical // components to 0-2-1. We call the shape being transposed the input shape and -// the transposed shape the output shape. The logical view of the input and -// output shapes for the transpose are called the 0-1-2 shape or reduced input -// shape and the 0-2-1 shape or the reduced output shape respectively. The -// original input and output shapes are called the unreduced input and output -// shapes. - +// the transposed shape the output shape. The logical view of the input/output +// shapes for the transpose are called the 0-1-2/0-2-1 shapes or the normalized +// shapes. The original input/output shapes are called unnormalized shapes.
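Two of the calculations above lend themselves to a short host-side illustration: the (y, x) thread coordinate that EmitThreadYXCoordinate derives from the linear thread id, and the +1 padding that GetSharedMemoryBufferForElementType applies to the minor dimension of the shared-memory tile. Sizes and names below are illustrative, not taken from any kernel.

```cpp
// Plain C++ illustration, not IR emission.
#include <cstdint>
#include <utility>

std::pair<int64_t, int64_t> ThreadYX(int64_t thread_id, int64_t num_threads_x) {
  // Row-major decomposition of the linear id over a (num_threads_y,
  // num_threads_x) thread block, matching the UDiv/URem pair emitted above.
  return {thread_id / num_threads_x, thread_id % num_threads_x};
}

// A 32x32 tile padded to 32x33 elements: successive rows start in different
// shared-memory banks, so a column-wise read no longer maps every thread to
// the same bank.
float padded_tile[32][32 + 1];
```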
+// // If `b` is a 0-2-1 transpose of `a` in 0-1-2, return the dimensions for the -// reduced shape of `b` or the 0-2-1 shape. +// normalized shape of `b` or the 0-2-1 shape. absl::optional > FindTranspose021(const Shape& a, const Shape& b); -// Return the unreduced output index corresponding to the given reduced output -// index. -IrArray::Index GetUnreducedOutputIndex( - const IrArray::Index& reduced_output_index, - const Shape& reduced_output_shape, const Shape& unreduced_output_shape, - llvm::IRBuilder<>* b); +// A tile is a spatial subdivision of a tensor. We group tensor elements into +// tiles so that we can launch kernels to process the tensor elements in blocks +// of tiles. +// +// A kernel mapping scheme describes a method to partition the tensors accessed +// by an unnested HLO instruction into tiles and blocks of tiles, and the +// associated information to use hardware threads to process the tensor elements +// in blocks of tiles. +// +// Currently, there are two main use cases for a tiling scheme. First, we +// implement kernels with 0-2-1 memory transpose using shared memory to improve +// memory access pattern. Second, we implement reduction to contiguous +// dimensions in layout, with or without memory transpose, to achieve better +// memory access pattern as well as to reduce the number of expensive +// instructions executed, such as thread synchronization related instructions +// and atomic operations. For both use cases, we can apply a normalization to +// the original tensors, to collapse contiguous dimensions for the same purpose +// and produce normalized three dimensional tensors. For this reason, the tiling +// scheme class only needs to handle normalized three dimensional tensors and +// two dimensional tiles. +// +// The current implementation of the class is somewhat NVIDIA GPU oriented; this +// can be improved when the need arises. The idea of 0-2-1 +// transpose using shared memory can be found in the following CUDA algorithm in +// TensorFlow: https://goo.gl/MStRV6. +// +// We use a thread block to process a tile because we want to use the HW thread +// block synchronization primitives to synchronize the processing of all the +// elements in the same tile. A thread block can be viewed as a two dimensional +// array of threads, described by the number of threads for the Y and X +// dimensions. A thread block (num_threads_y, num_threads_x) processes a tile of +// (tile_size_y, tile_size_x) as follows: each thread in the thread block +// processes one element in the tile so that all the threads in the thread block +// together process a subdivision of the tile that has the same dimension as the +// thread block array. Then the thread block moves on to process the next +// subdivision of the tile until the whole tile is processed. Therefore, each +// thread in the thread block processes +// tile_size_x/num_threads_x * tile_size_y/num_threads_y elements in a tile. +// +// There are situations where we want a thread block to process multiple +// tiles. We can't group those tiles into a bigger tile because we limit a tile +// to a two dimensional spatial subdivision of a tensor. For example, when we +// use tiling to implement reduction with transpose, we want the partial sum +// produced by each thread to accumulate values for more elements before using +// shfl_down and atomic_add instructions for further reduction, to amortize the +// cost of such expensive instructions. The concept of a tile block is introduced +// for this purpose.
A tile block is a three dimensional array of tiles, of +// which some dimensions may be degenerated to only one tile. +class KernelMappingScheme { + public: + enum { DimZ = 0, DimY, DimX, DimTot }; + + public: + KernelMappingScheme() {} + // dims_in_elems: the normalized tensor dimensions. + // req_block_sizes: the requested block size in number of tiles for each + // dimension. The actual block size is set to min(req_block_size, + // dims_in_number_of_blocks). + KernelMappingScheme(absl::Span dims_in_elems, int64 tile_size_y, + int64 tile_size_x, + absl::Span req_block_sizes, + int64 num_threads_y, int64 num_threads_x, + llvm::IRBuilder<>* b); + + absl::Span GetDimensionsInElements() const { + return dims_in_elems_; + } + absl::Span GetDimensionsInTiles() const { + return dims_in_tiles_; + } + absl::Span GetDimensionsInBlocks() const { + return dims_in_blocks_; + } + + int64 GetNumberOfTilesInTotal() const { + return absl::c_accumulate(dims_in_tiles_, 1LL, std::multiplies()); + } + int64 GetNumberOfTilesInOneBlock() const { + return absl::c_accumulate(block_sizes_, 1, std::multiplies()); + } + + int64 GetNumberOfBlocks() const { + return absl::c_accumulate(dims_in_blocks_, 1, std::multiplies()); + } + + int64 GetTileSizeForDimension(int d) const { + DCHECK(d >= DimZ && d <= DimX); + return tile_sizes_[d]; + } + int64 GetTileSizeForDimensionX() const { + return GetTileSizeForDimension(DimX); + } + int64 GetTileSizeForDimensionY() const { + return GetTileSizeForDimension(DimY); + } + + absl::Span GetBlockSizes() const { return block_sizes_; } + int64 GetTileBlockSizeForDimension(int d) const { + DCHECK(d >= DimZ && d <= DimX); + return dims_in_blocks_[d]; + } + + int64 GetNumberOfThreadsForDimensionX() const { return num_threads_x_; } + int64 GetNumberOfThreadsForDimensionY() const { return num_threads_y_; } + + int64 GetThreadsPerTile() const { + return GetNumberOfThreadsForDimensionX() * + GetNumberOfThreadsForDimensionY(); + } + + IrArray::Index EmitBlockIndex(llvm::Type* index_ty); + // Returns the index for the first tile in the block with the given block + // index. + IrArray::Index GetTileIndexForBlockOrigin(const IrArray::Index& block_index); + // Returns the index for the first element in the tile with the given tile + // index. + IrArray::Index GetElementIndexForTileOrigin(const IrArray::Index& tile_index); + + std::tuple EmitThreadYXCoordinate( + llvm::Type* index_ty); + + IrArray::Index GetUnnormalizedIndex( + const IrArray::Index& normalized_shape_index, + const Shape& unnormalized_shape); + + llvm::GlobalVariable* GetSharedMemoryBufferForElementType( + llvm::Type* elem_ty, absl::string_view buffer_name); + + private: + llvm::IRBuilder<>* b_; + // The number of elements in each dimension. + std::vector dims_in_elems_; + + // The number of elements for each dimension of a tile. + std::vector tile_sizes_; + // The number of tiles in each dimension. It is computed from dims_in_elem_ + // and tile_sizes_. + std::vector dims_in_tiles_; + + // The number of tiles for each dimension of a tile block. + std::vector block_sizes_; + // The number of blocks in each dimension of a tile block. It is computed from + // dims_in_tile_ and block_sizes_. + std::vector dims_in_blocks_; + + // Number of threads used to process elements in the X direction of a tile. + int64 num_threads_x_; + // Number of threads used to process elements in the Y direction of a tile. + int64 num_threads_y_; +}; // A class to represent information for tiled parameters to support IR emission // for 021 transpose. 
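To make the tile/block bookkeeping above concrete, here is a small standalone sketch (not part of the patch) of how the quantities tracked by KernelMappingScheme relate to each other. The CeilDiv helper, the example dimensions, and the main() harness are illustrative assumptions; the arithmetic simply mirrors the constructor comment (tiles cover elements, block sizes are clamped to the available tiles, blocks cover tiles), and the real class may compute these values differently.

```
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Illustrative only: mirrors the relationships described in the
// KernelMappingScheme constructor comment above.
int64_t CeilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

int main() {
  // A normalized (Z, Y, X) tensor of 4 x 96 x 100 elements (hypothetical).
  std::vector<int64_t> dims_in_elems = {4, 96, 100};
  // One-element tiles along Z; square 32 x 32 tiles along Y/X, as used for
  // the shared-memory transpose.
  std::vector<int64_t> tile_sizes = {1, 32, 32};
  // Requested number of tiles per block in each dimension.
  std::vector<int64_t> req_block_sizes = {8, 1, 1};

  std::vector<int64_t> dims_in_tiles(3), block_sizes(3), dims_in_blocks(3);
  for (int i = 0; i < 3; ++i) {
    // Number of tiles needed to cover the dimension.
    dims_in_tiles[i] = CeilDiv(dims_in_elems[i], tile_sizes[i]);
    // Actual block size: the requested size, clamped to the available tiles.
    block_sizes[i] = std::min(req_block_sizes[i], dims_in_tiles[i]);
    // Number of tile blocks needed to cover the tiles.
    dims_in_blocks[i] = CeilDiv(dims_in_tiles[i], block_sizes[i]);
    std::cout << "dim " << i << ": " << dims_in_tiles[i] << " tiles, "
              << block_sizes[i] << " tiles/block, " << dims_in_blocks[i]
              << " blocks\n";
  }
  return 0;
}
```

The product of dims_in_blocks computed this way corresponds to what GetNumberOfBlocks() returns in the class declaration above.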
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index df78726166e..ceea24685af 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -244,10 +244,11 @@ StatusOr EncodeSelfDescribingShapeConstant(const Shape& shape, StatusOr DecodeSelfDescribingShapeConstant(const void* shape_ptr, int32 size_bytes) { - Shape shape; - TF_RET_CHECK(shape.ParseFromArray(shape_ptr, size_bytes)); + ShapeProto shape_proto; + TF_RET_CHECK(shape_proto.ParseFromArray(shape_ptr, size_bytes)); + Shape shape(shape_proto); TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape)); - return shape; + return std::move(shape); } llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal, diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc index fd16af67fe9..e22c2173c27 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc @@ -47,7 +47,8 @@ namespace { // Adds the inner comparison loop body where we compare elements. void EmitCompareLoopBody( int64 iteration_bound, PrimitiveType key_type, int64 num_values, - llvm::Value* element_pair_index, int64 xor_mask, llvm::Type* index_type, + int64 iota_values_parameter_index, llvm::Value* element_pair_index, + int64 xor_mask, llvm::Type* index_type, std::function read_element, std::function write_element, @@ -139,34 +140,42 @@ void EmitCompareLoopBody( is_signed_comparison = false; } // If key2 < key1 - ksl.IfReturnVoid( - "is_smaller_than", + auto is_smaller_than = b->CreateICmp(is_signed_comparison ? llvm::ICmpInst::ICMP_SLT : llvm::ICmpInst::ICMP_ULT, - compare_key2, compare_key1), - [&]() { - // Swap key1 with key2. - write_element(0, current_keys_index, key2); - write_element(0, compare_keys_index, key1); - for (int64 i = 1; i <= num_values; ++i) { - // Also swap the values. - auto value1 = read_element(i, current_keys_index); - auto value2 = read_element(i, compare_keys_index); - write_element(i, current_keys_index, value2); - write_element(i, compare_keys_index, value1); - } - }); + compare_key2, compare_key1); + if (iota_values_parameter_index >= 0) { + auto keys_equal = b->CreateICmpEQ(compare_key1, compare_key2); + auto key_index1 = + read_element(iota_values_parameter_index, current_keys_index); + auto key_index2 = + read_element(iota_values_parameter_index, compare_keys_index); + auto index_is_smaller_than = + b->CreateICmp(llvm::ICmpInst::ICMP_ULT, key_index2, key_index1); + is_smaller_than = b->CreateOr( + is_smaller_than, b->CreateAnd(keys_equal, index_is_smaller_than)); + } + ksl.IfReturnVoid("is_smaller_than", is_smaller_than, [&]() { + // Swap key1 with key2. + write_element(0, current_keys_index, key2); + write_element(0, compare_keys_index, key1); + for (int64 i = 1; i <= num_values; ++i) { + // Also swap the values. 
+ auto value1 = read_element(i, current_keys_index); + auto value2 = read_element(i, compare_keys_index); + write_element(i, current_keys_index, value2); + write_element(i, compare_keys_index, value1); + } + }); }); } -void EmitTiledCompareLoop(const IrArray::Index& tiled_keys_index, - int64 dimension_to_sort, - int64 dimension_to_sort_bound, - PrimitiveType keys_type, - absl::Span xor_masks, - const std::vector& params, - const std::vector& param_shmem_buffers, - int64 tile_size, llvm::IRBuilder<>* b) { +void EmitTiledCompareLoop( + const IrArray::Index& tiled_keys_index, int64 dimension_to_sort, + int64 dimension_to_sort_bound, PrimitiveType keys_type, + absl::Span xor_masks, const std::vector& params, + const std::vector& param_shmem_buffers, + int64 iota_values_parameter_index, int64 tile_size, llvm::IRBuilder<>* b) { KernelSupportLibrary ksl(b); llvm::Value* thread_id = llvm_ir::EmitCallToIntrinsic( llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, b); @@ -253,20 +262,22 @@ void EmitTiledCompareLoop(const IrArray::Index& tiled_keys_index, RoundDownToNearest(dimension_to_sort_bound, tile_size))), [&]() { EmitCompareLoopBody(dimension_to_sort_bound % tile_size, keys_type, - params.size() - 1, element_pair_index, xor_mask, + params.size() - 1, iota_values_parameter_index, + element_pair_index, xor_mask, tiled_keys_index.GetType(), read_element, write_element, b); }, [&]() { - EmitCompareLoopBody( - tile_size, keys_type, params.size() - 1, element_pair_index, - xor_mask, tiled_keys_index.GetType(), read_element, - write_element, b, /*needs_bounds_checks=*/false); + EmitCompareLoopBody(tile_size, keys_type, params.size() - 1, + iota_values_parameter_index, element_pair_index, + xor_mask, tiled_keys_index.GetType(), + read_element, write_element, b, + /*needs_bounds_checks=*/false); }); } else { EmitCompareLoopBody(tile_size, keys_type, params.size() - 1, - element_pair_index, xor_mask, - tiled_keys_index.GetType(), read_element, + iota_values_parameter_index, element_pair_index, + xor_mask, tiled_keys_index.GetType(), read_element, write_element, b, /*needs_bounds_checks=*/false); } // Wait until all comparisons have happened. 
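The behavioral change in the hunks above is the extra tie-break on the iota operand: when two keys compare equal, the element with the smaller original index is treated as smaller, which is what makes the sort stable. Below is a minimal host-side sketch of that ordering, not part of the patch; the std::sort harness, variable names, and sample data are illustrative assumptions standing in for the emitted IR.

```
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

int main() {
  // Keys to sort, plus an "iota" operand recording each element's original
  // position, playing the role of iota_values_parameter_index above.
  std::vector<int> keys = {3, 1, 3, 2};
  std::vector<int64_t> iota = {0, 1, 2, 3};

  std::vector<std::pair<int, int64_t>> pairs;
  for (size_t i = 0; i < keys.size(); ++i) {
    pairs.push_back({keys[i], iota[i]});
  }

  // The emitted IR swaps a pair when key2 < key1, OR when the keys are equal
  // and index2 < index1. As an ordering, that means: smaller key first; on
  // equal keys, smaller original index first, i.e. a stable sort.
  std::sort(pairs.begin(), pairs.end(), [](const auto& a, const auto& b) {
    return a.first < b.first || (a.first == b.first && a.second < b.second);
  });

  for (const auto& p : pairs) {
    std::cout << p.first << " (original index " << p.second << ")\n";
  }
  return 0;
}
```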
@@ -296,6 +307,7 @@ void EmitTiledCompareLoop(const IrArray::Index& tiled_keys_index, Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, const std::vector& values_arrays, + int64 iota_values_parameter_index, absl::string_view name, absl::Span xor_masks, llvm::IRBuilder<>* b, const gpu::LaunchDimensions& launch_dimensions, @@ -367,8 +379,8 @@ Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, if (xor_masks.size() > 1) { EmitTiledCompareLoop(keys_index, dimension_to_sort, dimension_to_sort_bound, keys_shape.element_type(), - xor_masks, params, param_shmem_buffers, tile_size, - b); + xor_masks, params, param_shmem_buffers, + iota_values_parameter_index, tile_size, b); } else { auto read_element = [&](int64 operand, llvm::Value* index) { keys_index[dimension_to_sort] = index; @@ -380,9 +392,10 @@ Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, params[operand].EmitWriteArrayElement(keys_index, value, b); }; EmitCompareLoopBody(dimension_to_sort_bound, keys_shape.element_type(), - values_arrays.size(), tiles_index[rank - 1], - xor_masks[0], tiles_index.GetType(), read_element, - write_element, b); + values_arrays.size(), iota_values_parameter_index, + tiles_index[rank - 1], xor_masks[0], + tiles_index.GetType(), read_element, write_element, + b); } return Status::OK(); }; diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.h b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h index 556a217322d..685f9383acb 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h @@ -31,9 +31,12 @@ namespace llvm_ir { // Emits llvm IR to do pairwise comparisons/swaps in the 'dimension_to_sort' // dimension of 'keys_array'. All other dimensions are kept as-is. This // implements the inner loop of BitonicSort. It is assumed that 'xor_masks' -// contains only powers of 2, or values 2^k - 1 (k > 0). +// contains only powers of 2, or values 2^k - 1 (k > 0). If +// 'iota_values_parameter_index' is >= 0, it points at a 'values_arrays' operand +// that is a iota and can be used to make the sorting stable. 
Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, const std::vector& values_arrays, + int64 iota_values_parameter_index, absl::string_view name, absl::Span xor_masks, llvm::IRBuilder<>* b, const gpu::LaunchDimensions& launch_dimensions, diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index cca37556173..6c897009833 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -96,44 +96,18 @@ ExecutionOptions CreateExecutionOptions( const ExecutableBuildOptions& build_options, const ProgramShape* program_shape) { ExecutionOptions execution_options = CreateDefaultExecutionOptions(); - if (build_options.hlo_profile().has_value()) { - execution_options.mutable_debug_options()->set_xla_hlo_profile( - *build_options.hlo_profile()); - } - if (build_options.generate_hlo_graph().has_value()) { - execution_options.mutable_debug_options()->set_xla_generate_hlo_graph( - build_options.generate_hlo_graph().value()); - } - if (build_options.dump_optimized_hlo_proto_to().has_value()) { - execution_options.mutable_debug_options() - ->set_xla_dump_optimized_hlo_proto_to( - build_options.dump_optimized_hlo_proto_to().value()); - } - if (build_options.dump_unoptimized_hlo_proto_to().has_value()) { - execution_options.mutable_debug_options() - ->set_xla_dump_unoptimized_hlo_proto_to( - build_options.dump_unoptimized_hlo_proto_to().value()); - } - if (build_options.dump_per_pass_hlo_proto_to().has_value()) { - execution_options.mutable_debug_options() - ->set_xla_dump_per_pass_hlo_proto_to( - build_options.dump_per_pass_hlo_proto_to().value()); + if (build_options.has_debug_options()) { + *execution_options.mutable_debug_options() = build_options.debug_options(); } if (build_options.result_layout() != nullptr) { *execution_options.mutable_shape_with_output_layout() = - *build_options.result_layout(); + build_options.result_layout()->ToProto(); } else { + Shape result_shape(program_shape->result()); + LayoutUtil::SetToDefaultLayout(&result_shape); *execution_options.mutable_shape_with_output_layout() = - program_shape->result(); - LayoutUtil::SetToDefaultLayout( - execution_options.mutable_shape_with_output_layout()); + result_shape.ToProto(); } - - for (const std::string& disabled_pass : build_options.disabled_hlo_passes()) { - execution_options.mutable_debug_options()->add_xla_disable_hlo_passes( - disabled_pass); - } - return execution_options; } @@ -145,7 +119,7 @@ StatusOr> LocalService::CompileExecutable( const ExecutableBuildOptions& build_options) { const HloModuleProto& proto = computation.proto(); TF_RET_CHECK(proto.has_host_program_shape()); - const ProgramShape& program_shape = proto.host_program_shape(); + ProgramShape program_shape(proto.host_program_shape()); // Validate incoming layouts. 
if (argument_layouts.size() != program_shape.parameters_size()) { @@ -220,4 +194,10 @@ StatusOr LocalService::GlobalDataToShapedBuffer( return buffers[replica_number]; } +StatusOr LocalService::RegisterReplicatedBuffers( + std::vector replicated_buffers, const string& tag) { + return allocation_tracker_.RegisterReplicatedBuffers( + std::move(replicated_buffers), tag); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/local_service.h b/tensorflow/compiler/xla/service/local_service.h index 3b4f0b50832..f56ba32b04b 100644 --- a/tensorflow/compiler/xla/service/local_service.h +++ b/tensorflow/compiler/xla/service/local_service.h @@ -63,6 +63,11 @@ class LocalService : public Service { StatusOr GlobalDataToShapedBuffer( const GlobalDataHandle& data, int replica_number); + // Registers a vector of shaped buffers of device memory, one per replica, and + // returns a corresponding handle that can be used for talking to XLA clients. + StatusOr RegisterReplicatedBuffers( + std::vector replicated_buffers, const string& tag); + private: explicit LocalService(const ServiceOptions& options, std::unique_ptr backend); diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc index ec52a24d782..972a5b9ced0 100644 --- a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc +++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc @@ -113,6 +113,13 @@ Status LogicalBufferAnalysis::HandleGetTupleElement(HloInstruction*) { return Status::OK(); } +Status LogicalBufferAnalysis::HandleAddDependency( + HloInstruction* add_dependency) { + // AddDependency just forwards the value of its zero-th operand and does not + // create buffers. + return Status::OK(); +} + Status LogicalBufferAnalysis::HandleCopy(HloInstruction* copy) { // The top-level buffer (index={}) for kCopy is newly created, but all other // buffers (in the case of a tuple shape) come from the operand diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.h b/tensorflow/compiler/xla/service/logical_buffer_analysis.h index 81f524d84a8..7ffca943d0f 100644 --- a/tensorflow/compiler/xla/service/logical_buffer_analysis.h +++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.h @@ -64,6 +64,7 @@ class LogicalBufferAnalysis : public DfsHloVisitorWithDefault { Status HandleRecvDone(HloInstruction* recv_done) override; Status HandleSend(HloInstruction* send) override; Status HandleTupleSelect(HloInstruction* tuple_select) override; + Status HandleAddDependency(HloInstruction* add_dependency) override; // A map from the buffer ID to the logical buffer std::vector> logical_buffers_; diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h index 6152cdc6099..432aa1ea0b6 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher.h +++ b/tensorflow/compiler/xla/service/pattern_matcher.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_ +#include "absl/strings/str_replace.h" #include "absl/strings/string_view.h" #include "absl/utility/utility.h" #include "tensorflow/compiler/xla/layout_util.h" @@ -44,32 +45,48 @@ namespace xla { // // This pattern will match Add instructions whose first operand is a constant. // -// Each pattern type has the following modifiers: +// Each pattern type has the following modifiers, which are described where +// nontrivial. 
// // Op(): -// - WithName: match operations with the given name -// - WithOpcode: match operations with the given opcode -// - WithShape: match operations whose shape matches the given pattern -// - WithOperand: match operations whose operand matches the given pattern +// - Is: is the given HloInstruction* (i.e. pointer equality) +// - WithName +// - WithOpcode +// - WithoutOpcode: anything other than the given opcode +// - WithShape: instr's shape matches the given pattern +// - WithShapeEqualTo: instr's shape is equal to the given Shape +// - WithShapeCompatibleTo: instr's shape is compatible with the given Shape +// - WithNumOperands +// - WithOperand: operand at the given index matches the given pattern +// - IsConstant +// - IsNonConstant +// - IsConstantScalar/IsEffectiveConstantScalar: Optionally accepts a value, +// e.g. IsConstantScalar() or IsConstantScalar(42). +// - WithFusionKind +// - WithTupleIndex: get-tuple-element operations with the given tuple index +// - WithOneUse: Instruction is used as an operand exactly once. +// - WithOneUser: Instruction is used by exactly one other instruction, but +// is possibly used more than once as an operand (e.g. multiply(x,x)). // // Shape(): -// - EqualTo: matches shapes that are equal to the argument -// - CompatibleTo: matches shapes that are compatible to the argument -// - IsScalar/IsArray/IsTuple: matches scalar/array/tuple shapes -// - IsDenseArray/IsSparseArray: matches arrays with dense/sparse format -// - WithLayout: match shapes whose layout matches the given pattern -// - WithLayoutEqualTo: matches shapes whose layouts equal the argument -// - WithSubshape: matches tuple shapes whose subshape matches the given -// pattern -// - WithSubshapeEqualTo: matches shapes with a subshape equal the argument -// - WithElementType: matches array/scalar shapes with the given element -// type -// - WithRank: matches array/scalar types with the given rank +// - EqualTo +// - CompatibleTo +// - IsScalar/IsEffectiveScalar/IsArray/IsTuple +// - IsDenseArray/IsSparseArray +// - WithLayout: layout shape's layout matches the given pattern (e.g. +// Layout().WithDenseFormat()) +// - WithLayoutEqualTo: shape's layout equals the argument (i.e. another +// Layout, but not the result of Layout().foo()) +// - WithSubshape: shape is a tuple whose subshape matches the given pattern +// (e.g. Shape().IsScalar()). +// - WithSubshapeEqualTo: shape is a tuple with a subshape equal to the arg +// (i.e. another Shape, but not the result of Shape().foo()) +// - WithElementType: shape is an array/scalar with the given elem type +// - WithRank: shape is an array/scalar with the given rank // // Layout(): -// - EqualTo: matches layouts that are equal to the argument -// - WithDenseFormat/WithSparseFormat: matches layouts with dense/sparse -// format +// - EqualTo +// - WithDenseFormat/WithSparseFormat // // Op(), Shape(), and Layout() may be passed an argument of type // HloInstruction**, Shape**, or Layout**, respectively, or const versions of @@ -82,53 +99,55 @@ namespace xla { // CHECK(Match(foo, // match::Op().WithOperand(0, match::Op(&matched_operand)))); // -// Helpers are provided for common nullary, unary, binary, and ternary -// instructions. These helpers can be called with no arguments, in which case -// they will match any instruction matching the opcode. They may also be called -// with matches for the operands and with an optional capture. (The capture must -// be the first argument.) 
Some examples of these helpers and their equivalents -// are provided below. -// +// Helpers are provided for most HLO instructions. These helpers can be called +// with no arguments, in which case they will match any instruction matching the +// opcode. They may also be called with matches for the operands and with an +// optional capture. (The capture must be the first argument.) Some examples of +// these helpers and their equivalents are provided below. + // Example nullary instruction: -// Param() == Op().WithOpcode(HloOpcode::kParam) -// Param(&a) == Op(&a).WithOpcode(HloOpcode::kParam) +// Parameter() == Op().WithOpcode(HloOpcode::kParameter) +// Parameter(&a) == Op(&a).WithOpcode(HloOpcode::kParameter) // // Example unary instruction: -// Abs() == Op().WithOpcode(HloOpcode::kAbs) -// Abs(Op(&a)) == Op().WithOpcode(HloOpcode::kAbs) -// .WithOperand(0, Op(&a))) -// Abs(&a, Op(&b)) == Op(&a).WithOpcode(HloOpcode::kAbs) -// .WithOperand(0, Op(&b)) +// Abs() == Op().WithOpcode(HloOpcode::kAbs) +// Abs(Op(&a)) == Op().WithOpcode(HloOpcode::kAbs) +// .WithOperand(0, Op(&a))) +// Abs(&a, Op(&b)) == Op(&a).WithOpcode(HloOpcode::kAbs) +// .WithOperand(0, Op(&b)) // -// Example binary instruction: -// Add() == Op().WithOpcode(HloOpcode::kAdd) -// Add(Op(&a), Op(&b)) == Op().WithOpcode(HloOpcode::kAdd) -// .WithOperand(0, Op(&a)) -// .WithOperand(1, Op(&b)) -// Add(&a, Op(&b), Op(&c)) == Op(&a).WithOpcode(HloOpcode::kAdd) -// .WithOperand(0, Op(&b)) -// .WithOperand(1, Op(&c)) +// Commutative binary instructions have a special form that accepts either order +// of args, e.g.: // -// Example ternary instruction: -// Clamp() == Op().WithOpcode(HloOpcode::kClamp) -// Clamp(Op(&a), Op(&b), Op(&c)) == Op().WithOpcode(HloOpcode::kClamp) -// .WithOperand(0, Op(&a)) -// .WithOperand(1, Op(&b)) -// .WithOperand(2, Op(&c)) -// Clamp(&a, Op(&b), Op(&c), Op(&d)) == Op(&a).WithOpcode(HloOpcode::kClamp) -// .WithOperand(0, Op(&b)) -// .WithOperand(1, Op(&c)) -// .WithOperand(2, Op(&d)) +// AddAnyOrder(Parameter(1), Abs()) == +// Op().WithOpcode(HloOpcode::kAdd) +// .WithBinaryOperandsAnyOrder(Op().WithParameterNum(1), Abs()); // +// MultiplyAnyOrder(&a, Parameter(), Abs()) // Captures the mul in `a`. +// +// The following additional helpers are provided. In all cases, `&a` is +// optional. +// +// ConstantScalar(&a) == Op(&a).IsConstantScalar(); +// ConstantScalar(&a, v) == Op(&a).IsConstantScalar(v); +// ConstantEffectiveScalar(&a) == Op(&a).IsConstantEffectiveScalar(); +// ConstantEffectiveScalar(&a, v) == Op(&a).IsConstantEffectiveScalar(&a, v) +// NonConstant(&a) == Op(&a).IsNonConstant() +// GetTupleElement(&a, b, index) == Op(&a).WithTupleIndex(index) +// .WithOperand(0, b); +// Parameter(&a, n) == Op(&a).WithParameterNum(n); struct MatchOption { // If true, actually capture matched item into the user pointer. bool capture; + + // An explanation for why we failed to match is streamed here, if not-null. + std::ostream* explain_os; }; template bool Match(Value* value, const Pattern& pattern, - MatchOption option = {/*.capture=*/true}) { + MatchOption option = {/*.capture=*/true, /*.explain_os=*/nullptr}) { if (option.capture) { auto new_option = option; new_option.capture = false; @@ -143,6 +162,77 @@ namespace match { namespace detail { +// Macro for streaming to option.explain_os if it's not null. 
+// +// EXPLAIN << "value of foo(): " << foo() +// +#pragma push_macro("EXPLAIN") +#define EXPLAIN \ + if (option.explain_os) *option.explain_os + +// kIndentInc is the additional number of spaces that we indent by when we +// increase the indent "by one". +enum { + kIndentInc = 2, +}; + +// Writes a newline and then `indent` spaces. +// +// We follow an unintuitive convention in this file's pretty-printers: Indents +// are performed by the caller, not the callee. For example, if you want to +// print +// +// foo: +// - bar +// +// you'd do: +// +// Foo::DescribeTo(std::ostream* os, int64 indent) { +// *os << "foo:"; +// Indent(os, indent) // Create a newline at the *current* indent level. +// *os << " - "; +// bar.DescribeTo(os, indent + 3); // + 3 because strlen(" * ") == 3. +// } +// +// Bar::DescribeTo(std::ostream* os, int64 indent) { *os << "bar"; } +// +// Notice that Bar::DescribeTo() does not call Indent; the indenting is +// performed by Foo. This convention allows the caller to decide whether a +// matcher is preceded by a newline, which is important e.g. for the AllOf +// matcher. +// +// (Incidentally, indenting in Match's explanations is handled differently. +// Indents are a common case in DescribeTo [we're printing a whole tree], but +// they're a special case in Match [we're printing only a path through the tree +// that encounters a failing node]. Indents in Match only appear when we +// encounter a failing disjunction, so we just handle them as a special case +// there.) +inline void Indent(std::ostream* os, int64 indent) { + *os << "\n"; + for (int64 i = 0; i < indent; ++i) { + *os << " "; + } +} + +// SFINAE template that determines whether T declares a static member +// kIsTrivialMatcher. +// +// Trivial matchers get special treatment. For example, when printing +// a conjunction of matchers, we don't print "and" after a trivial matcher. This +// yields e.g. +// "a shape compatible with f32[1,2]" +// rather than +// "a shape AND compatible with f32[1,2]" +template +struct IsTrivialMatcher { + static constexpr bool value = false; +}; +template +struct IsTrivialMatcher::type> { + static constexpr bool value = true; +}; + template class AllOfPattern { public: @@ -162,10 +252,19 @@ class AllOfPattern { return matched; } + void DescribeTo(std::ostream* os, int64 indent = 0) const { + DescribeToImpl(os, std::integral_constant(), indent); + } + + // Accessor for patterns_. Please don't use this outside of this file. + const std::tuple& patterns() const { return patterns_; } + private: template bool MatchImpl(ItemType* item, MatchOption option, std::integral_constant) const { + // We don't need to do any EXPLAINing here; it's all correctly handled by + // our sub-matchers (if any fail). return std::get(patterns_).Match(item, option) && MatchImpl(item, option, std::integral_constant()); } @@ -176,6 +275,73 @@ class AllOfPattern { return true; } + // Pretty-printing a conjunction has some special cases to make it easy to + // read in the simple (common) case. + // + // If sizeof...(Patterns) == 1, prints as e.g. + // + // a shape + // + // If sizeof...(Patterns) == 2 and patterns_[0] is a trivial matcher (e.g. 
"a + // shape") prints as + // + // a shape compatible with f32[1,2] + // + // If sizeof...(Patterns) > 2 and patterns_[0] is a trivial matcher, prints as + // + // a shape: + // * compatible with f32[1,2] AND + // * that represents a scalar + // + // Otherwise prints as: + // + // all of: + // * foo AND + // * bar + // + template + void DescribeToImpl(std::ostream* os, std::integral_constant, + int64 indent) const { + constexpr bool first_is_trivial = + IsTrivialMatcher(patterns_))>::type>::value; + constexpr bool is_last = index == sizeof...(Patterns) - 1; + const auto& submatcher = std::get(patterns_); + + auto print_bulleted_item = [&] { + *os << " * "; + submatcher.DescribeTo(os, indent + 3); + if (!is_last) { + *os << " AND"; + Indent(os, indent); + } + }; + + if (index == 0) { + if (first_is_trivial || is_last) { + submatcher.DescribeTo(os, indent + kIndentInc); + if (sizeof...(Patterns) > 2) { + *os << ":"; + Indent(os, indent); + } + } else { + *os << "all of:"; + Indent(os, indent); + print_bulleted_item(); + } + } else if (first_is_trivial && index == 1 && sizeof...(Patterns) == 2) { + *os << " "; + submatcher.DescribeTo(os, indent); + } else { + print_bulleted_item(); + } + DescribeToImpl(os, std::integral_constant(), indent); + } + + void DescribeToImpl(std::ostream* os, + std::integral_constant, + int64 indent) const {} + std::tuple patterns_; }; @@ -183,10 +349,6 @@ class AllOfPattern { // Returns a pattern that represents the conjunction of all input patterns. All // patterns need to match in order to have the AllOf pattern match. -// -// TODO(timshen): Currently AllOf is still nested, e.g. AllOf, B> is -// not AllOf. We might want to flatten the AllOf type structure if the -// C++ compile error message gets annoying. template detail::AllOfPattern::type, Patterns...> AllOf( const Patterns&... patterns) { @@ -194,6 +356,25 @@ detail::AllOfPattern::type, Patterns...> AllOf( Patterns...>(patterns...); } +// AllOf, X, Y, ...> => AllOf. +// +// This transformation is necessary for good pretty-printing. +template +detail::AllOfPattern::type, InnerPs..., + OuterPs...> +AllOf(const detail::AllOfPattern& inner_p, + const OuterPs&... outer_ps) { + // Invoke constructor of AllOfPattern. + auto make_all_of = [](const InnerPs&... inner_ps, + const OuterPs&... 
outer_ps) { + return detail::AllOfPattern::type, + InnerPs..., OuterPs...>(inner_ps..., + outer_ps...); + }; + return absl::apply(make_all_of, std::tuple_cat(inner_p.patterns(), + std::make_tuple(outer_ps...))); +} + namespace detail { template @@ -204,8 +385,18 @@ class LayoutPattern; class LayoutPatternBaseImpl { public: bool Match(const ::xla::Layout* layout, MatchOption option) const { - return layout != nullptr; + if (layout == nullptr) { + EXPLAIN << "Layout is null"; + return false; + } + return true; } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "a layout"; + } + + static constexpr bool kIsTrivialMatcher = true; }; // A LayoutPattern implementation that matches only if the layout equals a @@ -216,7 +407,17 @@ class LayoutPatternEqualImpl { : layout_(layout) {} bool Match(const ::xla::Layout* layout, MatchOption option) const { - return LayoutUtil::Equal(*layout_, *layout); + if (!LayoutUtil::Equal(*layout_, *layout)) { + EXPLAIN << "Layout " << LayoutUtil::HumanString(*layout) + << " is not equal to expected " + << LayoutUtil::HumanString(*layout_); + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "equal to " << LayoutUtil::HumanString(*layout_); } private: @@ -230,7 +431,16 @@ class LayoutPatternFormatImpl { explicit constexpr LayoutPatternFormatImpl(Format format) : format_(format) {} bool Match(const ::xla::Layout* layout, MatchOption option) const { - return layout->format() == format_; + if (layout->format() != format_) { + EXPLAIN << "Layout has format " << Format_Name(layout->format()) + << " but expected " << Format_Name(format_); + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "with format " << Format_Name(format_); } private: @@ -242,11 +452,13 @@ template class LayoutPattern { private: template - LayoutPattern> - AppendImpl(NewImpl new_impl) const { - return LayoutPattern>( - AllOf(impl_, std::move(new_impl)), matched_layout_); + auto AppendImpl(NewImpl new_impl) const + -> LayoutPattern(std::declval(), + std::move(new_impl)))> { + auto new_allof = AllOf(impl_, std::move(new_impl)); + return LayoutPattern(std::move(new_allof), + matched_layout_); } public: @@ -276,6 +488,10 @@ class LayoutPattern { return false; } + void DescribeTo(std::ostream* os, int64 indent = 0) const { + impl_.DescribeTo(os, indent); + } + // Modifies the pattern to match only if the layout equals the given proto. // The layout must outlive the returned pattern. constexpr auto EqualTo(const ::xla::Layout* layout) const @@ -306,19 +522,48 @@ class AnyOfPattern { explicit AnyOfPattern(const Patterns&... patterns) : patterns_(patterns...) {} bool Match(const Item* item, MatchOption option) const { - return MatchImpl(item, option, std::integral_constant()); + return MatchImpl(item, option); } bool Match(Item* item, MatchOption option) const { - return MatchImpl(item, option, std::integral_constant()); + return MatchImpl(item, option); + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "any of:"; + Indent(os, indent); + DescribeToImpl(os, std::integral_constant(), indent); } private: + template + bool MatchImpl(ItemType* item, MatchOption option) const { + // If we're generating an explanation, buffer it until we know we failed. 
+ absl::optional explanation; + MatchOption new_option = option; + if (option.explain_os) { + new_option.explain_os = &explanation.emplace(); + } + bool rv = MatchRecursiveImpl(item, new_option, + std::integral_constant()); + if (!rv && option.explain_os) { + EXPLAIN << "None of the following matchers succeeded:"; + EXPLAIN << explanation->str(); + } + return rv; + } + template - bool MatchImpl(ItemType* item, MatchOption option, - std::integral_constant) const { + bool MatchRecursiveImpl(ItemType* item, MatchOption option, + std::integral_constant) const { auto new_option = option; new_option.capture = false; + + absl::optional explanation; + if (option.explain_os) { + new_option.explain_os = &explanation.emplace(); + } + // Try to match the sub-pattern without capturing behavior. if (std::get(patterns_).Match(item, new_option)) { // Capture the branch. @@ -337,20 +582,46 @@ class AnyOfPattern { // AnyOf will be a runtime number indicate which sub-pattern is matched. // Then we run another pass to do captures only with the help of the // trace. - bool ret = std::get(patterns_).Match(item, option); - DCHECK(ret); + bool matched = std::get(patterns_).Match(item, option); + DCHECK(matched); } return true; } - return MatchImpl(item, option, std::integral_constant()); + if (option.explain_os) { + EXPLAIN << "\nMatcher #" << index + 1; + EXPLAIN << "\n - "; + std::get(patterns_).DescribeTo(option.explain_os, /*indent=*/3); + EXPLAIN << "\nfailed with"; + EXPLAIN << "\n - "; + EXPLAIN << absl::StrReplaceAll(explanation->str(), {{"\n", "\n "}}); + } + return MatchRecursiveImpl(item, option, + std::integral_constant()); } template - bool MatchImpl(ItemType* item, MatchOption option, - std::integral_constant) const { + bool MatchRecursiveImpl( + ItemType* item, MatchOption option, + std::integral_constant) const { return false; } + template + void DescribeToImpl(std::ostream* os, std::integral_constant, + int64 indent) const { + *os << " - "; + std::get(patterns_).DescribeTo(os, indent + 3); + if (index != sizeof...(Patterns) - 1) { + *os << " OR"; + Indent(os, indent); + } + DescribeToImpl(os, std::integral_constant(), indent); + } + + void DescribeToImpl(std::ostream* os, + std::integral_constant, + int64 indent) const {} + std::tuple patterns_; }; @@ -395,8 +666,17 @@ class ShapePattern; class ShapePatternBaseImpl { public: bool Match(const ::xla::Shape* shape, MatchOption option) const { + if (shape == nullptr) { + EXPLAIN << "Shape is null"; + } return shape != nullptr; } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "a shape"; + } + + static constexpr bool kIsTrivialMatcher = true; }; // A ShapePattern implementation that matches only if the shape equals a Shape @@ -407,7 +687,16 @@ class ShapePatternEqualImpl { : shape_(shape) {} bool Match(const ::xla::Shape* shape, MatchOption option) const { - return ShapeUtil::Equal(*shape_, *shape); + if (!ShapeUtil::Equal(*shape_, *shape)) { + EXPLAIN << "Shape not equal to " + << ShapeUtil::HumanStringWithLayout(*shape_); + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "equal to " << ShapeUtil::HumanStringWithLayout(*shape_); } private: @@ -422,7 +711,16 @@ class ShapePatternCompatibleImpl { : shape_(shape) {} bool Match(const ::xla::Shape* shape, MatchOption option) const { - return ShapeUtil::Compatible(*shape_, *shape); + if (!ShapeUtil::Compatible(*shape_, *shape)) { + EXPLAIN << "Shape not compatible with " + << ShapeUtil::HumanString(*shape_); + return 
false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "compatible with " << ShapeUtil::HumanString(*shape_); } private: @@ -437,7 +735,16 @@ class ShapePatternElementTypeImpl { : element_type_(element_type) {} bool Match(const ::xla::Shape* shape, MatchOption option) const { - return shape->element_type() == element_type_; + if (shape->element_type() != element_type_) { + EXPLAIN << "Shape does not have element type " + << PrimitiveType_Name(element_type_); + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "with element type " << PrimitiveType_Name(element_type_); } private: @@ -450,7 +757,15 @@ class ShapePatternIsScalarImpl { explicit constexpr ShapePatternIsScalarImpl() {} bool Match(const ::xla::Shape* shape, MatchOption option) const { - return ShapeUtil::IsScalar(*shape); + if (!ShapeUtil::IsScalar(*shape)) { + EXPLAIN << "Shape is not a scalar"; + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "that represents a scalar"; } }; @@ -460,7 +775,15 @@ class ShapePatternIsArrayImpl { explicit constexpr ShapePatternIsArrayImpl() {} bool Match(const ::xla::Shape* shape, MatchOption option) const { - return ShapeUtil::IsArray(*shape); + if (!ShapeUtil::IsArray(*shape)) { + EXPLAIN << "Shape is not an array"; + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "that represents an array"; } }; @@ -470,7 +793,34 @@ class ShapePatternIsTupleImpl { explicit constexpr ShapePatternIsTupleImpl() {} bool Match(const ::xla::Shape* shape, MatchOption option) const { - return ShapeUtil::IsTuple(*shape); + if (!ShapeUtil::IsTuple(*shape)) { + EXPLAIN << "Shape is not a tuple"; + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "that represents a tuple"; + } +}; + +// A ShapePattern implementation that matches only if the shape is an effective +// scalar. +class ShapePatternEffectiveScalarImpl { + public: + explicit constexpr ShapePatternEffectiveScalarImpl() {} + + bool Match(const ::xla::Shape* shape, MatchOption option) const { + if (!ShapeUtil::IsEffectiveScalar(*shape)) { + EXPLAIN << "Shape is not an effective scalar"; + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "that is an effective scalar"; } }; @@ -481,7 +831,23 @@ class ShapePatternRankImpl { explicit constexpr ShapePatternRankImpl(int64 rank) : rank_(rank) {} bool Match(const ::xla::Shape* shape, MatchOption option) const { - return ShapeUtil::Rank(*shape) == rank_; + if (ShapeUtil::Rank(*shape) != rank_) { + if (rank_ == 0) { + EXPLAIN << "Shape is not a scalar"; + } else { + EXPLAIN << "Shape does not have rank " << rank_; + } + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + if (rank_ == 0) { + *os << "that is a scalar"; + } else { + *os << "that has " << rank_ << " dimension" << (rank_ != 1 ? 
"s" : ""); + } } private: @@ -503,8 +869,21 @@ class ShapePatternLayoutImpl { } bool Match(Shape* shape, MatchOption option) const { - return LayoutUtil::HasLayout(*shape) && - layout_.Match(shape->mutable_layout(), option); + if (!LayoutUtil::HasLayout(*shape)) { + EXPLAIN << "Shape does not have a layout"; + return false; + } + if (!layout_.Match(shape->mutable_layout(), option)) { + EXPLAIN << "\nin layout"; + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "with"; + Indent(os, indent + kIndentInc); + layout_.DescribeTo(os, indent + kIndentInc); } private: @@ -522,17 +901,40 @@ class ShapePatternSubshapeImpl { : index_(index), subshape_(subshape) {} bool Match(const ::xla::Shape* shape, MatchOption option) const { - return ShapeUtil::IndexIsValid(*shape, index_) && - subshape_.Match(&ShapeUtil::GetSubshape(*shape, index_), option); + return MatchImpl(shape, option); } bool Match(::xla::Shape* shape, MatchOption option) const { - return ShapeUtil::IndexIsValid(*shape, index_) && - subshape_.Match(ShapeUtil::GetMutableSubshape(shape, index_), - option); + return MatchImpl(shape, option); + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "with subshape at index " << index_.ToString() << " which is"; + Indent(os, indent + kIndentInc); + subshape_.DescribeTo(os, indent + kIndentInc); } private: + Shape* GetSubshape(Shape* shape) const { + return ShapeUtil::GetMutableSubshape(shape, index_); + } + const Shape* GetSubshape(const Shape* shape) const { + return &ShapeUtil::GetSubshape(*shape, index_); + } + + template + bool MatchImpl(ShapeType* shape, MatchOption option) const { + if (!ShapeUtil::IndexIsValid(*shape, index_)) { + EXPLAIN << "No subshape at " << index_.ToString(); + return false; + } + if (!subshape_.Match(GetSubshape(shape), option)) { + EXPLAIN << "\nin subshape at " << index_.ToString(); + return false; + } + return true; + } + ShapeIndexView index_; ShapePattern subshape_; }; @@ -542,10 +944,12 @@ template class ShapePattern { private: template - ShapePattern> AppendImpl( - NewImpl new_impl) const { - return ShapePattern>( - AllOf(impl_, std::move(new_impl)), matched_shape_); + auto AppendImpl(NewImpl new_impl) const + -> ShapePattern(std::declval(), + std::move(new_impl)))> { + auto new_all_of = AllOf(impl_, std::move(new_impl)); + return ShapePattern(std::move(new_all_of), + matched_shape_); } public: @@ -560,6 +964,11 @@ class ShapePattern { } return true; } + if (shape) { + EXPLAIN << "\nin " + << (shape->has_layout() ? ShapeUtil::HumanStringWithLayout(*shape) + : ShapeUtil::HumanString(*shape)); + } return false; } @@ -571,9 +980,16 @@ class ShapePattern { } return true; } + EXPLAIN << "\nin " + << (shape->has_layout() ? ShapeUtil::HumanStringWithLayout(*shape) + : ShapeUtil::HumanString(*shape)); return false; } + void DescribeTo(std::ostream* os, int64 indent = 0) const { + return impl_.DescribeTo(os, indent); + } + // Modifies the pattern to match only if the shape equals the given proto. // The layout must outlive the returned pattern. constexpr auto EqualTo(const ::xla::Shape* shape) const @@ -612,6 +1028,11 @@ class ShapePattern { return AppendImpl(ShapePatternIsTupleImpl()); } + constexpr auto IsEffectiveScalar() const + -> decltype(this->AppendImpl(ShapePatternEffectiveScalarImpl())) { + return AppendImpl(ShapePatternEffectiveScalarImpl()); + } + // Modifies the pattern to match only if the shape has the given rank. 
constexpr auto WithRank(int64 rank) const -> decltype(this->AppendImpl(ShapePatternRankImpl(rank))) { @@ -706,6 +1127,22 @@ Shape(::xla::Shape** matched_shape) { namespace detail { +// Overloads to get a const or non-const operand out of an instruction. +inline HloInstruction* HloOperand(HloInstruction* instr, int64 idx) { + return instr->mutable_operand(idx); +} +inline const HloInstruction* HloOperand(const HloInstruction* instr, + int64 idx) { + return instr->operand(idx); +} + +// Pretty-printer for HloInstruction. Sort of like ToShortString, but with +// fewer %s and more shapes. +inline string InstToString(const HloInstruction* inst) { + return inst->ToString( + HloPrintOptions().set_print_metadata(false).set_print_percent(false)); +} + template class HloInstructionPattern; @@ -714,8 +1151,18 @@ class HloInstructionPattern; class HloInstructionPatternBaseImpl { public: bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { - return inst != nullptr; + if (inst == nullptr) { + EXPLAIN << "HloInstruction* is null"; + return false; + } + return true; } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "an HloInstruction"; + } + + static constexpr bool kIsTrivialMatcher = true; }; // An HloInstructionPattern implementation that matches only if the instruction @@ -726,13 +1173,44 @@ class HloInstructionPatternNameImpl { : name_(name) {} bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { - return inst->name() == name_; + if (inst->name() != name_) { + EXPLAIN << "HloInstruction not named \"" << name_ << "\""; + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "named \"" << name_ << "\""; } private: absl::string_view name_; }; +// An HloInstructionPattern implementation that matches only if the instruction +// equals a particular pointer. +class HloInstructionIsImpl { + public: + explicit HloInstructionIsImpl(const HloInstruction* inst) : inst_(inst) {} + + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { + if (inst != inst_) { + EXPLAIN << "HloInstruction " << inst << " is not " << inst_ << " (" + << InstToString(inst_) << ")"; + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "which is " << inst_ << " (" << InstToString(inst_) << ")"; + } + + private: + const HloInstruction* inst_; +}; + // An HloInstructionPattern implementation that matches only if the instruction // has a given opcode. 
class HloInstructionPatternOpcodeImpl { @@ -742,7 +1220,25 @@ class HloInstructionPatternOpcodeImpl { : opcode_(opcode), invert_(invert) {} bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { - return (invert_ ^ (inst->opcode() == opcode_)); + if (invert_ && inst->opcode() == opcode_) { + EXPLAIN << "HloInstruction has opcode " << HloOpcodeString(opcode_) + << ", expected anything else"; + return false; + } + if (!invert_ && inst->opcode() != opcode_) { + EXPLAIN << "HloInstruction doesn't have opcode " + << HloOpcodeString(opcode_); + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + if (!invert_) { + *os << "with opcode " << HloOpcodeString(opcode_); + } else { + *os << "with any opcode other than " << HloOpcodeString(opcode_); + } } private: @@ -757,8 +1253,17 @@ class HloInstructionPatternNumOperandsImpl { explicit constexpr HloInstructionPatternNumOperandsImpl(int64 num_operands) : num_operands_(num_operands) {} - bool Match(const ::xla::HloInstruction* inst, MatchOption /*option*/) const { - return inst->operand_count() == num_operands_; + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { + if (inst->operand_count() != num_operands_) { + EXPLAIN << "HloInstruction doesn't have " << num_operands_ << " operands"; + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "with " << num_operands_ << " operand" + << (num_operands_ != 1 ? "s" : ""); } private: @@ -775,11 +1280,25 @@ class HloInstructionPatternShapeImpl { : shape_(shape) {} bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { - return shape_.Match(&inst->shape(), option); + if (!shape_.Match(&inst->shape(), option)) { + EXPLAIN << "\nin output shape"; + return false; + } + return true; } bool Match(::xla::HloInstruction* inst, MatchOption option) const { - return shape_.Match(inst->mutable_shape(), option); + if (!shape_.Match(inst->mutable_shape(), option)) { + EXPLAIN << "\nin output shape"; + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "outputting"; + Indent(os, indent + kIndentInc); + shape_.DescribeTo(os, indent + kIndentInc); } private: @@ -797,20 +1316,197 @@ class HloInstructionPatternOperandImpl { : operand_index_(operand_index), operand_(operand) {} bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { - return operand_index_ < inst->operand_count() && - operand_.Match(inst->operand(operand_index_), option); + return MatchImpl(inst, option); } bool Match(::xla::HloInstruction* inst, MatchOption option) const { - return operand_index_ < inst->operand_count() && - operand_.Match(inst->mutable_operand(operand_index_), option); + return MatchImpl(inst, option); + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "with operand " << operand_index_ << " which is:"; + Indent(os, indent + kIndentInc); + operand_.DescribeTo(os, indent + kIndentInc); } private: + template + bool MatchImpl(HloInstructionType* inst, MatchOption option) const { + if (operand_index_ >= inst->operand_count()) { + EXPLAIN << "desired operand index " << operand_index_ + << " is out of bounds"; + return false; + } + if (!operand_.Match(HloOperand(inst, operand_index_), option)) { + EXPLAIN << "\nin operand " << operand_index_; + return false; + } + return true; + } + int64 operand_index_; HloInstructionPattern operand_; }; +// Matches a binary instruction whose operands 
come in any order. +template +class HloInstructionPatternBinaryOperandsAnyOrderImpl { + public: + explicit constexpr HloInstructionPatternBinaryOperandsAnyOrderImpl( + const HloInstructionPattern& op1, + const HloInstructionPattern& op2) + : op1_(op1), op2_(op2) {} + + bool Match(HloInstruction* inst, MatchOption option) const { + return MatchImpl(inst, option); + } + + bool Match(const HloInstruction* inst, MatchOption option) const { + return MatchImpl(inst, option); + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "with two operands in either order:"; + Indent(os, indent); + *os << " - "; + op1_.DescribeTo(os, indent + 3); + Indent(os, indent); + *os << " - "; + op2_.DescribeTo(os, indent + 3); + } + + private: + HloInstruction* operand(HloInstruction* inst, int64 idx) const { + return inst->mutable_operand(idx); + } + const HloInstruction* operand(const HloInstruction* inst, int64 idx) const { + return inst->operand(idx); + } + + template + bool MatchImpl(HloInstructionType* inst, MatchOption option) const { + // We could implement this using AnyOf and AllOf matchers, but the templates + // get pretty difficult to debug, since any compile error herein becomes + // not-an-error via SFINAE. Also this way lets us give better messages on + // failure. + if (inst->operand_count() != 2) { + EXPLAIN << "HloInstruction did not have two operands"; + return false; + } + + // If we're not generating explanations, this is pretty simple. + if (!option.explain_os) { + auto try_match = [&](int64 idx1, int64 idx2) { + MatchOption new_option = option; + new_option.capture = false; + if (op1_.Match(operand(inst, idx1), new_option) && + op2_.Match(operand(inst, idx2), new_option)) { + if (option.capture) { + bool matched = op1_.Match(operand(inst, idx1), option) && + op2_.Match(operand(inst, idx2), option); + DCHECK(matched); + } + return true; + } + return false; + }; + return try_match(0, 1) || try_match(1, 0); + } + + // If we are generating explanations, we have some work to do in order to + // generate a helpful error. + // + // First, try all four operand/matcher combinations, recording the + // failure explanations separately from option.explain_os. matches[i][j] + // tells us if matcher_i matches operand j. + bool matches[/*matcher*/ 2][/*operand*/ 2]; + std::stringstream explanations[/*matcher*/ 2][/*operand*/ 2]; + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 2; ++j) { + MatchOption new_option = option; + new_option.capture = false; + new_option.explain_os = &explanations[i][j]; + matches[i][j] = i == 0 ? op1_.Match(operand(inst, j), new_option) + : op2_.Match(operand(inst, j), new_option); + } + } + + // Check if the match succeeded. + for (int i = 0; i < 2; ++i) { + if (matches[0][i] && matches[1][(i + 1) % 2]) { + // Rerun the matches with capture enabled if necessary. + if (option.capture) { + auto* operand1 = operand(inst, i); + auto* operand2 = operand(inst, (i + 1) % 2); + bool matched = + op1_.Match(operand1, option) && op2_.Match(operand2, option); + DCHECK(matched); + } + return true; + } + } + + auto describe_matcher = [&](int matcher_idx) { + EXPLAIN << "\n - "; + if (matcher_idx == 0) { + op1_.DescribeTo(option.explain_os, /*indent=*/3); + } else { + CHECK_EQ(matcher_idx, 1); + op2_.DescribeTo(option.explain_os, /*indent=*/3); + } + for (int i = 0; i < 2; ++i) { + if (matches[matcher_idx][/*operand*/ i]) { + continue; + } + EXPLAIN << "\ndoes not match " << (i == 0 ? 
"LHS" : "RHS") << ":\n"; + EXPLAIN << " - "; + EXPLAIN << absl::StrReplaceAll( + explanations[matcher_idx][/*operand*/ i].str(), {{"\n", "\n "}}); + } + }; + + // If we failed to match, one of the following is true: + // 1. op1 (op2) matches neither LHS nor RHS, or + // 2. op1 and op2 both match LHS (RHS), but neither matches RHS (LHS). + // We print different explanations depending on which case we're in. + + // Case 1. + bool wrote_explanation = false; + for (int i = 0; !wrote_explanation && i < 2; ++i) { + if (!matches[i][0] && !matches[i][1]) { + EXPLAIN << "HloInstruction's operands (ignoring order) did not match " + << (i == 0 ? "first" : "second") << " matcher. Specifically,"; + describe_matcher(i); + wrote_explanation = true; + } + } + + // Case 2. + for (int i = 0; !wrote_explanation && i < 2; ++i) { + if (matches[/*matcher*/ 0][/*operand*/ i] && + matches[/*matcher*/ 1][/*operand*/ i]) { + CHECK(!matches[0][(i + 1) % 2]); + CHECK(!matches[1][(i + 1) % 2]); + CHECK(!wrote_explanation); + EXPLAIN << "HloInstruction's " << (i == 1 ? "LHS" : "RHS") + << " operand did not match either of the two matchers. " + "Specifically,"; + describe_matcher(0); + EXPLAIN << "\nand"; + describe_matcher(1); + wrote_explanation = true; + } + } + + CHECK(wrote_explanation); + return false; + } + + HloInstructionPattern op1_; + HloInstructionPattern op2_; +}; + // An HloInstructionPattern implementation that matches only if the instruction // is a fusion node with a particular kind. class HloInstructionPatternFusionKindImpl { @@ -820,14 +1516,32 @@ class HloInstructionPatternFusionKindImpl { : kind_(kind) {} bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { - return inst->opcode() == HloOpcode::kFusion && inst->fusion_kind() == kind_; + return MatchImpl(inst, option); } bool Match(::xla::HloInstruction* inst, MatchOption option) const { - return inst->opcode() == HloOpcode::kFusion && inst->fusion_kind() == kind_; + return MatchImpl(inst, option); + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "with fusion kind " << ToString(kind_); } private: + template + bool MatchImpl(HloInstructionType* inst, MatchOption option) const { + if (inst->opcode() != HloOpcode::kFusion) { + EXPLAIN << "HloInstruction does not have fusion kind " << ToString(kind_) + << "; it's not a fusion"; + return false; + } + if (inst->fusion_kind() != kind_) { + EXPLAIN << "HloInstruction does not have fusion kind " << ToString(kind_); + return false; + } + return true; + } + ::xla::HloInstruction::FusionKind kind_; }; @@ -839,47 +1553,211 @@ class HloInstructionPatternTupleIndexImpl { : tuple_index_(tuple_index) {} bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { - return inst->opcode() == HloOpcode::kGetTupleElement && - inst->tuple_index() == tuple_index_; + return MatchImpl(inst, option); } bool Match(::xla::HloInstruction* inst, MatchOption option) const { - return inst->opcode() == HloOpcode::kGetTupleElement && - inst->tuple_index() == tuple_index_; + return MatchImpl(inst, option); + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "which is a GTE with index " << tuple_index_; } private: + template + bool MatchImpl(HloInstructionType* inst, MatchOption option) const { + if (inst->opcode() != HloOpcode::kGetTupleElement) { + EXPLAIN << "HloInstruction is not a GTE with index " << tuple_index_ + << "; it's not a GTE at all"; + return false; + } + if (inst->tuple_index() != tuple_index_) { + EXPLAIN << "HloInstruction is not a 
GTE with index " << tuple_index_; + return false; + } + return true; + } + int64 tuple_index_; }; -template -class HloPredicatePatternImpl { +class HloInstructionPatternParameterNumImpl { public: - explicit HloPredicatePatternImpl(Predicate pred) : pred_(std::move(pred)) {} + explicit constexpr HloInstructionPatternParameterNumImpl(int64 parameter_num) + : parameter_num_(parameter_num) {} - bool Match(const ItemType* item, MatchOption option) const { - return pred_(item); + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { + return MatchImpl(inst, option); } - bool Match(ItemType* item, MatchOption option) const { return pred_(item); } + bool Match(::xla::HloInstruction* inst, MatchOption option) const { + return MatchImpl(inst, option); + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "which is parameter " << parameter_num_; + } private: - Predicate pred_; + template + bool MatchImpl(HloInstructionType* inst, MatchOption option) const { + if (inst->opcode() != HloOpcode::kParameter || + inst->parameter_number() != parameter_num_) { + EXPLAIN << "HloInstruction is not parameter " << parameter_num_; + return false; + } + return true; + } + + int64 parameter_num_; }; -struct PatternFriend; +// Superclass that contains common code used by Op::WithOneUse() and +// Op::WithOneUser(). +class HloInstructionPatternOneUseOrUserImpl { + protected: + bool MatchOneUser(const HloInstruction* inst, MatchOption option) const { + if (inst->user_count() != 1) { + EXPLAIN << "HloInstruction has " << inst->user_count() + << " users, but expected exactly one."; + if (inst->user_count() > 1) { + EXPLAIN << "\nAll users:"; + for (const HloInstruction* user : inst->users()) { + EXPLAIN << "\n - " << InstToString(user); + } + } + return false; + } + return true; + } +}; + +class HloInstructionPatternOneUseImpl + : public HloInstructionPatternOneUseOrUserImpl { + public: + bool Match(const HloInstruction* inst, MatchOption option) const { + if (!MatchOneUser(inst, option)) { + return false; + } + + int64 use_count = absl::c_count_if( + inst->users()[0]->operands(), + [&](const HloInstruction* operand) { return operand == inst; }); + if (use_count != 1) { + EXPLAIN << "HloInstruction is used " << use_count + << " times by its user, but is expected to be used just once: " + << InstToString(inst->users()[0]); + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "which has exactly one use"; + } +}; + +class HloInstructionPatternOneUserImpl + : public HloInstructionPatternOneUseOrUserImpl { + public: + bool Match(const HloInstruction* inst, MatchOption option) const { + return MatchOneUser(inst, option); + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "which has exactly one user (but possibly is used multiple times by " + "that instruction)"; + } +}; + +// Matches a constant scalar or effective scalar, optionally with a given value. 
+template +class HloConstantScalarImpl { + public: + explicit constexpr HloConstantScalarImpl(bool match_effective_scalar) + : val_(absl::nullopt), match_effective_scalar_(match_effective_scalar) {} + + constexpr HloConstantScalarImpl(ScalarTy val, bool match_effective_scalar) + : val_(val), match_effective_scalar_(match_effective_scalar) {} + + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { + return MatchImpl(inst, option); + } + + bool Match(::xla::HloInstruction* inst, MatchOption option) const { + return MatchImpl(inst, option); + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "which is a constant " + << (match_effective_scalar_ ? "effective " : "") << "scalar"; + if (val_.has_value()) { + *os << " with value " << *val_; + } + } + + private: + template + bool MatchImpl(InstTy* inst, MatchOption option) const { + const auto* const_inst = DynCast(inst); + if (!const_inst) { + EXPLAIN << "HloInstruction is not a constant"; + return false; + } + if (match_effective_scalar_ && + !ShapeUtil::IsEffectiveScalar(inst->shape())) { + EXPLAIN << "HloInstruction is not an effective scalar"; + return false; + } + if (!match_effective_scalar_ && !ShapeUtil::IsScalar(inst->shape())) { + EXPLAIN << "HloInstruction is not a scalar"; + return false; + } + if (!val_.has_value()) { + return true; + } + + // Check that literal == static_cast(val) and + // val == static_cast(literal). This is sufficient to ensure that + // the two constant scalars are actually "equal". + auto val_literal = LiteralUtil::CreateR0(*val_); + auto literal_r0_or = const_inst->literal().Reshape({}); + auto val_as_literal_ty_or = + val_literal.Convert(const_inst->shape().element_type()); + if (!literal_r0_or.ok() || !val_as_literal_ty_or.ok()) { + EXPLAIN << "could not construct relevant Literals (how did this happen?)"; + return false; + } + auto literal_r0 = std::move(literal_r0_or).ValueOrDie(); + auto val_as_literal_ty = std::move(val_as_literal_ty_or).ValueOrDie(); + auto literal_r0_as_val_ty_or = + literal_r0.Convert(val_literal.shape().element_type()); + bool rv = literal_r0_as_val_ty_or.ok() && // + literal_r0_as_val_ty_or.ValueOrDie() == val_literal && + literal_r0 == val_as_literal_ty; + if (!rv) { + EXPLAIN << "HloInstruction's constant value " << literal_r0.ToString() + << " did not match expected value " << *val_; + } + return rv; + } + + absl::optional val_; + bool match_effective_scalar_; +}; // A pattern that matches HloInstructions. 
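The two-way cast in `HloConstantScalarImpl::MatchImpl` is the crux of the value comparison, so here is the same round-trip idea reduced to plain scalars. This is an illustrative sketch only, not XLA code; `int` stands in for the constant's element type and `double` for `ScalarTy`.

```
#include <iostream>

// Round-trip equality: literal == cast(val) AND val == cast(literal).
// One direction alone is not enough: casting 1.25 to int gives 1, which
// equals a stored constant of 1, but casting that constant back gives
// 1.0 != 1.25, so the pair is correctly rejected.
bool RoundTripEqual(int literal, double val) {
  return literal == static_cast<int>(val) &&
         static_cast<double>(literal) == val;
}

int main() {
  std::cout << RoundTripEqual(1, 1.0) << "\n";   // 1: accepted, like ConstantScalar(1)
  std::cout << RoundTripEqual(1, 1.25) << "\n";  // 0: rejected despite the int cast agreeing
}
```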
template class HloInstructionPattern { private: template - HloInstructionPattern> - AppendImpl(NewImpl new_impl) const { - return HloInstructionPattern< - HloInstructionType, AllOfPattern<::xla::HloInstruction, Impl, NewImpl>>( - AllOf(impl_, std::move(new_impl)), matched_inst_); + auto AppendImpl(NewImpl new_impl) const -> HloInstructionPattern< + HloInstructionType, decltype(AllOf( + std::declval(), std::move(new_impl)))> { + auto new_allof = AllOf(impl_, std::move(new_impl)); + return HloInstructionPattern( + std::move(new_allof), matched_inst_); } public: @@ -895,6 +1773,9 @@ class HloInstructionPattern { } return true; } + if (inst != nullptr) { + EXPLAIN << "\nin " << InstToString(inst); + } return false; } @@ -906,6 +1787,7 @@ class HloInstructionPattern { } return true; } + EXPLAIN << "\nin " << InstToString(inst); return false; } @@ -935,12 +1817,47 @@ class HloInstructionPattern { return AppendImpl(HloInstructionPatternOpcodeImpl(opcode, true)); } + constexpr auto Is(const HloInstruction* instr) const + -> decltype(this->AppendImpl(HloInstructionIsImpl(instr))) { + return AppendImpl(HloInstructionIsImpl(instr)); + } + // Modifies the pattern to match only if the instruction is a constant. constexpr auto IsConstant() const -> decltype(this->WithOpcode(HloOpcode::kConstant)) { return WithOpcode(HloOpcode::kConstant); } + constexpr auto IsConstantScalar() const -> decltype(this->AppendImpl( + HloConstantScalarImpl(/*match_effective_scalar=*/false))) { + return AppendImpl( + HloConstantScalarImpl(/*match_effective_scalar=*/false)); + } + + // This does not check that T has the same type as the instruction, so e.g. + // IsConstantScalar(1.0) may match a constant of shape int32[]. + template + constexpr auto IsConstantScalar(const ScalarTy& val) const + -> decltype(this->AppendImpl(HloConstantScalarImpl( + val, /*match_effective_scalar=*/false))) { + return AppendImpl( + HloConstantScalarImpl(val, /*match_effective_scalar=*/false)); + } + + constexpr auto IsConstantEffectiveScalar() const -> decltype(this->AppendImpl( + HloConstantScalarImpl(/*match_effective_scalar=*/true))) { + return AppendImpl( + HloConstantScalarImpl(/*match_effective_scalar=*/true)); + } + + template + constexpr auto IsConstantEffectiveScalar(const ScalarTy& val) const + -> decltype(this->AppendImpl(HloConstantScalarImpl( + val, /*match_effective_scalar=*/true))) { + return AppendImpl( + HloConstantScalarImpl(val, /*match_effective_scalar=*/true)); + } + // Modifies the pattern to match only if the instruction is not a constant. constexpr auto IsNonConstant() const -> decltype(this->WithoutOpcode(HloOpcode::kConstant)) { @@ -957,6 +1874,22 @@ class HloInstructionPattern { HloInstructionPatternShapeImpl(shape)); } + // Make this a templated function to work around gcc 4.9.4 template infinite + // recursion bug. + template + constexpr auto WithShapeEqualTo(const ::xla::Shape* shape) + -> decltype(this->WithShape(Shape().EqualTo(shape))) { + return WithShape(Shape().EqualTo(shape)); + } + + // Make this a templated function to work around gcc 4.9.4 template infinite + // recursion bug. + template + constexpr auto WithShapeCompatibleTo(const ::xla::Shape* shape) + -> decltype(this->WithShape(Shape().CompatibleTo(shape))) { + return WithShape(Shape().CompatibleTo(shape)); + } + // Modifies the pattern to match only if the instruction has an operand that // matches the given pattern. 
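The fluent constant-scalar and shape helpers added above are exercised in the tests further down; a condensed usage sketch follows, assuming a hypothetical `HloInstruction* root` and the `m` alias for `xla::match`. The operand matcher introduced by the preceding comment continues below this aside.

```
namespace m = xla::match;

Shape f16_scalar = ShapeUtil::MakeShape(F16, {});

// Constant scalar with a specific value; per the comment above, the C++ type
// of the value is not checked against the constant's element type.
bool is_42 = Match(root, m::Op().IsConstantScalar(42));

// Exact shape (including layout) vs. shape compatibility (element type and
// dimensions, ignoring layout).
bool exact      = Match(root, m::Constant().WithShapeEqualTo(&f16_scalar));
bool compatible = Match(root, m::Constant().WithShapeCompatibleTo(&f16_scalar));
```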
template @@ -971,6 +1904,20 @@ class HloInstructionPattern { operand_index, operand)); } + template + constexpr auto WithBinaryOperandsAnyOrder( + const HloInstructionPattern& op1, + const HloInstructionPattern& op2) const + -> decltype(this->AppendImpl( + HloInstructionPatternBinaryOperandsAnyOrderImpl< + OperandType1, OperandImpl1, OperandType2, OperandImpl2>(op1, + op2))) { + return AppendImpl( + HloInstructionPatternBinaryOperandsAnyOrderImpl< + OperandType1, OperandImpl1, OperandType2, OperandImpl2>(op1, op2)); + } + // Modifies the pattern to match only if the instruction is a fusion node with // the given kind. constexpr auto WithFusionKind(HloInstruction::FusionKind kind) const @@ -985,17 +1932,34 @@ class HloInstructionPattern { return AppendImpl(HloInstructionPatternTupleIndexImpl(tuple_index)); } - private: - template - constexpr auto WithPredicate(Predicate pred) const -> decltype( - this->AppendImpl(HloPredicatePatternImpl( - std::move(pred)))) { - return AppendImpl( - HloPredicatePatternImpl(std::move(pred))); + // Modifies the pattern to match only if the instruction is a parameter + // with the given parameter number. + constexpr auto WithParameterNum(int64 parameter_num) const -> decltype( + this->AppendImpl(HloInstructionPatternParameterNumImpl(parameter_num))) { + return AppendImpl(HloInstructionPatternParameterNumImpl(parameter_num)); } - friend struct PatternFriend; + // Modifies the pattern to match if the instruction is used exactly once. + // Does not match if the instruction is used twice by the same user (e.g. + // multiply(x,x)). + constexpr auto WithOneUse() const + -> decltype(this->AppendImpl(HloInstructionPatternOneUseImpl())) { + return AppendImpl(HloInstructionPatternOneUseImpl()); + } + // Modifies the pattern to match if the instruction is used by exactly one + // other instruction. Will match if the instruction is used twice, so long as + // it's by the same user (e.g. multiply(x,x)). + constexpr auto WithOneUser() const + -> decltype(this->AppendImpl(HloInstructionPatternOneUserImpl())) { + return AppendImpl(HloInstructionPatternOneUserImpl()); + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + impl_.DescribeTo(os, indent); + } + + private: Impl impl_; HloInstructionType** matched_inst_; }; @@ -1036,6 +2000,7 @@ Op(::xla::HloInstruction** matched_inst) { XLA_NULLOP_PATTERN(Constant) XLA_NULLOP_PATTERN(Parameter) XLA_NULLOP_PATTERN(Iota) +XLA_NULLOP_PATTERN(Rng) #undef XLA_NULLOP_PATTERN // Helpers for unary instructions. 
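A usage sketch of the parameter-number and tuple-index constraints added above, assuming hypothetical instructions `p1` (a parameter with number 1) and `gte` (a get-tuple-element with index 0), plus the usual `m` alias; the unary helpers announced by the preceding comment follow after this aside.

```
namespace m = xla::match;

bool is_p1   = Match(p1, m::Op().WithOpcode(HloOpcode::kParameter).WithParameterNum(1));
bool is_gte0 = Match(gte, m::Op().WithOpcode(HloOpcode::kGetTupleElement).WithTupleIndex(0));

// Equivalent shorthand defined further below in this header:
bool also_p1 = Match(p1, m::Parameter(1));
```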
@@ -1067,8 +2032,10 @@ XLA_UNOP_PATTERN(RoundNearestAfz) XLA_UNOP_PATTERN(Bitcast) XLA_UNOP_PATTERN(Broadcast) XLA_UNOP_PATTERN(Ceil) +XLA_UNOP_PATTERN(Convert) XLA_UNOP_PATTERN(Copy) XLA_UNOP_PATTERN(Cos) +XLA_UNOP_PATTERN(CrossReplicaSum) XLA_UNOP_PATTERN(Exp) XLA_UNOP_PATTERN(Fft) XLA_UNOP_PATTERN(Floor) @@ -1088,6 +2055,7 @@ XLA_UNOP_PATTERN(Reverse) XLA_UNOP_PATTERN(SendDone) XLA_UNOP_PATTERN(Sign) XLA_UNOP_PATTERN(Sin) +XLA_UNOP_PATTERN(Slice) XLA_UNOP_PATTERN(Sort) XLA_UNOP_PATTERN(Tanh) XLA_UNOP_PATTERN(Transpose) @@ -1125,25 +2093,32 @@ XLA_UNOP_PATTERN(Transpose) #define XLA_COMMUTATIVE_BINOP_PATTERN(NAME) \ XLA_BINOP_PATTERN(NAME) \ \ - template \ - inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs) \ - ->decltype(AnyOf(NAME(lhs, rhs), NAME(rhs, lhs))) { \ - return AnyOf(NAME(lhs, rhs), NAME(rhs, lhs)); \ - } \ - \ template \ inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs, \ Rhs&& rhs) \ - ->decltype(AnyOf(NAME(matched_inst, lhs, rhs), \ - NAME(matched_inst, rhs, lhs))) { \ - return AnyOf(NAME(matched_inst, lhs, rhs), \ - NAME(matched_inst, rhs, lhs)); \ + ->decltype(Op(matched_inst) \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithBinaryOperandsAnyOrder(std::forward(lhs), \ + std::forward(rhs))) { \ + return Op(matched_inst) \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithBinaryOperandsAnyOrder(std::forward(lhs), \ + std::forward(rhs)); \ + } \ + template \ + inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs) \ + ->decltype(NAME##AnyOrder( \ + nullptr, std::forward(lhs), std::forward(rhs))) { \ + return NAME##AnyOrder( \ + nullptr, std::forward(lhs), std::forward(rhs)); \ } XLA_COMMUTATIVE_BINOP_PATTERN(Add) XLA_BINOP_PATTERN(Atan2) XLA_BINOP_PATTERN(Divide) XLA_BINOP_PATTERN(Complex) +XLA_BINOP_PATTERN(Convolution) XLA_BINOP_PATTERN(Dot) +XLA_BINOP_PATTERN(DynamicSlice) XLA_COMMUTATIVE_BINOP_PATTERN(Eq) XLA_BINOP_PATTERN(Gather) XLA_BINOP_PATTERN(Ge) @@ -1155,7 +2130,9 @@ XLA_COMMUTATIVE_BINOP_PATTERN(Minimum) XLA_COMMUTATIVE_BINOP_PATTERN(Multiply) XLA_COMMUTATIVE_BINOP_PATTERN(Ne) XLA_BINOP_PATTERN(Outfeed) +XLA_BINOP_PATTERN(Pad) XLA_BINOP_PATTERN(Power) +XLA_BINOP_PATTERN(ReduceWindow) XLA_BINOP_PATTERN(Remainder) XLA_BINOP_PATTERN(Send) XLA_BINOP_PATTERN(Subtract) @@ -1202,6 +2179,7 @@ XLA_BINOP_PATTERN(ShiftRightLogical) .WithOperand(2, std::forward(arg2)); \ } XLA_TERNOP_PATTERN(Clamp); +XLA_TERNOP_PATTERN(Scatter); XLA_TERNOP_PATTERN(Select); #undef XLA_TERNOP_PATTERN @@ -1255,31 +2233,10 @@ inline auto WithOperands(Matcher&& m, int64 operand_num, FirstArg&& first_arg, // We could implement all ops as "variadic" ops, but it would make the // already-bad compile errors even worse. XLA_VARIADIC_OP_PATTERN(Concatenate); +XLA_VARIADIC_OP_PATTERN(CustomCall); +XLA_VARIADIC_OP_PATTERN(Map) XLA_VARIADIC_OP_PATTERN(Reduce); - -namespace detail { -struct PatternFriend { - template - static auto ConstantScalar(T constant) -> decltype( - Constant() - .WithShape(match::Shape().IsScalar()) - .WithPredicate( - std::declval>())) { - std::function pred = - [constant](const HloInstruction* instr) { - const auto& literal = Cast(instr)->literal(); - auto status_or_const = LiteralUtil::CreateR0(constant).Convert( - literal.shape().element_type()); - return status_or_const.ok() && - literal == status_or_const.ConsumeValueOrDie(); - }; - - return Constant() - .WithShape(match::Shape().IsScalar()) - .WithPredicate(std::move(pred)); - } -}; -} // namespace detail +XLA_VARIADIC_OP_PATTERN(Tuple); // Helpers for matching non-constant instructions. 
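The rewrite of `XLA_COMMUTATIVE_BINOP_PATTERN` above changes what `NAME##AnyOrder` returns: instead of an `AnyOf` over the two operand orders, it is now an ordinary `Op()` pattern built on `WithBinaryOperandsAnyOrder`, so further constraints can be chained onto it and the failure explanation covers both operand orders. A sketch mirroring the MultiplyAnyOrder test below, with a hypothetical `root` that multiplies constants 42 and 52 in either order:

```
namespace m = xla::match;

const HloInstruction* instr = nullptr;
bool ok = Match(root, m::MultiplyAnyOrder(&instr, m::ConstantScalar(42),
                                          m::ConstantScalar(52)));

// Chaining works because the result exposes the usual Op() API:
bool ok2 = Match(root, m::MultiplyAnyOrder(m::ConstantScalar(42),
                                           m::ConstantScalar(52))
                           .IsNonConstant());
```

The non-constant helpers announced by the preceding comment follow this aside.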
inline auto NonConstant() -> decltype(Op().IsNonConstant()) { @@ -1318,14 +2275,71 @@ inline auto GetTupleElement(HloInstructionType** matched_inst, Arg&& arg, .WithTupleIndex(tuple_index); } -template -inline auto ConstantScalar(T constant) - -> decltype(detail::PatternFriend::ConstantScalar(constant)) { - return detail::PatternFriend::ConstantScalar(constant); +// Add overloads for Parameter which take an int64 specifying the parameter +// number. +inline auto Parameter(int64 parameter_num) -> decltype( + Op().WithOpcode(HloOpcode::kParameter).WithParameterNum(parameter_num)) { + return Op().WithOpcode(HloOpcode::kParameter).WithParameterNum(parameter_num); +} +template +inline auto Parameter(HloInstructionType** matched_inst, int64 parameter_num) + -> decltype(Op(matched_inst) + .WithOpcode(HloOpcode::kParameter) + .WithParameterNum(parameter_num)) { + return Op(matched_inst) + .WithOpcode(HloOpcode::kParameter) + .WithParameterNum(parameter_num); +} + +inline auto ConstantScalar() -> decltype(Op().IsConstantScalar()) { + return Op().IsConstantScalar(); +} + +template +inline auto ConstantScalar(HloInstructionType** matched_inst) + -> decltype(Op(matched_inst).IsConstantScalar()) { + return Op(matched_inst).IsConstantScalar(); +} + +template +inline auto ConstantScalar(ScalarTy val) + -> decltype(Op().IsConstantScalar(val)) { + return Op().IsConstantScalar(val); +} + +template +inline auto ConstantScalar(HloInstructionType** matched_inst, ScalarTy val) + -> decltype(Op(matched_inst).IsConstantScalar(val)) { + return Op(matched_inst).IsConstantScalar(val); +} + +inline auto ConstantEffectiveScalar() -> decltype(Op().IsConstantScalar()) { + return Op().IsConstantEffectiveScalar(); +} + +template +inline auto ConstantEffectiveScalar(HloInstructionType** matched_inst) + -> decltype(Op(matched_inst).IsConstantScalar()) { + return Op(matched_inst).IsConstantEffectiveScalar(); +} + +template +inline auto ConstantEffectiveScalar(ScalarTy val) + -> decltype(Op().IsConstantEffectiveScalar(val)) { + return Op().IsConstantEffectiveScalar(val); +} + +template +inline auto ConstantEffectiveScalar(HloInstructionType** matched_inst, + ScalarTy val) + -> decltype(Op(matched_inst).IsConstantEffectiveScalar(val)) { + return Op(matched_inst).IsConstantEffectiveScalar(val); } } // namespace match } // namespace xla +#undef EXPLAIN +#pragma pop_macro("EXPLAIN") #endif // TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_ diff --git a/tensorflow/compiler/xla/service/pattern_matcher_gmock.h b/tensorflow/compiler/xla/service/pattern_matcher_gmock.h new file mode 100644 index 00000000000..8fe2d10a11b --- /dev/null +++ b/tensorflow/compiler/xla/service/pattern_matcher_gmock.h @@ -0,0 +1,92 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_GMOCK_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_GMOCK_H_ + +#include +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { + +namespace pattern_matcher_gmock_detail { +template +class GmockMatcher { + public: + explicit GmockMatcher(Pattern p) : pattern_(std::move(p)) {} + + // In service of better error messages, list out the overloads explicitly + // rather than just using a template. gMock's polymorphism plus + // pattern_matcher yields some pretty gnarly stuff. + bool MatchAndExplain(const Layout& l, + ::testing::MatchResultListener* listener) const { + return MatchAndExplainImpl(&l, listener); + } + bool MatchAndExplain(const Layout* l, + ::testing::MatchResultListener* listener) const { + return MatchAndExplainImpl(l, listener); + } + + bool MatchAndExplain(const Shape& s, + ::testing::MatchResultListener* listener) const { + return MatchAndExplainImpl(&s, listener); + } + bool MatchAndExplain(const Shape* s, + ::testing::MatchResultListener* listener) const { + return MatchAndExplainImpl(s, listener); + } + + bool MatchAndExplain(const HloInstruction& instr, + ::testing::MatchResultListener* listener) const { + return MatchAndExplainImpl(&instr, listener); + } + bool MatchAndExplain(const HloInstruction* instr, + ::testing::MatchResultListener* listener) const { + return MatchAndExplainImpl(instr, listener); + } + + void DescribeTo(std::ostream* os) const { pattern_.DescribeTo(os); } + + void DescribeNegationTo(std::ostream* os) const { + *os << "is NOT: "; + DescribeTo(os); + } + + private: + template + bool MatchAndExplainImpl(const T* t, + ::testing::MatchResultListener* listener) const { + MatchOption options{/*.capture=*/true, /*.explain_os=*/listener->stream()}; + return Match(t, pattern_, options); + } + + Pattern pattern_; +}; +} // namespace pattern_matcher_gmock_detail + +template +::testing::PolymorphicMatcher< + pattern_matcher_gmock_detail::GmockMatcher> +GmockMatch(Pattern&& p) { + return ::testing::MakePolymorphicMatcher( + pattern_matcher_gmock_detail::GmockMatcher( + std::forward(p))); +} + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_GMOCK_H_ diff --git a/tensorflow/compiler/xla/service/pattern_matcher_gmock_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_gmock_test.cc new file mode 100644 index 00000000000..9ca2fb05c1f --- /dev/null +++ b/tensorflow/compiler/xla/service/pattern_matcher_gmock_test.cc @@ -0,0 +1,76 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +namespace m = ::xla::match; +using ::testing::Eq; +using ::testing::Not; + +template +string Describe(const ::testing::Matcher& m) { + std::stringstream ss; + m.DescribeTo(&ss); + return ss.str(); +} + +template +string Explain( + const MatchedTy& val, + const ::testing::Matcher::type>& m) { + ::testing::StringMatchResultListener listener; + EXPECT_THAT(val, ::testing::Not(m)); // For the error message. + EXPECT_FALSE(m.MatchAndExplain(val, &listener)); + return listener.str(); +} + +// This file tests the GmockMatch function. The actual explanation and +// description returned by matchers is tested in pattern_matchers_test. +TEST(PatternMatcherGmock, MatchShape) { + Shape s = ShapeUtil::MakeShape(F32, {10, 100}); + // You can pass const Shape& or a const Shape*. + EXPECT_THAT(s, GmockMatch(m::Shape())); + EXPECT_THAT(&s, Not(GmockMatch(m::Shape().WithElementType(F16)))); + EXPECT_THAT(Describe(GmockMatch(m::Shape().IsArray())), + "a shape that represents an array"); +} + +TEST(PatternMatcherGmock, MatchLayout) { + Layout l = LayoutUtil::MakeLayout({0, 1}); + EXPECT_THAT(l, GmockMatch(m::Layout())); + EXPECT_THAT(&l, Not(GmockMatch(m::Layout().WithSparseFormat()))); + EXPECT_THAT(Describe(GmockMatch(m::Layout().WithSparseFormat())), + "a layout with format SPARSE"); +} + +TEST(PatternMatchGmock, MatchInstruction) { + auto instr = + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {42}), "p"); + EXPECT_THAT(instr.get(), GmockMatch(m::Parameter())); + EXPECT_THAT(*instr, GmockMatch(m::Parameter(0))); + EXPECT_THAT(*instr, Not(GmockMatch(m::Parameter(1)))); + EXPECT_THAT(Describe(GmockMatch(m::Parameter())), + "an HloInstruction with opcode parameter"); +} + +} // anonymous namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/pattern_matcher_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_test.cc index 3f74273517a..186ef0c7911 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher_test.cc +++ b/tensorflow/compiler/xla/service/pattern_matcher_test.cc @@ -14,14 +14,18 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/core/platform/test.h" namespace xla { namespace { +namespace m = match; + TEST(PatternMatcherTest, AddOp) { constexpr char kModuleStr[] = R"(HloModule two_plus_two_module ENTRY %two_plus_two_computation () -> f32[] { @@ -229,23 +233,74 @@ TEST(PatternMatcherTest, AnyOf) { } TEST(PatternMatcherTest, ConstantScalar) { + using match::ConstantEffectiveScalar; + using match::ConstantScalar; + using match::Op; + using match::Tuple; + constexpr char kModuleStr[] = R"( - HloModule test_module ENTRY test { ROOT constant = f16[] constant(42) })"; + HloModule test_module + ENTRY test { + a = s32[] constant(1) + b = s32[1,1] constant(s32[1,1]{{2}}) + c = s32[1,2] constant(s32[1,2]{{2,2}}) + d = f32[] constant(1) + e = f32[] constant(1.25) + ROOT tuple = (s32[], s32[1,1], s32[1,2], f32[], f32[]) tuple(a,b,c,d,e) + })"; TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); auto* root = hlo_module->entry_computation()->root_instruction(); - EXPECT_TRUE(Match(root, match::ConstantScalar(42))); - EXPECT_FALSE(Match(root, match::ConstantScalar(41))); - EXPECT_FALSE(Match(root, match::ConstantScalar(0))); -} + const HloInstruction* a = root->operand(0); + const HloInstruction* b = root->operand(1); + const HloInstruction* c = root->operand(2); + const HloInstruction* d = root->operand(3); + const HloInstruction* e = root->operand(4); + EXPECT_TRUE(Match(a, ConstantScalar())); + EXPECT_TRUE(Match(a, ConstantScalar(1))); + EXPECT_TRUE(Match(a, ConstantEffectiveScalar())); + EXPECT_TRUE(Match(a, ConstantEffectiveScalar(1))); + EXPECT_FALSE(Match(a, ConstantScalar(2))); + EXPECT_FALSE(Match(a, ConstantScalar(2.01))); + EXPECT_FALSE(Match(a, ConstantEffectiveScalar(2))); + EXPECT_FALSE(Match(a, ConstantEffectiveScalar(1.01))); -TEST(PatternMatcherTest, NoMatchConstantScalar) { - constexpr char kModuleStr[] = R"( - HloModule test_module ENTRY test { ROOT v = f16[] parameter(0) })"; - TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); - auto* root = hlo_module->entry_computation()->root_instruction(); + EXPECT_FALSE(Match(b, ConstantScalar())); + EXPECT_FALSE(Match(b, ConstantScalar(2))); + EXPECT_TRUE(Match(b, ConstantEffectiveScalar())); + EXPECT_TRUE(Match(b, ConstantEffectiveScalar(2))); - EXPECT_FALSE(Match(root, match::ConstantScalar(42))); + EXPECT_FALSE(Match(c, ConstantScalar())); + EXPECT_FALSE(Match(c, ConstantScalar(2))); + EXPECT_FALSE(Match(c, ConstantEffectiveScalar())); + EXPECT_FALSE(Match(c, ConstantEffectiveScalar(2))); + + EXPECT_TRUE(Match(d, ConstantScalar(1))); + EXPECT_TRUE(Match(d, ConstantEffectiveScalar(1))); + EXPECT_TRUE(Match(d, ConstantScalar(1.0))); + EXPECT_TRUE(Match(d, ConstantEffectiveScalar(1.0))); + + EXPECT_TRUE(Match(e, ConstantScalar(1.25f))); + EXPECT_TRUE(Match(e, ConstantScalar(1.25))); + EXPECT_TRUE(Match(e, ConstantEffectiveScalar(1.25))); + EXPECT_FALSE(Match(e, ConstantScalar(1))); + EXPECT_FALSE(Match(e, ConstantEffectiveScalar(1))); + + const HloInstruction* instr = nullptr; + EXPECT_TRUE(Match(a, ConstantScalar(&instr))); + EXPECT_EQ(instr, a); + + instr = nullptr; + EXPECT_TRUE(Match(a, ConstantScalar(&instr, 1))); + 
EXPECT_EQ(instr, a); + + instr = nullptr; + EXPECT_TRUE(Match(a, ConstantEffectiveScalar(&instr))); + EXPECT_EQ(instr, a); + + instr = nullptr; + EXPECT_TRUE(Match(a, ConstantEffectiveScalar(&instr, 1))); + EXPECT_EQ(instr, a); } TEST(PatternMatcherTest, MultiplyAnyOrder) { @@ -267,6 +322,15 @@ TEST(PatternMatcherTest, MultiplyAnyOrder) { root, MultiplyAnyOrder(&instr, ConstantScalar(42), ConstantScalar(52)))); EXPECT_TRUE(Match( root, MultiplyAnyOrder(&instr, ConstantScalar(52), ConstantScalar(42)))); + + // Check that MultiplyAnyOrder exposes the same API as Op(), so we can call + // e.g. IsNonConstant() on it. + EXPECT_TRUE(Match( + root, MultiplyAnyOrder(&instr, ConstantScalar(42), ConstantScalar(52)) + .IsNonConstant())); + EXPECT_TRUE( + Match(root, MultiplyAnyOrder(ConstantScalar(42), ConstantScalar(52)) + .IsNonConstant())); } TEST(PatternMatcherTest, AnyOfShortCircuit) { @@ -315,14 +379,22 @@ TEST(PatternMatcherTest, AllOf) { TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); auto* root = hlo_module->entry_computation()->root_instruction(); + auto f16_scalar = ShapeUtil::MakeShape(F16, {}); + auto f16_pattern = Constant().WithShapeEqualTo(&f16_scalar); + auto f16_compatible_pattern = Constant().WithShapeCompatibleTo(&f16_scalar); auto scalar_pattern = Constant().WithShape(match::Shape().IsScalar()); - auto f16_pattern = Constant().WithShape(match::Shape().WithElementType(F16)); ASSERT_TRUE(Match(root, scalar_pattern)); ASSERT_TRUE(Match(root, f16_pattern)); - EXPECT_TRUE(Match(root, AllOf(scalar_pattern, f16_pattern))); - EXPECT_TRUE(Match(root, AllOf(f16_pattern, scalar_pattern))); + ASSERT_TRUE(Match(root, f16_compatible_pattern)); + EXPECT_TRUE(Match(root, AllOf(scalar_pattern, f16_pattern, + f16_compatible_pattern))); + EXPECT_TRUE( + Match(root, AllOf(f16_pattern, f16_compatible_pattern, + scalar_pattern))); EXPECT_FALSE( Match(root, AllOf(Broadcast(Op()), f16_pattern))); + EXPECT_FALSE(Match( + root, AllOf(Broadcast(Op()), f16_compatible_pattern))); EXPECT_FALSE( Match(root, AllOf(Broadcast(Op()), scalar_pattern))); } @@ -431,5 +503,433 @@ TEST(PatternMatcherTest, TestConcat) { Reshape(ConstantScalar(4))))); } +template +string Description(const Pattern& pattern) { + std::stringstream ss; + pattern.DescribeTo(&ss); + return ss.str(); +} + +template +string Explanation(Elem* elem, const Pattern& pattern) { + std::stringstream ss; + MatchOption options{/*.capture=*/true, /*.explain_os=*/&ss}; + Match(elem, pattern, options); + return ss.str(); +} +template +string Explanation(const std::unique_ptr& elem, const Pattern& pattern) { + return Explanation(elem.get(), pattern); +} +template +string Explanation(const Elem& elem, const Pattern& pattern) { + return Explanation(&elem, pattern); +} + +// Helper macro for checking a pattern's description and the explanation printed +// when attempting to match (and presumably failing) on a given object. +// +// We use a macro rather than a function because we want good line numbers in +// errors. We use this rather than writing a helper that returns a pair of +// (description, explanation) and doing something like +// +// EXPECT_THAT(DescAndExplanation(...), ::testing::Pair(..., ...)); +// +// because EXPECT_EQ prints a unified diff if multiline string comparison fails, +// while EXPECT_THAT does not. This unified diff makes the errors much easier +// to read. 
+#define EXPECT_DESC_AND_EXPLANATION(elem, pattern, expected_desc, \ + expected_explanation) \ + do { \ + EXPECT_EQ(Description(pattern), (expected_desc)); \ + EXPECT_EQ(Explanation((elem), (pattern)), expected_explanation); \ + } while (0) + +TEST(PatternMatcherTest, LayoutDescribeToAndExplain) { + auto layout = LayoutUtil::MakeLayout({1, 2}); + auto layout2 = LayoutUtil::MakeLayout({2, 2}); + + EXPECT_DESC_AND_EXPLANATION(static_cast(nullptr), m::Layout(), + "a layout", "Layout is null"); + EXPECT_DESC_AND_EXPLANATION(layout2, m::Layout().EqualTo(&layout), + "a layout equal to {1,2}", + "Layout {2,2} is not equal to expected {1,2}"); + EXPECT_DESC_AND_EXPLANATION(layout2, m::Layout().WithSparseFormat(), + "a layout with format SPARSE", + "Layout has format DENSE but expected SPARSE"); + EXPECT_DESC_AND_EXPLANATION(layout, + m::Layout().EqualTo(&layout).WithSparseFormat(), + "a layout:\n" + " * equal to {1,2} AND\n" + " * with format SPARSE", + "Layout has format DENSE but expected SPARSE"); +} + +TEST(PatternMatcherTest, ShapeDescribeToAndExplain) { + auto shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 2}, {0, 1}); + auto layout = shape.layout(); + + EXPECT_DESC_AND_EXPLANATION(static_cast(nullptr), m::Shape(), + "a shape", "Shape is null"); + EXPECT_DESC_AND_EXPLANATION( + ShapeUtil::MakeShapeWithLayout(F32, {1, 2}, {1, 0}), + m::Shape().EqualTo(&shape), "a shape equal to f32[1,2]{0,1}", + "Shape not equal to f32[1,2]{0,1}\n" + "in f32[1,2]{1,0}"); + EXPECT_DESC_AND_EXPLANATION(ShapeUtil::MakeShape(F32, {2, 2}), + m::Shape().CompatibleTo(&shape), + "a shape compatible with f32[1,2]", + "Shape not compatible with f32[1,2]\n" + "in f32[2,2]{1,0}"); + EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().WithElementType(F16), + "a shape with element type F16", + "Shape does not have element type F16\n" + "in f32[1,2]{0,1}"); + EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().IsScalar(), + "a shape that represents a scalar", + "Shape is not a scalar\n" + "in f32[1,2]{0,1}"); + EXPECT_DESC_AND_EXPLANATION(ShapeUtil::MakeNil(), m::Shape().IsArray(), + "a shape that represents an array", + "Shape is not an array\n" + "in ()"); + EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().IsTuple(), + "a shape that represents a tuple", + "Shape is not a tuple\n" + "in f32[1,2]{0,1}"); + EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().IsEffectiveScalar(), + "a shape that is an effective scalar", + "Shape is not an effective scalar\n" + "in f32[1,2]{0,1}"); + EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().WithRank(42), + "a shape that has 42 dimensions", + "Shape does not have rank 42\n" + "in f32[1,2]{0,1}"); + EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().WithRank(0), + "a shape that is a scalar", + "Shape is not a scalar\n" + "in f32[1,2]{0,1}"); + EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().WithRank(1).IsArray(), + "a shape:\n" + " * that has 1 dimension AND\n" + " * that represents an array", + "Shape does not have rank 1\n" + "in f32[1,2]{0,1}"); + EXPECT_DESC_AND_EXPLANATION(ShapeUtil::MakeNil(), + m::Shape().IsArray().WithRank(1), + "a shape:\n" + " * that represents an array AND\n" + " * that has 1 dimension", + "Shape is not an array\n" + "in ()"); + EXPECT_DESC_AND_EXPLANATION( + ShapeUtil::MakeShapeWithLayout(F32, {1, 2}, {1, 0}), + m::Shape().WithLayoutEqualTo(&layout), + "a shape with\n a layout equal to {0,1}", + "Layout {1,0} is not equal to expected {0,1}\n" + "in f32[1,2]{1,0}"); + EXPECT_DESC_AND_EXPLANATION( + shape, m::Shape().WithLayout(m::Layout().WithSparseFormat()), + "a shape with\n a layout 
with format SPARSE", + "Layout has format DENSE but expected SPARSE\n" + "in f32[1,2]{0,1}"); + EXPECT_DESC_AND_EXPLANATION(shape, + m::Shape().WithSubshapeEqualTo({10}, &shape), + "a shape with subshape at index {10} which is\n" + " a shape equal to f32[1,2]{0,1}", + "No subshape at {10}\n" + "in f32[1,2]{0,1}"); + EXPECT_DESC_AND_EXPLANATION( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {2, 2})}), + m::Shape().WithSubshapeEqualTo({0}, &shape), + "a shape with subshape at index {0} which is\n" + " a shape equal to f32[1,2]{0,1}", + "Shape not equal to f32[1,2]{0,1}\n" + "in f32[2,2]{1,0}\n" + "in subshape at {0}\n" + "in (f32[2,2])"); + EXPECT_DESC_AND_EXPLANATION(shape, + m::Shape().WithSubshapeCompatibleTo({10}, &shape), + "a shape with subshape at index {10} which is\n" + " a shape compatible with f32[1,2]", + "No subshape at {10}\n" + "in f32[1,2]{0,1}"); + EXPECT_DESC_AND_EXPLANATION( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {2, 2})}), + m::Shape().WithSubshapeCompatibleTo({0}, &shape), + "a shape with subshape at index {0} which is\n" + " a shape compatible with f32[1,2]", + "Shape not compatible with f32[1,2]\n" + "in f32[2,2]{1,0}\n" + "in subshape at {0}\n" + "in (f32[2,2])"); + EXPECT_DESC_AND_EXPLANATION( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeTupleShape({shape})}), + m::Shape().WithSubshape({0, 0}, m::Shape().IsScalar()), + "a shape with subshape at index {0,0} which is\n" + " a shape that represents a scalar", + "Shape is not a scalar\n" + "in f32[1,2]{0,1}\n" + "in subshape at {0,0}\n" + "in ((f32[1,2]))"); +} + +std::unique_ptr SetName(absl::string_view name, + std::unique_ptr instr) { + instr->SetAndSanitizeName(string(name)); + return instr; +} + +TEST(PatternMatcherTest, HloInstructionDescribeToAndExplain) { + std::unique_ptr iota = + SetName("i", HloInstruction::CreateIota(ShapeUtil::MakeShape(S32, {42}), + /*iota_dimension=*/0)); + std::unique_ptr constant = + SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + + EXPECT_DESC_AND_EXPLANATION(static_cast(nullptr), + m::Op(), "an HloInstruction", + "HloInstruction* is null"); + EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithName("foo"), + "an HloInstruction named \"foo\"", + "HloInstruction not named \"foo\"\n" + "in i = s32[42]{0} iota(), iota_dimension=0"); + EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithOpcode(HloOpcode::kAdd), + "an HloInstruction with opcode add", + "HloInstruction doesn't have opcode add\n" + "in i = s32[42]{0} iota(), iota_dimension=0"); + EXPECT_DESC_AND_EXPLANATION( + constant, m::Op().IsNonConstant(), + "an HloInstruction with any opcode other than constant", + "HloInstruction has opcode constant, expected anything else\n" + "in c = s32[] constant(0)"); + EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithNumOperands(42), + "an HloInstruction with 42 operands", + "HloInstruction doesn't have 42 operands\n" + "in i = s32[42]{0} iota(), iota_dimension=0"); + EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithShape(m::Shape().IsTuple()), + "an HloInstruction outputting\n" + " a shape that represents a tuple", + "Shape is not a tuple\n" + "in s32[42]{0}\n" + "in output shape\n" + "in i = s32[42]{0} iota(), iota_dimension=0"); + EXPECT_DESC_AND_EXPLANATION( + iota, m::Op().WithOperand(2, m::Op().WithOpcode(HloOpcode::kAdd)), + "an HloInstruction with operand 2 which is:\n" + " an HloInstruction with opcode add", + "desired operand index 2 is out of bounds\n" + "in i = s32[42]{0} iota(), iota_dimension=0"); + + EXPECT_DESC_AND_EXPLANATION( + SetName("a", 
HloInstruction::CreateBinary(ShapeUtil::MakeShape(S32, {}), + HloOpcode::kAdd, constant.get(), + constant.get())), + m::Op().WithOperand(1, m::Op().IsNonConstant()), + "an HloInstruction with operand 1 which is:\n" + " an HloInstruction with any opcode other than constant", + "HloInstruction has opcode constant, expected anything else\n" + "in c = s32[] constant(0)\n" + "in operand 1\n" + "in a = s32[] add(s32[] c, s32[] c)"); + EXPECT_DESC_AND_EXPLANATION( + iota, m::Op().WithFusionKind(HloInstruction::FusionKind::kLoop), + "an HloInstruction with fusion kind kLoop", + "HloInstruction does not have fusion kind kLoop; it's not a fusion\n" + "in i = s32[42]{0} iota(), iota_dimension=0"); + EXPECT_DESC_AND_EXPLANATION( + iota, m::Op().WithTupleIndex(42), + "an HloInstruction which is a GTE with index 42", + "HloInstruction is not a GTE with index 42; it's not a GTE at all\n" + "in i = s32[42]{0} iota(), iota_dimension=0"); + EXPECT_DESC_AND_EXPLANATION(iota, m::Op().IsConstantScalar(), + "an HloInstruction which is a constant scalar", + "HloInstruction is not a constant\n" + "in i = s32[42]{0} iota(), iota_dimension=0"); + EXPECT_DESC_AND_EXPLANATION( + SetName("c", HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1, 2}))), + m::Op().IsConstantEffectiveScalar(), + "an HloInstruction which is a constant effective scalar", + "HloInstruction is not an effective scalar\n" + "in c = s32[2]{0} constant({1, 2})"); + EXPECT_DESC_AND_EXPLANATION( + SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(10))), + m::Op().IsConstantScalar(42), + "an HloInstruction which is a constant scalar with value 42", + "HloInstruction's constant value 10 did not match expected value 42\n" + "in c = s32[] constant(10)"); + EXPECT_DESC_AND_EXPLANATION( + SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.25))), + m::Op().IsConstantEffectiveScalar(1.25), + "an HloInstruction which is a constant effective scalar with value 1.25", + "HloInstruction's constant value 2.25 did not match expected value 1.25\n" + "in c = f64[] constant(2.25)"); + EXPECT_DESC_AND_EXPLANATION( + constant, m::Op().Is(iota.get()), + absl::StrCat("an HloInstruction which is 0x", absl::Hex(iota.get()), + " (i = s32[42]{0} iota(), iota_dimension=0)"), + absl::StrCat("HloInstruction 0x", absl::Hex(constant.get()), " is not 0x", + absl::Hex(iota.get()), + " (i = s32[42]{0} iota(), iota_dimension=0)\n" + "in c = s32[] constant(0)")); +} + +TEST(PatternMatcherTest, HloInstructionMatcherAnyOrderDescribeTo) { + auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); + EXPECT_DESC_AND_EXPLANATION( + SetName("a", HloInstruction::CreateBinary( + scalar_s32, HloOpcode::kAdd, + SetName("b", HloInstruction::CreateConstant( + LiteralUtil::CreateR0(0))) + .get(), + SetName("c", HloInstruction::CreateConstant( + LiteralUtil::CreateR0(0))) + .get())), + m::AddAnyOrder(m::Op().WithName("b"), m::Op().WithName("bar")), + "an HloInstruction:\n" + " * with opcode add AND\n" + " * with two operands in either order:\n" + " - an HloInstruction named \"b\"\n" + " - an HloInstruction named \"bar\"", + "HloInstruction's operands (ignoring order) did not match second " + "matcher. 
Specifically,\n" + " - an HloInstruction named \"bar\"\n" + "does not match LHS:\n" + " - HloInstruction not named \"bar\"\n" + " in b = s32[] constant(0)\n" + "does not match RHS:\n" + " - HloInstruction not named \"bar\"\n" + " in c = s32[] constant(0)\n" + "in a = s32[] add(s32[] b, s32[] c)"); + + EXPECT_DESC_AND_EXPLANATION( + SetName("a", + HloInstruction::CreateBinary( + scalar_s32, HloOpcode::kAdd, + HloInstruction::CreateParameter(0, scalar_s32, "p").get(), + SetName("c", HloInstruction::CreateConstant( + LiteralUtil::CreateR0(0))) + .get())), + m::AddAnyOrder(m::Op().IsConstantScalar(), m::Op().IsConstant()), + "an HloInstruction:\n" + " * with opcode add AND\n" + " * with two operands in either order:\n" + " - an HloInstruction which is a constant scalar\n" + " - an HloInstruction with opcode constant", + "HloInstruction's LHS operand did not match either of the two matchers. " + "Specifically,\n" + " - an HloInstruction which is a constant scalar\n" + "does not match LHS:\n" + " - HloInstruction is not a constant\n" + " in p = s32[] parameter(0)\n" + "and\n" + " - an HloInstruction with opcode constant\n" + "does not match LHS:\n" + " - HloInstruction doesn't have opcode constant\n" + " in p = s32[] parameter(0)\n" + "in a = s32[] add(s32[] p, s32[] c)"); +} + +TEST(PatternMatcherTest, AnyOfMatcherDescribeToAndExplain) { + EXPECT_DESC_AND_EXPLANATION( + SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))), + m::AnyOf(m::Op().WithName("foo"), + m::Op().WithName("bar")), + "any of:\n" + " - an HloInstruction named \"foo\" OR\n" + " - an HloInstruction named \"bar\"", + "None of the following matchers succeeded:\n" + "Matcher #1\n" + " - an HloInstruction named \"foo\"\n" + "failed with\n" + " - HloInstruction not named \"foo\"\n" + " in c = s32[] constant(0)\n" + "Matcher #2\n" + " - an HloInstruction named \"bar\"\n" + "failed with\n" + " - HloInstruction not named \"bar\"\n" + " in c = s32[] constant(0)"); +} + +TEST(PatternMatcherTest, Parameter) { + auto param = + HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {}), "p1"); + auto non_param = + SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + EXPECT_FALSE(Match(param.get(), m::Parameter(0))); + EXPECT_TRUE(Match(param.get(), m::Parameter())); + EXPECT_TRUE(Match(param.get(), m::Parameter(1))); + EXPECT_FALSE(Match(non_param.get(), m::Parameter())); + EXPECT_FALSE(Match(non_param.get(), m::Parameter(1))); + + EXPECT_DESC_AND_EXPLANATION(non_param, m::Parameter(1), + "an HloInstruction:\n" + " * with opcode parameter AND\n" + " * which is parameter 1", + "HloInstruction doesn't have opcode parameter\n" + "in c = s32[] constant(0)"); + EXPECT_EQ(Explanation(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "p0"), + m::Parameter(1)), + "HloInstruction is not parameter 1\n" + "in p0 = f32[] parameter(0)"); +} + +TEST(PatternMatcherTest, OneUseAndOneUser) { + auto param = + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0"); + + EXPECT_FALSE(Match(param.get(), m::Op().WithOneUse())); + EXPECT_DESC_AND_EXPLANATION( + param, m::Op().WithOneUse(), + "an HloInstruction which has exactly one use", + "HloInstruction has 0 users, but expected exactly one.\n" + "in p0 = f32[] parameter(0)"); + + EXPECT_FALSE(Match(param.get(), m::Op().WithOneUser())); + EXPECT_DESC_AND_EXPLANATION( + param, m::Op().WithOneUser(), + "an HloInstruction which has exactly one user (but possibly is used " + "multiple times by that instruction)", + "HloInstruction has 
0 users, but expected exactly one.\n" + "in p0 = f32[] parameter(0)"); + + { + auto reshape = + SetName("r", HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {1}), param.get())); + EXPECT_TRUE(Match(param.get(), m::Op().WithOneUse())); + EXPECT_TRUE(Match(param.get(), m::Op().WithOneUser())); + + auto reshape1 = + SetName("r1", HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {1}), param.get())); + EXPECT_FALSE(Match(param.get(), m::Op().WithOneUse())); + EXPECT_FALSE(Match(param.get(), m::Op().WithOneUser())); + + const char* kMultipleUserExplanation = + "HloInstruction has 2 users, but expected exactly one.\n" + "All users:\n" + " - r = f32[1]{0} reshape(f32[] p0)\n" + " - r1 = f32[1]{0} reshape(f32[] p0)\n" + "in p0 = f32[] parameter(0)"; + EXPECT_EQ(Explanation(param.get(), m::Op().WithOneUse()), + kMultipleUserExplanation); + EXPECT_EQ(Explanation(param.get(), m::Op().WithOneUser()), + kMultipleUserExplanation); + } + + auto add = SetName("add", HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, + param.get(), param.get())); + EXPECT_TRUE(Match(param.get(), m::Op().WithOneUser())); + EXPECT_FALSE(Match(param.get(), m::Op().WithOneUse())); + EXPECT_EQ(Explanation(param.get(), m::Op().WithOneUse()), + "HloInstruction is used 2 times by its user, but is expected to be " + "used just once: add = f32[] add(f32[] p0, f32[] p0)\n" + "in p0 = f32[] parameter(0)"); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion_test.cc b/tensorflow/compiler/xla/service/reduce_precision_insertion_test.cc index 16fa80d53e7..efeec965714 100644 --- a/tensorflow/compiler/xla/service/reduce_precision_insertion_test.cc +++ b/tensorflow/compiler/xla/service/reduce_precision_insertion_test.cc @@ -54,7 +54,7 @@ TEST_F(ReducePrecisionInsertionTest, BeforeUnaryInstruction) { HloInstruction* b = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kCos, a)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected state before adding ops. @@ -81,7 +81,7 @@ TEST_F(ReducePrecisionInsertionTest, BeforeUnaryScalarInstruction) { HloInstruction* b = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kCos, a)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected state before adding ops. @@ -111,7 +111,7 @@ TEST_F(ReducePrecisionInsertionTest, BeforeBinaryInstruction) { HloInstruction* c = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected state before adding ops. @@ -140,7 +140,7 @@ TEST_F(ReducePrecisionInsertionTest, BeforeZeroInputInstruction) { HloInstruction* b = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kCos, a)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected state before adding ops. 
@@ -173,7 +173,7 @@ TEST_F(ReducePrecisionInsertionTest, AvoidAddingDuplicateInstructions) { HloInstruction* d = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, b, c)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected state before adding ops. @@ -205,7 +205,7 @@ TEST_F(ReducePrecisionInsertionTest, AfterRootInstruction) { HloInstruction* b = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kCos, a)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected state before adding ops. @@ -242,7 +242,7 @@ TEST_F(ReducePrecisionInsertionTest, AfterNonRootInstruction) { HloInstruction* c = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a_cos, b_cos)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); // Confirm expected graph before adding ops. @@ -295,7 +295,7 @@ TEST_F(ReducePrecisionInsertionTest, ShouldReduceOutputPrecisionIsFalse) { HloInstruction* y = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kCos, x)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected graph before adding ops. @@ -321,7 +321,7 @@ TEST_F(ReducePrecisionInsertionTest, InsertionIsNotRecursive) { HloInstruction* b = builder.AddInstruction( HloInstruction::CreateReducePrecision(shape, a, 8, 23)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected state before adding ops. @@ -348,7 +348,7 @@ TEST_F(ReducePrecisionInsertionTest, SkipRedundantReducePrecisionAfter) { HloInstruction* y = builder.AddInstruction( HloInstruction::CreateReducePrecision(shape, x, 5, 10)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected graph before adding ops. @@ -376,7 +376,7 @@ TEST_F(ReducePrecisionInsertionTest, AddNonRedundantReducePrecision) { HloInstruction* y = builder.AddInstruction( HloInstruction::CreateReducePrecision(shape, x, 8, 23)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected graph before adding ops. @@ -402,7 +402,7 @@ TEST_F(ReducePrecisionInsertionTest, IgnoreOpsInsideFusionNode) { builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x")); HloInstruction* y = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kCos, x)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Manually fuse the kCos operation into a fusion operation. 
@@ -438,7 +438,7 @@ TEST_F(ReducePrecisionInsertionTest, OpGetsInsertedInHeadOfFusionNode) { builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x")); HloInstruction* y = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kCos, x)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Manually fuse the kCos operation into a fusion operation. @@ -485,7 +485,7 @@ TEST_F(ReducePrecisionInsertionTest, OpGetsInsertedInTailOfFusionNode) { builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x")); HloInstruction* y = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kCos, x)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Manually fuse the kCos operation into a fusion operation. diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 75f7413b3c3..5ec7fe2aded 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -41,6 +41,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/source_map_util.h" #include "tensorflow/compiler/xla/service/stream_pool.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" +#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -275,8 +276,8 @@ StatusOr> Service::CreateModuleConfig( } if (execution_options != nullptr && execution_options->has_shape_with_output_layout()) { - const auto& shape_with_output_layout = - execution_options->shape_with_output_layout(); + const Shape shape_with_output_layout( + execution_options->shape_with_output_layout()); TF_RETURN_IF_ERROR( ValidateResultShape(shape_with_output_layout, program_shape.result())); TF_RETURN_IF_ERROR( @@ -658,9 +659,9 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, // replica 0. 
TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, - CreateModuleConfig(request.computation().host_program_shape(), - replicated_arguments.front(), - request.execution_options())); + CreateModuleConfig( + ProgramShape{request.computation().host_program_shape()}, + replicated_arguments.front(), request.execution_options())); VLOG(3) << "ExecuteGraphParallel created HloModuleConfig computation layout: " << module_config->entry_computation_layout().ToString(); @@ -745,9 +746,9 @@ Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, } if (available_device_count < arg->device_count() * replica_count) { return ResourceExhausted( - "Requested device count (%d) exceeds the number of available devices " - "on the target (%d)", - arg->device_count(), available_device_count); + "Requested logical device count (%d) with replica count (%d) exceeds " + "the number of available physical devices on the target (%d)", + arg->device_count(), replica_count, available_device_count); } for (int64 i = 0; i < arg->device_count(); ++i) { @@ -818,14 +819,17 @@ Status Service::Compile(const CompileRequest* arg, CompileResponse* result) { "The compile request does not support multiple device handles."); } - std::vector argument_shapes; - absl::c_transform(arg->input_shape_with_layout(), - std::back_inserter(argument_shapes), - [](const Shape& shape) { return &shape; }); + std::vector argument_shapes; + argument_shapes.reserve(arg->input_shape_with_layout_size()); + std::vector argument_shape_ptrs; + for (const ShapeProto& shape_proto : arg->input_shape_with_layout()) { + argument_shapes.push_back(Shape(shape_proto)); + argument_shape_ptrs.push_back(&argument_shapes.back()); + } TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, - CreateModuleConfig(arg->computation().host_program_shape(), - argument_shapes, &arg->execution_options())); + CreateModuleConfig(ProgramShape{arg->computation().host_program_shape()}, + argument_shape_ptrs, &arg->execution_options())); VLOG(3) << "Compile created HloModuleConfig computation layout: " << module_config->entry_computation_layout().ToString(); @@ -930,14 +934,14 @@ Status Service::TransferToClient(const TransferToClientRequest* arg, TF_ASSIGN_OR_RETURN(const ShapedBuffer* shaped_buffer, allocation_tracker_.ResolveForReplica(arg->data(), 0)); - const Shape* return_shape; + Shape return_shape; if (arg->has_shape_with_layout()) { - if (!LayoutUtil::HasLayout(arg->shape_with_layout())) { + return_shape = Shape(arg->shape_with_layout()); + if (!LayoutUtil::HasLayout(return_shape)) { return InvalidArgument("shape_with_layout must have layout if present."); } - return_shape = &arg->shape_with_layout(); } else { - return_shape = &shaped_buffer->on_host_shape(); + return_shape = Shape(shaped_buffer->on_host_shape()); } TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream( @@ -948,30 +952,15 @@ Status Service::TransferToClient(const TransferToClientRequest* arg, execute_backend_->transfer_manager()->TransferLiteralFromDevice( stream.get(), *shaped_buffer)); - if (LayoutUtil::LayoutsInShapesEqual(*return_shape, result_literal.shape())) { + if (LayoutUtil::LayoutsInShapesEqual(return_shape, result_literal.shape())) { *result->mutable_literal() = result_literal.ToProto(); } else { *result->mutable_literal() = - result_literal.Relayout(*return_shape).ToProto(); + result_literal.Relayout(return_shape).ToProto(); } return Status::OK(); } -namespace { - -// Creates a clone of the given shaped buffer with the given device ordinal. 
The -// shape and DeviceMemoryBase values of the clone are identical to the original. -std::unique_ptr CloneShapedBufferOnDevice( - const ShapedBuffer& shaped_buffer, int device_ordinal) { - auto clone = absl::make_unique( - shaped_buffer.on_host_shape(), shaped_buffer.on_device_shape(), - shaped_buffer.platform(), device_ordinal); - clone->buffers() = shaped_buffer.buffers(); - return clone; -} - -} // namespace - Status Service::TransferToServer(const TransferToServerRequest* arg, TransferToServerResponse* result) { TF_ASSIGN_OR_RETURN(Literal literal, @@ -1060,11 +1049,11 @@ Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg, executor = replicas[arg->replica_id()]; } - auto literal = Literal::CreateFromShape(arg->shape_with_layout()); + auto literal = Literal::CreateFromShape(Shape(arg->shape_with_layout())); TF_RETURN_IF_ERROR( execute_backend_->transfer_manager()->TransferLiteralFromOutfeed( - executor, arg->shape_with_layout(), literal)); + executor, Shape(arg->shape_with_layout()), literal)); *result->mutable_literal() = literal.ToProto(); return Status::OK(); } @@ -1087,7 +1076,7 @@ Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg, "constant computation may not depend on any parameters."); } - ProgramShape program_shape = arg->computation().host_program_shape(); + ProgramShape program_shape(arg->computation().host_program_shape()); TF_DCHECK_OK(ShapeUtil::ValidateShape(program_shape.result())); if (arg->has_output_layout()) { TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape( @@ -1118,7 +1107,7 @@ Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg, Status Service::GetShape(const GetShapeRequest* arg, GetShapeResponse* result) { TF_ASSIGN_OR_RETURN(const ShapedBuffer* buffer, allocation_tracker_.ResolveForReplica(arg->data(), 0)); - *result->mutable_shape() = buffer->on_host_shape(); + *result->mutable_shape() = buffer->on_host_shape().ToProto(); return Status::OK(); } @@ -1131,7 +1120,7 @@ Status Service::GetComputationGraphStats( return InvalidArgument("Program shape may not be empty."); } - HloModuleConfig config(arg->computation().host_program_shape()); + HloModuleConfig config(ProgramShape{arg->computation().host_program_shape()}); config.set_debug_options(arg->debug_options()); TF_ASSIGN_OR_RETURN(std::unique_ptr module, CreateModuleFromProto(arg->computation(), config)); diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 61a60ef9efa..7e7282a7370 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -391,17 +391,6 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return ShapeUtil::MakeShape(element_type, new_dimensions); } -/* static */ StatusOr ShapeInference::InferAfterAllShape( - absl::Span arg_shapes) { - for (const Shape* arg_shape : arg_shapes) { - if (arg_shape->element_type() != TOKEN) { - return InvalidArgument( - "Operands of token instructions must be TOKEN types."); - } - } - return ShapeUtil::MakeTokenShape(); -} - /* static */ StatusOr ShapeInference::InferConvertShape( const Shape& operand_shape, PrimitiveType new_element_type) { auto old_element_type = operand_shape.element_type(); @@ -1029,7 +1018,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, switch (opcode) { case HloOpcode::kTuple: { Shape result = ShapeUtil::MakeTupleShape({}); - result.mutable_tuple_shapes()->Reserve(operand_shapes.size()); + 
result.mutable_tuple_shapes()->reserve(operand_shapes.size()); for (const Shape* shape : operand_shapes) { ShapeUtil::AppendShapeToTuple(*shape, &result); } @@ -2038,7 +2027,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, dimension); } - return ShapeUtil::MakeShape(S64, {}); + // TODO(b/119580730): Remove this restriction when very large dimension size + // is needed. + if (shape.dimensions(dimension) > std::numeric_limits::max()) { + return InvalidArgument( + "GetDimensionSize's input shape is %s, the %dth dimension exceeds the " + "UINT_MAX limit.", + ShapeUtil::HumanString(shape), dimension); + } + + return ShapeUtil::MakeShape(U32, {}); } /* static */ StatusOr ShapeInference::InferSliceShape( diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index 31ef4b2e410..d94385a04d5 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -232,13 +232,6 @@ class ShapeInference { static StatusOr InferConcatOpShape( absl::Span arg_shapes, int64 dimension); - // Infers the shape produced by a kAfterAll. Trivially this shape is always a - // TOKEN shape. However, ShapeInference serves two purposes: inferring shapes - // and checking operand shapes. This method verifies that the operand shapes - // are all TOKENs. - static StatusOr InferAfterAllShape( - absl::Span arg_shapes); - // Helper that validates the given operand shape can be converted to the // target output_shape via a convert instruction -- the requirement is that // the shape is identical except for the element type. diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index 7a565bf0768..17cdaa74fc3 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -172,7 +172,7 @@ TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) { HloInstruction* mul = builder.AddInstruction(HloInstruction::CreateBinary( add->shape(), HloOpcode::kMultiply, add, sub)); - auto module = CreateNewUnverifiedModule("fuse_with_constant_operands"); + auto module = CreateNewVerifiedModule("fuse_with_constant_operands"); HloComputation* entry_computation = module->AddEntryComputation(builder.Build(mul)); HloInstruction* call = module->OutlineExpressionFromComputation( @@ -247,7 +247,7 @@ TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) { conv_shape.ValueOrDie(), x, transpose_y, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - auto module = CreateNewUnverifiedModule("test_module"); + auto module = CreateNewVerifiedModule("test_module"); HloComputation* entry_computation = module->AddEntryComputation(builder.Build(conv)); FoldTranspose(module.get()); @@ -302,7 +302,7 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeRhs) { conv_shape.ValueOrDie(), x, transpose_y, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - auto module = CreateNewUnverifiedModule("test_module"); + auto module = CreateNewVerifiedModule("test_module"); HloComputation* entry_computation = module->AddEntryComputation(builder.Build(conv)); FoldTranspose(module.get()); @@ -362,7 +362,7 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) { conv_shape.ValueOrDie(), transpose_x, y, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - auto module = CreateNewUnverifiedModule("test_module"); + auto module = 
CreateNewVerifiedModule("test_module"); HloComputation* entry_computation = module->AddEntryComputation(builder.Build(conv)); FoldTranspose(module.get()); @@ -428,7 +428,7 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeLhs) { conv_shape.ValueOrDie(), transpose_x, y, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - auto module = CreateNewUnverifiedModule("test_module"); + auto module = CreateNewVerifiedModule("test_module"); HloComputation* entry_computation = module->AddEntryComputation(builder.Build(conv)); FoldTranspose(module.get()); diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index 96f3055c98e..50d51eaeb76 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -280,6 +280,13 @@ Status TuplePointsToAnalysis::HandleDomain(HloInstruction* domain) { return Status::OK(); } +Status TuplePointsToAnalysis::HandleAddDependency( + HloInstruction* add_dependency) { + // AddDependency just forwards the value of its zero-th operand. + CreateCopiedPointsToSet(add_dependency, add_dependency->operand(0)); + return Status::OK(); +} + Status TuplePointsToAnalysis::HandleRecvDone(HloInstruction* recv_done) { // RecvDone aliases its input (Recv) tuple element {0} to element {0} of its // output. The other indices ({} and {1}) define their own buffers. diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h index bcfcb388f95..0a1d5649d6d 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h @@ -252,6 +252,7 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { Status HandleRecvDone(HloInstruction* recv_done) override; Status HandleSend(HloInstruction* send) override; Status HandleTupleSelect(HloInstruction* tuple_select) override; + Status HandleAddDependency(HloInstruction* add_dependency) override; string ToString() const; diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index 10ef2d38fa2..561762b5d42 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -264,6 +264,22 @@ TEST_F(TuplePointsToAnalysisTest, GetTupleElement) { UnorderedElementsAre(inner_tuple)); } +TEST_F(TuplePointsToAnalysisTest, AddDependency) { + auto builder = HloComputation::Builder(TestName()); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + auto token = builder.AddInstruction(HloInstruction::CreateToken()); + auto add_dependency = builder.AddInstruction( + HloInstruction::CreateAddDependency(constant, token)); + BuildModuleAndRunAnalysis(builder.Build()); + + auto& points_to_set = points_to_analysis_->GetPointsToSet(add_dependency); + EXPECT_EQ(1, points_to_set.size()); + EXPECT_FALSE(points_to_set.IsAmbiguous()); + EXPECT_TRUE(points_to_set.IsDistinct()); + ExpectHasTopLevelBuffers(points_to_set.CreateFlattenedSet(), {constant}); +} + TEST_F(TuplePointsToAnalysisTest, DuplicatedElement) { // Create a tuple which contains duplicate elements. 
auto builder = HloComputation::Builder(TestName()); diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc index b7c28bfac78..41011176ffa 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/tuple_util.h" #include "tensorflow/compiler/xla/service/while_loop_analysis.h" #include "tensorflow/compiler/xla/service/while_util.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" namespace xla { @@ -207,6 +208,37 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( continue; } + if (!hoist_size_inflating_ops_) { + // Check that hoisting the instruction doesn't cause a significant memory + // blow-up. LICM extends the live-range of the output of the hoisted + // instruction to be the entire while loop, which may be problematic on + // platforms where memory is limited. This can be especially harmful if + // the instruction has a significantly larger output than its input, e.g. + // kIota, kBroadcast or kConstant. + int64 input_size = 0, output_size = 0; + + for (auto* operand : instruction->operands()) { + ShapeUtil::ForEachSubshape( + operand->shape(), + [&input_size](const Shape& subshape, const ShapeIndex& /*index*/) { + if (ShapeUtil::IsArray(subshape)) { + input_size += ShapeUtil::ByteSizeOfElements(subshape); + } + }); + } + ShapeUtil::ForEachSubshape( + instruction->shape(), + [&output_size](const Shape& subshape, const ShapeIndex& /*index*/) { + if (ShapeUtil::IsArray(subshape)) { + output_size += ShapeUtil::ByteSizeOfElements(subshape); + } + }); + + if (output_size > input_size) { + continue; + } + } + auto is_invariant = [&](HloInstruction* op) { return hoisted_instructions.find(op) != hoisted_instructions.end() || unhoisted_invariant_instructions.count(op) || diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h index 3031899f71e..bd6232dc0a9 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h @@ -34,8 +34,14 @@ class WhileLoopInvariantCodeMotion : public HloModulePass { // Setting `hoist_constants` to false can be help if LICM is run in the mid // level HLO pipeline because hoisting constants out of while loop bodies can // break optimizations like constant folding. - explicit WhileLoopInvariantCodeMotion(bool hoist_constants = false) - : hoist_constants_(hoist_constants) {} + // Setting `hoist_size_inflating_ops` to false will forbid hoisting + // instructions where the size of the output(s) is larger than the size of the + // input(s). This is useful on platforms on which it's important to prevent + // blow-ups in memory size. 
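As a usage sketch of the constructor documented in the comment above (and exercised by the `NoHoistInflating` test later in this patch): LICM configured to hoist constants but not size-inflating instructions. The helper name is invented for this sketch; the pass API comes from this header.

```c++
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h"
#include "tensorflow/compiler/xla/statusor.h"

// Sketch: constants may be hoisted, but instructions whose outputs are larger
// than their inputs (kIota, kBroadcast, large kConstant, ...) stay inside the
// while body.
xla::StatusOr<bool> RunLicmWithoutInflation(xla::HloModule* module) {
  return xla::WhileLoopInvariantCodeMotion(
             /*hoist_constants=*/true,
             /*hoist_size_inflating_ops=*/false)
      .Run(module);
}
```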
+ explicit WhileLoopInvariantCodeMotion(bool hoist_constants = false, + bool hoist_size_inflating_ops = true) + : hoist_constants_(hoist_constants), + hoist_size_inflating_ops_(hoist_size_inflating_ops) {} ~WhileLoopInvariantCodeMotion() override = default; absl::string_view name() const override { @@ -49,6 +55,7 @@ class WhileLoopInvariantCodeMotion : public HloModulePass { HloInstruction* while_instr); bool hoist_constants_; + bool hoist_size_inflating_ops_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc index 046ccb2d3f2..8e7c4bc8828 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc @@ -570,5 +570,59 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DoNotHoistOutOfSingleIteration) { EXPECT_FALSE(simplified_loop); } +const char* const kInflatingTestCase = R"( +HloModule ModuleWithWhile + +mul { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT mul = f32[] multiply(lhs, rhs) +} + +body { + p_body = (f32[]) parameter(0) + iota = f32[1024, 1024] iota(), iota_dimension=0 + add = f32[1024, 1024] add(iota, iota) + constant = f32[] constant(1.0) + reduce = f32[] reduce(f32[1024, 1024] add, f32[] constant), dimensions={0,1}, to_apply=mul + ROOT root = (f32[]) tuple(reduce) +} + +condition { + p_cond = (f32[]) parameter(0) + ROOT result = pred[] constant(true) +} + +ENTRY entry { + param = f32[] parameter(0) + while_init = (f32[]) tuple(param) + ROOT while = (f32[]) while(while_init), condition=condition, body=body +} +)"; + +TEST_F(WhileLoopInvariantCodeMotionTest, HoistsInflatingByDefault) { + auto m = ParseAndReturnVerifiedModule(kInflatingTestCase).ValueOrDie(); + + TF_ASSERT_OK_AND_ASSIGN( + bool simplified_loop, + WhileLoopInvariantCodeMotion(/*hoist_constants=*/true).Run(m.get())); + EXPECT_TRUE(simplified_loop); + + HloComputation* while_body = m->GetComputationWithName("wide.body"); + ASSERT_NE(while_body, nullptr); + EXPECT_THAT(while_body->instructions(), Not(Contains(op::Iota()))); +} + +TEST_F(WhileLoopInvariantCodeMotionTest, NoHoistInflating) { + auto m = ParseAndReturnVerifiedModule(kInflatingTestCase).ValueOrDie(); + + TF_ASSERT_OK_AND_ASSIGN( + bool simplified_loop, + WhileLoopInvariantCodeMotion(/*hoist_constants=*/true, + /*hoist_size_inflating_ops=*/false) + .Run(m.get())); + EXPECT_FALSE(simplified_loop); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc index 6f924a29d8a..d30f67dd811 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc @@ -19,13 +19,17 @@ limitations under the License. 
#include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/types/optional.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_query.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/service/while_loop_analysis.h" namespace xla { +namespace m = match; using absl::optional; using hlo_query::ContainsInstrWithOpcode; @@ -302,6 +306,147 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { return true; } +// Removes each loop parameter (i.e. member of the while loop tuple) that is a +// constant and is the same in the while loop body and the while loop init. +static StatusOr TryRemoveConstantParams(HloInstruction* while_op) { + HloModule* module = while_op->GetModule(); + HloComputation* computation = while_op->parent(); + auto* while_init = while_op->mutable_operand(0); + auto* while_body = while_op->while_body(); + auto* while_cond = while_op->while_condition(); + auto* while_body_root = while_body->root_instruction(); + if (while_init->opcode() != HloOpcode::kTuple || + while_body_root->opcode() != HloOpcode::kTuple) { + return false; + } + + TF_RET_CHECK(while_cond->num_parameters() == 1); + TF_RET_CHECK(while_body->num_parameters() == 1); + TF_RET_CHECK( + ShapeUtil::Compatible(while_init->shape(), while_body_root->shape())); + + absl::flat_hash_set constant_tuple_indices; + const auto& while_shape = while_init->shape(); + for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) { + auto* init_elem = while_init->operand(i); + auto* body_elem = while_body_root->operand(i); + if (init_elem->opcode() == HloOpcode::kConstant && + body_elem->opcode() == HloOpcode::kConstant && + init_elem->literal() == body_elem->literal()) { + constant_tuple_indices.insert(i); + } + } + + if (constant_tuple_indices.empty()) { + return false; + } + + // OK, we found some constant elements of the while parameter! Eliminate + // them. + std::vector new_while_shape_elems; + for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) { + if (!constant_tuple_indices.count(i)) { + new_while_shape_elems.push_back(while_shape.tuple_shapes(i)); + } + } + Shape new_while_shape = ShapeUtil::MakeTupleShape(new_while_shape_elems); + + // `new_instrs` holds instructions created outside of a computation for + // cloning. Elements added here just need to live until the end of the + // relevant CloneWithReplacement call. + std::vector> new_instrs; + auto add_new_instr = [&](std::unique_ptr instr) { + new_instrs.push_back(std::move(instr)); + return new_instrs.back().get(); + }; + + // Returns a new tuple without the elements of constant_tuple_indices. 
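To make the index bookkeeping of the `remove_constant_elems`/`add_constant_elems` helpers below concrete, here is a standalone illustration with plain integers (not HLO; the function name is invented): removing the constant tuple indices shifts every surviving element left, which is what the running counter `j` in `add_constant_elems` accounts for.

```c++
#include <set>
#include <vector>

// Illustration only: map old while-tuple indices to their position in the
// shrunken tuple after the constant indices are removed.
// E.g. tuple_size = 3, constant_indices = {1}  ->  {0, -1, 1}.
std::vector<int> OldToNewTupleIndex(int tuple_size,
                                    const std::set<int>& constant_indices) {
  std::vector<int> old_to_new(tuple_size, -1);  // -1 means "element removed"
  int next = 0;
  for (int i = 0; i < tuple_size; ++i) {
    if (constant_indices.count(i) == 0) {
      old_to_new[i] = next++;
    }
  }
  return old_to_new;
}
```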
+ auto remove_constant_elems = [&](HloInstruction* instr) { + CHECK(ShapeUtil::Compatible(instr->shape(), while_shape)); + + std::vector tuple_elems; + for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) { + if (!constant_tuple_indices.count(i)) { + tuple_elems.push_back( + add_new_instr(HloInstruction::CreateGetTupleElement( + while_shape.tuple_shapes(i), instr, i))); + } + } + return HloInstruction::CreateTuple(tuple_elems); + }; + + auto add_constant_elems = [&](HloInstruction* instr) { + CHECK(ShapeUtil::Compatible(instr->shape(), new_while_shape)); + + std::vector tuple_elems; + int64 j = 0; + for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) { + if (constant_tuple_indices.count(i)) { + tuple_elems.push_back(while_init->mutable_operand(i)); + } else { + tuple_elems.push_back( + add_new_instr(HloInstruction::CreateGetTupleElement( + while_shape.tuple_shapes(i), instr, j))); + ++j; + } + } + return HloInstruction::CreateTuple(tuple_elems); + }; + + // Special case: constant_tuple_indices covers the whole while parameter, so + // the new while shape is the empty tuple. In this case, the value of the + // while loop is simply equal to the value of `init`. + // + // It's unfortunate to special-case this, but it's simpler than the + // alternative. The problem is that if our while parameter has no + // non-constant elems, the tuple returned by `add_constant_elems` won't depend + // on instr (the loop body/cond parameter), and therefore + // CloneWithReplacementPairs will *leave the parameter out entirely*, creating + // invalid HLO. + if (ShapeUtil::IsEmptyTuple(new_while_shape)) { + TF_RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, while_init)); + return true; + } + + std::unique_ptr new_while_cond = + while_cond->CloneWithReplacementPairs({ + while_cond->parameter_instruction(0), + add_constant_elems(add_new_instr(HloInstruction::CreateParameter( + 0, new_while_shape, + while_cond->parameter_instruction(0)->name()))), + }); + + std::unique_ptr new_while_body = + while_body->CloneWithReplacementPairs( + { + while_body->parameter_instruction(0), + add_constant_elems(add_new_instr(HloInstruction::CreateParameter( + 0, new_while_shape, + while_cond->parameter_instruction(0)->name()))), + }, + { + while_body->root_instruction(), + remove_constant_elems( + add_new_instr(while_body->root_instruction()->Clone())), + }); + + // Create the final while loop, and add any new instructions created to + // `computation`. + new_instrs.clear(); + TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction( + while_op, + add_constant_elems( + computation->AddInstruction(HloInstruction::CreateWhile( + new_while_shape, + module->AddEmbeddedComputation(std::move(new_while_cond)), + module->AddEmbeddedComputation(std::move(new_while_body)), + add_new_instr(remove_constant_elems(while_init))))))); + for (auto& instr : new_instrs) { + computation->AddInstruction(std::move(instr)); + } + return true; +} + // Tries to remove a while loop from the graph. // // - Loops with trip count of 0 can be replaced by the loop's "init" value. @@ -381,16 +526,14 @@ static StatusOr TryPropagateConstant(HloInstruction* while_op) { // performance by forcing us to copy constants. 
absl::flat_hash_map index_to_constant; for (int i = 0; i < root_operands.size(); i++) { - HloInstruction* instr = root_operands[i]; - if (instr->opcode() == HloOpcode::kGetTupleElement && - instr->tuple_index() == i && instr->operand(0) == while_body_param && - ShapeUtil::IsScalar(instr->shape())) { - auto tuple_element = while_init->operand(i); - if (tuple_element->IsConstant()) { - VLOG(3) << "Found loop invariant tuple element " << i << " " - << tuple_element->ToString(); - index_to_constant[i] = tuple_element; - } + const HloInstruction* init_tuple_elem = nullptr; + if (Match(root_operands[i], + m::GetTupleElement(m::Op().Is(while_body_param), i) + .WithShape(m::Shape().IsScalar())) && + Match(while_init->operand(i), m::Constant(&init_tuple_elem))) { + VLOG(3) << "Found loop invariant tuple element " << i << " " + << init_tuple_elem->ToString(); + index_to_constant[i] = init_tuple_elem; } } @@ -519,14 +662,6 @@ static StatusOr TryFlattenNestedTuples(HloInstruction* while_op) { return false; } - // Cowardly refuse to perform this optimization in the presence of kDomain - // instructions, which may reference other instructions in the loop and - // therefore make this complicated. - if (ContainsInstrWithOpcode(while_body, {HloOpcode::kDomain}) || - ContainsInstrWithOpcode(while_cond, {HloOpcode::kDomain})) { - return false; - } - std::vector flattened_shape_elems; ShapeUtil::ForEachSubshape(while_shape, [&](const Shape& s, const ShapeIndex& /*index*/) { @@ -605,6 +740,243 @@ static StatusOr TryFlattenNestedTuples(HloInstruction* while_op) { return true; } +// Tries to merge loop induction variables of a given type. +// +// In this pass we're only concerned with elements of the loop's tuple that +// are effective-scalars of type `elem_ty`. Some terminology: +// +// - The trip counter is the first element of the loop's tuple that starts at +// 0 and does x++ on each iteration. +// +// - An induction variable is an element of the loop's tuple that is not the +// trip counter and does `x += ` on each iteration of the loop. +// Negative constants are OK. +// +// This pass adds a trip counter if one isn't already present, then replaces +// each induction variable with +// +// + * . +// +// This reduces the number of scalar operations in the loop, which is important +// e.g. on GPUs, where each scalar operation is nontrivially expensive because +// it's a separate kernel launch. +// +// Returns the new loop if a change was made, or null if no change was made. +// Note that the new loop is not a valid replacement for the old loop; it may +// need to be wrapped in a tuple that changes its shape. We return the loop +// itself so that you can call TryMergeInductionVariables in a loop, once for +// each integral type elem_ty. 
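The comment above reduces to a single arithmetic identity: a variable that starts at `init` and does `x += step` on each iteration holds `init + trip_count * step` on every iteration, so only one counter has to live in the loop carry. A tiny standalone check of that identity with plain integers (not HLO):

```c++
#include <cassert>

// Worked example of the identity behind TryMergeInductionVariables: every
// induction variable equals init + trip_counter * step on each iteration.
int main() {
  const int b_init = 100, b_step = -3;  // induction variable; step may be negative
  int trip_counter = 0;                 // starts at 0, incremented by 1 per iteration
  int b = b_init;
  for (int iter = 0; iter < 10; ++iter) {
    assert(b == b_init + trip_counter * b_step);
    ++trip_counter;
    b += b_step;
  }
  return 0;
}
```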
+static StatusOr TryMergeInductionVariables( + HloInstruction* while_op, PrimitiveType elem_ty) { + CHECK(primitive_util::IsIntegralType(elem_ty)) << PrimitiveType_Name(elem_ty); + HloModule* module = while_op->GetModule(); + HloComputation* computation = while_op->parent(); + auto* while_init = while_op->mutable_operand(0); + auto* while_body = while_op->while_body(); + auto* while_cond = while_op->while_condition(); + auto* while_body_root = while_body->root_instruction(); + if (while_init->opcode() != HloOpcode::kTuple || + while_body_root->opcode() != HloOpcode::kTuple) { + return nullptr; + } + + TF_RET_CHECK(while_cond->num_parameters() == 1); + TF_RET_CHECK(while_body->num_parameters() == 1); + TF_RET_CHECK( + ShapeUtil::Compatible(while_init->shape(), while_body_root->shape())); + Shape while_shape = while_init->shape(); + + // The tuple index of the trip counter, if one is present. + absl::optional trip_counter; + // Maps the tuple index of each induction variable to its constant increment. + absl::flat_hash_map induction_vars; + for (int64 i = 0; i < while_body_root->operand_count(); ++i) { + HloInstruction* constant; + if (!Match(while_body_root->mutable_operand(i), + m::AddAnyOrder(m::GetTupleElement(m::Parameter(), i), + m::ConstantScalar(&constant)) + .WithShape(m::Shape().WithElementType(elem_ty)))) { + continue; + } + if (!trip_counter && constant->literal().IsAll(1) && + while_init->operand(i)->IsConstant() && + while_init->operand(i)->literal().IsAll(0)) { + VLOG(10) << "Found existing trip counter at index " << i; + trip_counter = i; + } else { + VLOG(10) << "Found induction variable at index " << i; + induction_vars.emplace(i, Cast(constant)); + } + } + + // There's only something to simplify if we can either: + // + // - combine one or more induction vars with an existing trip counter, or + // - replace two or more induction variables with a new trip counter. + // + // Put another way, there's only something to simplify if the number of + // induction vars plus the number of existing trip counters (0 or 1) is >= 2. + if (induction_vars.size() + (trip_counter.has_value() ? 1 : 0) < 2) { + return nullptr; + } + + // OK, we're going to do the transformation! Set up some helpers. + + // `new_instrs` holds instructions created outside of a computation for + // cloning. Elements added here just need to live until the end of the + // relevant CloneWithReplacement call. + std::vector> new_instrs; + auto add_new_instr = [&](std::unique_ptr instr) { + new_instrs.push_back(std::move(instr)); + return new_instrs.back().get(); + }; + + auto add_binary_op = [&](const Shape& shape, HloOpcode opcode, + HloInstruction* lhs, HloInstruction* rhs) { + // Reshape lhs/rhs to the output shape if necessary. This deals with the + // fact that induction variables need only be effective scalars, not true + // scalars. + if (!ShapeUtil::Compatible(shape, lhs->shape())) { + lhs = add_new_instr(HloInstruction::CreateReshape(shape, lhs)); + } + if (!ShapeUtil::Compatible(shape, rhs->shape())) { + rhs = add_new_instr(HloInstruction::CreateReshape(shape, rhs)); + } + return add_new_instr(HloInstruction::CreateBinary(shape, opcode, lhs, rhs)); + }; + + auto add_gte = [&](HloInstruction* src, int64 idx) { + return add_new_instr(HloInstruction::CreateGetTupleElement( + src->shape().tuple_shapes(idx), src, idx)); + }; + + // Our new while loop will have the same shape as the old while loop, except + // we'll add a trip counter to the end if it wasn't originally present. 
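For readers unfamiliar with the `xla::match` combinators used in the detection loop above (`m::AddAnyOrder`, `m::GetTupleElement`, `m::ConstantScalar`, `.WithShape(...)`), here is a simplified, hypothetical predicate written in the same style, using only combinators that appear in this patch:

```c++
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"

namespace m = xla::match;

// Hypothetical helper, for illustration only: does `instr` compute
// "scalar something + scalar constant" (in either operand order)?
// On success, *constant_out points at the matched constant.
bool IsScalarAddOfConstant(xla::HloInstruction* instr,
                           xla::HloInstruction** constant_out) {
  return xla::Match(instr,
                    m::AddAnyOrder(m::Op(), m::ConstantScalar(constant_out))
                        .WithShape(m::Shape().IsScalar()));
}
```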
+ Shape new_while_shape = while_shape; + bool added_trip_counter = false; + if (!trip_counter) { + VLOG(10) << "Adding new trip counter to end of loop's tuple."; + trip_counter = new_while_shape.tuple_shapes_size(); + *new_while_shape.add_tuple_shapes() = + ShapeUtil::MakeShape(elem_ty, /*dimensions=*/{}); + added_trip_counter = true; + } + + // Converts `instr` into a tuple of the "old" form -- that is, to a tuple with + // shape `while_body->shape()` and where the induction variables are "reified" + // (i.e. they have value + * ). + auto convert_to_old_form = [&](HloInstruction* instr) { + CHECK(ShapeUtil::Compatible(instr->shape(), new_while_shape)); + std::vector tuple_elems; + for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) { + const auto& elem_shape = while_shape.tuple_shapes(i); + if (!induction_vars.count(i)) { + tuple_elems.push_back(add_gte(instr, i)); + continue; + } + tuple_elems.push_back(add_binary_op( + elem_shape, HloOpcode::kAdd, add_gte(instr, i), + add_binary_op(elem_shape, HloOpcode::kMultiply, + add_gte(instr, *trip_counter), + add_new_instr(induction_vars.at(i)->Clone())))); + } + return HloInstruction::CreateTuple(tuple_elems); + }; + + // Converts `root` into a tuple of the "new" form -- that is, to a tuple with + // shape `new_while_shape` and where the induction variables (but not trip + // counters) are replaced with their unchanging values. + auto convert_to_new_form = [&](HloInstruction* old_root, + HloParameterInstruction* loop_body_param) { + CHECK(ShapeUtil::Compatible(old_root->shape(), while_shape)); + std::vector tuple_elems; + + // In the new form, induction variables come from `init`, everything else + // (including the trip counter if it's not one we created ourselves) comes + // from the `root` tuple unmodified. + for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) { + tuple_elems.push_back( + add_gte((induction_vars.count(i) ? loop_body_param : old_root), i)); + } + // If we created a trip counter ourselves, add 1 to it in the next + // iteration. + if (added_trip_counter) { + tuple_elems.push_back(add_binary_op( + new_while_shape.tuple_shapes(*trip_counter), HloOpcode::kAdd, + add_gte(loop_body_param, *trip_counter), + add_new_instr( + HloInstruction::CreateConstant(LiteralUtil::One(elem_ty))))); + } + + return HloInstruction::CreateTuple(tuple_elems); + }; + + // Creates a new init tuple, which is the same as the old init tuple except if + // we added a trip counter, it's set to 0. + auto get_new_while_init = [&](HloInstruction* init) { + CHECK(ShapeUtil::Compatible(init->shape(), while_shape)); + if (!added_trip_counter) { + return init; + } + std::vector tuple_elems; + for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) { + tuple_elems.push_back(add_gte(init, i)); + } + tuple_elems.push_back(add_new_instr( + HloInstruction::CreateConstant(LiteralUtil::Zero(elem_ty)))); + return add_new_instr(HloInstruction::CreateTuple(tuple_elems)); + }; + + std::unique_ptr new_while_cond = + while_cond->CloneWithReplacementPairs({ + while_cond->parameter_instruction(0), + convert_to_old_form(add_new_instr(HloInstruction::CreateParameter( + 0, new_while_shape, + while_cond->parameter_instruction(0)->name()))), + }); + + // Creating the new while body proceeds in two steps. First we convert the + // users of the parameter to the old form. Then as a second + // CloneWithReplacement operation we convert the root to the new form. 
We + // have to do this in two steps because the new root needs to use the new + // param0, and during the first clone operation, only the *old-form* param0 is + // accessible. + // + // We have to add temp_new_while_body to the module because cloning a + // computation touches the module (to get its NameUniquer). + HloComputation* temp_new_while_body = + module->AddEmbeddedComputation(while_body->CloneWithReplacementPairs({ + while_body->parameter_instruction(0), + convert_to_old_form(add_new_instr(HloInstruction::CreateParameter( + 0, new_while_shape, + while_body->parameter_instruction(0)->name()))), + })); + std::unique_ptr new_while_body = + temp_new_while_body->CloneWithReplacementPairs({ + temp_new_while_body->root_instruction(), + convert_to_new_form( + add_new_instr(temp_new_while_body->root_instruction()->Clone()), + Cast( + temp_new_while_body->parameter_instruction(0))), + }); + TF_RETURN_IF_ERROR(module->RemoveEmbeddedComputation(temp_new_while_body)); + + // Create the final while loop, and add any new instructions created to + // `computation`. + new_instrs.clear(); + auto* new_while = computation->AddInstruction(HloInstruction::CreateWhile( + new_while_shape, + module->AddEmbeddedComputation(std::move(new_while_cond)), + module->AddEmbeddedComputation(std::move(new_while_body)), + get_new_while_init(while_init))); + TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction( + while_op, convert_to_old_form(new_while))); + for (auto& instr : new_instrs) { + computation->AddInstruction(std::move(instr)); + } + return new_while; +} + StatusOr WhileLoopSimplifier::Run(HloModule* module) { XLA_VLOG_LINES(3, "WhileLoopSimplifier::Run(), before:\n" + module->ToString()); @@ -650,19 +1022,50 @@ StatusOr WhileLoopSimplifier::Run(HloModule* module) { continue; } + // TODO(b/119281462): Cowardly refuse to perform any of the following + // optimizations in the presence of kDomain instructions. It seems that + // modifying a while loop's tuple doesn't work when kDomain is present. + if (ContainsInstrWithOpcode(while_op->while_body(), {HloOpcode::kDomain}) || + ContainsInstrWithOpcode(while_op->while_condition(), + {HloOpcode::kDomain})) { + continue; + } + + // Each of the optimizations below modifies the while loop itself if it's + // successful, meaning that `while_op` is no longer valid after one of these + // transformations returns true. + TF_ASSIGN_OR_RETURN(result, TryFlattenNestedTuples(while_op)); changed |= result; if (result) { - // Successfully flattening nested tuples results in us cloning and - // replacing the while loop, meaning that `while_op` is no longer valid. continue; } TF_ASSIGN_OR_RETURN(result, TryRemoveDeadWhileParams(while_op)); changed |= result; if (result) { - // Successfully removing dead while params results in us cloning and - // replacing the while loop, meaning that `while_op` is no longer valid. + continue; + } + + TF_ASSIGN_OR_RETURN(result, TryRemoveConstantParams(while_op)); + changed |= result; + if (result) { + continue; + } + + bool merged_induction_vars = false; + // Notably missing from this list are S16 and U16. These don't currently + // work because S/U16 literals are not implemented. 
+ for (auto elem_ty : {S8, U8, S32, U32, S64, U64}) { + TF_ASSIGN_OR_RETURN(auto* new_while_op, + TryMergeInductionVariables(while_op, elem_ty)); + if (new_while_op) { + while_op = new_while_op; + changed = true; + merged_induction_vars = true; + } + } + if (merged_induction_vars) { continue; } } diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc index 05005e0b262..4950e8269e9 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc @@ -17,9 +17,12 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_replace.h" +#include "tensorflow/compiler/xla/service/algebraic_simplifier.h" +#include "tensorflow/compiler/xla/service/hlo_cse.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -27,8 +30,17 @@ limitations under the License. namespace xla { namespace { +using ::testing::_; namespace op = xla::testing::opcode_matchers; +// Returns the first kWhile instruction within m's entry computation. +HloInstruction* FindFirstWhile(HloModule* m) { + const auto& instrs = m->entry_computation()->instructions(); + return *absl::c_find_if(instrs, [](const HloInstruction* instr) { + return instr->opcode() == HloOpcode::kWhile; + }); +} + class WhileLoopSimplifierTest : public HloTestBase { protected: // Makes an HloModule that contains a loop with `num_iters` iteration. @@ -540,11 +552,7 @@ TEST_F(WhileLoopSimplifierTest, FlattenNestedTuple) { // it easy to find. EXPECT_TRUE(HloDCE().Run(m.get()).ok()); - const auto& instrs = m->entry_computation()->instructions(); - HloInstruction* new_while = - *absl::c_find_if(instrs, [](const HloInstruction* instr) { - return instr->opcode() == HloOpcode::kWhile; - }); + HloInstruction* new_while = FindFirstWhile(m.get()); Shape flat_tuple = ShapeUtil::ParseShapeString("(s32[1], s32[2], s32[3], s32[4])") .ValueOrDie(); @@ -563,5 +571,177 @@ TEST_F(WhileLoopSimplifierTest, FlattenNestedTuple) { .ValueOrDie())); } +// Edge-case: All elements of the loop carry are constants which can be removed, +// leaving us with a nullary loop. This is a special case, we just replace the +// loop with its init. 
+TEST_F(WhileLoopSimplifierTest, OnlyConstantsInLoopCarry) { + const string hlo_string = R"( + HloModule Test + Body { + param = (s32[1]) parameter(0) + a = s32[1] constant({0}) + ROOT tuple = (s32[1]) tuple(a) + } + Cond { + param = (s32[1]) parameter(0) + ROOT cond = pred[] constant(true) + } + ENTRY Loop { + a = s32[1] constant({0}) + init = (s32[1]) tuple(a) + ROOT while = (s32[1]) while(init), condition=Cond, body=Body + })"; + + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); + EXPECT_TRUE(HloDCE().Run(m.get()).ok()); + EXPECT_TRUE(TupleSimplifier().Run(m.get()).ok()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + op::Tuple(op::Constant())); +} + +TEST_F(WhileLoopSimplifierTest, RemoveConstantFromLoopCarry) { + const string hlo_string = R"( + HloModule Test + Body { + param = (s32[1], s32[2], s32[3]) parameter(0) + a = s32[1] get-tuple-element(param), index=0 + a.1 = s32[1] add(a, a) + b = s32[2] constant({1,1}) + c = s32[3] constant({10,10,10}) + ROOT tuple = (s32[1], s32[2], s32[3]) tuple(a.1, b, c) + } + Cond { + param = (s32[1], s32[2], s32[3]) parameter(0) + /* Use each tuple element. The verifier will then ensure that if any of + * these get modified, they're replaced with values of the correct shape. */ + a = s32[1] get-tuple-element(param), index=0 + b = s32[2] get-tuple-element(param), index=1 + c = s32[3] get-tuple-element(param), index=2 + ROOT cond = pred[] constant(true) + } + ENTRY Loop { + /* Only `b` should be simplified away. `a` is not a constant within the + * loop, and `c`'s value changes depending on whether we run 0 or 1 + * iterations of the loop. */ + a = s32[1] constant({0}) + b = s32[2] constant({1,1}) + c = s32[3] constant({2,2,2}) + init = (s32[1], s32[2], s32[3]) tuple(a,b,c) + ROOT while = (s32[1], s32[2], s32[3]) while(init), + condition=Cond, body=Body + })"; + + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); + // DCE away the old loop so there's just one while loop in the module, making + // it easy to find. + EXPECT_TRUE(HloDCE().Run(m.get()).ok()); + // Run the tuple simplifier to make the resulting HLO a bit easier to check. 
+ EXPECT_TRUE(TupleSimplifier().Run(m.get()).ok()); + + HloInstruction* new_while = FindFirstWhile(m.get()); + Shape new_while_shape = + ShapeUtil::ParseShapeString("(s32[1], s32[3])").ValueOrDie(); + EXPECT_TRUE(ShapeUtil::Equal(new_while->shape(), new_while_shape)); + EXPECT_TRUE(ShapeUtil::Equal( + new_while->while_body()->root_instruction()->shape(), new_while_shape)); + EXPECT_TRUE(ShapeUtil::Equal( + new_while->while_body()->parameter_instruction(0)->shape(), + new_while_shape)); + EXPECT_TRUE(ShapeUtil::Equal( + new_while->while_condition()->parameter_instruction(0)->shape(), + new_while_shape)); + EXPECT_TRUE(ShapeUtil::Equal( + m->entry_computation()->root_instruction()->shape(), + ShapeUtil::ParseShapeString("(s32[1], s32[2], s32[3])").ValueOrDie())); + EXPECT_THAT(m->entry_computation()->root_instruction(), + op::Tuple(_, op::Constant(), _)); +} + +const char* const kSimpleMergeInductionVariablesModule = R"( + HloModule Test + Body { + param = (TYPE[], TYPE[], TYPE[]) parameter(0) + + a = TYPE[] get-tuple-element(param), index=0 + one = TYPE[] constant(1) + a1 = TYPE[] add(a, one) + + b = TYPE[] get-tuple-element(param), index=1 + negone = TYPE[] constant(-1) + b1 = TYPE[] add(b, negone) + + c = TYPE[] add(a, b) + + ROOT tuple = (TYPE[], TYPE[], TYPE[]) tuple(a1,b1,c) + } + Cond { + param = (TYPE[], TYPE[], TYPE[]) parameter(0) + a = TYPE[] get-tuple-element(param), index=0 + b = TYPE[] get-tuple-element(param), index=1 + sum = TYPE[] power(a, b) + ten = TYPE[] constant(10) + ROOT cond = pred[] less-than(sum, ten) + } + ENTRY Loop { + a = TYPE[] constant(10) + b = TYPE[] constant(100) + c = TYPE[] constant(0) + init = (TYPE[], TYPE[], TYPE[]) tuple(a,b,c) + while = (TYPE[], TYPE[], TYPE[]) while(init), condition=Cond, body=Body + + a1 = TYPE[] get-tuple-element(while), index=0 + b1 = TYPE[] get-tuple-element(while), index=1 + ROOT sum = TYPE[] add(a1, b1) + })"; + +TEST_F(WhileLoopSimplifierTest, MergeInductionVariables_Simple) { + string hlo_string = absl::StrReplaceAll(kSimpleMergeInductionVariablesModule, + {{"TYPE", "s32"}}); + + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); + // DCE away the old loop so there's just one while loop in the module, making + // it easy to find, and run the tuple simplifier to make the resulting HLO + // easier to check. + EXPECT_TRUE(HloDCE().Run(m.get()).ok()); + EXPECT_TRUE(TupleSimplifier().Run(m.get()).ok()); + + HloInstruction* new_while = FindFirstWhile(m.get()); + // We should have added a new loop counter for s32[] to the end of the tuple. 
+ SCOPED_TRACE(m->ToString()); + Shape new_while_shape = + ShapeUtil::ParseShapeString("(s32[], s32[], s32[], s32[])").ValueOrDie(); + EXPECT_TRUE(ShapeUtil::Equal(new_while->shape(), new_while_shape)); + EXPECT_TRUE(ShapeUtil::Equal( + new_while->while_body()->root_instruction()->shape(), new_while_shape)); + EXPECT_TRUE(ShapeUtil::Equal( + new_while->while_body()->parameter_instruction(0)->shape(), + new_while_shape)); + EXPECT_TRUE(ShapeUtil::Equal( + new_while->while_condition()->parameter_instruction(0)->shape(), + new_while_shape)); + + EXPECT_THAT(new_while->while_body()->root_instruction(), + op::Tuple(op::GetTupleElement(op::Parameter(), 0), + op::GetTupleElement(op::Parameter(), 1), op::Add(), + op::Add(op::GetTupleElement(op::Parameter(), 3), + op::Constant()))); + EXPECT_THAT(new_while->while_condition()->root_instruction(), + op::Lt(op::Power(op::Add(), op::Add()), op::Constant())); +} + +// We shouldn't merge S16 induction variables; we can't create constants of this +// type because S16 literals are not implemented. +TEST_F(WhileLoopSimplifierTest, MergeInductionVariables_SkipS16) { + string hlo_string = absl::StrReplaceAll(kSimpleMergeInductionVariablesModule, + {{"TYPE", "s16"}}); + EXPECT_FALSE( + WhileLoopSimplifier() + .Run(ParseAndReturnVerifiedModule(hlo_string).ValueOrDie().get()) + .ValueOrDie()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/shape.cc b/tensorflow/compiler/xla/shape.cc new file mode 100644 index 00000000000..746ab9e9977 --- /dev/null +++ b/tensorflow/compiler/xla/shape.cc @@ -0,0 +1,107 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/shape.h" + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/shape_util.h" + +namespace xla { + +Shape::Shape(const ShapeProto& shape_proto) { + set_element_type(shape_proto.element_type()); + dimensions_.reserve(shape_proto.dimensions_size()); + for (const int64 dimension : shape_proto.dimensions()) { + add_dimensions(dimension); + } + tuple_shapes_.reserve(shape_proto.tuple_shapes_size()); + for (const ShapeProto& element_shape : shape_proto.tuple_shapes()) { + *add_tuple_shapes() = Shape(element_shape); + } + if (shape_proto.has_layout()) { + *mutable_layout() = shape_proto.layout(); + } +} + +ShapeProto Shape::ToProto() const { + ShapeProto proto; + proto.set_element_type(element_type_); + proto.mutable_dimensions()->Reserve(dimensions_size()); + for (const int64 dimension : dimensions()) { + proto.add_dimensions(dimension); + } + proto.mutable_tuple_shapes()->Reserve(tuple_shapes_size()); + for (const Shape& shape : tuple_shapes()) { + *proto.add_tuple_shapes() = shape.ToProto(); + } + if (has_layout()) { + *proto.mutable_layout() = layout(); + } + return proto; +} + +string Shape::ToString(bool print_layout) const { + if (print_layout) { + return ShapeUtil::HumanStringWithLayout(*this); + } else { + return ShapeUtil::HumanString(*this); + } +} + +std::ostream& operator<<(std::ostream& out, const Shape& shape) { + out << shape.ToString(/*print_layout=*/true); + return out; +} + +ProgramShape::ProgramShape(const ProgramShapeProto& program_shape_proto) { + for (const ShapeProto& shape_proto : program_shape_proto.parameters()) { + *add_parameters() = Shape(shape_proto); + } + *mutable_result() = Shape(program_shape_proto.result()); + for (const string& name : program_shape_proto.parameter_names()) { + add_parameter_names(name); + } +} + +ProgramShapeProto ProgramShape::ToProto() const { + ProgramShapeProto proto; + for (const Shape& shape : parameters()) { + *proto.add_parameters() = shape.ToProto(); + } + *proto.mutable_result() = result().ToProto(); + for (const string& name : parameter_names()) { + proto.add_parameter_names(name); + } + return proto; +} + +string ProgramShape::ToString() const { + std::vector parameter_strings(parameters_size()); + for (int i = 0; i < parameters_size(); ++i) { + parameter_strings[i] = absl::StrCat( + i < parameter_names_size() ? parameter_names(i) : "(unknown)", ": ", + ShapeUtil::HumanString(parameters(i))); + } + return absl::StrCat("(", absl::StrJoin(parameter_strings, ", "), ") -> ", + ShapeUtil::HumanString(result())); +} + +std::ostream& operator<<(std::ostream& out, const ProgramShape& program_shape) { + out << program_shape.ToString() << "\n"; + return out; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/shape.h b/tensorflow/compiler/xla/shape.h new file mode 100644 index 00000000000..7f6b14ab428 --- /dev/null +++ b/tensorflow/compiler/xla/shape.h @@ -0,0 +1,204 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SHAPE_H_ +#define TENSORFLOW_COMPILER_XLA_SHAPE_H_ + +#include +#include + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// A shape describes the number of dimensions in a array, the bounds of each +// dimension, and the primitive component type. For tuples, shape describes the +// structure (number of elements and nesting). +class Shape { + public: + Shape() = default; + + // Construct a shape from a ShapeProto. + explicit Shape(const ShapeProto& shape_proto); + + // Returns a ShapeProto representation of the Shape. + ShapeProto ToProto() const; + + // Returns a human-readable string that represents the given shape, with or + // without layout. e.g. "F32[42,12] {0, 1}" or "F32[64]". + string ToString(bool print_layout = false) const; + + // The following methods mirror the protobuf generated code interface for the + // message ShapeProto. This enabled easy migration of this data structure + // from a proto to a proper C++ class. + // TODO(b/29771030): Replace or augment these methods with a more ergonomic + // interface. + + // Methods for accessing the primitive type. + PrimitiveType element_type() const { return element_type_; } + void set_element_type(PrimitiveType value) { element_type_ = value; } + + // Methods for accessing the dimensions array. + int dimensions_size() const { return dimensions_.size(); } + int64 dimensions(int index) const { return dimensions_.at(index); } + void set_dimensions(int index, int64 value) { dimensions_.at(index) = value; } + void add_dimensions(int64 value) { dimensions_.push_back(value); } + void clear_dimensions() { dimensions_.clear(); } + const std::vector& dimensions() const { return dimensions_; } + std::vector* mutable_dimensions() { return &dimensions_; } + + // Methods for accessing the tuple subshapes. This field only non-empty for + // tuple shapes. + int tuple_shapes_size() const { return tuple_shapes_.size(); } + const Shape& tuple_shapes(int index) const { return tuple_shapes_.at(index); } + Shape* mutable_tuple_shapes(int index) { return &tuple_shapes_.at(index); } + Shape* add_tuple_shapes() { + tuple_shapes_.push_back(Shape()); + return &tuple_shapes_.back(); + } + void clear_tuple_shapes() { tuple_shapes_.clear(); } + const std::vector& tuple_shapes() const { return tuple_shapes_; } + std::vector* mutable_tuple_shapes() { return &tuple_shapes_; } + + // Methods for accessing the layout field. 
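Since `Shape`'s accessors above deliberately mirror the protobuf-generated API, existing call sites keep working after the migration. A small hedged sketch of building a tuple shape directly through these accessors (in real code one would normally use `ShapeUtil::MakeShape`/`MakeTupleShape`, which also attach a default layout):

```c++
#include "tensorflow/compiler/xla/shape.h"

// Sketch: construct f32[2,3] and wrap it in a tuple using the proto-style
// accessors of the new Shape class. No layout is set here; ShapeUtil's
// factory functions would normally add a default one.
xla::Shape MakeTupleOfMatrix() {
  xla::Shape matrix;
  matrix.set_element_type(xla::F32);
  matrix.add_dimensions(2);
  matrix.add_dimensions(3);

  xla::Shape tuple;
  tuple.set_element_type(xla::TUPLE);
  *tuple.add_tuple_shapes() = matrix;  // tuple.ToString() == "(f32[2,3])"
  return tuple;
}
```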
+ bool has_layout() const { return layout_.has_value(); } + const Layout& layout() const { + if (layout_.has_value()) { + return *layout_; + } else { + return Layout::default_instance(); + } + } + Layout* mutable_layout() { + if (!layout_.has_value()) { + layout_ = Layout(); + } + return &layout_.value(); + } + void clear_layout() { layout_.reset(); } + + void Swap(Shape* other) { + using std::swap; + swap(*this, *other); + } + + void Clear() { + element_type_ = PRIMITIVE_TYPE_INVALID; + dimensions_.clear(); + tuple_shapes_.clear(); + layout_.reset(); + } + + string SerializeAsString() const { return ToProto().SerializeAsString(); } + string ShortDebugString() const { return ToProto().ShortDebugString(); } + string DebugString() const { return ToProto().DebugString(); } + + public: + // The element type of this shape (tuple, array, etc). + PrimitiveType element_type_ = PRIMITIVE_TYPE_INVALID; + + // The array bounds of the dimensions. This is nonempty only for array shapes. + std::vector dimensions_; + + // The tuple element subshapes. This is nonempty only for tuple shapes. + std::vector tuple_shapes_; + + // The array layout of the shape. This is present only for array shapes. + absl::optional layout_; +}; + +// Shape of the parameters and output of an XLA computation. This is analogous +// to a traditional function signature. +class ProgramShape { + public: + ProgramShape() = default; + + // Creates a ProgramShape from a ProgramShapeProto protobuf. + explicit ProgramShape(const ProgramShapeProto& program_shape_proto); + + // Returns a proto representation of the object. + ProgramShapeProto ToProto() const; + + string ToString() const; + + // The following methods mirror the protobuf generated code interface for the + // message ProgramShapeProto. This enabled easy migration of this data + // structure from a proto to a proper C++ class. + // TODO(b/29771030): Replace or augment these methods with a more ergonomic + // interface. + + // Methods for accessing and manipulating the Shape of the parameters. + int parameters_size() const { return parameters_.size(); } + const Shape& parameters(int index) const { return parameters_.at(index); } + Shape* mutable_parameters(int index) { return ¶meters_.at(index); } + Shape* add_parameters() { + parameters_.emplace_back(); + return ¶meters_.back(); + } + void clear_parameters() { parameters_.clear(); } + const std::vector& parameters() const { return parameters_; } + std::vector* mutable_parameters() { return ¶meters_; } + + // Methods for accessing and manipulating the Shape of the result. + const Shape& result() const { return result_; } + Shape* mutable_result() { return &result_; } + + // Methods for accessing and manipulating the names of the parameters. 
+ int parameter_names_size() const { return parameter_names_.size(); } + const string& parameter_names(int index) const { + return parameter_names_.at(index); + } + void set_parameter_names(int index, const string& value) { + parameter_names_.at(index) = value; + } + string* mutable_parameter_names(int index) { + return ¶meter_names_.at(index); + } + void add_parameter_names(const string& value) { + parameter_names_.push_back(value); + } + string* add_parameter_names() { + parameter_names_.push_back(""); + return ¶meter_names_.back(); + } + void clear_parameter_names() { parameter_names_.clear(); } + const std::vector& parameter_names() const { + return parameter_names_; + } + std::vector* mutable_parameter_names() { return ¶meter_names_; } + + string ShortDebugString() const { return ToProto().ShortDebugString(); } + string DebugString() const { return ToProto().DebugString(); } + + private: + // The shapes of the parameters of the computation represented by this object. + std::vector parameters_; + + // The names of the parameters of the computation represented by this object. + std::vector parameter_names_; + + // The shape of the result of the computation represented by this object. + Shape result_; +}; + +std::ostream& operator<<(std::ostream& out, const Shape& shape); +std::ostream& operator<<(std::ostream& out, const ProgramShape& program_shape); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SHAPE_H_ diff --git a/tensorflow/compiler/xla/shape_test.cc b/tensorflow/compiler/xla/shape_test.cc new file mode 100644 index 00000000000..e396897eeeb --- /dev/null +++ b/tensorflow/compiler/xla/shape_test.cc @@ -0,0 +1,149 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/shape.h" + +#include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace { + +class ShapeTest : public ::testing::Test { + protected: + const Shape opaque_ = ShapeUtil::MakeOpaqueShape(); + const Shape token_ = ShapeUtil::MakeTokenShape(); + const Shape scalar_ = ShapeUtil::MakeShape(F32, {}); + const Shape matrix_ = ShapeUtil::MakeShape(U32, {1, 2}); + const Shape matrix2_ = ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1}); + const Shape tuple_ = + ShapeUtil::MakeTupleShape({opaque_, scalar_, matrix_, matrix2_}); + const Shape nested_tuple_ = + ShapeUtil::MakeTupleShape({tuple_, matrix_, token_}); +}; + +TEST_F(ShapeTest, ShapeToFromProto) { + for (const Shape& shape : + {opaque_, token_, scalar_, matrix_, matrix2_, tuple_, nested_tuple_}) { + Shape shape_copy(shape.ToProto()); + EXPECT_TRUE(ShapeUtil::Equal(shape, shape_copy)) + << shape << " != " << shape_copy; + } +} + +TEST_F(ShapeTest, ShapeToString) { + EXPECT_EQ("opaque[]", opaque_.ToString()); + EXPECT_EQ("token[]", token_.ToString()); + EXPECT_EQ("f32[]", scalar_.ToString()); + EXPECT_EQ("u32[1,2]", matrix_.ToString()); + EXPECT_EQ("s32[3,4]", matrix2_.ToString()); + EXPECT_EQ("(opaque[], f32[], u32[1,2], s32[3,4])", tuple_.ToString()); + EXPECT_EQ("((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])", + nested_tuple_.ToString()); + + EXPECT_EQ("opaque[]", opaque_.ToString(/*print_layout=*/true)); + EXPECT_EQ("f32[]", scalar_.ToString(/*print_layout=*/true)); + EXPECT_EQ("u32[1,2]{1,0}", matrix_.ToString(/*print_layout=*/true)); + EXPECT_EQ("s32[3,4]{0,1}", matrix2_.ToString(/*print_layout=*/true)); + EXPECT_EQ("(opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1})", + tuple_.ToString(/*print_layout=*/true)); + EXPECT_EQ( + "((opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1}), u32[1,2]{1,0}, " + "token[])", + nested_tuple_.ToString(/*print_layout=*/true)); +} + +TEST_F(ShapeTest, ProgramShapeToFromProto) { + ProgramShape program_shape; + *program_shape.add_parameters() = ShapeUtil::MakeShape(F32, {1, 2, 3}); + *program_shape.add_parameters() = ShapeUtil::MakeTokenShape(); + *program_shape.add_parameters() = ShapeUtil::MakeShape(S64, {}); + *program_shape.add_parameters() = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(S32, {}), + ShapeUtil::MakeTupleShape({ShapeUtil::MakeTokenShape()}), + ShapeUtil::MakeShape(F32, {42, 42})}); + + *program_shape.mutable_result() = ShapeUtil::MakeShape(F32, {7}); + + program_shape.add_parameter_names("foo"); + program_shape.add_parameter_names("bar"); + program_shape.add_parameter_names("baz"); + program_shape.add_parameter_names("qux qux"); + + // Create a copy of the program shape by round-tripping through a proto. 
+ ProgramShape program_shape_copy(program_shape.ToProto()); + ASSERT_EQ(program_shape.parameters_size(), + program_shape_copy.parameters_size()); + for (int i = 0; i < program_shape.parameters_size(); ++i) { + EXPECT_TRUE(ShapeUtil::Equal(program_shape.parameters(i), + program_shape_copy.parameters(i))); + } + + EXPECT_TRUE( + ShapeUtil::Equal(program_shape.result(), program_shape_copy.result())); + + ASSERT_EQ(program_shape.parameter_names_size(), + program_shape_copy.parameter_names_size()); + for (int i = 0; i < program_shape.parameter_names_size(); ++i) { + EXPECT_EQ(program_shape.parameter_names(i), + program_shape_copy.parameter_names(i)); + } +} + +TEST_F(ShapeTest, ProgramShapeToString) { + ProgramShape prog = ShapeUtil::MakeProgramShape( + {opaque_, scalar_, matrix_, matrix2_, tuple_, nested_tuple_}, + nested_tuple_); + EXPECT_EQ( + "((unknown): opaque[], " + "(unknown): f32[], " + "(unknown): u32[1,2], " + "(unknown): s32[3,4], " + "(unknown): (opaque[], f32[], u32[1,2], s32[3,4]), " + "(unknown): ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])) " + "-> " + "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])", + prog.ToString()); + + prog.add_parameter_names("arg0"); + prog.add_parameter_names("scalar"); + prog.add_parameter_names("matrix"); + prog.add_parameter_names("matrix2"); + prog.add_parameter_names("tuple"); + prog.add_parameter_names("nested_tuple"); + EXPECT_EQ( + "(arg0: opaque[], " + "scalar: f32[], " + "matrix: u32[1,2], " + "matrix2: s32[3,4], " + "tuple: (opaque[], f32[], u32[1,2], s32[3,4]), " + "nested_tuple: ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], " + "token[])) " + "-> " + "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])", + prog.ToString()); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h index df610102b4c..7bf97729165 100644 --- a/tensorflow/compiler/xla/shape_tree.h +++ b/tensorflow/compiler/xla/shape_tree.h @@ -667,12 +667,11 @@ void ShapeTree::CopySubtreeFrom(const ShapeTree& other, template bool ShapeTree::operator==(const ShapeTree& other) const { bool equal = true; - ForEachElement( - [this, &other, &equal](const ShapeIndex& index, const T& data) { - if (data != other.element(index)) { - equal = false; - } - }); + ForEachElement([&other, &equal](const ShapeIndex& index, const T& data) { + if (data != other.element(index)) { + equal = false; + } + }); return equal; } diff --git a/tensorflow/compiler/xla/shape_tree_test.cc b/tensorflow/compiler/xla/shape_tree_test.cc index c8ff55e7845..2b6c484bc4f 100644 --- a/tensorflow/compiler/xla/shape_tree_test.cc +++ b/tensorflow/compiler/xla/shape_tree_test.cc @@ -52,10 +52,10 @@ class ShapeTreeTest : public ::testing::Test { TEST_F(ShapeTreeTest, DefaultConstructor) { ShapeTree int_tree; - EXPECT_TRUE(ShapeUtil::IsNil(int_tree.shape())); + EXPECT_TRUE(ShapeUtil::IsEmptyTuple(int_tree.shape())); ShapeTree bool_tree; - EXPECT_TRUE(ShapeUtil::IsNil(bool_tree.shape())); + EXPECT_TRUE(ShapeUtil::IsEmptyTuple(bool_tree.shape())); } void ShapeTreeTest::TestShapeConstructor(const Shape& shape, diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index d0c35d8dee4..f3cc51ca915 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -79,14 +79,14 @@ bool ShapeIndexView::StartsWith(ShapeIndexView prefix) const { indices_.subspan(0, prefix.size()) == prefix.indices_; } -namespace { - -// Returns whether the given primitive type 
corresponds to an array shape. -bool IsArrayPrimitiveType(PrimitiveType primitive_type) { +/* static */ bool ShapeUtil::IsArrayPrimitiveType( + PrimitiveType primitive_type) { return primitive_type != PRIMITIVE_TYPE_INVALID && primitive_type != TUPLE && primitive_type != OPAQUE && primitive_type != TOKEN; } +namespace { + // Recursive helper for comparing the equality of two shapes. Returns true if // the shapes are the same. If compare_layouts is true, then layouts must also // match. @@ -121,6 +121,23 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts, VLOG(3) << "CompareShapes: lhs layout != rhs layout"; return false; } + + const auto& lhs_tiles = lhs.layout().tiles(); + const auto& rhs_tiles = rhs.layout().tiles(); + if (lhs_tiles.size() != rhs_tiles.size()) { + return false; + } + for (int64 i = 0; i < lhs_tiles.size(); i++) { + if (!absl::c_equal(lhs_tiles[i].dimensions(), + rhs_tiles[i].dimensions())) { + return false; + } + } + + if (lhs.layout().element_size_in_bits() != + rhs.layout().element_size_in_bits()) { + return false; + } } } @@ -203,7 +220,7 @@ StatusOr MakeShapeWithLayoutInternal( /* static */ ProgramShape ShapeUtil::MakeProgramShape( std::initializer_list parameters, Shape result) { ProgramShape program_shape; - for (const auto& shape : parameters) { + for (const Shape& shape : parameters) { *program_shape.add_parameters() = shape; } *program_shape.mutable_result() = std::move(result); @@ -272,7 +289,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( /* static */ Shape ShapeUtil::MakeTupleShape(absl::Span shapes) { Shape result; result.set_element_type(TUPLE); - result.mutable_tuple_shapes()->Reserve(shapes.size()); + result.mutable_tuple_shapes()->reserve(shapes.size()); for (const auto& shape : shapes) { AppendShapeToTuple(shape, &result); } @@ -372,10 +389,6 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return IsTuple(shape) && TupleElementCount(shape) == 0; } -/* static */ bool ShapeUtil::IsNil(const Shape& shape) { - return IsEmptyTuple(shape); -} - /* static */ int64 ShapeUtil::TupleElementCount(const Shape& shape) { CHECK(IsTuple(shape)) << HumanString(shape); return shape.tuple_shapes_size(); @@ -1155,7 +1168,7 @@ Status ForEachMutableSubshapeHelper( // Let the argument `permutation` be P. This is a permutation over `shape`'s // dimensions, so our return value will be a shape with dims P.I = P. Our // goal is to construct a layout permutation L* that we can apply to P such - // that that the physical dimension ordering of the returned shape is the same + // that the physical dimension ordering of the returned shape is the same // as that of the original shape, namely L'. 
// // Our returned shape has dims P and layout L*, so its in-memory layout is @@ -1600,7 +1613,8 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, /* static */ Shape ShapeUtil::DeleteDimension(int64 dim_to_delete, Shape shape) { CHECK(IsArray(shape)); - shape.mutable_dimensions()->erase(shape.dimensions().begin() + dim_to_delete); + shape.mutable_dimensions()->erase(shape.mutable_dimensions()->begin() + + dim_to_delete); if (LayoutUtil::HasLayout(shape)) { Layout* layout = shape.mutable_layout(); layout->set_format(DENSE); @@ -1634,11 +1648,6 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, return shape; } -std::ostream& operator<<(std::ostream& out, const Shape& shape) { - out << ShapeUtil::HumanStringWithLayout(shape); - return out; -} - /*static*/ size_t ShapeUtil::Hash(const Shape& shape) { using tensorflow::hash; using tensorflow::Hash64Combine; diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index a7a3026cf3f..84a27f662a5 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -28,6 +28,7 @@ limitations under the License. #include "absl/types/span.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -37,6 +38,7 @@ limitations under the License. #include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -100,6 +102,11 @@ class ShapeIndex { string ToString() const; + template + friend H AbslHashValue(H h, const ShapeIndex& index) { + return H::combine(std::move(h), index.indices_); + } + private: container_type indices_; }; @@ -461,6 +468,9 @@ class ShapeUtil { // arrays. static bool IsArray(const Shape& shape); + // Returns whether the given primitive type corresponds to an array shape. + static bool IsArrayPrimitiveType(PrimitiveType primitive_type); + // Returns whether the shape is a tuple with at least one element which is // also a tuple. static bool IsNestedTuple(const Shape& shape); @@ -468,9 +478,6 @@ class ShapeUtil { // Returns true if shape is an empty tuple. static bool IsEmptyTuple(const Shape& shape); - // Returns true if shape is the nil shape (an empty tuple). - static bool IsNil(const Shape& shape); - // Returns the number of elements in the given tuple shape. // Precondition: IsTuple(shape) static int64 TupleElementCount(const Shape& shape); @@ -754,10 +761,18 @@ class ShapeUtil { pool.emplace(tensorflow::Env::Default(), "foreach", kNumThreads); } + tensorflow::mutex mu; + Status status; // Guarded by mu + while (n < rank) { if (pool != absl::nullopt) { - pool->Schedule( - [indexes, &visitor_function] { visitor_function(indexes); }); + pool->Schedule([indexes, &visitor_function, &mu, &status] { + StatusOr result = visitor_function(indexes); + if (!result.ok()) { + tensorflow::mutex_lock lock(mu); + status = status.ok() ? result.status() : status; + } + }); } else { TF_ASSIGN_OR_RETURN(bool should_continue, visitor_function(indexes)); if (!should_continue) { @@ -775,14 +790,14 @@ class ShapeUtil { } } - return Status::OK(); + // Waits for the scheduled work to complete. 
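The ForEachParallel hunk above now records the first visitor error under a mutex and waits for outstanding work by resetting the thread pool before returning. A minimal sketch of that error-propagation pattern using plain std:: primitives rather than the TensorFlow thread pool and Status types (the `Status` struct and function name below are illustrative, not taken from the diff):

```c++
#include <functional>
#include <mutex>
#include <string>
#include <thread>
#include <vector>

// Stand-in for a Status type: ok flag plus an error message.
struct Status {
  bool ok = true;
  std::string message;
};

// Runs all tasks concurrently, keeps only the first failure (guarded by a
// mutex, as in the hunk), and joins the workers before returning -- the role
// played by `pool.reset()` above.
Status RunAllAndKeepFirstError(
    const std::vector<std::function<Status()>>& tasks) {
  std::mutex mu;
  Status status;  // Guarded by mu.
  std::vector<std::thread> workers;
  workers.reserve(tasks.size());
  for (const auto& task : tasks) {
    workers.emplace_back([&task, &mu, &status] {
      Status result = task();
      if (!result.ok) {
        std::lock_guard<std::mutex> lock(mu);
        if (status.ok) status = result;  // Later errors are dropped.
      }
    });
  }
  for (std::thread& worker : workers) worker.join();
  return status;
}
```

As in the diff, once an error is recorded, subsequent failures are ignored; callers learn only that some scheduled index failed.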
+ pool.reset(); + return status; } TF_DISALLOW_COPY_AND_ASSIGN(ShapeUtil); }; -std::ostream& operator<<(std::ostream& out, const Shape& shape); - } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_ diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 0c647369a37..60bdbe30204 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -376,12 +376,12 @@ TEST(ShapeUtilTest, ByteSizeOfWithoutPadding) { } TEST(ShapeUtilTest, NilShape) { - EXPECT_TRUE(ShapeUtil::IsNil(ShapeUtil::MakeNil())); - EXPECT_FALSE(ShapeUtil::IsNil(ShapeUtil::MakeShape(F32, {1, 2, 3}))); - EXPECT_FALSE(ShapeUtil::IsNil(ShapeUtil::MakeShape(F32, {0, 1}))); - EXPECT_FALSE(ShapeUtil::IsNil( + EXPECT_TRUE(ShapeUtil::IsEmptyTuple(ShapeUtil::MakeNil())); + EXPECT_FALSE(ShapeUtil::IsEmptyTuple(ShapeUtil::MakeShape(F32, {1, 2, 3}))); + EXPECT_FALSE(ShapeUtil::IsEmptyTuple(ShapeUtil::MakeShape(F32, {0, 1}))); + EXPECT_FALSE(ShapeUtil::IsEmptyTuple( ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {})}))); - EXPECT_FALSE(ShapeUtil::IsNil( + EXPECT_FALSE(ShapeUtil::IsEmptyTuple( ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {0})}))); } @@ -546,68 +546,6 @@ TEST(ShapeUtilTest, IsLeafIndex) { EXPECT_TRUE(ShapeUtil::IsLeafIndex(nested_tuple_shape, {1, 1})); } -TEST(ShapeUtilTest, HumanString) { - Shape opaque = ShapeUtil::MakeOpaqueShape(); - Shape token = ShapeUtil::MakeTokenShape(); - Shape scalar = ShapeUtil::MakeShape(F32, {}); - Shape matrix = ShapeUtil::MakeShape(U32, {1, 2}); - Shape matrix2 = ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1}); - Shape tuple = ShapeUtil::MakeTupleShape({opaque, scalar, matrix, matrix2}); - Shape nested_tuple = ShapeUtil::MakeTupleShape({tuple, matrix, token}); - - EXPECT_EQ("opaque[]", ShapeUtil::HumanString(opaque)); - EXPECT_EQ("token[]", ShapeUtil::HumanString(token)); - EXPECT_EQ("f32[]", ShapeUtil::HumanString(scalar)); - EXPECT_EQ("u32[1,2]", ShapeUtil::HumanString(matrix)); - EXPECT_EQ("s32[3,4]", ShapeUtil::HumanString(matrix2)); - EXPECT_EQ("(opaque[], f32[], u32[1,2], s32[3,4])", - ShapeUtil::HumanString(tuple)); - EXPECT_EQ("((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])", - ShapeUtil::HumanString(nested_tuple)); - - EXPECT_EQ("opaque[]", ShapeUtil::HumanStringWithLayout(opaque)); - EXPECT_EQ("f32[]", ShapeUtil::HumanStringWithLayout(scalar)); - EXPECT_EQ("u32[1,2]{1,0}", ShapeUtil::HumanStringWithLayout(matrix)); - EXPECT_EQ("s32[3,4]{0,1}", ShapeUtil::HumanStringWithLayout(matrix2)); - EXPECT_EQ("(opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1})", - ShapeUtil::HumanStringWithLayout(tuple)); - EXPECT_EQ( - "((opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1}), u32[1,2]{1,0}, " - "token[])", - ShapeUtil::HumanStringWithLayout(nested_tuple)); - - ProgramShape prog = ShapeUtil::MakeProgramShape( - {opaque, scalar, matrix, matrix2, tuple, nested_tuple}, nested_tuple); - EXPECT_EQ( - "((unknown): opaque[], " - "(unknown): f32[], " - "(unknown): u32[1,2], " - "(unknown): s32[3,4], " - "(unknown): (opaque[], f32[], u32[1,2], s32[3,4]), " - "(unknown): ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])) " - "-> " - "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])", - ShapeUtil::HumanString(prog)); - - prog.add_parameter_names("arg0"); - prog.add_parameter_names("scalar"); - prog.add_parameter_names("matrix"); - prog.add_parameter_names("matrix2"); - prog.add_parameter_names("tuple"); - prog.add_parameter_names("nested_tuple"); - 
EXPECT_EQ( - "(arg0: opaque[], " - "scalar: f32[], " - "matrix: u32[1,2], " - "matrix2: s32[3,4], " - "tuple: (opaque[], f32[], u32[1,2], s32[3,4]), " - "nested_tuple: ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], " - "token[])) " - "-> " - "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])", - ShapeUtil::HumanString(prog)); -} - TEST(ShapeUtilTest, ForEachSubshapeArray) { const Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); int calls = 0; diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index db34d34f969..f7f090fe4ab 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -79,6 +79,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/core:lib", + "@com_google_absl//absl/base", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:span", ], @@ -135,6 +136,7 @@ cc_library( "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", @@ -297,6 +299,56 @@ xla_test( ], ) +xla_test( + name = "conv_depthwise_test", + timeout = "long", + srcs = ["conv_depthwise_test.cc"], + blacklisted_backends = [ + # disabled because of a break b/119590850. + "gpu", + ], + shard_count = 50, + deps = [ + "//tensorflow/compiler/xla:execution_options_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/service:bfloat16_normalization", + "//tensorflow/compiler/xla/service:despecializer", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/types:optional", + ], +) + +xla_test( + name = "grouped_convolution_test", + timeout = "long", + srcs = ["grouped_convolution_test.cc"], + blacklisted_backends = [ + # disabled because of a break b/119590850. + "gpu", + # disabled because it times out. 
+ "cpu", + ], + shard_count = 50, + deps = [ + "//tensorflow/compiler/xla:execution_options_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/service:bfloat16_normalization", + "//tensorflow/compiler/xla/service:despecializer", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/types:optional", + ], +) + xla_test( name = "check_execution_arity_test", srcs = ["check_execution_arity_test.cc"], @@ -1265,6 +1317,7 @@ xla_test( "enable_for_xla_interpreter", ], deps = [ + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1865,6 +1918,7 @@ xla_test( xla_test( name = "multioutput_fusion_test", srcs = ["multioutput_fusion_test.cc"], + backends = ["gpu"], deps = [ "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index 2180b22cb3b..f6be27bee27 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -350,6 +350,44 @@ TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) { error_spec_); } +// TODO(b/119692968): This test runs OOM on the GPU and CPU backend. +XLA_TEST_F(ArrayElementwiseOpTest, + DISABLED_ON_GPU(DISABLED_ON_CPU(DeeplyNestedAddWithSlices))) { + XlaBuilder builder(TestName()); + std::vector values(30, 0.0); + auto a_literal = LiteralUtil::CreateR1(values); + auto a = Parameter(&builder, 0, a_literal.shape(), "x"); + auto b_literal = LiteralUtil::CreateR1(values); + auto b = Parameter(&builder, 1, b_literal.shape(), "x"); + + // Construct a sequence of diamond-shaped gadgets like this: + // + // add + // / \ + // slice slice + // \ / + // add + // + // Each 'left' slice removes the last element, each 'right' slice removes the + // first element. In this way, we index into the add with different + // multi-dimensional index arrays, which defeats the caching we use to avoid + // exponential compile time. 
+ std::function generate_recursive = + [&](int64 slice_size) -> XlaOp { + if (slice_size == values.size()) { + return Add(a, b); + } + XlaOp param = generate_recursive(slice_size + 1); + auto slice1 = Slice(param, {0}, {slice_size}, {1}); + auto slice2 = Slice(param, {1}, {slice_size + 1}, {1}); + return Add(slice1, slice2); + }; + generate_recursive(1); + auto a_data = client_->TransferToServer(a_literal).ConsumeValueOrDie(); + auto b_data = client_->TransferToServer(b_literal).ConsumeValueOrDie(); + ComputeAndCompareR1(&builder, {0.0}, {a_data.get(), b_data.get()}); +} + XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantF32s) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {-2.5f, 3.14f, 2.25f, -10.0f, 6.0f}); @@ -2744,12 +2782,16 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGtR3F32sWithDegenerateDim2) { Array3D expected_3d( {{{0, 1}, {0, 0}, {0, 0}}, {{0, 1}, {1, 0}, {0, 1}}}); const string expected = R"(pred[2,3,2] { -{ { 0, 1 }, +{ + { 0, 1 }, { 0, 0 }, - { 0, 0 } }, -{ { 0, 1 }, + { 0, 0 } +}, +{ + { 0, 1 }, { 1, 0 }, - { 0, 1 } } + { 0, 1 } +} })"; EXPECT_EQ(expected, ExecuteToString(&builder, {})); } diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc index dde19fb65d6..702fb32adfc 100644 --- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc @@ -161,8 +161,7 @@ XLA_TEST_F(BroadcastSimpleTest, 1DTo2D) { XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsUsual) { XlaBuilder b(TestName()); - BroadcastInDim(ConstantR1(&b, {1, 2}), - ShapeUtil::MakeShape(F32, {2, 2}), {1}); + BroadcastInDim(ConstantR1(&b, {1, 2}), {2, 2}, {1}); Array2D expected(2, 2); expected(0, 0) = 1; @@ -175,8 +174,7 @@ XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsUsual) { XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsTranspose) { XlaBuilder b(TestName()); - BroadcastInDim(ConstantR1(&b, {1, 2}), - ShapeUtil::MakeShape(F32, {2, 2}), {0}); + BroadcastInDim(ConstantR1(&b, {1, 2}), {2, 2}, {0}); Array2D expected(2, 2); expected(0, 0) = 1; @@ -189,8 +187,8 @@ XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsTranspose) { XLA_TEST_F(BroadcastSimpleTest, 2DTo3D_WithDims) { XlaBuilder b(TestName()); - BroadcastInDim(ConstantR2(&b, {{1.0, 5.0}, {2.0, 6.0}}), - ShapeUtil::MakeShape(F32, {2, 2, 2}), {0, 1}); + BroadcastInDim(ConstantR2(&b, {{1.0, 5.0}, {2.0, 6.0}}), {2, 2, 2}, + {0, 1}); Array3D expected(2, 2, 2); expected(0, 0, 0) = 1.0; @@ -207,8 +205,8 @@ XLA_TEST_F(BroadcastSimpleTest, 2DTo3D_WithDims) { XLA_TEST_F(BroadcastSimpleTest, 2DTo3D_WithDimsNotPossibleWithBroadCast) { XlaBuilder b(TestName()); - BroadcastInDim(ConstantR2(&b, {{1.0, 5.0}, {2.0, 6.0}}), - ShapeUtil::MakeShape(F32, {2, 2, 2}), {0, 2}); + BroadcastInDim(ConstantR2(&b, {{1.0, 5.0}, {2.0, 6.0}}), {2, 2, 2}, + {0, 2}); Array3D expected(2, 2, 2); expected(0, 0, 0) = 1.0; @@ -225,8 +223,7 @@ XLA_TEST_F(BroadcastSimpleTest, 2DTo3D_WithDimsNotPossibleWithBroadCast) { XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsNotPossibleWithBroadCast) { XlaBuilder b(TestName()); - BroadcastInDim(ConstantR1(&b, {1, 2}), - ShapeUtil::MakeShape(F32, {3, 2}), {1}); + BroadcastInDim(ConstantR1(&b, {1, 2}), {3, 2}, {1}); Array2D expected(3, 2); expected(0, 0) = 1; diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index b98572e24c8..12c02998333 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ 
b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -107,7 +107,7 @@ StatusOr ClientLibraryTestBase::ExecuteAndTransfer( ExecutionOptions execution_options = execution_options_; if (shape_with_output_layout != nullptr) { *execution_options.mutable_shape_with_output_layout() = - *shape_with_output_layout; + shape_with_output_layout->ToProto(); } return client_->ExecuteAndTransfer(computation, arguments, &execution_options); @@ -127,7 +127,7 @@ StatusOr ClientLibraryTestBase::ExecuteAndTransferReference( ExecutionOptions execution_options = execution_options_; if (shape_with_output_layout != nullptr) { *execution_options.mutable_shape_with_output_layout() = - *shape_with_output_layout; + shape_with_output_layout->ToProto(); } execution_options.clear_device_handles(); return ref_client_->ExecuteAndTransfer(computation, arguments, diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 34148e5886d..65a23dd8835 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -76,7 +76,7 @@ class ClientLibraryTestBase : public ::testing::Test { void SetFastMathDisabled(bool disabled) { auto* opts = execution_options_.mutable_debug_options(); opts->set_xla_cpu_enable_fast_math(!disabled); - opts->set_xla_gpu_enable_fast_math(!disabled); + opts->set_xla_gpu_enable_fast_min_max(!disabled); } void SetSeed(uint64 seed) { execution_options_.set_seed(seed); } diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc index 6f2ca84bb64..363dee74b27 100644 --- a/tensorflow/compiler/xla/tests/client_test.cc +++ b/tensorflow/compiler/xla/tests/client_test.cc @@ -50,7 +50,8 @@ XLA_TEST_F(ClientTest, ExecuteWithLayout) { ExecutionOptions execution_options = execution_options_; *execution_options.mutable_shape_with_output_layout() = ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2}, - execute_layout); + execute_layout) + .ToProto(); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr data, client_->Execute(computation, {}, &execution_options)); @@ -84,7 +85,8 @@ XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) { {ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2}, /*minor_to_major=*/{0, 1}), ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2}, - /*minor_to_major=*/{1, 0})}); + /*minor_to_major=*/{1, 0})}) + .ToProto(); TF_ASSERT_OK_AND_ASSIGN( auto result, diff --git a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc index 9811a015e91..4f5b525a342 100644 --- a/tensorflow/compiler/xla/tests/concat_test.cc +++ b/tensorflow/compiler/xla/tests/concat_test.cc @@ -492,6 +492,32 @@ XLA_TEST_F(ConcatTest, ConcatR3WeirdDims) { ComputeAndCompareR3(&builder, expected, {p0.get(), p1.get()}); } +XLA_TEST_F(ConcatTest, ConcatDeeplyNested) { + XlaBuilder builder(TestName()); + auto a_literal = LiteralUtil::CreateR1({256.0}); + auto a = Parameter(&builder, 0, a_literal.shape(), "x"); + auto b = ConcatInDim(&builder, {a, a}, 0); + auto c = ConcatInDim(&builder, {b, b}, 0); + auto d = ConcatInDim(&builder, {c, c}, 0); + auto e = ConcatInDim(&builder, {d, d}, 0); + auto f = ConcatInDim(&builder, {e, e}, 0); + auto g = ConcatInDim(&builder, {f, f}, 0); + auto h = ConcatInDim(&builder, {g, g}, 0); + auto i = ConcatInDim(&builder, {h, h}, 0); + auto j = ConcatInDim(&builder, {i, i}, 0); + auto k = ConcatInDim(&builder, {j, j}, 0); + auto l = ConcatInDim(&builder, {k, k}, 0); + 
auto m = ConcatInDim(&builder, {l, l}, 0); + auto n = ConcatInDim(&builder, {m, m}, 0); + auto o = ConcatInDim(&builder, {n, n}, 0); + auto p = ConcatInDim(&builder, {o, o}, 0); + auto q = ConcatInDim(&builder, {p, p}, 0); + ConcatInDim(&builder, {q, q}, 0); + std::vector expected(131072, 256.0); + auto a_data = client_->TransferToServer(a_literal).ConsumeValueOrDie(); + ComputeAndCompareR1(&builder, expected, {a_data.get()}); +} + // Describes a binary rank-2 concatenation test. struct R2BinarySpec { int64 lhs_dim0; diff --git a/tensorflow/compiler/xla/tests/conv_depthwise_test.cc b/tensorflow/compiler/xla/tests/conv_depthwise_test.cc new file mode 100644 index 00000000000..bc9bd8a2691 --- /dev/null +++ b/tensorflow/compiler/xla/tests/conv_depthwise_test.cc @@ -0,0 +1,234 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/service/bfloat16_normalization.h" +#include "tensorflow/compiler/xla/service/despecializer.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" + +namespace xla { +namespace { + +string GetFloatDataType(bool use_bfloat16) { + return use_bfloat16 ? "bf16" : "f32"; +} + +struct DepthwiseConvolution2DSpec { + int64 output_feature, window, stride, pad, lhs_dilate; + std::vector activation_dims; + std::vector activation_layout; + std::vector kernel_dims; + std::vector kernel_layout; + std::vector output_dims; + std::vector output_layout; +}; + +class DepthwiseConvolution2DTest + : public HloTestBase, + public ::testing::WithParamInterface< + ::testing::tuple> {}; + +static std::vector GetConv2DTestCases() { + std::vector config_set; + std::vector> config_options = { + {128, 6, 3, 64}, {256, 5, 3, 256}, {256, 5, 2, 144}, {144, 5, 3, 64}, + {144, 5, 2, 256}, {8, 48, 17, 8}, {128, 20, 6, 64}, {128, 1, 2, 144}, + {256, 1, 2, 64}, {64, 14, 12, 172}, {16, 9, 4, 16}}; + + for (auto option : config_options) { + int64 feature = option[0]; + int64 activation_size = option[1]; + int64 kernel_size = option[2]; + int64 batch = option[3]; + + std::vector kernel_layout = {3, 2, 1, 0}; + DepthwiseConvolution2DSpec config; + config.output_feature = feature; + config.window = kernel_size; + + config.activation_dims = {batch, activation_size, activation_size, feature}; + config.activation_layout = {3, 0, 2, 1}; + + config.kernel_dims = {kernel_size, kernel_size, 1, feature}; + config.kernel_layout = {3, 2, 1, 0}; + + if (activation_size == 1 && kernel_size == 2) { + // Test for outer dim. 
+ config.output_dims = {batch, activation_size + kernel_size - 1, + activation_size + kernel_size, feature}; + } else if (feature == 256) { + // Restrict dilation-based tests only to one feature configuration. + config.stride = activation_size - 1; + config.pad = 0; + config.lhs_dilate = feature / 32; + config.output_dims = {batch, feature / 32, + activation_size - kernel_size + 1, feature}; + } else { + config.stride = config.pad = config.lhs_dilate = -1; + config.output_dims = {batch, activation_size - kernel_size + 1, + activation_size - kernel_size + 1, feature}; + } + + // Try this layout for all kernel shapes. + config.output_layout = {3, 0, 2, 1}; + config_set.push_back(config); + + // Try other layouts only for certain kernel shapes. + if (kernel_size % 2 == 0) { + config.activation_layout = {0, 3, 2, 1}; + config_set.push_back(config); + + config.output_layout = {0, 3, 2, 1}; + config_set.push_back(config); + + config.activation_layout = {3, 0, 2, 1}; + config_set.push_back(config); + } + } + + return config_set; +} + +string DepthwiseConvolution2DTestDataToString( + const ::testing::TestParamInfo< + ::testing::tuple>& data) { + const auto& spec = ::testing::get<0>(data.param); + const string data_type = GetFloatDataType(::testing::get<1>(data.param)); + string str = absl::StrCat( + "activation_dims_", absl::StrJoin(spec.activation_dims, "x"), + "_activation_layout_", absl::StrJoin(spec.activation_layout, "_"), + "_kernel_dims_", absl::StrJoin(spec.kernel_dims, "x"), "_kernel_layout_", + absl::StrJoin(spec.kernel_layout, "_"), "_output_dims_", + absl::StrJoin(spec.output_dims, "x"), "_output_layout_", + absl::StrJoin(spec.output_layout, "_"), data_type); + // -1 indicates non-existence. + if (spec.stride != -1) { + absl::StrAppend(&str, "_lhs_dilation_", spec.lhs_dilate, "x1"); + } + + // Test names are not allowed to contain the '-' character. 
+ absl::c_replace(str, '-', 'n'); + return str; +} + +string BuildHloTextDepthwiseConvolution2D( + const DepthwiseConvolution2DSpec& spec, bool use_bfloat16) { + const string data_type = GetFloatDataType(use_bfloat16); + if (spec.activation_dims[1] == 1 && spec.kernel_dims[1] == 2) { + return absl::StrFormat( + R"( + HloModule TensorFlowDepthwiseConv + + ENTRY main { + activation = %s[%s]{%s} parameter(0) + kernel = %s[%s]{%s} parameter(1) + ROOT conv = %s[%s]{%s} convolution(%s[%s]{%s} activation, %s[%s]{%s} kernel), + window={size=%dx%d pad=1_1x%d_%d rhs_dilate=1x%d}, dim_labels=b01f_01io->b01f, + feature_group_count=%d + } + )", + data_type, absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), data_type, + absl::StrJoin(spec.output_dims, ","), + absl::StrJoin(spec.output_layout, ","), data_type, + absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), spec.window, spec.window, + spec.window, spec.window, spec.window, spec.output_feature); + + } else if (spec.stride == -1) { + return absl::StrFormat( + R"( + HloModule TensorFlowDepthwiseConv + + ENTRY main { + activation = %s[%s]{%s} parameter(0) + kernel = %s[%s]{%s} parameter(1) + ROOT conv = %s[%s]{%s} convolution(%s[%s]{%s} activation, %s[%s]{%s} kernel), + window={size=%dx%d}, dim_labels=b01f_01io->b01f, + feature_group_count=%d + } + )", + data_type, absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), data_type, + absl::StrJoin(spec.output_dims, ","), + absl::StrJoin(spec.output_layout, ","), data_type, + absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), spec.window, spec.window, + spec.output_feature); + } else { + return absl::StrFormat( + R"( + HloModule TensorFlowDepthwiseConv + + ENTRY main { + activation = %s[%s]{%s} parameter(0) + kernel = %s[%s]{%s} parameter(1) + ROOT conv = %s[%s]{%s} convolution(%s[%s]{%s} activation, %s[%s]{%s} kernel), + window={size=%dx%d stride=%dx1 pad=%d_%dx0_0 lhs_dilate=%dx1}, + dim_labels=b01f_01io->b01f, feature_group_count=%d + } + )", + data_type, absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), data_type, + absl::StrJoin(spec.output_dims, ","), + absl::StrJoin(spec.output_layout, ","), data_type, + absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), spec.window, spec.window, + spec.stride, 0, 0, spec.lhs_dilate, spec.output_feature); + } +} + +XLA_TEST_P(DepthwiseConvolution2DTest, DoIt) { + const DepthwiseConvolution2DSpec& spec = ::testing::get<0>(GetParam()); + bool use_bfloat16 = ::testing::get<1>(GetParam()); + const string hlo_text = + BuildHloTextDepthwiseConvolution2D(spec, use_bfloat16); + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{0.01, 0.01}, + [](HloModule* module) -> Status { + BFloat16MixedPrecisionRemoval remover; + TF_RETURN_IF_ERROR(remover.Run(module).status()); + Despecializer despecializer; + return 
despecializer.Run(module).status(); + })); +} + +INSTANTIATE_TEST_CASE_P( + DepthwiseConvolution2DTestWithRandomIndices, DepthwiseConvolution2DTest, + ::testing::Combine(::testing::ValuesIn(GetConv2DTestCases()), + ::testing::Bool()), + DepthwiseConvolution2DTestDataToString); + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index 211d004ec8c..459add96813 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -721,8 +721,6 @@ class Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid : public ConvolutionTest { ComputeAndCompareLiteral(&builder, expected_r4, {input_literal.get(), filter_literal.get()}, error_spec_); - - auto filter_r = filter_r1.Reshape(filter_dims); } }; @@ -731,6 +729,291 @@ TYPED_TEST(Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid, Types) { this->RunTest(); } +template +class Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid_Output_Batch_In_Lanes + : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 4, 4, 512}; + std::vector filter_dims = {3, 3, 1, 512}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/512); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape), + static_cast(1)); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), + static_cast(2)); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + std::vector output_elems(2048, static_cast(18)); + + auto expected_r1 = LiteralUtil::CreateR1(output_elems); + auto expected_r4 = expected_r1.Reshape({1, 2, 2, 512}).ConsumeValueOrDie(); + auto expected_r4_relaid = + expected_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1})); + + auto input_literal = + client_->TransferToServer(input_r4).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, expected_r4_relaid, + {input_literal.get(), filter_literal.get()}, + error_spec_, &expected_r4_relaid.shape()); + } +}; + +TYPED_TEST_CASE( + Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid_Output_Batch_In_Lanes, + TestTypes); +TYPED_TEST(Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid_Output_Batch_In_Lanes, + Types) { + this->RunTest(); +} + +template +class Convolve2D_256x4x4x512_3x3x1x512_Depthwise_Input_Batch_in_Lanes + : public ConvolutionTest { + public: 
+ void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {256, 4, 4, 512}; + std::vector filter_dims = {3, 3, 1, 512}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/512); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape), + static_cast(1)); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + auto input_r4_relaid = + input_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1})); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), + static_cast(2)); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + std::vector output_elems(2048 * 256, static_cast(18)); + + auto expected_r1 = LiteralUtil::CreateR1(output_elems); + auto expected_r4 = + expected_r1.Reshape({256, 2, 2, 512}).ConsumeValueOrDie(); + + auto input_literal = + client_->TransferToServer(input_r4_relaid).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, expected_r4, + {input_literal.get(), filter_literal.get()}, + error_spec_); + } +}; + +TYPED_TEST_CASE(Convolve2D_256x4x4x512_3x3x1x512_Depthwise_Input_Batch_in_Lanes, + TestTypes); +TYPED_TEST(Convolve2D_256x4x4x512_3x3x1x512_Depthwise_Input_Batch_in_Lanes, + Types) { + this->RunTest(); +} + +template +class Convolve2D_256x4x4x512_3x3x1x512_Depthwise_Both_Batch_in_Lanes + : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {256, 4, 4, 512}; + std::vector filter_dims = {3, 3, 1, 512}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. 
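Several of these depthwise tests build expected literals filled entirely with 18. A quick arithmetic check of that constant, assuming the usual VALID depthwise semantics in which each output element contracts a single input channel over the 3x3 window:

```c++
// All-ones input, all-twos 3x3 depthwise filter, VALID padding: every output
// element sums one channel over the 3x3 window, so 3 * 3 * (1 * 2) = 18,
// independent of batch size or feature count.
constexpr int kWindow = 3;
constexpr int kExpectedDepthwise = kWindow * kWindow * 1 * 2;
static_assert(kExpectedDepthwise == 18, "matches the output_elems fill value");
```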
+ ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/512); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape), + static_cast(1)); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + auto input_r4_relaid = + input_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1})); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), + static_cast(2)); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + std::vector output_elems(2048 * 256, static_cast(18)); + + auto expected_r1 = LiteralUtil::CreateR1(output_elems); + auto expected_r4 = + expected_r1.Reshape({256, 2, 2, 512}).ConsumeValueOrDie(); + auto expected_r4_relaid = + expected_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1})); + + auto input_literal = + client_->TransferToServer(input_r4_relaid).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, expected_r4_relaid, + {input_literal.get(), filter_literal.get()}, + error_spec_, &expected_r4_relaid.shape()); + } +}; + +TYPED_TEST_CASE(Convolve2D_256x4x4x512_3x3x1x512_Depthwise_Both_Batch_in_Lanes, + TestTypes); +TYPED_TEST(Convolve2D_256x4x4x512_3x3x1x512_Depthwise_Both_Batch_in_Lanes, + Types) { + this->RunTest(); +} + +template +class Convolve2D_1x4x4x5_3x3x1x5_Depthwise_Valid_Output_Batch_In_Lanes + : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 4, 4, 5}; + std::vector filter_dims = {3, 3, 1, 5}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. 
+ ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/5); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); + iota_int_init_value(input_elems, 1); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + auto input_r4_relaid = + input_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1})); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); + iota_int_init_value(filter_elems, 1); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + auto expected_r1 = LiteralUtil::CreateR1( + {static_cast(6864), static_cast(7296), static_cast(7746), + static_cast(8214), static_cast(8700), static_cast(7809), + static_cast(8286), static_cast(8781), static_cast(9294), + static_cast(9825), static_cast(10644), static_cast(11256), + static_cast(11886), static_cast(12534), static_cast(13200), + static_cast(11589), static_cast(12246), static_cast(12921), + static_cast(13614), static_cast(14325)}); + auto expected_r4 = expected_r1.Reshape({1, 2, 2, 5}).ConsumeValueOrDie(); + auto expected_r4_relaid = + expected_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1})); + + auto input_literal = + client_->TransferToServer(input_r4_relaid).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, expected_r4_relaid, + {input_literal.get(), filter_literal.get()}, + error_spec_, &expected_r4_relaid.shape()); + } +}; + +TYPED_TEST_CASE( + Convolve2D_1x4x4x5_3x3x1x5_Depthwise_Valid_Output_Batch_In_Lanes, + TestTypes); +TYPED_TEST(Convolve2D_1x4x4x5_3x3x1x5_Depthwise_Valid_Output_Batch_In_Lanes, + Types) { + this->RunTest(); +} + template class Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Valid : public ConvolutionTest { public: @@ -786,8 +1069,6 @@ class Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Valid : public ConvolutionTest { ComputeAndCompareLiteral(&builder, expected_r4, {input_literal.get(), filter_literal.get()}, error_spec_); - - auto filter_r = filter_r1.Reshape(filter_dims); } }; @@ -796,6 +1077,146 @@ TYPED_TEST(Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Valid, Types) { this->RunTest(); } +template +class Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Input_Batch_In_Lanes + : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 4, 4, 160}; + std::vector filter_dims = {3, 3, 1, 160}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. 
+ ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/160); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape), + static_cast(1)); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + auto input_r4_relaid = + input_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1})); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), + static_cast(2)); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + std::vector output_elems(640, static_cast(18)); + + auto expected_r1 = LiteralUtil::CreateR1(output_elems); + auto expected_r4 = expected_r1.Reshape({1, 2, 2, 160}).ConsumeValueOrDie(); + auto expected_r4_relaid = + expected_r4.Relayout(LayoutUtil::MakeLayout({3, 0, 2, 1})); + + auto input_literal = + client_->TransferToServer(input_r4_relaid).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, expected_r4_relaid, + {input_literal.get(), filter_literal.get()}, + error_spec_, &expected_r4_relaid.shape()); + } +}; + +TYPED_TEST_CASE(Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Input_Batch_In_Lanes, + TestTypes); +TYPED_TEST(Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Input_Batch_In_Lanes, + Types) { + this->RunTest(); +} + +template +class Convolve2D_1x4x4x160_3x3x1x160_Dephtwise_Both_Batch_In_Lanes + : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 4, 4, 160}; + std::vector filter_dims = {3, 3, 1, 160}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. 
+ ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/160); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape), + static_cast(1)); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + auto input_r4_relaid = + input_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1})); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), + static_cast(2)); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + std::vector output_elems(640, static_cast(18)); + + auto expected_r1 = LiteralUtil::CreateR1(output_elems); + auto expected_r4 = expected_r1.Reshape({1, 2, 2, 160}).ConsumeValueOrDie(); + auto expected_r4_relaid = + expected_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1})); + + auto input_literal = + client_->TransferToServer(input_r4_relaid).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, expected_r4_relaid, + {input_literal.get(), filter_literal.get()}, + error_spec_, &expected_r4_relaid.shape()); + } +}; + +TYPED_TEST_CASE(Convolve2D_1x4x4x160_3x3x1x160_Dephtwise_Both_Batch_In_Lanes, + TestTypes); +TYPED_TEST(Convolve2D_1x4x4x160_3x3x1x160_Dephtwise_Both_Batch_In_Lanes, + Types) { + this->RunTest(); +} + template class Convolve2D_1x4x4x1024_3x3x1x1024_Depthwise_Valid : public ConvolutionTest { @@ -852,8 +1273,6 @@ class Convolve2D_1x4x4x1024_3x3x1x1024_Depthwise_Valid ComputeAndCompareLiteral(&builder, expected_r4, {input_literal.get(), filter_literal.get()}, error_spec_); - - auto filter_r = filter_r1.Reshape(filter_dims); } }; @@ -863,7 +1282,7 @@ TYPED_TEST(Convolve2D_1x4x4x1024_3x3x1x1024_Depthwise_Valid, Types) { } template -class Convolve2D_1x2x2x6_2x2x1x12_Grouped_Valid : public ConvolutionTest { +class Convolve2D_1x2x2x6_2x2x2x12_Grouped_Valid : public ConvolutionTest { public: void RunTest() { XlaBuilder builder(TestName()); @@ -922,8 +1341,329 @@ class Convolve2D_1x2x2x6_2x2x1x12_Grouped_Valid : public ConvolutionTest { } }; -TYPED_TEST_CASE(Convolve2D_1x2x2x6_2x2x1x12_Grouped_Valid, TestTypes); -TYPED_TEST(Convolve2D_1x2x2x6_2x2x1x12_Grouped_Valid, Types) { +TYPED_TEST_CASE(Convolve2D_1x2x2x6_2x2x2x12_Grouped_Valid, TestTypes); +TYPED_TEST(Convolve2D_1x2x2x6_2x2x2x12_Grouped_Valid, Types) { + this->RunTest(); +} + +template +class Convolve2D_1x2x2x1024_2x2x128x512_Grouped_Valid : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 2, 2, 1024}; + std::vector filter_dims = {2, 2, 128, 512}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers 
for 2D convolution. + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/8); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape), + static_cast(1)); + + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), + static_cast(2)); + + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + std::vector output_elems(512, static_cast(1024)); + auto expected_r1 = LiteralUtil::CreateR1(output_elems); + auto expected_r4 = expected_r1.Reshape({1, 1, 1, 512}).ConsumeValueOrDie(); + + auto input_literal = + client_->TransferToServer(input_r4).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, expected_r4, + {input_literal.get(), filter_literal.get()}, + error_spec_); + } +}; + +TYPED_TEST_CASE(Convolve2D_1x2x2x1024_2x2x128x512_Grouped_Valid, TestTypes); +TYPED_TEST(Convolve2D_1x2x2x1024_2x2x128x512_Grouped_Valid, Types) { + this->RunTest(); +} + +template +class Convolve2D_1x2x2x1024_2x2x128x8_Grouped_Valid : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 2, 2, 1024}; + std::vector filter_dims = {2, 2, 128, 8}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. 
+ ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/8); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape), + static_cast(1)); + + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), + static_cast(2)); + + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + std::vector output_elems(8, static_cast(1024)); + auto expected_r1 = LiteralUtil::CreateR1(output_elems); + auto expected_r4 = expected_r1.Reshape({1, 1, 1, 8}).ConsumeValueOrDie(); + + auto input_literal = + client_->TransferToServer(input_r4).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, expected_r4, + {input_literal.get(), filter_literal.get()}, + error_spec_); + } +}; + +TYPED_TEST_CASE(Convolve2D_1x2x2x1024_2x2x128x8_Grouped_Valid, TestTypes); +TYPED_TEST(Convolve2D_1x2x2x1024_2x2x128x8_Grouped_Valid, Types) { + this->RunTest(); +} + +template +class Convolve2D_1x2x2x12_2x2x3x4_Grouped_Valid : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 2, 2, 12}; + std::vector filter_dims = {2, 2, 3, 4}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. 
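The two 1x2x2x1024 grouped tests above fill their expected outputs with 1024. A quick check of that value, assuming each of the 8 groups contracts its 1024 / 8 = 128 input features over the 2x2 VALID window:

```c++
// All-ones input, all-twos 2x2 filter, 8 groups over 1024 input features:
// 2 * 2 spatial taps * 128 features per group * (1 * 2) = 1024 per output.
constexpr int kGroupedExpected = 2 * 2 * (1024 / 8) * (1 * 2);
static_assert(kGroupedExpected == 1024, "matches output_elems in both tests");
```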
+ ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/4); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); + iota_int_init_value(input_elems, 1); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); + iota_int_init_value(filter_elems, 1); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + auto expected_r1 = + LiteralUtil::CreateR1({static_cast(7712), static_cast(8816), + static_cast(9992), static_cast(11240)}); + auto expected_r4 = expected_r1.Reshape({1, 1, 1, 4}).ConsumeValueOrDie(); + + auto input_literal = + client_->TransferToServer(input_r4).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, expected_r4, + {input_literal.get(), filter_literal.get()}, + error_spec_); + } +}; + +TYPED_TEST_CASE(Convolve2D_1x2x2x12_2x2x3x4_Grouped_Valid, TestTypes); +TYPED_TEST(Convolve2D_1x2x2x12_2x2x3x4_Grouped_Valid, Types) { + this->RunTest(); +} + +template +class Convolve2D_1x2x2x12_2x2x3x4_Grouped_Valid_Filter_OF_In_Sublanes + : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 2, 2, 12}; + std::vector filter_dims = {2, 2, 4, 3}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. 
+ ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(3); + dnums.set_kernel_output_feature_dimension(2); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/4); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); + iota_int_init_value(input_elems, 1); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); + iota_int_init_value(filter_elems, 1); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + auto filter_r4_relaid = + filter_r4.Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); + auto expected_r1 = LiteralUtil::CreateR1( + {static_cast(6968), static_cast(8516), static_cast(10280), + static_cast(12260)}); + auto expected_r4 = expected_r1.Reshape({1, 1, 1, 4}).ConsumeValueOrDie(); + + auto input_literal = + client_->TransferToServer(input_r4).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r4_relaid).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, expected_r4, + {input_literal.get(), filter_literal.get()}, + error_spec_); + } +}; + +TYPED_TEST_CASE(Convolve2D_1x2x2x12_2x2x3x4_Grouped_Valid_Filter_OF_In_Sublanes, + TestTypes); +TYPED_TEST(Convolve2D_1x2x2x12_2x2x3x4_Grouped_Valid_Filter_OF_In_Sublanes, + Types) { + this->RunTest(); +} + +template +class Convolve2D_1x1x1x12_1x1x3x4_Grouped_Valid : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 1, 1, 12}; + std::vector filter_dims = {1, 1, 3, 4}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. 
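The 1x1x1x12 grouped test that this hunk continues expects the literal {38, 98, 176, 272}. A hand re-computation of those values, under the assumptions that group g contracts input features 3g..3g+2 against filter output column g, and that the iota initialization yields input[f] = f + 1 and row-major filter[{0,0,i,o}] = 4*i + o + 1:

```c++
#include <cassert>

void CheckGroupedIotaExpectation() {
  const int expected[4] = {38, 98, 176, 272};  // literal built by the test
  for (int g = 0; g < 4; ++g) {
    int sum = 0;
    for (int i = 0; i < 3; ++i) {
      const int input = 3 * g + i + 1;   // iota over the 12 input features
      const int weight = 4 * i + g + 1;  // iota over the {1,1,3,4} filter
      sum += input * weight;
    }
    assert(sum == expected[g]);
  }
}
```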
+ ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/4); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); + iota_int_init_value(input_elems, 1); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); + iota_int_init_value(filter_elems, 1); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + auto expected_r1 = + LiteralUtil::CreateR1({static_cast(38), static_cast(98), + static_cast(176), static_cast(272)}); + auto expected_r4 = expected_r1.Reshape({1, 1, 1, 4}).ConsumeValueOrDie(); + + auto input_literal = + client_->TransferToServer(input_r4).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, expected_r4, + {input_literal.get(), filter_literal.get()}, + error_spec_); + } +}; + +TYPED_TEST_CASE(Convolve2D_1x1x1x12_1x1x3x4_Grouped_Valid, TestTypes); +TYPED_TEST(Convolve2D_1x1x1x12_1x1x3x4_Grouped_Valid, Types) { this->RunTest(); } @@ -1217,6 +1957,18 @@ ENTRY Test { EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.001})); } +XLA_TEST_F(ConvolutionHloTest, DISABLED_ON_CPU(ConvolveF64ForwardReversed)) { + constexpr char kHlo[] = R"( +HloModule TestModule + +ENTRY Test { + %arg0 = f64[3,56,56,16] parameter(0) + %arg1 = f64[3,3,3,64] parameter(1) + ROOT %conv = f64[54,54,16,64] convolution(%arg0, %arg1), window={size=3x3 rhs_reversal=1x1}, dim_labels=f01b_i01o->01bf +})"; + EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.001})); +} + XLA_TEST_F(ConvolutionHloTest, DISABLED_ON_CPU(ConvolveF64BackwardFilter)) { constexpr char kHlo[] = R"( HloModule TestModule diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index 6c0847a8757..25091b8d5d5 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -637,6 +637,76 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMul) { {x_data.get(), y_data.get()}, this->error_spec_); } +#ifndef XLA_TEST_BACKEND_CPU +// TODO(b/74459949): failed on CPU on 2018-10-29. 
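+// Batch dot with a rank-3 LHS and a rank-2 RHS: each batch of x is contracted
+// against the matching row of y. With the rows {1, 0} and {0, 1} used below,
+// batch 0 selects x[0, 0, :] = {1, 2} and batch 1 selects x[1, 1, :] = {7, 8}.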
+XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulR3LhsR2Rhs) { + using T = TypeParam; + + XlaBuilder builder(this->TestName()); + auto x = + Parameter(&builder, 0, ShapeUtil::MakeShapeWithType({2, 2, 2}), "x"); + auto y = Parameter(&builder, 1, ShapeUtil::MakeShapeWithType({2, 2}), "y"); + + DotDimensionNumbers dnums; + dnums.add_lhs_contracting_dimensions(1); + dnums.add_rhs_contracting_dimensions(1); + dnums.add_lhs_batch_dimensions(0); + dnums.add_rhs_batch_dimensions(0); + + DotGeneral(x, y, dnums); + + auto x_data = + this->client_ + ->TransferToServer(LiteralUtil::CreateR3FromArray3D( + {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}})) + .ConsumeValueOrDie(); + + auto y_data = this->client_ + ->TransferToServer(LiteralUtil::CreateR2FromArray2D( + {{1.0f, 0.0f}, {0.0f, 1.0f}})) + .ConsumeValueOrDie(); + + this->template ComputeAndCompareR2( + &builder, + /*expected=*/{{1.0f, 2.0f}, {7.0f, 8.0f}}, {x_data.get(), y_data.get()}, + this->error_spec_); +} + +// TODO(b/74459949): failed on CPU on 2018-10-29. +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulR2LhsR3Rhs) { + using T = TypeParam; + + XlaBuilder builder(this->TestName()); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShapeWithType({2, 2}), "x"); + auto y = + Parameter(&builder, 1, ShapeUtil::MakeShapeWithType({2, 2, 2}), "y"); + + DotDimensionNumbers dnums; + dnums.add_lhs_contracting_dimensions(1); + dnums.add_rhs_contracting_dimensions(1); + dnums.add_lhs_batch_dimensions(0); + dnums.add_rhs_batch_dimensions(0); + + DotGeneral(x, y, dnums); + + auto x_data = this->client_ + ->TransferToServer(LiteralUtil::CreateR2FromArray2D( + {{1.0f, 0.0f}, {0.0f, 1.0f}})) + .ConsumeValueOrDie(); + + auto y_data = + this->client_ + ->TransferToServer(LiteralUtil::CreateR3FromArray3D( + {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}})) + .ConsumeValueOrDie(); + + this->template ComputeAndCompareR2( + &builder, + /*expected=*/{{1.0f, 2.0f}, {7.0f, 8.0f}}, {x_data.get(), y_data.get()}, + this->error_spec_); +} +#endif // XLA_TEST_BACKEND_CPU + XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulMultipleBatch) { using T = TypeParam; diff --git a/tensorflow/compiler/xla/tests/grouped_convolution_test.cc b/tensorflow/compiler/xla/tests/grouped_convolution_test.cc new file mode 100644 index 00000000000..8f7049910e7 --- /dev/null +++ b/tensorflow/compiler/xla/tests/grouped_convolution_test.cc @@ -0,0 +1,245 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/service/bfloat16_normalization.h" +#include "tensorflow/compiler/xla/service/despecializer.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" + +namespace xla { +namespace { + +string GetFloatDataType(bool use_bfloat16) { + return use_bfloat16 ? "bf16" : "f32"; +} + +struct GroupedConvolution2DSpec { + int64 input_feature, output_feature, window, stride, pad, lhs_dilate; + int64 group_size, group_count; + std::vector activation_dims; + std::vector activation_layout; + std::vector kernel_dims; + std::vector kernel_layout; + std::vector output_dims; + std::vector output_layout; +}; + +class GroupedConvolution2DTest + : public HloTestBase, + public ::testing::WithParamInterface< + ::testing::tuple> {}; + +static std::vector GetConv2DTestCases() { + std::vector config_set; + // Add to this set if you want a new test configuration. + // Rule : the penultimate number must be divisible by the last number. + std::vector> config_options = {{8, 2, 2, 1, 1024, 128}, + {512, 3, 3, 144, 1024, 16}, + {256, 3, 3, 129, 512, 64}, + {64, 1, 2, 127, 32, 8}, + {256, 3, 3, 256, 1024, 4}}; + + for (auto option : config_options) { + int64 output_feature = option[0]; + int64 activation_size = option[1]; + int64 kernel_size = option[2]; + int64 batch = option[3]; + int64 input_feature = option[4]; + int64 group_size = option[5]; + + std::vector kernel_layout = {3, 2, 1, 0}; + GroupedConvolution2DSpec config; + config.group_size = group_size; + config.group_count = input_feature / group_size; + config.output_feature = output_feature; + config.window = kernel_size; + + config.activation_dims = {batch, activation_size, activation_size, + input_feature}; + config.activation_layout = {3, 0, 2, 1}; + + config.kernel_dims = {kernel_size, kernel_size, group_size, output_feature}; + config.kernel_layout = {3, 2, 1, 0}; + + if (activation_size == 1 && kernel_size == 2) { + // Test for outer dim. + config.output_dims = {batch, activation_size + kernel_size - 1, + activation_size + kernel_size, output_feature}; + } else if (output_feature == 256) { + // Restrict dilation-based tests only to one feature configuration. + config.stride = activation_size - 1; + config.pad = 0; + config.lhs_dilate = output_feature / 32; + config.output_dims = {batch, output_feature / 32, + activation_size - kernel_size + 1, output_feature}; + } else { + config.stride = config.pad = config.lhs_dilate = -1; + config.output_dims = {batch, activation_size - kernel_size + 1, + activation_size - kernel_size + 1, output_feature}; + } + + // Try this layout for all kernel shapes. + config.output_layout = {3, 0, 2, 1}; + config_set.push_back(config); + + // Try other layouts only for certain kernel shapes. 
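+ // For even kernel sizes, also permute the activation and output layouts so
+ // the same configuration is exercised with non-default minor-to-major orders.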
+ if (kernel_size % 2 == 0) { + config.activation_layout = {0, 3, 2, 1}; + config_set.push_back(config); + + config.output_layout = {0, 3, 2, 1}; + config_set.push_back(config); + + config.activation_layout = {3, 0, 2, 1}; + config_set.push_back(config); + } + } + + return config_set; +} + +string GroupedConvolution2DTestDataToString( + const ::testing::TestParamInfo< + ::testing::tuple>& data) { + const auto& spec = ::testing::get<0>(data.param); + const string data_type = GetFloatDataType(::testing::get<1>(data.param)); + string str = absl::StrCat( + "activation_dims_", absl::StrJoin(spec.activation_dims, "x"), + "_activation_layout_", absl::StrJoin(spec.activation_layout, "_"), + "_kernel_dims_", absl::StrJoin(spec.kernel_dims, "x"), "_kernel_layout_", + absl::StrJoin(spec.kernel_layout, "_"), "_output_dims_", + absl::StrJoin(spec.output_dims, "x"), "_output_layout_", + absl::StrJoin(spec.output_layout, "_"), data_type); + // -1 indicates non-existence. + if (spec.stride != -1) { + absl::StrAppend(&str, "_lhs_dilation_", spec.lhs_dilate, "x1"); + } + + // Test names are not allowed to contain the '-' character. + absl::c_replace(str, '-', 'n'); + return str; +} + +string BuildHloTextGroupedConvolution2D(const GroupedConvolution2DSpec& spec, + bool use_bfloat16) { + const string data_type = GetFloatDataType(use_bfloat16); + if (spec.activation_dims[1] == 1 && spec.kernel_dims[1] == 2) { + // Check for outer dim. + return absl::StrFormat( + R"( + HloModule TensorFlowDepthwiseConv + + ENTRY main { + activation = %s[%s]{%s} parameter(0) + kernel = %s[%s]{%s} parameter(1) + ROOT conv = %s[%s]{%s} convolution(%s[%s]{%s} activation, %s[%s]{%s} kernel), + window={size=%dx%d pad=1_1x%d_%d rhs_dilate=1x%d}, dim_labels=b01f_01io->b01f, + feature_group_count=%d + } + )", + data_type, absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), data_type, + absl::StrJoin(spec.output_dims, ","), + absl::StrJoin(spec.output_layout, ","), data_type, + absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), spec.window, spec.window, + spec.window, spec.window, spec.window, spec.group_count); + + } else if (spec.stride == -1) { + // Check for basic, non-dilated cases. + return absl::StrFormat( + R"( + HloModule TensorFlowDepthwiseConv + + ENTRY main { + activation = %s[%s]{%s} parameter(0) + kernel = %s[%s]{%s} parameter(1) + ROOT conv = %s[%s]{%s} convolution(%s[%s]{%s} activation, %s[%s]{%s} kernel), + window={size=%dx%d}, dim_labels=b01f_01io->b01f, + feature_group_count=%d + } + )", + data_type, absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), data_type, + absl::StrJoin(spec.output_dims, ","), + absl::StrJoin(spec.output_layout, ","), data_type, + absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), spec.window, spec.window, + spec.group_count); + } else { + // Check for base dilations. 
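+ // This branch emits an explicit window with stride, pad and lhs_dilate
+ // attributes, matching the stride/pad/lhs_dilate fields populated for the
+ // dilation-based configurations above.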
+ return absl::StrFormat( + R"( + HloModule TensorFlowDepthwiseConv + + ENTRY main { + activation = %s[%s]{%s} parameter(0) + kernel = %s[%s]{%s} parameter(1) + ROOT conv = %s[%s]{%s} convolution(%s[%s]{%s} activation, %s[%s]{%s} kernel), + window={size=%dx%d stride=%dx1 pad=%d_%dx0_0 lhs_dilate=%dx1}, + dim_labels=b01f_01io->b01f, feature_group_count=%d + } + )", + data_type, absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), data_type, + absl::StrJoin(spec.output_dims, ","), + absl::StrJoin(spec.output_layout, ","), data_type, + absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), spec.window, spec.window, + spec.stride, 0, 0, spec.lhs_dilate, spec.group_count); + } +} + +XLA_TEST_P(GroupedConvolution2DTest, DoIt) { + const GroupedConvolution2DSpec& spec = ::testing::get<0>(GetParam()); + bool use_bfloat16 = ::testing::get<1>(GetParam()); + const string hlo_text = BuildHloTextGroupedConvolution2D(spec, use_bfloat16); + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{0.01, 0.01}, + [](HloModule* module) -> Status { + BFloat16MixedPrecisionRemoval remover; + TF_RETURN_IF_ERROR(remover.Run(module).status()); + Despecializer despecializer; + return despecializer.Run(module).status(); + })); +} + +INSTANTIATE_TEST_CASE_P( + GroupedConvolution2DTestWithRandomIndices, GroupedConvolution2DTest, + ::testing::Combine(::testing::ValuesIn(GetConv2DTestCases()), + ::testing::Bool()), + GroupedConvolution2DTestDataToString); + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index d8fa00272f8..989a7c705a8 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -99,6 +99,8 @@ void VerifiedHloModule::VerifyOrAddFailure(const string& message) { ADD_FAILURE() << "HloVerifier failed on module " << name() << (message.empty() ? "" : absl::StrCat(" (", message, ")")) << ": " << status; + LOG(ERROR) << "Contents of bad module:"; + XLA_LOG_LINES(tensorflow::ERROR, ToString()); } } @@ -140,14 +142,6 @@ std::unique_ptr HloTestBase::CreateNewVerifiedModule( allow_mixed_precision_in_hlo_verifier_); } -StatusOr> -HloTestBase::ParseAndReturnUnverifiedModule(absl::string_view hlo_text, - const HloModuleConfig& config) { - auto module = absl::make_unique(TestName(), config); - TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get())); - return std::move(module); -} - StatusOr> HloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text, const HloModuleConfig& config) { diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 366726d90b4..1d1e7f43729 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/base/macros.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/service/backend.h" @@ -100,6 +101,7 @@ class HloTestBase : public ::testing::Test { // // This returns a vanilla HloModule that doesn't run the HLO verifier on // destruction. 
+ ABSL_DEPRECATED("Use CreateNewVerifiedModule instead.") std::unique_ptr CreateNewUnverifiedModule( const string& name = TestName()); @@ -108,12 +110,6 @@ class HloTestBase : public ::testing::Test { std::unique_ptr CreateNewVerifiedModule( const string& name = TestName()); - // Parses the given string and returns module as a vanilla, unverified - // HloModule. - StatusOr> ParseAndReturnUnverifiedModule( - absl::string_view hlo_text, - const HloModuleConfig& config = HloModuleConfig()); - // Parses the given string and returns module as a VerifiedHloModule. StatusOr> ParseAndReturnVerifiedModule( absl::string_view hlo_text, diff --git a/tensorflow/compiler/xla/tests/iota_test.cc b/tensorflow/compiler/xla/tests/iota_test.cc index 310f3495922..65205f53ddc 100644 --- a/tensorflow/compiler/xla/tests/iota_test.cc +++ b/tensorflow/compiler/xla/tests/iota_test.cc @@ -113,5 +113,26 @@ INSTANTIATE_TEST_CASE_P(IotaR3TestInstantiation, IotaR3Test, /*step=*/10), ::testing::Values(0, 1, 2))); +class IotaR3PredTest : public ClientLibraryTestBase, + public ::testing::WithParamInterface {}; + +TEST_P(IotaR3PredTest, DoIt) { + const auto element_type = PRED; + const int64 num_elements = 2; + const int64 iota_dim = GetParam(); + XlaBuilder builder(TestName() + "_" + PrimitiveType_Name(element_type)); + std::vector dimensions = {42, 19}; + dimensions.insert(dimensions.begin() + iota_dim, num_elements); + Iota(&builder, ShapeUtil::MakeShape(element_type, dimensions), iota_dim); + if (primitive_util::IsFloatingPointType(element_type)) { + ComputeAndCompare(&builder, {}, ErrorSpec{0.0001}); + } else { + ComputeAndCompare(&builder, {}); + } +} + +INSTANTIATE_TEST_CASE_P(IotaR3PredTestInstantiation, IotaR3PredTest, + ::testing::Values(0, 1, 2)); + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/replay_test.cc b/tensorflow/compiler/xla/tests/replay_test.cc index 5cf87e565bf..34c7dc7c464 100644 --- a/tensorflow/compiler/xla/tests/replay_test.cc +++ b/tensorflow/compiler/xla/tests/replay_test.cc @@ -55,7 +55,8 @@ TEST_F(ReplayTest, TwoPlusTwoReplay) { client_->GetComputationShape(computation).ConsumeValueOrDie(); std::unique_ptr replayed_shape = client_->GetComputationShape(replayed).ConsumeValueOrDie(); - ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape)); + ASSERT_TRUE(protobuf_util::ProtobufEquals(original_shape->ToProto(), + replayed_shape->ToProto())); // Run it. Literal literal = @@ -87,7 +88,8 @@ XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) { client_->GetComputationShape(computation).ConsumeValueOrDie(); std::unique_ptr replayed_shape = client_->GetComputationShape(replayed).ConsumeValueOrDie(); - ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape)); + ASSERT_TRUE(protobuf_util::ProtobufEquals(original_shape->ToProto(), + replayed_shape->ToProto())); // Run it. std::unique_ptr x_data = @@ -133,7 +135,8 @@ TEST_F(ReplayTest, MapPlusTwoOverR1) { client_->GetComputationShape(computation).ConsumeValueOrDie(); std::unique_ptr replayed_shape = client_->GetComputationShape(replayed).ConsumeValueOrDie(); - ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape)); + ASSERT_TRUE(protobuf_util::ProtobufEquals(original_shape->ToProto(), + replayed_shape->ToProto())); // Run it. 
Literal literal = diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc index dedc95b5ae8..298136002e9 100644 --- a/tensorflow/compiler/xla/tests/reshape_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_test.cc @@ -618,7 +618,8 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { ExecutionOptions execution_options = execution_options_; *execution_options.mutable_shape_with_output_layout() = ShapeUtil::MakeShapeWithLayout(use_bfloat16() ? BF16 : F32, {2, 8}, - {1, 0}); + {1, 0}) + .ToProto(); Literal actual = client_ ->ExecuteAndTransfer(computation, {input.get()}, &execution_options) @@ -767,7 +768,8 @@ XLA_TEST_P(ReshapeTest, NoopReshape) { ExecutionOptions execution_options = execution_options_; *execution_options.mutable_shape_with_output_layout() = ShapeUtil::MakeShapeWithLayout(use_bfloat16() ? BF16 : F32, {7, 2, 3, 5}, - {2, 3, 0, 1}); + {2, 3, 0, 1}) + .ToProto(); Literal output_literal = client_ ->ExecuteAndTransfer(computation, {input_data.get()}, diff --git a/tensorflow/compiler/xla/tests/scatter_test.cc b/tensorflow/compiler/xla/tests/scatter_test.cc index 7e1f4aa0eb4..32de0fdf78f 100644 --- a/tensorflow/compiler/xla/tests/scatter_test.cc +++ b/tensorflow/compiler/xla/tests/scatter_test.cc @@ -129,6 +129,42 @@ ENTRY main { RunTest(hlo_text, &operand, &scatter_indices, &updates); } +XLA_TEST_F(ScatterTest, TensorFlowScatterV2_InversePermutation) { + const char* hlo_text = R"( +HloModule TensorFlowScatterV2 + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + permutation = s32[3,4] parameter(0) + reshape = s32[3,4,1] reshape(permutation) + operand = s32[3,4] iota(), iota_dimension=1 + updates = s32[3,4,1,1] iota(), iota_dimension=1 + iota = s32[3,4,1] iota(), iota_dimension=0 + indices = s32[3,4,2] concatenate(iota, reshape), dimensions={2} + ROOT scatter = s32[3,4] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={2,3}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=2 +} +)"; + Literal permutation = + LiteralUtil::CreateR2({{1, 3, 2, 0}, {3, 0, 2, 1}, {2, 3, 1, 0}}); + HloModuleConfig config; + config.set_debug_options(GetDebugOptionsForTest()); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_text, config)); + auto actual = ExecuteAndTransfer(std::move(module), {&permutation}); + Literal expected = + LiteralUtil::CreateR2({{3, 0, 2, 1}, {1, 3, 2, 0}, {3, 2, 0, 1}}); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual)); +} + XLA_TEST_F(ScatterTest, SimpleR4) { const char* hlo_text = R"( HloModule SimpleR4 diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index 2f18036ff4c..eafa48ed7b8 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include +#include "absl/base/casts.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/primitive_util.h" @@ -28,65 +29,113 @@ namespace xla { namespace { template -void PopulateWithRandomFloatingPointDataImpl(Literal* literal, - std::minstd_rand0* engine, - bool no_duplicates) { - CHECK(engine != nullptr); - CHECK_EQ(literal->shape().element_type(), - primitive_util::NativeToPrimitiveType()); - if (no_duplicates) { - // Duplicates may be generated if the number of elements in the literal - // exceeds the number of positive values supported by the type. - FloatT next_value = std::numeric_limits::min(); - for (FloatT& value : literal->data()) { - value = next_value; - next_value = - std::nextafter(next_value, std::numeric_limits::max()); - } - std::shuffle(literal->data().begin(), literal->data().end(), - *engine); - } else { - std::uniform_real_distribution generator(-0.1f, 0.2f); - for (FloatT& value : literal->data()) { - value = static_cast(generator(*engine)); - } +void PopulateWithRandomFloatingPointData(Literal* literal, + std::minstd_rand0* engine) { + std::uniform_real_distribution generator(-0.1f, 0.2f); + for (FloatT& value : literal->data()) { + value = static_cast(generator(*engine)); } } template -void PopulateWithRandomFloatingPointData(Literal* literal, - std::minstd_rand0* engine, - bool no_duplicates) { - CHECK(engine != nullptr); - PopulateWithRandomFloatingPointDataImpl(literal, engine, - no_duplicates); -} +void PopulateWithIntNext(Literal* literal); template <> -void PopulateWithRandomFloatingPointData(Literal* literal, - std::minstd_rand0* engine, - bool no_duplicates) { - // no_duplicates is ignored for half types. Unique values can only be - // generated for arrays with fewer than ~2**16 elements and no_duplicates is - // best-effort anyway. - CHECK(engine != nullptr); - std::uniform_real_distribution generator(-0.1f, 0.2f); +void PopulateWithIntNext(Literal* literal) { + // Duplicates may be generated if we don't have enough bits. + uint16 next_value = 0; for (half& value : literal->data()) { - value = static_cast(generator(*engine)); + // Zero-out the MSB of the exponent to avoid Infs and NaNs, and put it into + // the sign bit. We could be less wasteful, but this is best-effort anyway. + uint16 exponent_msb = next_value & 0x4000; + value.x = (next_value & 0xBFFF) | (exponent_msb << 1); + next_value++; } } template <> -void PopulateWithRandomFloatingPointData(Literal* literal, - std::minstd_rand0* engine, - bool no_duplicates) { - // no_duplicates is ignored for bfloat types. Unique values can only be - // generated for arrays with fewer than ~2**16 elements and no_duplicates is - // best-effort anyway. - CHECK(engine != nullptr); - std::uniform_real_distribution generator(-0.1f, 0.2f); +void PopulateWithIntNext(Literal* literal) { + // Duplicates may be generated if we don't have enough bits. + // Start at 0x80 rather than 0 to avoid denormals. + uint16 next_value = 0x80; for (bfloat16& value : literal->data()) { - value = static_cast(generator(*engine)); + // Zero-out the MSB of the exponent to avoid Infs and NaNs, and put it into + // the sign bit. We could be less wasteful, but this is best-effort anyway. 
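+ // For example, next_value = 0x4080 has the exponent MSB (bit 14) set;
+ // masking with 0xBFFF clears it and OR-ing in (0x4000 << 1) sets the sign
+ // bit instead, yielding the bit pattern 0x8080.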
+ uint16 exponent_msb = next_value & 0x4000; + value.value = (next_value & 0xBFFF) | (exponent_msb << 1); + next_value++; + } +} + +template +void PopulateWithNextAfter(Literal* literal) { + // Duplicates may be generated if the number of elements in the literal + // exceeds the number of positive values supported by the type. + float next_value = std::numeric_limits::min(); + for (float& value : literal->data()) { + value = next_value; + next_value = std::nextafter(next_value, std::numeric_limits::max()); + } +} + +template ::value || + std::is_same::value, + int>::type = 0> +void PopulateWithNoDuplicateData(Literal* literal, std::minstd_rand0* engine) { + PopulateWithIntNext(literal); + std::shuffle(literal->data().begin(), literal->data().end(), + *engine); +} + +template ::value && + !std::is_same::value, + int>::type = 0> +void PopulateWithNoDuplicateData(Literal* literal, std::minstd_rand0* engine) { + PopulateWithNextAfter(literal); + std::shuffle(literal->data().begin(), literal->data().end(), + *engine); +} + +template +void PopulateWithFloatingPointData(Literal* literal, std::minstd_rand0* engine, + bool no_duplicates) { + CHECK(engine != nullptr); + CHECK_EQ(literal->shape().element_type(), + primitive_util::NativeToPrimitiveType()); + if (no_duplicates) { + PopulateWithNoDuplicateData(literal, engine); + } else { + PopulateWithRandomFloatingPointData(literal, engine); + } +} + +template <> +void PopulateWithFloatingPointData(Literal* literal, + std::minstd_rand0* engine, + bool no_duplicates) { + CHECK(engine != nullptr); + CHECK_EQ(literal->shape().element_type(), + primitive_util::NativeToPrimitiveType()); + if (no_duplicates) { + PopulateWithNoDuplicateData(literal, engine); + } else { + PopulateWithRandomFloatingPointData(literal, engine); + } +} + +template <> +void PopulateWithFloatingPointData(Literal* literal, + std::minstd_rand0* engine, + bool no_duplicates) { + CHECK(engine != nullptr); + CHECK_EQ(literal->shape().element_type(), + primitive_util::NativeToPrimitiveType()); + if (no_duplicates) { + PopulateWithNoDuplicateData(literal, engine); + } else { + PopulateWithRandomFloatingPointData(literal, engine); } } @@ -135,20 +184,16 @@ StatusOr MakeFakeLiteralInternal(const Shape& shape, Literal literal(shape); switch (shape.element_type()) { case BF16: - PopulateWithRandomFloatingPointData(&literal, engine, - no_duplicates); + PopulateWithFloatingPointData(&literal, engine, no_duplicates); break; case F16: - PopulateWithRandomFloatingPointData(&literal, engine, - no_duplicates); + PopulateWithFloatingPointData(&literal, engine, no_duplicates); break; case F32: - PopulateWithRandomFloatingPointData(&literal, engine, - no_duplicates); + PopulateWithFloatingPointData(&literal, engine, no_duplicates); break; case F64: - PopulateWithRandomFloatingPointData(&literal, engine, - no_duplicates); + PopulateWithFloatingPointData(&literal, engine, no_duplicates); break; case S8: PopulateWithRandomIntegralData(&literal, engine, no_duplicates); diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc index e066b3f4f22..e8f5d7a9a79 100644 --- a/tensorflow/compiler/xla/tests/test_utils_test.cc +++ b/tensorflow/compiler/xla/tests/test_utils_test.cc @@ -175,5 +175,28 @@ ENTRY %sort.148.1589 (parameter.0: s32[1048576], parameter.1: s32[1048576]) -> ( } } +XLA_TEST_F(TestUtilsTest, NoDuplicatesBfloat16) { + // Inputs which are sort keys in key/value sorts should have no duplicates. 
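+ // The bf16 sort-key operand below holds 2 * 1452 = 2904 elements, so the
+ // fake-argument generator must produce that many distinct bit patterns for
+ // the uniqueness check at the end of this test to pass.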
+ auto module = ParseHloString(R"( +HloModule sort, is_scheduled=true + +ENTRY %sort. (parameter.0: bf16[2,1452], parameter.1: s32[2,1452]) -> (bf16[2,1452], s32[2,1452]) { + %parameter.0 = bf16[2,1452]{1,0} parameter(0) + %parameter.1 = s32[2,1452]{1,0} parameter(1) + ROOT %sort = (bf16[2,1452]{1,0}, s32[2,1452]{1,0}) sort(bf16[2,1452]{1,0} %parameter.0, s32[2,1452]{1,0} %parameter.1), dimensions={1} +} +)") + .ValueOrDie(); + TF_ASSERT_OK_AND_ASSIGN(std::vector args, + MakeFakeArguments(module.get())); + ASSERT_EQ(args.size(), 2); + const Literal& key_arg = args[0]; + + absl::flat_hash_set key_set; + for (const bfloat16& value : key_arg.data()) { + EXPECT_TRUE(key_set.insert(absl::bit_cast(value)).second); + } +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/token_hlo_test.cc b/tensorflow/compiler/xla/tests/token_hlo_test.cc index a2b7c26331b..601c6b06938 100644 --- a/tensorflow/compiler/xla/tests/token_hlo_test.cc +++ b/tensorflow/compiler/xla/tests/token_hlo_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include "absl/strings/str_cat.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -108,26 +109,6 @@ XLA_TEST_F(TokenHloTest, InvalidTupleTokenShapedEntryParameter) { ::testing::HasSubstr("Entry parameter 0 is or contains a token shape")); } -XLA_TEST_F(TokenHloTest, InvalidOperandToTokenInstruction) { - std::unique_ptr module = CreateNewUnverifiedModule(); - auto builder = HloComputation::Builder(TestName()); - auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0")); - builder.AddInstruction(HloInstruction::CreateAfterAll({param})); - builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(123))); - module->AddEntryComputation(builder.Build()); - - Status status = - HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false) - .Run(module.get()) - .status(); - ASSERT_IS_NOT_OK(status); - EXPECT_THAT(status.error_message(), - ::testing::HasSubstr( - "Operands of token instructions must be TOKEN types")); -} - XLA_TEST_F(TokenHloTest, TokenInWhileLoop) { // Thread a token around a while loop. Token is created and consumed by a // AfterAll instruction in the while body. @@ -220,5 +201,95 @@ ENTRY %TokenInConditional (param.3: pred[]) -> s32[] { } } +XLA_TEST_F(TokenHloTest, AddDependency) { + string module_string = R"( +HloModule AddDependency, is_scheduled=true + +// Computes (p0 + 42) * (-p1) +// where there is a dependency from the add to the negation using a token +// with after-all and add-dependency instructions. 
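+// With the arguments used below (p0 = 10, p1 = 3) this evaluates to
+// (10 + 42) * (-3) = -156, which is the expected literal checked at the end
+// of the test.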
+ENTRY %AddDependency (p0: f32[], p1: f32[]) -> f32[] { + %p0 = f32[] parameter(0) + %p1 = f32[] parameter(1) + + %forty_two = f32[] constant(42.0) + %add = f32[] add(f32[] %p0, f32[] %forty_two) + %token = token[] after-all(f32[] %add) + %p1_after_token = f32[] add-dependency(f32[] %p1, token[] %token) + %neg = f32[] negate(f32[] %p1_after_token) + ROOT %product = f32[] multiply(f32[] %add, f32[] %neg) +} +)"; + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseHloString(module_string, GetModuleConfigForTest())); + auto p0 = LiteralUtil::CreateR0(10.0); + auto p1 = LiteralUtil::CreateR0(3.0); + auto expected = LiteralUtil::CreateR0(-156.0); + EXPECT_EQ(expected, ExecuteNoHloPasses(std::move(module), {&p0, &p1})); +} + +XLA_TEST_F(TokenHloTest, AddDependencyOfConstant) { + string module_string = R"( +HloModule AddDependencyOfConstant, is_scheduled=true + +ENTRY %AddDependency (p0: f32[]) -> f32[] { + %p0 = f32[] parameter(0) + %forty_two = f32[] constant(42.0) + %token = token[] after-all(f32[] %p0) + %forty_two_after_token = f32[] add-dependency(f32[] %forty_two, token[] %token) + ROOT %product = f32[] multiply(f32[] %p0, f32[] %forty_two_after_token) +} +)"; + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseHloString(module_string, GetModuleConfigForTest())); + auto p0 = LiteralUtil::CreateR0(10.0); + auto expected = LiteralUtil::CreateR0(420.0); + EXPECT_EQ(expected, ExecuteNoHloPasses(std::move(module), {&p0})); +} + +XLA_TEST_F(TokenHloTest, AddDependencyAsRoot) { + string module_string = R"( +HloModule AddDependencyAsRoot, is_scheduled=true +ENTRY %AddDependency (p: f32[3]) -> f32[3] { + %p = f32[3] parameter(0) + %neg = f32[3] negate(f32[3] %p) + %token = token[] after-all() + ROOT %add_dep = f32[3] add-dependency(f32[3] %neg, token[] %token) +} +)"; + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseHloString(module_string, GetModuleConfigForTest())); + auto input = LiteralUtil::CreateR1({1.0, 3.0, 7.0}); + auto expected = LiteralUtil::CreateR1({-1.0, -3.0, -7.0}); + EXPECT_EQ(expected, ExecuteNoHloPasses(std::move(module), {&input})); +} + +XLA_TEST_F(TokenHloTest, TupleShapedAddDependency) { + string module_string = R"( +HloModule TupleShapedAddDependency, is_scheduled=true +ENTRY %TupleShapedAddDependency (p0: f32[3], p1: f32[3]) -> f32[3] { + %p0 = f32[3] parameter(0) + %p1 = f32[3] parameter(1) + %forty_two = f32[] constant(42.0) + %token = token[] after-all() + %tuple = (f32[3], token[], f32[3], f32[]) tuple(f32[3] %p0, token[] %token, f32[3] %p1, f32[] %forty_two) + %add_dep = (f32[3], token[], f32[3], f32[]) add-dependency((f32[3], token[], f32[3], f32[]) %tuple, token[] %token) + %elem0 = f32[3] get-tuple-element((f32[3], token[], f32[3], f32[]) %add_dep), index=0 + %elem2 = f32[3] get-tuple-element((f32[3], token[], f32[3], f32[]) %add_dep), index=2 + ROOT %diff = f32[3] subtract(f32[3] %elem0, f32[3] %elem2) +} +)"; + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseHloString(module_string, GetModuleConfigForTest())); + auto p0 = LiteralUtil::CreateR1({3.0, 3.0, 47.0}); + auto p1 = LiteralUtil::CreateR1({1.0, -2.0, 2.0}); + auto expected = LiteralUtil::CreateR1({2.0, 5.0, 45.0}); + EXPECT_EQ(expected, ExecuteNoHloPasses(std::move(module), {&p0, &p1})); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc index ca036f1ae0d..e57d072a063 100644 --- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc +++ 
b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc @@ -157,10 +157,12 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client, TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice( stream_ptr.get(), Literal::CreateFromShape(rhs_arg_shape), rhs_arg)); + ExecutableBuildOptions build_options; + build_options.mutable_debug_options()->set_xla_hlo_profile(true); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr local_executable, client->Compile(computation, {&lhs_arg_shape, &rhs_arg_shape}, - ExecutableBuildOptions().set_hlo_profile(true))); + build_options)); Executable* executable = local_executable->executable(); HloExecutionProfile hlo_execution_profile( @@ -208,7 +210,7 @@ XLA_TEST_F(HloProfileTest, ProfileSingleComputation) { string profile_output; ExecuteAndFetchProfile(&profile_output, client, computation, lhs_shape, rhs_shape); - + VLOG(4) << "Profile Output:\n" << profile_output; std::vector profile_output_lines = absl::StrSplit(profile_output, '\n'); diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index 47be9f5adf1..ff2c3399928 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -82,13 +82,17 @@ struct Options { std::unique_ptr CompileExecutable(const HloSnapshot& module, LocalClient* client) { XlaComputation computation(module.hlo().hlo_module()); - std::vector argument_layouts; - for (const auto& param : + std::vector argument_layouts; + argument_layouts.reserve( + computation.proto().host_program_shape().parameters_size()); + std::vector argument_layout_ptrs; + for (const ShapeProto& param : computation.proto().host_program_shape().parameters()) { - argument_layouts.push_back(¶m); + argument_layouts.push_back(Shape(param)); + argument_layout_ptrs.push_back(&argument_layouts.back()); } return client - ->Compile(computation, argument_layouts, ExecutableBuildOptions()) + ->Compile(computation, argument_layout_ptrs, ExecutableBuildOptions()) .ValueOrDie(); } @@ -149,7 +153,7 @@ StatusOr ReplayComputation(const HloSnapshot& module, << "--generate_fake_infeed only works if the model has 0 or 1 " "infeed ops, but this one has >= 2."; provide_infeed = true; - infeed_shape = instruction.shape(); + infeed_shape = Shape(instruction.shape()); LOG(INFO) << "Generating fake infeed shape for inferred shape: " << ShapeUtil::HumanString(infeed_shape); } @@ -315,9 +319,10 @@ int RealMain(absl::Span args, const Options& opts) { if (snapshot.has_result()) { Literal literal = Literal::CreateFromProto(snapshot.result()).ConsumeValueOrDie(); - fprintf(stdout, "was %s:%s\n", - ShapeUtil::HumanString(snapshot.result().shape()).c_str(), - literal.ToString().c_str()); + fprintf( + stdout, "was %s:%s\n", + ShapeUtil::HumanString(Shape(snapshot.result().shape())).c_str(), + literal.ToString().c_str()); } } } diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index 8ce74164741..6722641e9d2 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -152,6 +152,13 @@ static inline absl::Span AsInt64Slice( slice.size()); } +// TODO(b/29771030): This nop overload was added to simplify the migration of +// Shape from a proto to a C++ class. Remove after class has been migrated. +static inline absl::Span AsInt64Slice( + absl::Span slice) { + return slice; +} + // As above, but for uint64 types. 
static inline absl::Span AsUInt64Slice( const tensorflow::protobuf::RepeatedField& v) { @@ -387,6 +394,19 @@ T CeilOfRatio(T dividend, T divisor) { return tensorflow::MathUtil::CeilOfRatio(dividend, divisor); } +template +std::vector ElementWiseCeilOfRatio(absl::Span dividends, + absl::Span divisors) { + std::vector ceil_of_ratios; + CHECK_EQ(dividends.size(), divisors.size()); + ceil_of_ratios.reserve(dividends.size()); + absl::c_transform(dividends, divisors, std::back_inserter(ceil_of_ratios), + [](const T dividend, const T divisor) { + return CeilOfRatio(dividend, divisor); + }); + return ceil_of_ratios; +} + // Rounds the value up to a multiple of the divisor by first calling CeilOfRatio // then multiplying by the divisor. For example: RoundUpToNearest(13, 8) => 16 template diff --git a/tensorflow/compiler/xla/window_util.cc b/tensorflow/compiler/xla/window_util.cc index 8ea8dbab257..f113a705b41 100644 --- a/tensorflow/compiler/xla/window_util.cc +++ b/tensorflow/compiler/xla/window_util.cc @@ -185,6 +185,17 @@ bool HasWindowReversal(const Window& window) { return false; } +bool AllOrNoneReversed(const Window& window) { + if (window.dimensions().size() == 0) { + return true; + } + bool reversed = window.dimensions()[0].window_reversal(); + return std::all_of(window.dimensions().begin(), window.dimensions().end(), + [&](const WindowDimension& dim) { + return dim.window_reversal() == reversed; + }); +} + bool HasDilation(const Window& window) { return HasBaseDilation(window) || HasWindowDilation(window); } diff --git a/tensorflow/compiler/xla/window_util.h b/tensorflow/compiler/xla/window_util.h index 1fb9e855fc1..099d7ecdd5c 100644 --- a/tensorflow/compiler/xla/window_util.h +++ b/tensorflow/compiler/xla/window_util.h @@ -56,6 +56,7 @@ bool HasWindowDilation(const Window& window); bool HasDilation(const Window& window); bool HasWindowReversal(const Window& window); +bool AllOrNoneReversed(const Window& window); // Returns true if the given logical dimension is inactive in the sense that it // has window bound 1, no striding and no padding. diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 28df3b03f39..bdeb1728fa2 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -193,7 +193,11 @@ message DebugOptions { // - Assuming that operations never produce or consume NaN or +/- Inf. // - Assuming that +0 and -0 are indistinguishable. bool xla_cpu_enable_fast_math = 99; - bool xla_gpu_enable_fast_math = 100; + + // When true we lower the Minimum and Maximum hlos in the GPU backend such + // that Min(NotNaN, NaN) = Min(NaN, NotNaN) = NotNaN. In other words, if this + // flag is true we don't propagate NaNs through Min and Max. + bool xla_gpu_enable_fast_min_max = 100; // Crashes the program when any kind of verification fails, instead of just // logging the failures. One example is cross checking of convolution results @@ -224,7 +228,7 @@ message ExecutionOptions { // may be faster when using this layout. // // We use a Shape here to accommodate computations that return a tuple. - Shape shape_with_output_layout = 2; + ShapeProto shape_with_output_layout = 2; // Used to seed random-number generators used in this computation. If this is // 0, we generate a seed ourselves. @@ -253,7 +257,7 @@ message TransferToClientRequest { // This optional field directs the service to return the literal in this // layout. A shape is used to hold the layout to accommodate tuples.
- Shape shape_with_layout = 2; + ShapeProto shape_with_layout = 2; } message TransferToClientResponse { @@ -281,7 +285,7 @@ message TransferToInfeedResponse { message TransferFromOutfeedRequest { // This optional field directs the service to return the literal in this // layout. A shape is used to hold the layout to accommodate tuples. - Shape shape_with_layout = 1; + ShapeProto shape_with_layout = 1; int64 replica_id = 2; DeviceHandle device_handle = 3; @@ -332,7 +336,7 @@ message CompileRequest { // The layouts of the input arguments. If not set, the default layout will be // used. Although the real arguments are not needed in compilation, the // layouts of the arguments can affect the compilation. - repeated Shape input_shape_with_layout = 3; + repeated ShapeProto input_shape_with_layout = 3; } message CompileResponse { @@ -406,7 +410,7 @@ message LoadDataRequest { string columnio_field = 2; // Individual element shape, excluding rows. - Shape element_shape = 3; + ShapeProto element_shape = 3; // Warning: ColumnIO does not support random-access, so use offset with // caution in performance-critical scenarios. @@ -422,7 +426,7 @@ message LoadDataRequest { message LoadDataResponse { GlobalDataHandle data = 1; - Shape data_shape = 2; + ShapeProto data_shape = 2; int64 available_rows = 3; int64 rows_loaded = 4; int64 nanoseconds = 5; @@ -433,7 +437,7 @@ message GetShapeRequest { } message GetShapeResponse { - Shape shape = 1; + ShapeProto shape = 1; } message UnpackRequest { diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 683ccc40f16..85ec83437a1 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -108,6 +108,16 @@ enum Format { SPARSE = 2; } +// Describes a tile used in tiling-based layout. Refer to +// g3doc/layout_with_tiling.md for details about tiling-based layout. +message Tile { + // Number of elements in each dimension of the tile. It's ordered from the + // most major dimension of the tile to the most minor dimension of the tile. + // The dimensions correspond to a suffix of the dimensions of the shape being + // tiled. + repeated int64 dimensions = 1; +} + // A layout describes how the array is placed in (1D) memory space. This // includes the minor-to-major ordering of dimensions within a shape. // @@ -138,6 +148,20 @@ message Layout { // memory. This field must be unset unless the format is SPARSE. int64 max_sparse_elements = 5; + // A sequence of tiles, starting from the tile that's applied first to the + // Shape. + // + // TODO(b/119839262): implement tiling in each backend or add Unimplemented + // error. + repeated Tile tiles = 6; + + // Bit size of each element. If the size is bigger than what the element + // type requires, the value is stored in the least significant + // bits and the additional most significant bits are filled with 0's. + // + // TODO(b/119839262): implement in each backend or add Unimplemented error. + int64 element_size_in_bits = 7; + // Important: if any field is added, be sure to modify ShapeUtil::Equal() and // LayoutUtil::Hash appropriately to account for the new field. } @@ -154,7 +178,7 @@ message Layout { // See the XLA documentation for more information on shapes and layouts. // // LINT.IfChange -message Shape { +message ShapeProto { reserved 1; reserved "rank"; @@ -169,7 +193,7 @@ message Shape { repeated int64 dimensions = 3; // For tuples only, the shapes of constitutent shapes in the tuple sequence. 
- repeated Shape tuple_shapes = 4; + repeated ShapeProto tuple_shapes = 4; // The layout used to back this shape. Layout layout = 5; @@ -183,9 +207,9 @@ message Shape { // Shape of the parameters and output of a computation (like a traditional // function signature). -message ProgramShape { - repeated Shape parameters = 1; - Shape result = 2; +message ProgramShapeProto { + repeated ShapeProto parameters = 1; + ShapeProto result = 2; repeated string parameter_names = 3; } @@ -320,7 +344,7 @@ message DeviceAssignmentProto { // Transfers to/from the client are encoded in literal form, and the structure // of the repeated fields is implied by the shape. message LiteralProto { - Shape shape = 1; + ShapeProto shape = 1; repeated bool preds = 2; bytes s8s = 15; bytes u8s = 3; @@ -521,7 +545,7 @@ message OpSharding { } Type type = 1; // The shape of the sharded tile. - Shape tile_shape = 2; + ShapeProto tile_shape = 2; // The shape of the tile assignment tensor - this must be the same rank as // tile_shape and the product of its dimensions must equal // tile_assignment_devices.size(). diff --git a/tensorflow/compiler/xrt/BUILD b/tensorflow/compiler/xrt/BUILD index 2ff97914f86..2dae746d034 100644 --- a/tensorflow/compiler/xrt/BUILD +++ b/tensorflow/compiler/xrt/BUILD @@ -22,6 +22,7 @@ xla_proto_library( deps = [ "//tensorflow/compiler/tf2xla:host_compute_metadata_proto", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla:xla_proto", "//tensorflow/compiler/xla/service:hlo_proto", ], ) @@ -32,20 +33,25 @@ cc_library( "xrt_compilation_cache.cc", "xrt_device.cc", "xrt_state.cc", + "xrt_util.cc", ], hdrs = [ "xrt_compilation_cache.h", "xrt_device.h", "xrt_state.h", + "xrt_util.h", ], deps = [ "//tensorflow/compiler/jit:xla_device", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla:xla_proto", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service:backend", "//tensorflow/compiler/xla/service:device_memory_allocator", diff --git a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc index dc62cf7a6b2..2ccdf0f02d8 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc @@ -33,6 +33,7 @@ limitations under the License. 
#include "tensorflow/compiler/xrt/xrt.pb.h" #include "tensorflow/compiler/xrt/xrt_compilation_cache.h" #include "tensorflow/compiler/xrt/xrt_device.h" +#include "tensorflow/compiler/xrt/xrt_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/tensor.h" @@ -108,19 +109,26 @@ Status XRTCompileOp::Compile(OpKernelContext* ctx, TF_ASSIGN_OR_RETURN(xla::XlaComputation computation, client->LoadSnapshot(computation_proto.hlo_snapshot())); - std::vector argument_layouts( + std::vector argument_layouts( + config.program_shape().parameters_size()); + std::vector argument_layout_ptrs( config.program_shape().parameters_size()); for (int i = 0; i < config.program_shape().parameters_size(); ++i) { - argument_layouts[i] = &config.program_shape().parameters(i); + argument_layouts[i] = xla::Shape(config.program_shape().parameters(i)); + argument_layout_ptrs[i] = &argument_layouts[i]; } xla::ExecutableBuildOptions build_options; build_options.set_device_ordinal(client->default_device_ordinal()); - build_options.set_result_layout(config.program_shape().result()); + build_options.set_result_layout(xla::Shape(config.program_shape().result())); build_options.set_device_allocator(device_ref.backend()->memory_allocator()); + if (config.has_debug_options()) { + *build_options.mutable_debug_options() = + BuildXlaDebugOptions(config.debug_options()); + } VLOG(1) << "Building executable"; auto compile_result = - client->Compile(computation, argument_layouts, build_options); + client->Compile(computation, argument_layout_ptrs, build_options); if (!compile_result.ok()) { return compile_result.status(); } @@ -174,11 +182,12 @@ void XRTCompileOp::Compute(OpKernelContext* ctx) { ctx->set_output(0, handle_output); xla::LocalExecutable* executable = entry->get().get_executable(); - xla::ProgramShape program_shape = executable->executable() - ->module() - .config() - .entry_computation_layout() - .ComputeProgramShape(); + xla::ProgramShapeProto program_shape = executable->executable() + ->module() + .config() + .entry_computation_layout() + .ComputeProgramShape() + .ToProto(); Tensor program_shape_output(DT_STRING, TensorShape({1})); program_shape_output.vec()(0) = program_shape.SerializeAsString(); ctx->set_output(1, program_shape_output); diff --git a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc index 8c6191ddc06..751329eefc3 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc @@ -228,14 +228,35 @@ Status XRTExecuteOp::DoWork(OpKernelContext* context) { TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer( shaped_buffer, device_ref.backend(), device_ref.device_ordinal(), &output_tuple)); + if (config_proto.return_exploded_tuple() && + xla::ShapeUtil::IsTuple(output_tuple->on_device_shape())) { + int64 tuple_element_count = + xla::ShapeUtil::TupleElementCount(output_tuple->on_device_shape()); + Tensor* output_tensor; + TF_RETURN_IF_ERROR(context->allocate_output( + 0, TensorShape({tuple_element_count}), &output_tensor)); - Tensor* output_tensor; - TF_RETURN_IF_ERROR( - context->allocate_output(0, TensorShape({}), &output_tensor)); - int64 key; - TF_RETURN_IF_ERROR(output_tuple->Intern(rm, &key)); - output_tensor->scalar()() = key; + for (int64 i = 0; i < tuple_element_count; ++i) { + xla::ShapeIndex shape_index; + shape_index.push_back(i); + XRTTupleAllocation* suballocation; + 
TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer( + output_tuple, shape_index, &suballocation, + /*alias_parent_allocation=*/false)); + int64 key; + TF_RETURN_IF_ERROR(suballocation->Intern(rm, &key)); + output_tensor->vec()(i) = key; + } + output_tuple->Unref(); + } else { + Tensor* output_tensor; + TF_RETURN_IF_ERROR( + context->allocate_output(0, TensorShape({}), &output_tensor)); + int64 key; + TF_RETURN_IF_ERROR(output_tuple->Intern(rm, &key)); + output_tensor->scalar()() = key; + } return Status::OK(); } diff --git a/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc b/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc index ffea592491d..3258286c106 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc @@ -87,6 +87,19 @@ REGISTER_KERNEL_BUILDER(Name("XRTReadLiteral") .HostMemory("literal"), XRTReadLiteralOp); +REGISTER_KERNEL_BUILDER(Name("XRTWriteLiteral") + .Device(DEVICE_XLA_GPU) + .HostMemory("handle") + .HostMemory("literal") + .HostMemory("output_handle"), + XRTWriteLiteralOp); +REGISTER_KERNEL_BUILDER(Name("XRTWriteLiteral") + .Device(DEVICE_XLA_CPU) + .HostMemory("handle") + .HostMemory("literal") + .HostMemory("output_handle"), + XRTWriteLiteralOp); + REGISTER_KERNEL_BUILDER(Name("XRTReadLiteralAndRelease") .Device(DEVICE_XLA_GPU) .HostMemory("handle") diff --git a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h index 54b06558adc..26a58fa42d8 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h +++ b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h @@ -393,6 +393,56 @@ class XRTReadLiteralOp : public OpKernel { } }; +// Op that writes a new literal value into device-resident memory. +template +class XRTWriteLiteralOp : public OpKernel { + public: + explicit XRTWriteLiteralOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + ~XRTWriteLiteralOp() override = default; + XRTWriteLiteralOp(const XRTWriteLiteralOp&) = delete; + XRTWriteLiteralOp& operator=(const XRTWriteLiteralOp&) = delete; + + void Compute(OpKernelContext* ctx) override { + VLOG(1) << "XRTWriteLiteralOp::Compute"; + + const Tensor& handle_tensor = ctx->input(0); + OP_REQUIRES( + ctx, TensorShapeUtils::IsScalar(handle_tensor.shape()), + errors::Internal("computation input should be an int64 scalar")); + int64 allocation_handle = handle_tensor.scalar()(); + + const Tensor& literal_info = ctx->input(1); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(literal_info.shape()), + errors::Internal("literal input should be a string scalar")); + xla::LiteralProto literal_proto; + OP_REQUIRES(ctx, + literal_proto.ParseFromString(literal_info.scalar()()), + errors::InvalidArgument( + "Unable to parse allocation input to LiteralProto")); + xla::Literal literal; + OP_REQUIRES_OK(ctx, XRTStateHelpers::MakeLiteral(literal_proto, &literal)); + + ResourceMgr* rm; + OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm)); + + XRTTupleAllocation* allocation; + OP_REQUIRES_OK( + ctx, XRTTupleAllocation::Lookup(rm, allocation_handle, &allocation)); + core::ScopedUnref allocation_unref(allocation); + // We are guaranteed that the underlying device object won't be deleted out + // from under us, while the ScopedRef is live. 
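+ // The ScopedRef pins the device owning the looked-up allocation; the parsed
+ // literal is then written in place into that allocation's device buffers and
+ // the original handle is returned unchanged.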
+ typename DeviceAccessor::ScopedRef device_ref; + OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef( + ctx, allocation->device_ordinal(), &device_ref)); + OP_REQUIRES_OK(ctx, + allocation->WriteLiteral(device_ref.backend(), literal)); + + Tensor output(DT_INT64, TensorShape({})); + output.scalar()() = allocation_handle; + ctx->set_output(0, output); + } +}; + // Op that discards a handle to device memory. template class XRTReleaseAllocationOp : public OpKernel { diff --git a/tensorflow/compiler/xrt/ops/xrt_state_ops.cc b/tensorflow/compiler/xrt/ops/xrt_state_ops.cc index 07d025ce343..a3d63106fa1 100644 --- a/tensorflow/compiler/xrt/ops/xrt_state_ops.cc +++ b/tensorflow/compiler/xrt/ops/xrt_state_ops.cc @@ -95,6 +95,20 @@ Copies an allocated tuple from device memory and returns it as a literal. 'literal' is a serialized xla::LiteralProto proto. )"); +REGISTER_OP("XRTWriteLiteral") + .Input("handle: int64") + .Input("literal: string") + .Output("output_handle: int64") + .SetShapeFn(tensorflow::shape_inference::ScalarShape) + .Doc( + R"( +Copies the input literal into the device memory pointed to by handle. +Returns the handle itself. + +'handle' is the id returned from the Op that produced the on-device allocation. +'literal' is a serialized xla::LiteralProto proto to be written to device memory. +)"); + REGISTER_OP("XRTReadLiteralAndRelease") .Input("handle: int64") .Output("literal: string") diff --git a/tensorflow/compiler/xrt/tests/raw_api_test.cc b/tensorflow/compiler/xrt/tests/raw_api_test.cc index 25464b5554d..abaa17e50e3 100644 --- a/tensorflow/compiler/xrt/tests/raw_api_test.cc +++ b/tensorflow/compiler/xrt/tests/raw_api_test.cc @@ -102,7 +102,7 @@ bool CompareLiteralProtos(const xla::LiteralProto& a, auto l_b = xla::Literal::CreateFromProto(b).ValueOrDie(); bool equal = l_a == l_b; if (!equal) { - LOG(INFO) << "LiteralProtos don't match " << a.DebugString() + LOG(INFO) << "LiteralProtos don't match: " << a.DebugString() << " != " << b.DebugString(); } return equal; @@ -175,6 +175,18 @@ xla::XlaComputation AddAndTuple() { return builder.Build().ValueOrDie(); } +xla::XlaComputation AddAndSubTuple() { + xla::XlaBuilder builder("AddAndSubTuple"); + auto p0 = xla::Parameter(&builder, 0, xla::ShapeUtil::MakeShape(xla::F32, {}), + "P0"); + auto p1 = xla::Parameter(&builder, 1, xla::ShapeUtil::MakeShape(xla::F32, {}), + "P1"); + auto sum = xla::Add(p0, p1); + auto sub = xla::Sub(p0, p1); + xla::Tuple(&builder, {sum, sub}); + return builder.Build().ValueOrDie(); +} + void StoreComputationSnapshot(const xla::XlaComputation& computation, xla::HloSnapshot* dst) { auto snapshot = computation.Snapshot().ValueOrDie(); @@ -203,6 +215,56 @@ xla::ProgramShape XlaCompiledProgramShape( ->ComputeProgramShape(); } +TEST(RawApiTest, AllocAndRewrite) { + xrt::XLAAllocation alloc; + alloc.set_device_ordinal(0); + *alloc.mutable_value() = + xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}).ToProto(); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + auto value = + ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString()); + auto handle = ops::XRTAllocate(root, value); + auto read_back = ops::XRTReadLiteral(root, handle); + TF_ASSERT_OK(root.status()); + + tensorflow::ClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({read_back, handle}, &outputs)); + EXPECT_EQ(outputs.size(), 2); + + int64 allocation_handle = outputs[1].scalar()(); + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + 
EXPECT_TRUE(CompareLiteralProtos(alloc.value(), response)); + outputs.clear(); + + xla::LiteralProto new_literal = + xla::LiteralUtil::CreateR2({{9, 2}, {4, 1}}).ToProto(); + auto new_value = ops::Const(root.WithDevice("/device:CPU:0"), + new_literal.SerializeAsString()); + auto write_op = + ops::XRTWriteLiteral(root, Input(allocation_handle), new_value); + TF_ASSERT_OK(root.status()); + TF_EXPECT_OK(session.Run({write_op}, &outputs)); + EXPECT_EQ(outputs.size(), 1); + EXPECT_EQ(allocation_handle, outputs[0].scalar()()); + outputs.clear(); + + auto read_after_write = ops::XRTReadLiteral(root, Input(allocation_handle)); + TF_EXPECT_OK(session.Run({read_after_write}, &outputs)); + EXPECT_EQ(outputs.size(), 1); + + xla::LiteralProto new_response; + EXPECT_TRUE(new_response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(CompareLiteralProtos(new_literal, new_response)); + + auto release = + ops::XRTReleaseAllocationHandle(root, Input(allocation_handle)); + TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {}, {release}, + &outputs)); +} + TEST(RawApiTest, ReadAndWriteState) { xrt::XLAAllocation alloc; alloc.set_device_ordinal(0); @@ -375,9 +437,12 @@ TEST(RawApiTest, CompileAndExecute) { xrt::XLAComputation c; auto config = c.mutable_config(); auto shapes = config->mutable_program_shape(); - *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2}); - *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2}); - *shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::F32, {2}); + *shapes->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); + *shapes->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); + *shapes->mutable_result() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); StoreComputationSnapshot(AddAndScale(), c.mutable_hlo_snapshot()); xrt::XRTExecutionConfig e; @@ -411,7 +476,7 @@ TEST(RawApiTest, CompileAndExecute) { auto expected = xla::LiteralUtil::CreateR1({27.0f, 21.0f}); EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); - xla::ProgramShape program_shape; + xla::ProgramShapeProto program_shape; EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec()(0))); EXPECT_EQ(program_shape.parameters_size(), 2); } @@ -427,9 +492,12 @@ TEST(RawApiTest, CompileAndExecuteWithArgumentVector) { xrt::XLAComputation c; auto config = c.mutable_config(); auto shapes = config->mutable_program_shape(); - *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2}); - *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2}); - *shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::F32, {2}); + *shapes->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); + *shapes->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); + *shapes->mutable_result() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); StoreComputationSnapshot(AddAndScale(), c.mutable_hlo_snapshot()); xrt::XRTExecutionConfig e; @@ -465,7 +533,7 @@ TEST(RawApiTest, CompileAndExecuteWithArgumentVector) { auto expected = xla::LiteralUtil::CreateR1({27.0f, 21.0f}); EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); - xla::ProgramShape program_shape; + xla::ProgramShapeProto program_shape; EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec()(0))); EXPECT_EQ(program_shape.parameters_size(), 2); } @@ -494,8 +562,8 @@ TEST(RawApiTest, CompileWithXlaReturnShapes) { xrt::XLAComputation c; auto config = c.mutable_config(); auto shapes = 
config->mutable_program_shape(); - *shapes->add_parameters() = param_shape; - *shapes->mutable_result() = result_shape; + *shapes->add_parameters() = param_shape.ToProto(); + *shapes->mutable_result() = result_shape.ToProto(); StoreComputationSnapshot(xla_computation, c.mutable_hlo_snapshot()); Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); @@ -510,8 +578,9 @@ TEST(RawApiTest, CompileWithXlaReturnShapes) { TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {c_handle.program_shape}, {release}, &outputs)); - xla::ProgramShape program_shape; - EXPECT_TRUE(program_shape.ParseFromString(outputs[0].vec()(0))); + xla::ProgramShapeProto program_shape_proto; + EXPECT_TRUE(program_shape_proto.ParseFromString(outputs[0].vec()(0))); + xla::ProgramShape program_shape(program_shape_proto); EXPECT_EQ(program_shape.parameters_size(), 1); VLOG(2) << "Param: " @@ -520,7 +589,7 @@ TEST(RawApiTest, CompileWithXlaReturnShapes) { << xla::ShapeUtil::HumanStringWithLayout(program_shape.result()); xla::ProgramShape xla_program_shape = - XlaCompiledProgramShape(xla_computation, *shapes); + XlaCompiledProgramShape(xla_computation, xla::ProgramShape(*shapes)); EXPECT_TRUE(xla::LayoutUtil::Equal( xla::ShapeUtil::GetSubshape(program_shape.parameters(0), {0}).layout(), xla::ShapeUtil::GetSubshape(xla_program_shape.parameters(0), {0}) @@ -547,11 +616,11 @@ TEST(RawApiTest, DotGeneralWithLayoutTest) { auto config = c.mutable_config(); auto shapes = config->mutable_program_shape(); *shapes->add_parameters() = - xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 2}, {0, 1}); + xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 2}, {0, 1}).ToProto(); *shapes->add_parameters() = - xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 1}, {0, 1}); + xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 1}, {0, 1}).ToProto(); *shapes->mutable_result() = - xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 1}, {0, 1}); + xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 1}, {0, 1}).ToProto(); StoreComputationSnapshot(Dot(), c.mutable_hlo_snapshot()); xrt::XRTExecutionConfig e; @@ -592,7 +661,7 @@ TEST(RawApiTest, CompileAndExecuteZeroArg) { xrt::XLAComputation c; auto config = c.mutable_config(); auto shapes = config->mutable_program_shape(); - *shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::F32, {}); + *shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::F32, {}).ToProto(); xrt::XRTExecutionConfig e; e.set_release_input_handles(true); @@ -632,10 +701,13 @@ TEST(RawApiTest, CompileAndExecuteReturnTuple) { xrt::XLAComputation c; auto config = c.mutable_config(); auto shapes = config->mutable_program_shape(); - *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2}); - *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2}); - *shapes->mutable_result() = xla::ShapeUtil::MakeTupleShape( - {xla::ShapeUtil::MakeShape(xla::F32, {2})}); + *shapes->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); + *shapes->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); + *shapes->mutable_result() = + xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {2})}) + .ToProto(); StoreComputationSnapshot(AddAndTuple(), c.mutable_hlo_snapshot()); xrt::XRTExecutionConfig e; @@ -671,14 +743,81 @@ TEST(RawApiTest, CompileAndExecuteReturnTuple) { EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); } +TEST(RawApiTest, CompileAndExecuteReturnExplodedTuple) { + xrt::XLAAllocation p0; + p0.set_device_ordinal(0); + 
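# --- Illustrative aside, not part of the original change -------------------
# Plain-Python view of what the CompileAndExecuteReturnExplodedTuple test that
# begins above exercises: AddAndSubTuple returns the tuple (p0 + p1, p0 - p1),
# and with the new return_exploded_tuple option each first-level tuple element
# comes back as its own allocation handle instead of one tuple-shaped
# allocation. The helper below is a stand-in, not the real execution path.
def add_and_sub_tuple(p0, p1):
    return (p0 + p1, p0 - p1)

result_tuple = add_and_sub_tuple(12.0, 3.0)
tuple_result = [result_tuple]         # default: one handle for the whole tuple
exploded_result = list(result_tuple)  # exploded: one handle per element
assert len(tuple_result) == 1 and len(exploded_result) == 2
assert exploded_result == [15.0, 9.0]   # the kResults checked by the test
# ---------------------------------------------------------------------------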
*p0.mutable_value() = xla::LiteralUtil::CreateR0(12.0f).ToProto(); + + xrt::XLAAllocation p1; + p1.set_device_ordinal(0); + *p1.mutable_value() = xla::LiteralUtil::CreateR0(3.0f).ToProto(); + + xrt::XLAComputation c; + auto config = c.mutable_config(); + auto shapes = config->mutable_program_shape(); + *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {}).ToProto(); + *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {}).ToProto(); + *shapes->mutable_result() = + xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {}), + xla::ShapeUtil::MakeShape(xla::F32, {})}) + .ToProto(); + StoreComputationSnapshot(AddAndSubTuple(), c.mutable_hlo_snapshot()); + + xrt::XRTExecutionConfig e; + e.set_release_input_handles(true); + e.set_release_compilation_handle(true); + e.set_return_exploded_tuple(true); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + auto e_config = + ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString()); + auto computation = + ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString()); + auto c_handle = ops::XRTCompile(root, computation); + auto p0_value = + ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString()); + auto p0_handle = ops::XRTAllocate(root, p0_value); + auto p1_value = + ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString()); + auto p1_handle = ops::XRTAllocate(root, p1_value); + auto result = ops::XRTExecute(root, c_handle.handle, e_config, + {Output(p0_handle), Output(p1_handle)}); + TF_ASSERT_OK(root.status()); + + ClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({result}, &outputs)); + EXPECT_EQ(outputs.size(), 1); + + auto handles_vec = outputs.front().vec(); + EXPECT_EQ(handles_vec.size(), 2); + + const float kResults[2] = {15.0f, 9.0f}; + for (int64 i = 0; i < handles_vec.size(); ++i) { + auto read_back = ops::XRTReadLiteralAndRelease(root, Input(handles_vec(i))); + std::vector voutputs; + TF_EXPECT_OK(session.Run({read_back}, &voutputs)); + EXPECT_EQ(voutputs.size(), 1); + + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(voutputs[0].scalar()())); + + auto expected = xla::LiteralUtil::CreateR0(kResults[i]); + EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); + } +} + TEST(RawApiTest, LeakCompilationReference) { xrt::XLAComputation c; auto config = c.mutable_config(); auto shapes = config->mutable_program_shape(); - *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2}); - *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2}); - *shapes->mutable_result() = xla::ShapeUtil::MakeTupleShape( - {xla::ShapeUtil::MakeShape(xla::F32, {2})}); + *shapes->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); + *shapes->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); + *shapes->mutable_result() = + xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {2})}) + .ToProto(); StoreComputationSnapshot(AddAndTuple(), c.mutable_hlo_snapshot()); Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); @@ -703,9 +842,9 @@ TEST(RawApiTest, CompileAndExecuteWithS64Argument) { xrt::XLAComputation c; auto config = c.mutable_config(); auto shapes = config->mutable_program_shape(); - *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S64, {}); - *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S64, {}); - *shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::S64, {}); + *shapes->add_parameters() = 
xla::ShapeUtil::MakeShape(xla::S64, {}).ToProto(); + *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S64, {}).ToProto(); + *shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::S64, {}).ToProto(); StoreComputationSnapshot(AddS64(), c.mutable_hlo_snapshot()); xrt::XRTExecutionConfig e; @@ -739,11 +878,11 @@ TEST(RawApiTest, CompileAndExecuteWithS64Argument) { auto expected = xla::LiteralUtil::CreateR0(15123899); EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); - xla::ProgramShape program_shape; + xla::ProgramShapeProto program_shape; EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec()(0))); EXPECT_EQ(program_shape.parameters_size(), 2); - EXPECT_TRUE( - xla::ShapeUtil::HasPrimitiveType(program_shape.result(), xla::S64)); + EXPECT_TRUE(xla::ShapeUtil::HasPrimitiveType( + xla::Shape(program_shape.result()), xla::S64)); } } // namespace diff --git a/tensorflow/compiler/xrt/xrt.proto b/tensorflow/compiler/xrt/xrt.proto index 6ab77fbaaf0..378bb9246f2 100644 --- a/tensorflow/compiler/xrt/xrt.proto +++ b/tensorflow/compiler/xrt/xrt.proto @@ -3,6 +3,7 @@ syntax = "proto3"; package xrt; import "tensorflow/compiler/tf2xla/host_compute_metadata.proto"; +import "tensorflow/compiler/xla/xla.proto"; import "tensorflow/compiler/xla/xla_data.proto"; import "tensorflow/compiler/xla/service/hlo.proto"; @@ -36,16 +37,18 @@ message XLAComputationConfig { tensorflow.tf2xla.HostComputeMetadata host_compute_metadata = 3; // The arg/result shapes for the whole computation. - xla.ProgramShape program_shape = 4; + xla.ProgramShapeProto program_shape = 4; // The arg/result shapes for each core of a model-parallel // computation. per_core_args_and_result_shapes is optional for a // single-core computation. - repeated xla.ProgramShape per_core_program_shape = 5; + repeated xla.ProgramShapeProto per_core_program_shape = 5; // Describes how replicated computation instances should be assigned to // devices. There are num_cores_per_replica computations, and each one will be // sent and executed to the set of replica device numbers described in the // DeviceAssignment proto. DeviceAssignment device_assignment = 6; + // The debugging options to be passed to the XLA compilation process. + xla.DebugOptions debug_options = 7; } // Options and XLA computation for a compilation. @@ -98,4 +101,8 @@ message XRTExecutionConfig { bool release_input_handles = 5; // If true, release the handle to the computation after running. bool release_compilation_handle = 6; + // If set to true, and the result shape is a tuple, then instead of returning + // a single tuple allocation the execution will return a vector of + // allocations, one for each of the first-level elements of the result tuple. 
+ bool return_exploded_tuple = 7; } diff --git a/tensorflow/compiler/xrt/xrt_state.cc b/tensorflow/compiler/xrt/xrt_state.cc index 3a99820d7aa..5c7c537c340 100644 --- a/tensorflow/compiler/xrt/xrt_state.cc +++ b/tensorflow/compiler/xrt/xrt_state.cc @@ -183,6 +183,20 @@ Status XRTTupleAllocation::ToLiteral(xla::Backend* backend, int device_ordinal, return Status::OK(); } +Status XRTTupleAllocation::WriteLiteral(xla::Backend* backend, + const xla::Literal& literal) { + if (!xla::ShapeUtil::Equal(literal.shape(), on_host_shape())) { + return errors::InvalidArgument( + "New literal shape not matching the existing one: literal=", + xla::ShapeUtil::HumanStringWithLayout(literal.shape()), + " device=", xla::ShapeUtil::HumanStringWithLayout(on_host_shape())); + } + auto transfer_manager = backend->transfer_manager(); + TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal())); + return transfer_manager->TransferLiteralToDevice(stream.get(), literal, + ToShapedBuffer()); +} + void XRTTupleAllocation::DiscardAllocation( const xla::ShapeIndex& buffer_index) { buffers_.element(buffer_index)->DiscardAllocation(); diff --git a/tensorflow/compiler/xrt/xrt_state.h b/tensorflow/compiler/xrt/xrt_state.h index 73b5584e38f..3664c0cd4e6 100644 --- a/tensorflow/compiler/xrt/xrt_state.h +++ b/tensorflow/compiler/xrt/xrt_state.h @@ -137,6 +137,9 @@ class XRTTupleAllocation : public ResourceBase { Status ToLiteral(xla::Backend* backend, int device_ordinal, xla::Literal* literal); + // Write a new literal value to the allocation. + Status WriteLiteral(xla::Backend* backend, const xla::Literal& literal); + // True if none of the buffers in the allocation are aliased by any other live // handle. bool IsExclusiveOwner(); diff --git a/tensorflow/compiler/xrt/xrt_util.cc b/tensorflow/compiler/xrt/xrt_util.cc new file mode 100644 index 00000000000..3ef8bedc732 --- /dev/null +++ b/tensorflow/compiler/xrt/xrt_util.cc @@ -0,0 +1,76 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xrt/xrt_util.h" + +#include +#include + +#include "tensorflow/compiler/xla/debug_options_flags.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { +namespace { + +bool DebugOptionsPassThroughEnabled() { + const char* env = getenv("TF_XLA_DEBUG_OPTIONS_PASSTHROUGH"); + bool enabled = + env != nullptr && (strcmp(env, "1") == 0 || strcmp(env, "true") == 0); + if (enabled) { + LOG(WARNING) << "Passing through XLA debug options!"; + } else { + LOG(WARNING) << "TF_XLA_DEBUG_OPTIONS_PASSTHROUGH not set, not all options " + "will be retained"; + } + return enabled; +} + +string SafeDebugPath(const string& path) { + if (path.empty() || path.compare(0, 5, "gs://") == 0 || + path.compare(0, 11, "bigstore://") == 0) { + return path; + } + LOG(WARNING) << "Invalid config path (will be dropped): " << path; + return string(); +} + +} // namespace + +xla::DebugOptions BuildXlaDebugOptions(const xla::DebugOptions& ref_options) { + static const bool options_passthrough = DebugOptionsPassThroughEnabled(); + if (options_passthrough) { + return ref_options; + } + xla::DebugOptions options = xla::GetDebugOptionsFromFlags(); + options.set_xla_generate_hlo_text_to( + SafeDebugPath(ref_options.xla_generate_hlo_text_to())); + options.set_xla_dump_optimized_hlo_proto_to( + SafeDebugPath(ref_options.xla_dump_optimized_hlo_proto_to())); + options.set_xla_dump_computations_to( + SafeDebugPath(ref_options.xla_dump_computations_to())); + options.set_xla_dump_executions_to( + SafeDebugPath(ref_options.xla_dump_executions_to())); + for (auto& pass : ref_options.xla_disable_hlo_passes()) { + options.add_xla_disable_hlo_passes(pass); + } + options.set_xla_dump_unoptimized_hlo_proto_to( + SafeDebugPath(ref_options.xla_dump_unoptimized_hlo_proto_to())); + options.set_xla_dump_per_pass_hlo_proto_to( + SafeDebugPath(ref_options.xla_dump_per_pass_hlo_proto_to())); + return options; +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/xrt/xrt_util.h b/tensorflow/compiler/xrt/xrt_util.h new file mode 100644 index 00000000000..d9c05a7f340 --- /dev/null +++ b/tensorflow/compiler/xrt/xrt_util.h @@ -0,0 +1,34 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Utility functions in support of the XRT API. + +#ifndef TENSORFLOW_COMPILER_XRT_XRT_UTIL_H_ +#define TENSORFLOW_COMPILER_XRT_XRT_UTIL_H_ + +#include "tensorflow/compiler/xla/xla.pb.h" + +namespace tensorflow { + +// Filters the debug options provided as argument according to the value of the +// TF_XLA_DEBUG_OPTIONS_PASSTHROUGH environment variable. If such variable is +// set to "1" or "true", the debug options will be returned as is. 
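# --- Illustrative aside, not part of the original change -------------------
# The path filtering done by SafeDebugPath above, restated in plain Python:
# empty, gs:// and bigstore:// paths are kept, anything else is dropped (the
# C++ version also logs a warning for dropped paths).
def safe_debug_path(path):
    if path == "" or path.startswith("gs://") or path.startswith("bigstore://"):
        return path
    return ""

assert safe_debug_path("gs://bucket/hlo_dumps") == "gs://bucket/hlo_dumps"
assert safe_debug_path("/tmp/hlo_dumps") == ""   # local paths are stripped
# ---------------------------------------------------------------------------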
Otherwise +// only a subset of them will be set in the returned ones, and all the paths +// contained in it, will be limited to gs:// and bigstore:// ones. +xla::DebugOptions BuildXlaDebugOptions(const xla::DebugOptions& ref_options); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_XRT_XRT_UTIL_H_ diff --git a/tensorflow/contrib/all_reduce/BUILD b/tensorflow/contrib/all_reduce/BUILD index a513aa1e7c4..f6c6560c1c3 100644 --- a/tensorflow/contrib/all_reduce/BUILD +++ b/tensorflow/contrib/all_reduce/BUILD @@ -9,8 +9,6 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -load("//tensorflow:tensorflow.bzl", "tf_py_test") - py_library( name = "all_reduce_py", srcs = ["__init__.py"], @@ -29,29 +27,6 @@ py_library( srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ - "//tensorflow/python:array_ops", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:nccl_ops", - ], -) - -tf_py_test( - name = "all_reduce_test", - srcs = ["python/all_reduce_test.py"], - additional_deps = [ - ":all_reduce", - "//third_party/py/numpy", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client", - "//tensorflow/python:framework_ops", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:math_ops", - "//tensorflow/python:constant_op", - "//tensorflow/python:client_testlib", - "//tensorflow/python:platform", - "//tensorflow/python:platform_test", - "//tensorflow/python:state_ops", + "//tensorflow/python/distribute:all_reduce", ], ) diff --git a/tensorflow/contrib/all_reduce/python/all_reduce.py b/tensorflow/contrib/all_reduce/python/all_reduce.py index 25f4b4b8d34..238cdaf8a79 100644 --- a/tensorflow/contrib/all_reduce/python/all_reduce.py +++ b/tensorflow/contrib/all_reduce/python/all_reduce.py @@ -18,842 +18,5 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections -import math - -from tensorflow.python.framework import device as device_lib -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nccl_ops - - -def _flatten_tensors(tensors): - """Check tensors for isomorphism and flatten. - - Args: - tensors: list of T `tf.Tensor` which must all have the same shape. - - Returns: - tensors: a list of T `tf.Tensor` which are flattened (1D) views of tensors - shape: the original shape of each element of input tensors - - Raises: - ValueError: tensors are empty or non-isomorphic or have unknown shape. - """ - if not tensors: - raise ValueError("tensors cannot be empty") - shape = tensors[0].shape - for tensor in tensors: - shape = shape.merge_with(tensor.shape) - if not shape.is_fully_defined(): - raise ValueError("Tensors must have statically known shape.") - if len(shape) != 1: - reshaped = [] - for t in tensors: - with ops.colocate_with(t): - reshaped.append(array_ops.reshape(t, [-1])) - tensors = reshaped - return tensors, shape - - -def _reshape_tensors(tensors, shape): - """Reshape tensors flattened by _flatten_tensors. - - Args: - tensors: list of T `tf.Tensor` of identical length 1D tensors. - shape: list of integers describing the desired shape. Product of - the elements must equal the length of each tensor. - - Returns: - list of T `tf.Tensor` which are the reshaped inputs. 
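# --- Illustrative aside, not part of the original change -------------------
# What the _flatten_tensors / _reshape_tensors pair above does, shown with
# numpy instead of graph ops: inputs of identical, fully known shape are
# flattened to 1-D for the reduction and the original shape is remembered so
# the reduced results can be restored afterwards.
import numpy as np

def flatten_tensors(tensors):
    shape = tensors[0].shape
    assert all(t.shape == shape for t in tensors), "tensors must be isomorphic"
    return [t.reshape(-1) for t in tensors], shape

def reshape_tensors(tensors, shape):
    return [t.reshape(shape) for t in tensors]

flat, shape = flatten_tensors([np.ones((2, 3)), np.zeros((2, 3))])
assert all(t.shape == (6,) for t in flat) and shape == (2, 3)
assert all(t.shape == (2, 3) for t in reshape_tensors(flat, shape))
# ---------------------------------------------------------------------------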
- """ - reshaped = [] - for t in tensors: - with ops.colocate_with(t): - reshaped.append(array_ops.reshape(t, shape)) - return reshaped - - -def _padded_split(tensor, pieces): - """Like split for 1D tensors but pads-out case where len % pieces != 0. - - Args: - tensor: T `tf.Tensor` that must be 1D. - pieces: a positive integer specifying the number of pieces into which - tensor should be split. - - Returns: - list of T `tf.Tensor` of length pieces, which hold the values of - thin input tensor, in order. The final tensor may - be zero-padded on the end to make its size equal to those of all - of the other tensors. - - Raises: - ValueError: The input tensor is not 1D. - """ - shape = tensor.shape - if 1 != len(shape): - raise ValueError("input tensor must be 1D") - tensor_len = shape.dims[0].value - with ops.colocate_with(tensor): - if tensor_len % pieces != 0: - # pad to an even length - chunk_size = 1 + tensor_len // pieces - if pieces > tensor_len: - # This is an edge case that should not come up in practice, - # i.e. a different reduction algorithm would be better, - # but we'll make it work just for completeness. - pad_len = pieces - tensor_len - extended_whole = array_ops.concat( - [tensor, array_ops.zeros([pad_len], dtype=tensor.dtype)], 0) - parts = array_ops.split(extended_whole, pieces) - return parts, pad_len - elif (pieces - 1) * chunk_size >= tensor_len: - # Another edge case of limited real interest. - pad_len = (pieces * chunk_size) % tensor_len - extended_whole = array_ops.concat( - [tensor, array_ops.zeros([pad_len], dtype=tensor.dtype)], 0) - parts = array_ops.split(extended_whole, pieces) - return parts, pad_len - else: - last_chunk_size = tensor_len - (pieces - 1) * chunk_size - pad_len = chunk_size - last_chunk_size - piece_lens = [chunk_size for _ in range(pieces - 1)] + [last_chunk_size] - parts = array_ops.split(tensor, piece_lens) - parts[-1] = array_ops.concat( - [parts[-1], array_ops.zeros([pad_len], dtype=tensor.dtype)], 0) - return parts, pad_len - else: - return array_ops.split(tensor, pieces), 0 - - -def _strip_padding(tensors, pad_len): - """Strip the suffix padding added by _padded_split. - - Args: - tensors: list of T `tf.Tensor` of identical length 1D tensors. - pad_len: number of elements to be stripped from the end of each tensor. - - Returns: - list of T `tf.Tensor` which are the stripped inputs. - - Raises: - ValueError: tensors must be a non-empty list of 1D tensors, and - each must be longer than pad_len. - """ - if not tensors: - raise ValueError("tensors cannot be empty") - shape = tensors[0].shape - if len(shape) > 1: - raise ValueError("tensors must be 1D") - prefix_len = int(shape[0] - pad_len) - if prefix_len < 0: - raise ValueError("pad_len longer than tensor") - stripped = [] - for t in tensors: - with ops.colocate_with(t): - stripped.append(array_ops.slice(t, [0], [prefix_len])) - return stripped - - -def _ragged_split(tensor, pieces): - """Like split for 1D tensors but allows case where len % pieces != 0. - - Args: - tensor: T `tf.Tensor` that must be 1D. - pieces: a positive integer specifying the number of pieces into which - tensor should be split. - - Returns: - list of T `tf.Tensor` of length pieces, which hold the values of - the input tensor, in order. The final tensor may be shorter - than the others, which will all be of equal length. - - Raises: - ValueError: input tensor must be 1D. 
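# --- Illustrative aside, not part of the original change -------------------
# The two 1-D splitting helpers above, _padded_split and _ragged_split,
# restated with numpy (edge cases such as pieces > len are omitted). The
# padded variant zero-fills the tail so every chunk has equal length and
# reports how much padding was added; the ragged variant lets the last chunk
# absorb the remainder.
import numpy as np

def padded_split(vec, pieces):
    chunk = -(-len(vec) // pieces)               # ceiling division
    pad_len = chunk * pieces - len(vec)
    padded = np.concatenate([vec, np.zeros(pad_len, vec.dtype)])
    return np.split(padded, pieces), pad_len

def ragged_split(vec, pieces):
    chunk = len(vec) // pieces
    sizes = [chunk] * (pieces - 1) + [len(vec) - chunk * (pieces - 1)]
    return np.split(vec, np.cumsum(sizes)[:-1])

v = np.arange(7)
parts, pad = padded_split(v, 3)
assert pad == 2 and all(len(p) == 3 for p in parts)
assert [len(p) for p in ragged_split(v, 3)] == [2, 2, 3]
# ---------------------------------------------------------------------------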
- """ - shape = tensor.shape - if 1 != len(shape): - raise ValueError("input tensor must be 1D") - tensor_len = shape.dims[0].value - chunk_size = tensor_len // pieces - with ops.colocate_with(tensor): - if tensor_len != (pieces * chunk_size): - # last piece will be short - assert pieces > 1 - last_chunk_size = tensor_len - ((pieces - 1) * chunk_size) - assert last_chunk_size > 0 - piece_lens = [chunk_size for _ in range(pieces - 1)] + [last_chunk_size] - return array_ops.split(tensor, piece_lens) - else: - return array_ops.split(tensor, pieces) - - -def _ring_permutations(num_workers, num_subchunks, gpu_perm): - """"Generate an array of device index arrays, one for each subchunk. - - In the basic ring reduction algorithm there are size(T)/num_devices - data chunks and each device process one chunk per tick, i.e. sending - one chunk and receiving one chunk. The idea of subchunking is that - each device processes num_subchunks smaller data regions per tick, - and the ring rank permutation is different for each subchunk index - so that a device is potentially sending to and receiving from - num_subchunks different other devices at each tick. Where multiple - independent data channels exist between devices, this strategy - supplies a method of using them in parallel. - - Args: - num_workers: number of worker tasks - num_subchunks: number of subchunks into which to divide each per-GPU chunk. - gpu_perm: an array of integers in [0, num_gpus-1] giving the default - ring order of GPUs at each worker. Other permutations will be generated - by rotating this array and splicing together per-worker instances. - - Raises: - ValueError: the number of subchunks may not exceed the number of GPUs. - - Returns: - pred_by_s_d: list of lists that maps (by index) from (subchunk, dev) to - preceding device in the permutation for that subchunk. The - device index of GPU i at worker j is i + (j * num_gpus). - rank_by_s_d: list of lists that maps (by index) from (subchunk, dev) to - local rank of device d in the permutation for that subchunk. - """ - num_gpus = len(gpu_perm) - devices = num_workers * num_gpus - if devices == 0: - return [], [] - if num_subchunks > num_gpus: - raise ValueError( - "num_subchunks %d must be <= num_gpus %d" % (num_subchunks, num_gpus)) - rotation_interval = max(1, int(num_gpus / num_subchunks)) - perms_by_s = [] - for s in range(0, num_subchunks): - full_order = [] - offset = s * rotation_interval - for w in range(0, num_workers): - default_order = [(w * num_gpus) + i for i in gpu_perm] - dev_order = default_order[offset:] + default_order[:offset] - full_order += dev_order - perms_by_s.append(full_order) - pred_by_s_d = [[-1 for d in range(0, devices)] - for s in range(0, num_subchunks)] - rank_by_s_d = [[-1 for d in range(0, devices)] - for s in range(0, num_subchunks)] - for s in range(0, num_subchunks): - for d in range(0, devices): - for t in range(0, devices): - if d == perms_by_s[s][t]: - rank_by_s_d[s][d] = t - pred_by_s_d[s][d] = perms_by_s[s][(t + devices - 1) % devices] - break - return (pred_by_s_d, rank_by_s_d) - - -def build_ring_all_reduce(input_tensors, num_workers, num_subchunks, - gpu_perm, red_op, un_op=None): - """Construct a subgraph performing a ring-style all-reduce of input_tensors. - - Args: - input_tensors: a list of T `tf.Tensor` objects, which must all - have the same shape and type. - num_workers: number of worker tasks spanned by input_tensors. - num_subchunks: number of subchunks each device should process in one tick. 
- gpu_perm: a list of ints giving a ring-wise rank ordering of GPUs at - each worker. All workers must have the same number of - GPUs with the same rank ordering. If NVLINK is available, this should - be a ring order supported by NVLINK edges. - red_op: a binary operator for elementwise reduction. - un_op: an optional unary operator to apply to fully reduced values. - - Raises: - ValueError: empty input_tensors or they don't all have same - size. - - Returns: - a list of T `tf.Tensor` identical sum-reductions of input_tensors. - """ - if len(input_tensors) < 2: - raise ValueError("input_tensors must be length 2 or longer") - input_tensors, shape = _flatten_tensors(input_tensors) - devices = [t.device for t in input_tensors] - (pred_by_s_d, rank_by_s_d) = _ring_permutations( - num_workers, num_subchunks, gpu_perm) - chunks_by_dev, pad_len = _build_ring_gather( - input_tensors, devices, - num_subchunks, pred_by_s_d, rank_by_s_d, red_op) - if un_op: - chunks_by_dev = _apply_unary_to_chunks(un_op, chunks_by_dev) - output_tensors = _build_ring_scatter(pred_by_s_d, rank_by_s_d, - chunks_by_dev) - if pad_len > 0: - output_tensors = _strip_padding(output_tensors, pad_len) - if len(shape) != 1: - output_tensors = _reshape_tensors(output_tensors, shape) - return output_tensors - - -def _build_ring_gather(input_tensors, devices, num_subchunks, - pred_by_s_d, rank_by_s_d, red_op): - """Construct a subgraph for the first (reduction) pass of ring all-reduce. - - Args: - input_tensors: a list of T `tf.Tensor` 1D input tensors of same - shape and type. - devices: array of device name strings - num_subchunks: number of subchunks each device should process in one tick. - pred_by_s_d: as produced by _ring_permutations - rank_by_s_d: as produced by _ring_permutations - red_op: a binary operator for elementwise reduction - - Raises: - ValueError: tensors must all be one dimensional. - - Returns: - list of list of T `tf.Tensor` of (partially) reduced values where - exactly num_subchunks chunks at each device are fully reduced. - """ - num_devices = len(input_tensors) - if num_devices == 0: - return [] - if num_devices == 1: - return input_tensors - shape = input_tensors[0].shape - if 1 != len(shape): - raise ValueError("input tensors must be 1D") - num_chunks = num_devices * num_subchunks - num_ticks = num_devices - 1 - # Initialize chunks_by_dev with splits of the input tensors. - chunks_by_dev = [] - split_pad_len = 0 - for d in range(0, num_devices): - with ops.device(devices[d]): - splits, split_pad_len = _padded_split(input_tensors[d], num_chunks) - chunks_by_dev.append(splits) - # Reduction phase - for tick in range(0, num_ticks): - # One new partial reduction for every chunk - new_partial_reductions = [None for _ in range(0, num_chunks)] - # Compute reductions with respect to last tick's values - for d in range(0, num_devices): - with ops.device(devices[d]): - for s in range(0, num_subchunks): - rank = rank_by_s_d[s][d] - seg_index = (rank + num_devices - (2 + tick)) % num_devices - pred_dev = pred_by_s_d[s][d] - chunk_index = (seg_index * num_subchunks) + s - new_partial_reductions[chunk_index] = red_op( - chunks_by_dev[pred_dev][chunk_index], - chunks_by_dev[d][chunk_index]) - # Update chunks_by_dev with the new values at the end of the tick. 
- for d in range(0, num_devices): - for s in range(0, num_subchunks): - rank = rank_by_s_d[s][d] - seg_index = (rank + num_devices - (2 + tick)) % num_devices - chunk_index = (seg_index * num_subchunks) + s - chunks_by_dev[d][chunk_index] = new_partial_reductions[chunk_index] - return chunks_by_dev, split_pad_len - - -def _apply_unary_to_chunks(f, chunks_by_dev): - """Apply a unary op to each tensor in chunks_by_dev, on same device. - - Args: - f: a unary function over T `tf.Tensor`. - chunks_by_dev: list of lists of T `tf.Tensor`. - - Returns: - new list of lists of T `tf.Tensor` with the same structure as - chunks_by_dev containing the derived tensors. - """ - output = [] - for x in chunks_by_dev: - with ops.colocate_with(x[0]): - output.append([f(t) for t in x]) - return output - - -def _build_ring_scatter(pred_by_s_d, rank_by_s_d, - chunks_by_dev): - """Construct subgraph for second (scatter) pass of ring all-reduce. - - Args: - pred_by_s_d: as produced by _ring_permutations - rank_by_s_d: as produced by _ring_permutations - chunks_by_dev: list of list of T `tf.Tensor` indexed by ints - (device, chunk) - - Raises: - ValueError: chunks_by_dev is not well-formed - - Returns: - list of T `tf.Tensor` which are the fully reduced tensors, one - at each device corresponding to the outer dimension of chunks_by_dev. - """ - num_devices = len(chunks_by_dev) - num_chunks = len(chunks_by_dev[0]) - if 0 != num_chunks % num_devices: - raise ValueError( - "Expect number of chunks per device to be divisible by num_devices") - num_subchunks = int(num_chunks / num_devices) - num_ticks = num_devices - 1 - for tick in range(0, num_ticks): - passed_values = [None for _ in range(0, num_chunks)] - for d in range(0, num_devices): - with ops.colocate_with(chunks_by_dev[d][0]): - for s in range(0, num_subchunks): - rank = rank_by_s_d[s][d] - seg_index = (rank + num_devices - (1 + tick)) % num_devices - pred_dev = pred_by_s_d[s][d] - chunk_index = (seg_index * num_subchunks) + s - passed_values[chunk_index] = array_ops.identity( - chunks_by_dev[pred_dev][chunk_index]) - for d in range(0, num_devices): - for s in range(0, num_subchunks): - rank = rank_by_s_d[s][d] - seg_index = (rank + num_devices - (1 + tick)) % num_devices - chunk_index = (seg_index * num_subchunks) + s - chunks_by_dev[d][chunk_index] = passed_values[chunk_index] - # Join chunks at each device. - output = [] - for x in chunks_by_dev: - with ops.colocate_with(x[0]): - output.append(array_ops.concat(x, 0)) - return output - - -def build_recursive_hd_all_reduce(input_tensors, red_op, un_op=None): - """Construct a subgraph for recursive halving-doubling all-reduce. - - The recursive halving-doubling algorithm is described in - http://www.mcs.anl.gov/~thakur/papers/ijhpca-coll.pdf - - The concept is to arrange the participating n devices in - a linear sequence where devices exchange data pairwise - with one other device in each round. During the gather - phase there are lg(n) rounds where devices exchange - increasingly smaller sub-tensors with another device - at increasingly greater distances, until at the top - each device has 1/n of the fully reduced values. During the - scatter phase each device exchanges its fully reduced - sub-tensor (which doubles in length at each round) - with one other device at increasingly smaller distances - until each device has all of the fully reduced values. - - Note: this preliminary version requires that len(input_tensors) be a - power of 2. TODO(tucker): relax this restriction. 
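# --- Illustrative aside, not part of the original change -------------------
# The ring all-reduce assembled by build_ring_all_reduce and the gather /
# scatter helpers above, simulated with numpy for a single subchunk and no
# padding. The index scheme makes each device touch a different chunk per
# tick, so updating in place within a tick is safe here. After the gather
# phase device d owns the fully reduced chunk (d + 1) mod n; the scatter phase
# then circulates those chunks until every device holds the complete result.
import numpy as np

def ring_all_reduce(per_device):
    n = len(per_device)
    chunks = [np.array_split(t.astype(float), n) for t in per_device]
    for tick in range(n - 1):            # gather (reduce-scatter) phase
        for d in range(n):
            idx = (d - tick - 1) % n
            chunks[d][idx] = chunks[d][idx] + chunks[(d - 1) % n][idx]
    for tick in range(n - 1):            # scatter (allgather) phase
        for d in range(n):
            idx = (d - tick) % n
            chunks[d][idx] = chunks[(d - 1) % n][idx]
    return [np.concatenate(c) for c in chunks]

inputs = [np.arange(8) + 10 * d for d in range(4)]
reduced = ring_all_reduce(inputs)
assert all(np.array_equal(r, sum(inputs)) for r in reduced)
# ---------------------------------------------------------------------------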
Also, the - number of elements in each tensor must be divisible by 2^h where h - is the number of hops in each phase. This will also be relaxed in - the future with edge-case specific logic. - - Args: - input_tensors: list of T `tf.Tensor` to be elementwise reduced. - red_op: a binary elementwise reduction Op. - un_op: an optional unary elementwise Op to apply to reduced values. - - Returns: - list of T `tf.Tensor` which are the fully reduced tensors, one - at each device of input_tensors. - - Raises: - ValueError: num_devices not a power of 2, or tensor len not divisible - by 2 the proper number of times. - """ - devices = [t.device for t in input_tensors] - input_tensors, shape = _flatten_tensors(input_tensors) - reduced_shards = _build_recursive_hd_gather(input_tensors, devices, red_op) - if un_op: - reduced_shards = [un_op(t) for t in reduced_shards] - output_tensors = _build_recursive_hd_scatter(reduced_shards, devices) - if len(shape) != 1: - output_tensors = _reshape_tensors(output_tensors, shape) - return output_tensors - - -def _build_recursive_hd_gather(input_tensors, devices, red_op): - """Construct the gather phase of recursive halving-doubling all-reduce. - - Args: - input_tensors: list of T `tf.Tensor` to be elementwise reduced. - devices: a list of strings naming the devices hosting input_tensors, - which will also be used to host the (partial) reduction values. - red_op: a binary elementwise reduction Op. - - Returns: - list of T `tf.Tensor` which are the fully reduced tensor shards. - - Raises: - ValueError: num_devices not a power of 2, or tensor len not divisible - by 2 the proper number of times. - """ - num_devices = len(devices) - num_hops = int(math.log(num_devices, 2)) - if num_devices != (2 ** num_hops): - raise ValueError("num_devices must be a power of 2") - chunks = input_tensors - for h in range(0, num_hops): - span = 2 ** h - group_size = span * 2 - new_chunks = [[] for _ in devices] - for d in range(0, num_devices): - if (d % group_size) >= (group_size / 2): - # skip right half of a pair - continue - left_dev = devices[d] - right_dev = devices[d + span] - left_split = array_ops.split(chunks[d], 2) - right_split = array_ops.split(chunks[d+span], 2) - with ops.device(left_dev): - new_chunks[d] = red_op(left_split[0], right_split[0]) - with ops.device(right_dev): - new_chunks[d + span] = red_op(left_split[1], right_split[1]) - chunks = new_chunks - return chunks - - -def _build_recursive_hd_scatter(input_tensors, devices): - """Construct the scatter phase of recursive halving-doublng all-reduce. - - Args: - input_tensors: list of T `tf.Tensor` that are fully-reduced shards. - devices: a list of strings naming the devices on which the reconstituted - full tensors should be placed. - - Returns: - list of T `tf.Tensor` which are the fully reduced tensors. 
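# --- Illustrative aside, not part of the original change -------------------
# The recursive halving-doubling scheme described in the docstrings above,
# simulated with numpy for a power-of-two number of in-process "devices".
# In each halving hop the two members of a pair exchange halves of their
# current chunk and keep one reduced half each, so after log2(n) hops every
# device owns a fully reduced 1/n shard; the doubling hops then concatenate
# shards pairwise, in reverse order, back into full tensors.
import numpy as np

def recursive_hd_all_reduce(per_device):
    n = len(per_device)
    assert n > 0 and n & (n - 1) == 0, "number of devices must be a power of 2"
    chunks = [t.astype(float) for t in per_device]
    span = 1
    while span < n:                       # halving (reduce-scatter) phase
        new_chunks = list(chunks)
        for d in range(n):
            if d % (2 * span) < span:     # d is the left member of its pair
                left = np.split(chunks[d], 2)
                right = np.split(chunks[d + span], 2)
                new_chunks[d] = left[0] + right[0]          # keep first half
                new_chunks[d + span] = left[1] + right[1]   # keep second half
        chunks = new_chunks
        span *= 2
    span = n // 2
    while span >= 1:                      # doubling (allgather) phase
        new_chunks = list(chunks)
        for d in range(n):
            if d % (2 * span) < span:
                joined = np.concatenate([chunks[d], chunks[d + span]])
                new_chunks[d] = joined
                new_chunks[d + span] = joined
        chunks = new_chunks
        span //= 2
    return chunks

inputs = [np.arange(8) + 10 * d for d in range(4)]
out = recursive_hd_all_reduce(inputs)
assert all(np.array_equal(t, sum(inputs)) for t in out)
# ---------------------------------------------------------------------------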
- """ - num_devices = len(devices) - num_hops = int(math.log(num_devices, 2)) - assert num_devices == (2 ** num_hops), "num_devices must be a power of 2" - chunks = input_tensors - for h in reversed(range(0, num_hops)): - span = 2 ** h - group_size = span * 2 - new_chunks = [[] for _ in devices] - for d in range(0, num_devices): - if (d % group_size) >= (group_size / 2): - # skip right half of a pair - continue - left_idx = d - right_idx = d + span - left_dev = devices[left_idx] - right_dev = devices[right_idx] - with ops.device(left_dev): - new_chunks[left_idx] = array_ops.concat([chunks[left_idx], - chunks[right_idx]], 0) - with ops.device(right_dev): - new_chunks[right_idx] = array_ops.concat([chunks[left_idx], - chunks[right_idx]], 0) - chunks = new_chunks - return chunks - - -def build_shuffle_all_reduce(input_tensors, gather_devices, red_op, un_op=None): - """Construct a subgraph for shuffle all-reduce. - - Shuffle reduce is essentially the algorithm implemented when using - parameter servers. Suppose tensor length is n, there are d devices - and g gather shards. Each device sends a n/g length sub-tensor to - each gather shard. The gather shards perform a reduction across d - fragments, then broadcast the result back to each device. The - devices then join the g fully reduced fragments they receive from - the shards. The gather shards could perform d-1 pairwise - reductions, or one d-way reduction. The first is better where - reduction Op time is low compared to transmission time, the second - better in the other case. - - Args: - input_tensors: list of T @(tf.Tensor} values to be reduced. - gather_devices: list of names of devices on which reduction shards - should be placed. - red_op: an n-array elementwise reduction Op - un_op: optional elementwise unary Op to be applied to fully-reduced values. - - Returns: - list of T `tf.Tensor` which are the fully reduced tensors. - """ - input_tensors, shape = _flatten_tensors(input_tensors) - dst_devices = [t.device for t in input_tensors] - reduced_shards = _build_shuffle_gather(input_tensors, gather_devices, - red_op, un_op) - output_tensors = _build_shuffle_scatter(reduced_shards, dst_devices) - if len(shape) != 1: - output_tensors = _reshape_tensors(output_tensors, shape) - return output_tensors - - -def _build_shuffle_gather(input_tensors, gather_devices, red_op, un_op=None): - """Construct the gather (concentrate and reduce) phase of shuffle all-reduce. - - Args: - input_tensors: list of T @(tf.Tensor} values to be reduced. - gather_devices: list of names of devices on which reduction shards - should be placed. - red_op: the binary reduction Op - un_op: optional elementwise unary Op to be applied to fully-reduced values. - - Returns: - list of T `tf.Tensor` which are the fully reduced shards. - - Raises: - ValueError: inputs not well-formed. 
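# --- Illustrative aside, not part of the original change -------------------
# The shuffle (parameter-server style) all-reduce described above, simulated
# with numpy: every device contributes one shard to each gather shard, each
# gather shard reduces its piece across all contributors, and the reduced
# shards are concatenated back on every destination device. Shard sizing here
# uses numpy's array_split rather than the exact _ragged_split layout.
import numpy as np

def shuffle_all_reduce(per_device, num_gather_shards):
    shards_by_source = [np.array_split(t.astype(float), num_gather_shards)
                        for t in per_device]
    reduced_shards = [sum(shards[g] for shards in shards_by_source)
                      for g in range(num_gather_shards)]
    full = np.concatenate(reduced_shards)
    return [full.copy() for _ in per_device]   # broadcast back to each device

inputs = [np.arange(10) + 10 * d for d in range(3)]
out = shuffle_all_reduce(inputs, num_gather_shards=2)
assert all(np.array_equal(t, sum(inputs)) for t in out)
# ---------------------------------------------------------------------------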
- """ - num_source_devices = len(input_tensors) - num_gather_devices = len(gather_devices) - shape = input_tensors[0].shape - if len(shape) != 1: - raise ValueError("input_tensors must be 1D") - shards_by_source = [] - for d in range(0, num_source_devices): - with ops.colocate_with(input_tensors[d]): - shards_by_source.append( - _ragged_split(input_tensors[d], num_gather_devices)) - reduced_shards = [] - for d in range(0, num_gather_devices): - with ops.device(gather_devices[d]): - values = [s[d] for s in shards_by_source] - red_shard = red_op(values) - if un_op: - red_shard = un_op(red_shard) - reduced_shards.append(red_shard) - return reduced_shards - - -def _build_shuffle_scatter(reduced_shards, dst_devices): - """Build the scatter phase of shuffle all-reduce. - - Args: - reduced_shards: list of T @(tf.Tensor} fully reduced shards - dst_devices: list of names of devices at which the fully-reduced value - should be reconstituted. - - Returns: - list of T `tf.Tensor` scattered tensors. - """ - num_devices = len(dst_devices) - out_tensors = [] - for d in range(0, num_devices): - with ops.device(dst_devices[d]): - out_tensors.append(array_ops.concat(reduced_shards, 0)) - return out_tensors - - -def _split_by_task(devices, values): - """Partition devices and values by common task. - - Args: - devices: list of device name strings - values: list of T `tf.tensor` of same length as devices. - - Returns: - (per_task_devices, per_task_values) where both values are - lists of lists with isomorphic structure: the outer list is - indexed by task, and the inner list has length of the number - of values belonging to that task. per_task_devices contains - the specific devices to which the values are local, and - per_task_values contains the corresponding values. - - Raises: - ValueError: devices must be same length as values. - """ - num_devices = len(devices) - if num_devices != len(values): - raise ValueError("len(devices) must equal len(values)") - per_task_devices = collections.OrderedDict() - per_task_values = collections.OrderedDict() - for d in range(num_devices): - d_spec = device_lib.DeviceSpec.from_string(devices[d]) - if not hasattr(d_spec, "task") or d_spec.task is None: - assert False, "failed to parse device %s" % devices[d] - index = (d_spec.job or "localhost", d_spec.replica or 0, d_spec.task) - if index not in per_task_devices: - per_task_devices[index] = [] - per_task_values[index] = [] - per_task_devices[index].append(devices[d]) - per_task_values[index].append(values[d]) - - return (list(per_task_devices.values()), list(per_task_values.values())) - - -def build_nccl_all_reduce(input_tensors, red_op, un_op=None): - """Build a subgraph that does one full all-reduce, using NCCL. - - Args: - input_tensors: list of T `tf.Tensor` of same-shape and type values to - be reduced. - red_op: binary elementwise reduction operator. Must be one of - {tf.add} - un_op: optional unary elementwise Op to apply to fully-reduce values. - - Returns: - list of T `tf.Tensor` of reduced values. - - Raises: - ValueError: red_op not supported. - """ - if red_op == math_ops.add: - output_tensors = nccl_ops.all_sum(input_tensors) - else: - raise ValueError("red_op not supported by NCCL all-reduce: ", red_op) - if un_op: - un_op_wrapped = [] - for t in output_tensors: - with ops.colocate_with(t): - un_op_wrapped.append(un_op(t)) - output_tensors = un_op_wrapped - return output_tensors - - -def _build_nccl_hybrid(input_tensors, red_op, upper_level_f): - """Construct a subgraph for NCCL hybrid all-reduce. 
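# --- Illustrative aside, not part of the original change -------------------
# The two-level "hybrid" pattern that the NCCL and shuffle hybrid builders in
# this file follow, simulated with numpy: values are first reduced inside each
# worker, one representative value per worker is then reduced across workers,
# and the total is broadcast back to every device of every worker.
import numpy as np

def hybrid_all_reduce(per_worker_values):
    # Stage 1: reduce within each worker (stands in for the NCCL/shuffle stage).
    per_worker_sum = [sum(v.astype(float) for v in values)
                      for values in per_worker_values]
    # Stage 2: reduce the one-value-per-worker list across workers.
    total = sum(per_worker_sum)
    # Stage 3: broadcast the result back to every device of every worker.
    return [[total.copy() for _ in values] for values in per_worker_values]

workers = [[np.arange(4) + 10 * w + d for d in range(2)] for w in range(3)]
out = hybrid_all_reduce(workers)
expected = sum(sum(values) for values in workers)
assert all(np.array_equal(t, expected) for values in out for t in values)
# ---------------------------------------------------------------------------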
- - Args: - input_tensors: list of T `tf.Tensor` of same-shape and type values to - be reduced. - red_op: binary elementwise reduction operator. - upper_level_f: function for reducing one value per worker, across - workers. - - Returns: - list of T `tf.Tensor` of reduced values. - - Raises: - ValueError: inputs not well-formed. - """ - input_tensors, shape = _flatten_tensors(input_tensors) - devices = [t.device for t in input_tensors] - per_worker_devices, per_worker_values = _split_by_task(devices, input_tensors) - num_workers = len(per_worker_devices) - up_values = [None for w in range(0, num_workers)] - up_devices = up_values[:] - down_values = up_values[:] - # First stage: reduce within each worker using NCCL - for w in range(0, num_workers): - worker_values = build_nccl_all_reduce(per_worker_values[w], red_op) - # NOTE: these reductions will not run to completion unless - # every output value is used. Since we only need one, we - # need to put control dependencies on the rest. - with ops.control_dependencies(worker_values): - with ops.device(worker_values[0].device): - up_values[w] = array_ops.identity(worker_values[0]) - up_devices[w] = per_worker_devices[w][0] - # Second stage: Apply upper_level_f to reduce across first device at - # each worker - level_2_output = upper_level_f(up_values) - # Third stage: propagate within each worker using NCCL Broadcast - for w in range(0, num_workers): - dst_tensors = [] - with ops.device(per_worker_devices[w][0]): - broadcast_src = nccl_ops.broadcast(array_ops.identity(level_2_output[w])) - for d in per_worker_devices[w]: - with ops.device(d): - dst_tensors.append(array_ops.identity(broadcast_src)) - down_values[w] = dst_tensors - output_tensors = [v for sublist in down_values for v in sublist] - if len(shape) != 1: - output_tensors = _reshape_tensors(output_tensors, shape) - return output_tensors - - -def _reduce_non_singleton(input_tensors, red_f, un_op): - """If input_tensors has more than one element apply red_f, else apply un_op.""" - if len(input_tensors) > 1: - return red_f(input_tensors) - else: - if not un_op: - return input_tensors - output_tensors = [] - for t in input_tensors: - with ops.colocate_with(t): - output_tensors.append(un_op(t)) - return output_tensors - - -def build_nccl_then_ring(input_tensors, subdiv, red_op, un_op=None): - """Construct hybrid of NCCL within workers, Ring across workers.""" - def upper_builder(y): - return build_ring_all_reduce(y, len(y), subdiv, [0], red_op, un_op) - def upper_level_f(x): - return _reduce_non_singleton(x, upper_builder, un_op) - return _build_nccl_hybrid(input_tensors, red_op, upper_level_f) - - -def build_nccl_then_recursive_hd(input_tensors, red_op, un_op=None): - """Construct hybrid of NCCL within workers, Recursive-HD across workers.""" - upper_level_f = lambda x: build_recursive_hd_all_reduce(x, red_op, un_op) - return _build_nccl_hybrid(input_tensors, red_op, upper_level_f) - - -def build_nccl_then_shuffle(input_tensors, gather_devices, nccl_red_op, - shuffle_red_op, un_op=None): - """Construct hybrid of NCCL within workers, Shuffle across workers.""" - upper_level_f = lambda x: build_shuffle_all_reduce(x, gather_devices, - shuffle_red_op, un_op) - return _build_nccl_hybrid(input_tensors, nccl_red_op, upper_level_f) - - -def _build_shuffle_hybrid(input_tensors, gather_devices, red_op, upper_level_f): - """Construct a subgraph for Shuffle hybrid all-reduce. - - Args: - input_tensors: list of T `tf.Tensor` of same-shape and type values to - be reduced. 
- gather_devices: list of device names on which to host gather shards. - red_op: binary elementwise reduction operator. - upper_level_f: function for reducing one value per worker, across - workers. - - Returns: - list of T `tf.Tensor` of reduced values. - - Raises: - ValueError: inputs not well-formed. - """ - input_tensors, shape = _flatten_tensors(input_tensors) - # First stage, reduce across each worker using gather_devices. - devices = [t.device for t in input_tensors] - per_worker_devices, per_worker_values = _split_by_task(devices, input_tensors) - num_workers = len(per_worker_devices) - up_values = [] - if len(gather_devices) != num_workers: - raise ValueError("For shuffle hybrid, gather_devices must contain one " - "device per worker. ") - for w in range(0, num_workers): - reduced_shards = _build_shuffle_gather( - per_worker_values[w], [gather_devices[w]], red_op) - up_values.append(reduced_shards[0]) - # Second stage, apply upper_level_f. - level_2_output = upper_level_f(up_values) - # Third stage, apply shuffle scatter at each worker. - output_tensors = [] - for w in range(0, num_workers): - output_tensors += _build_shuffle_scatter( - [level_2_output[w]], per_worker_devices[w]) - if len(shape) != 1: - output_tensors = _reshape_tensors(output_tensors, shape) - return output_tensors - - -def build_shuffle_then_ring(input_tensors, gather_devices, subdiv, - red_n_op, red_op, un_op=None): - """Construct hybrid of Shuffle within workers, Ring across workers.""" - def upper_builder(tensors): - return build_ring_all_reduce(tensors, len(tensors), subdiv, [0], - red_op, un_op) - def upper_level_f(tensors): - return _reduce_non_singleton(tensors, upper_builder, un_op) - return _build_shuffle_hybrid( - input_tensors, gather_devices, red_n_op, upper_level_f) - - -def build_shuffle_then_shuffle(input_tensors, first_gather_devices, - second_gather_devices, red_op, un_op=None): - """Construct hybrid of Shuffle within workers, Shuffle across workers.""" - def upper_builder(tensors): - return build_shuffle_all_reduce(tensors, second_gather_devices, - red_op, un_op) - def upper_level_f(tensors): - return _reduce_non_singleton(tensors, upper_builder, un_op) - return _build_shuffle_hybrid( - input_tensors, first_gather_devices, red_op, upper_level_f) +# pylint: disable=unused-import,wildcard-import +from tensorflow.python.distribute.all_reduce import * diff --git a/tensorflow/contrib/autograph/examples/benchmarks/BUILD b/tensorflow/contrib/autograph/examples/benchmarks/BUILD new file mode 100644 index 00000000000..6d2d70c99b4 --- /dev/null +++ b/tensorflow/contrib/autograph/examples/benchmarks/BUILD @@ -0,0 +1,36 @@ +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "py_test") +load("//tensorflow/tools/test:performance.bzl", "tf_py_logged_benchmark") + +py_library( + name = "benchmark_base", + srcs = [ + "benchmark_base.py", + ], + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +py_test( + name = "cartpole_benchmark", + size = "enormous", + srcs = ["cartpole_benchmark.py"], + tags = [ + "local", + "manual", + "no_oss", + "notap", + "nozapfhahn", + ], + deps = [ + ":benchmark_base", + # Note: required gym dependency may need to be added here. 
+ ], +) + +tf_py_logged_benchmark( + name = "cartpole_logged_benchmark", + target = "//tensorflow/contrib/autograph/examples/benchmarks:cartpole_benchmark", +) diff --git a/tensorflow/contrib/autograph/examples/benchmarks/benchmark_base.py b/tensorflow/contrib/autograph/examples/benchmarks/benchmark_base.py new file mode 100644 index 00000000000..93c694849c4 --- /dev/null +++ b/tensorflow/contrib/autograph/examples/benchmarks/benchmark_base.py @@ -0,0 +1,62 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Common benchmarking code. + +See https://www.tensorflow.org/community/benchmarks for usage. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time + +import numpy as np + +import tensorflow as tf + + +class ReportingBenchmark(tf.test.Benchmark): + """Base class for a benchmark that reports general performance metrics. + + Subclasses only need to call one of the _profile methods, and optionally + report_results. + """ + + def time_execution(self, name, target, iters, warm_up_iters=5): + for _ in range(warm_up_iters): + target() + + all_times = [] + for _ in range(iters): + iter_time = time.time() + target() + all_times.append(time.time() - iter_time) + + avg_time = np.average(all_times) + + extras = dict() + extras['all_times'] = all_times + + if isinstance(name, tuple): + extras['name'] = name + name = '_'.join(str(piece) for piece in name) + + self.report_benchmark( + iters=iters, wall_time=avg_time, name=name, extras=extras) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/contrib/autograph/examples/benchmarks/cartpole_benchmark.py b/tensorflow/contrib/autograph/examples/benchmarks/cartpole_benchmark.py new file mode 100644 index 00000000000..4f553be58e9 --- /dev/null +++ b/tensorflow/contrib/autograph/examples/benchmarks/cartpole_benchmark.py @@ -0,0 +1,492 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A basic RL cartpole benchmark. + +The RL model uses the OpenAI Gym environment to train a simple network using +the policy gradients method. 
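# --- Illustrative aside, not part of the original change -------------------
# The measurement pattern used by ReportingBenchmark.time_execution above,
# reduced to plain Python: run a few warm-up calls that are not measured, then
# time each of `iters` calls and report the average wall time together with
# the raw per-iteration samples.
import time

def time_execution(target, iters, warm_up_iters=5):
    for _ in range(warm_up_iters):
        target()                           # warm-up, excluded from the timing
    all_times = []
    for _ in range(iters):
        start = time.time()
        target()
        all_times.append(time.time() - start)
    return sum(all_times) / len(all_times), all_times

avg, samples = time_execution(lambda: sum(range(10000)), iters=20)
assert len(samples) == 20 and avg >= 0.0
# ---------------------------------------------------------------------------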
The training scales the gradients for each step +by the episode's cumulative discounted reward and averages these gradients over +a fixed number of games before applying the optimization step. + +For benchmarking purposes, we replace the OpenAI Gym environment to a fake +that returns random actions and rewards and never ends the episode. This way +the benchmarks compare the same amount of computation at each step. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gym +import numpy as np +import tensorflow as tf + +from tensorflow.contrib import eager +from tensorflow.contrib.autograph.examples.benchmarks import benchmark_base +from tensorflow.python import autograph as ag +from tensorflow.python.eager import context + +# +# AutoGraph implementation +# + + +@ag.convert() +def graph_append_discounted_rewards(destination, rewards, discount_rate): + """Discounts episode rewards and appends them to destination.""" + ag.set_element_type(rewards, tf.float32) + + cdr = 0.0 + reverse_discounted = [] + ag.set_element_type(reverse_discounted, tf.float32) + + for i in range(len(rewards) - 1, -1, -1): + cdr = cdr * discount_rate + rewards[i] + cdr.set_shape(()) + reverse_discounted.append(cdr) + + retval = destination + # Note: AutoGraph doesn't yet support reversed() so we use a loop instead. + for i in range(len(reverse_discounted) - 1, -1, -1): + retval.append(reverse_discounted[i]) + + return retval + + +class GraphPolicyNetwork(tf.keras.Model): + """Policy network for the cart-pole reinforcement learning problem. + + The forward path of the network takes an observation from the cart-pole + environment (length-4 vector) and outputs an action. + """ + + def __init__(self, hidden_size): + super(GraphPolicyNetwork, self).__init__() + self._hidden_layer = tf.keras.layers.Dense( + hidden_size, activation=tf.nn.elu) + self._output_layer = tf.keras.layers.Dense(1) + + def call(self, inputs): + """Calculates logits and action. + + Args: + inputs: Observations from a step in the cart-pole environment, of shape + `(batch_size, input_size)` + + Returns: + logits: the logits output by the output layer. This can be viewed as the + likelihood vales of choosing the left (0) action. Shape: + `(batch_size, 1)`. + actions: randomly selected actions ({0, 1}) based on the logits. Shape: + `(batch_size, 1)`. + """ + hidden = self._hidden_layer(inputs) + logits = self._output_layer(hidden) + + left_prob = tf.nn.sigmoid(logits) + action_probs = tf.concat([left_prob, 1.0 - left_prob], 1) + + actions = tf.multinomial(tf.log(action_probs), 1) + return logits, actions + + # TODO(mdan): Move this method out of the class. + @ag.convert() + def train(self, cart_pole_env, optimizer, discount_rate, num_games, + max_steps_per_game): + var_list = tf.trainable_variables() + grad_list = [ + tf.TensorArray(tf.float32, 0, dynamic_size=True) for _ in var_list + ] + + step_counts = [] + discounted_rewards = [] + ag.set_element_type(discounted_rewards, tf.float32) + ag.set_element_type(step_counts, tf.int32) + + # Note: we use a shared object, cart_pole_env here. Because calls to the + # object's method are made through py_func, TensorFlow cannot detect its + # data dependencies. Hence we must manually synchronize access to it + # and ensure the control dependencies are set in such a way that + # calls to reset(), take_one_step, etc. are made in the correct order. 
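# --- Illustrative aside, not part of the original change -------------------
# The reward discounting performed by graph_append_discounted_rewards (and its
# eager twin below), checked in plain Python: rewards are traversed backwards
# so position i accumulates rewards[i] + discount_rate * (discounted future).
def append_discounted_rewards(destination, rewards, discount_rate):
    cdr = 0.0
    reverse_discounted = []
    for r in reversed(rewards):
        cdr = cdr * discount_rate + r
        reverse_discounted.append(cdr)
    destination.extend(reversed(reverse_discounted))
    return destination

out = append_discounted_rewards([], [1.0, 1.0, 1.0], discount_rate=0.5)
assert out == [1.75, 1.5, 1.0]   # 1 + 0.5*(1 + 0.5*1), 1 + 0.5*1, 1
# ---------------------------------------------------------------------------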
+ sync_counter = tf.constant(0) + + for _ in tf.range(num_games): + with tf.control_dependencies([sync_counter]): + obs = cart_pole_env.reset() + with tf.control_dependencies([obs]): + sync_counter += 1 + + game_rewards = [] + ag.set_element_type(game_rewards, tf.float32) + + for step in tf.range(max_steps_per_game): + logits, actions = self(obs) # pylint:disable=not-callable + logits = tf.reshape(logits, ()) + actions = tf.reshape(actions, ()) + + labels = 1.0 - tf.cast(actions, tf.float32) + loss = tf.nn.sigmoid_cross_entropy_with_logits( + labels=labels, logits=logits) + grads = tf.gradients(loss, var_list) + + for i in range(len(grads)): + grad_list[i].append(grads[i]) + + with tf.control_dependencies([sync_counter]): + obs, reward, done = cart_pole_env.step(actions) + with tf.control_dependencies([obs]): + sync_counter += 1 + obs = tf.reshape(obs, (1, 4)) + + game_rewards.append(reward) + if reward < 0.1 or done: + step_counts.append(step + 1) + break + + discounted_rewards = graph_append_discounted_rewards( + discounted_rewards, game_rewards, discount_rate) + + discounted_rewards = ag.stack(discounted_rewards) + discounted_rewards.set_shape((None,)) + mean, variance = tf.nn.moments(discounted_rewards, [0]) + normalized_rewards = (discounted_rewards - mean) / tf.sqrt(variance) + + for i in range(len(grad_list)): + g = ag.stack(grad_list[i]) + + # This block just adjusts the shapes to match for multiplication. + r = normalized_rewards + if r.shape.ndims < g.shape.ndims: + r = tf.expand_dims(r, -1) + if r.shape.ndims < g.shape.ndims: + r = tf.expand_dims(r, -1) + + grad_list[i] = tf.reduce_mean(g * r, axis=0) + + optimizer.apply_gradients( + zip(grad_list, var_list), global_step=tf.train.get_global_step()) + + return ag.stack(step_counts) + + +@ag.convert() +def graph_train_model(policy_network, cart_pole_env, optimizer, iterations): + """Trains the policy network for a given number of iterations.""" + i = tf.constant(0) + mean_steps_per_iteration = [] + ag.set_element_type(mean_steps_per_iteration, tf.int32) + + while i < iterations: + steps_per_game = policy_network.train( + cart_pole_env, + optimizer, + discount_rate=0.95, + num_games=20, + max_steps_per_game=200) + mean_steps_per_iteration.append(tf.reduce_mean(steps_per_game)) + i += 1 + + return ag.stack(mean_steps_per_iteration) + + +class GraphGymCartpoleEnv(object): + """An env backed by OpenAI Gym's CartPole environment. + + Used to confirm a functional model only. + """ + + def __init__(self): + cart_pole_env = gym.make('CartPole-v1') + cart_pole_env.seed(0) + cart_pole_env.reset() + self.env = cart_pole_env + + def reset(self): + obs = ag.utils.wrap_py_func(self.env.reset, tf.float64, ()) + obs = tf.reshape(obs, (1, 4)) + obs = tf.cast(obs, tf.float32) + return obs + + def step(self, actions): + + def take_one_step(actions): + obs, reward, done, _ = self.env.step(actions) + obs = obs.astype(np.float32) + reward = np.float32(reward) + return obs, reward, done + + return ag.utils.wrap_py_func(take_one_step, + (tf.float32, tf.float32, tf.bool), (actions,)) + + +class GraphRandomCartpoleEnv(object): + """An environment that returns random actions and never finishes. + + Used during benchmarking, it will cause training to run a constant number of + steps. 
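+
+ Observations keep CartPole's (1, 4) shape, so the policy network receives
+ inputs of the same size as with the real environment; step() simply draws a
+ fresh random observation under a control dependency on the incoming actions.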
+ """ + + def reset(self): + return tf.random.normal((1, 4)) + + def step(self, actions): + with tf.control_dependencies([actions]): + random_obs = tf.random.normal((1, 4)) + fixed_reward = tf.constant(0.001) + done = tf.constant(False) + return random_obs, fixed_reward, done + + +# +# Eager implementation +# + + +def eager_append_discounted_rewards(discounted_rewards, rewards, discount_rate): + cdr = 0.0 + reverse_discounted = [] + + for i in range(len(rewards) - 1, -1, -1): + cdr = cdr * discount_rate + rewards[i] + reverse_discounted.append(cdr) + + discounted_rewards.extend(reversed(reverse_discounted)) + return discounted_rewards + + +class EagerPolicyNetwork(tf.keras.Model): + """Policy network for the cart-pole reinforcement learning problem. + + The forward path of the network takes an observation from the cart-pole + environment (length-4 vector) and outputs an action. + """ + + def __init__(self, hidden_size): + super(EagerPolicyNetwork, self).__init__() + self._hidden_layer = tf.keras.layers.Dense( + hidden_size, activation=tf.nn.elu) + self._output_layer = tf.keras.layers.Dense(1) + + def call(self, inputs): + """Calculates logits and action. + + Args: + inputs: Observations from a step in the cart-pole environment, of shape + `(batch_size, input_size)` + + Returns: + logits: the logits output by the output layer. This can be viewed as the + likelihood vales of choosing the left (0) action. Shape: + `(batch_size, 1)`. + actions: randomly selected actions ({0, 1}) based on the logits. Shape: + `(batch_size, 1)`. + """ + hidden = self._hidden_layer(inputs) + logits = self._output_layer(hidden) + + left_prob = tf.nn.sigmoid(logits) + action_probs = tf.concat([left_prob, 1.0 - left_prob], 1) + + self._grad_fn = eager.implicit_gradients( + self._get_cross_entropy_and_save_actions) + + actions = tf.multinomial(tf.log(action_probs), 1) + return logits, actions + + def _get_cross_entropy_and_save_actions(self, inputs): + logits, actions = self(inputs) # pylint:disable=not-callable + self._current_actions = actions + labels = 1.0 - tf.cast(actions, tf.float32) + return tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits) + + def train(self, cart_pole_env, optimizer, discount_rate, num_games, + max_steps_per_game): + grad_list = None + + step_counts = [] + discounted_rewards = [] + + for _ in range(num_games): + obs = cart_pole_env.reset() + + game_rewards = [] + + for step in range(max_steps_per_game): + grads_and_vars = self._grad_fn(tf.constant([obs], dtype=tf.float32)) + grads, var_list = zip(*grads_and_vars) + actions = self._current_actions.numpy()[0][0] + + if grad_list is None: + grad_list = [[g] for g in grads] + else: + for i in range(len(grads)): + grad_list[i].append(grads[i]) + + obs, reward, done = cart_pole_env.step(actions) + + game_rewards.append(reward) + if reward < 0.1 or done: + step_counts.append(step + 1) + break + + discounted_rewards = eager_append_discounted_rewards( + discounted_rewards, game_rewards, discount_rate) + + discounted_rewards = tf.stack(discounted_rewards) + mean, variance = tf.nn.moments(discounted_rewards, [0]) + normalized_rewards = (discounted_rewards - mean) / tf.sqrt(variance) + + for i in range(len(grad_list)): + g = tf.stack(grad_list[i]) + + r = normalized_rewards + while r.shape.ndims < g.shape.ndims: + r = tf.expand_dims(r, -1) + + grad_list[i] = tf.reduce_mean(g * r, axis=0) + + optimizer.apply_gradients( + zip(grad_list, var_list), global_step=tf.train.get_global_step()) + + return tf.stack(step_counts) + + +def 
eager_train_model(policy_network, cart_pole_env, optimizer, iterations): + """Trains the policy network for a given number of iterations.""" + mean_steps_per_iteration = [] + + for _ in range(iterations): + steps_per_game = policy_network.train( + cart_pole_env, + optimizer, + discount_rate=0.95, + num_games=20, + max_steps_per_game=200) + mean_steps_per_iteration.append(tf.reduce_mean(steps_per_game)) + + return mean_steps_per_iteration + + +class EagerGymCartpoleEnv(object): + """An env backed by OpenAI Gym's CartPole environment. + + Used to confirm a functional model only. + """ + + def __init__(self): + cart_pole_env = gym.make('CartPole-v1') + cart_pole_env.seed(0) + cart_pole_env.reset() + self.env = cart_pole_env + + def reset(self): + return self.env.reset() + + def step(self, actions): + obs, reward, done, _ = self.env.step(actions) + return obs, reward, done + + +class EagerRandomCartpoleEnv(object): + """An environment that returns random actions and never finishes. + + Used during benchmarking, it will cause training to run a constant number of + steps. + """ + + def reset(self): + return np.random.normal(size=(4,)) + + def step(self, actions): + with tf.control_dependencies([actions]): + random_obs = np.random.normal(size=(4,)) + fixed_reward = 0.001 + done = False + return random_obs, fixed_reward, done + + +def graph_demo_training(): + """Not used in the benchmark. Used to confirm a functional model.""" + with tf.Graph().as_default(): + tf.set_random_seed(0) + + network = GraphPolicyNetwork(hidden_size=5) + network.build((1, 4)) + env = GraphGymCartpoleEnv() + opt = tf.train.AdamOptimizer(0.05) + + train_ops = graph_train_model(network, env, opt, iterations=5) + + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + sess.run(tf.local_variables_initializer()) + steps_per_iteration = sess.run(train_ops) + for i, steps in enumerate(steps_per_iteration): + print('Step {} iterations: {}'.format(i, steps)) + + +def eager_demo_training(): + with context.eager_mode(): + network = EagerPolicyNetwork(hidden_size=5) + network.build((1, 4)) + env = EagerGymCartpoleEnv() + opt = tf.train.AdamOptimizer(0.05) + + steps_per_iteration = eager_train_model(network, env, opt, iterations=5) + for i, steps in enumerate(steps_per_iteration): + print('Step {} iterations: {}'.format(i, steps)) + + +class RLCartPoleBenchmark(benchmark_base.ReportingBenchmark): + """Actual benchmark. + + Trains the RL agent a fixed number of times, on random environments that + result in constant number of steps. 
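+
+ Timings are collected for both the AutoGraph (graph-mode) and eager
+ implementations, for policy networks with 10, 100 and 1000 hidden units.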
+ """ + + def benchmark_cartpole(self): + + def train_session(sess, ops): + return lambda: sess.run(ops) + + def train_eager(network, env, opt): + return lambda: eager_train_model(network, env, opt, iterations=10) + + for model_size in (10, 100, 1000): + with tf.Graph().as_default(): + network = GraphPolicyNetwork(hidden_size=model_size) + network.build((1, 4)) + env = GraphRandomCartpoleEnv() + opt = tf.train.AdamOptimizer(0.05) + train_ops = graph_train_model(network, env, opt, iterations=10) + + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + sess.run(tf.local_variables_initializer()) + + self.time_execution(('cartpole', 'autograph', model_size), + train_session(sess, train_ops), 20) + + with context.eager_mode(): + network = EagerPolicyNetwork(hidden_size=model_size) + network.build((1, 4)) + env = EagerRandomCartpoleEnv() + opt = tf.train.AdamOptimizer(0.05) + + self.time_execution(('cartpole', 'eager', model_size), + train_eager(network, env, opt), 20) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/contrib/batching/python/ops/batch_ops.py b/tensorflow/contrib/batching/python/ops/batch_ops.py index 55faad983f2..3e4d0dc1cec 100644 --- a/tensorflow/contrib/batching/python/ops/batch_ops.py +++ b/tensorflow/contrib/batching/python/ops/batch_ops.py @@ -18,8 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.framework import function +from tensorflow.python.eager import function from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_spec from tensorflow.python.ops import gen_batch_ops # go/tf-wildcard-import # pylint: disable=wildcard-import @@ -101,12 +102,15 @@ def batch_function(num_batch_threads, def decorator(fn): # pylint: disable=missing-docstring def decorated(*args): # pylint: disable=missing-docstring - types = [arg.dtype for arg in args] - @function.Defun(*types) + @function.defun() def computation(*computation_args): return fn(*computation_args) + computation = computation.get_concrete_function( + *[tensor_spec.TensorSpec(dtype=x.dtype, shape=x.shape, name=str(i)) + for i, x in enumerate(args)]) + with ops.name_scope("batch") as name: for a in args: if not isinstance(a, ops.Tensor): @@ -123,7 +127,7 @@ def batch_function(num_batch_threads, f=computation, in_tensors=list(args), captured_tensors=computation.captured_inputs, - Tout=[o.type for o in computation.definition.signature.output_arg]) + Tout=[o.dtype for o in computation.outputs]) return decorated diff --git a/tensorflow/contrib/batching/python/ops/batch_ops_test.py b/tensorflow/contrib/batching/python/ops/batch_ops_test.py index 01ee8703a93..9109b9c1c91 100644 --- a/tensorflow/contrib/batching/python/ops/batch_ops_test.py +++ b/tensorflow/contrib/batching/python/ops/batch_ops_test.py @@ -219,6 +219,7 @@ class BatchOpsTest(test.TestCase): @batch_ops.batch_function(1, 10, 100000) def computation(in_t): + self.assertTrue(in_t.shape is not None) return in_t + 1 inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py index 13215ffabf3..8b6ed9f041b 100644 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py @@ -81,7 +81,7 @@ class ExpectationImportanceSampleTest(test.TestCase): # Compute E_p[X_1 * X_2 
> 0], with X_i the ith component of X ~ p(x). # Should equal 1/2 because p is a spherical Gaussian centered at (0, 0). def indicator(x): - x1_times_x2 = math_ops.reduce_prod(x, reduction_indices=[-1]) + x1_times_x2 = math_ops.reduce_prod(x, axis=[-1]) return 0.5 * (math_ops.sign(x1_times_x2) + 1.0) prob = mc.expectation_importance_sampler( diff --git a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py index 18d40fc1dff..e83a5485119 100644 --- a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py +++ b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py @@ -353,12 +353,12 @@ def expectation(f, samples, log_prob=None, use_reparametrization=True, def _sample_mean(values): """Mean over sample indices. In this module this is always [0].""" - return math_ops.reduce_mean(values, reduction_indices=[0]) + return math_ops.reduce_mean(values, axis=[0]) def _sample_max(values): """Max over sample indices. In this module this is always [0].""" - return math_ops.reduce_max(values, reduction_indices=[0]) + return math_ops.reduce_max(values, axis=[0]) def _get_samples(dist, z, n, seed): diff --git a/tensorflow/contrib/bigtable/README.md b/tensorflow/contrib/bigtable/README.md index 2c44abed5e1..79052bee35c 100644 --- a/tensorflow/contrib/bigtable/README.md +++ b/tensorflow/contrib/bigtable/README.md @@ -51,25 +51,18 @@ BIGTABLE_TABLE_NAME = '' PREFIX = 'train-' def main(): + tf.enable_eager_execution() + client = tf.contrib.cloud.BigtableClient(GCP_PROJECT_ID, BIGTABLE_INSTANCE_ID) table = client.table(BIGTABLE_TABLE_NAME) dataset = table.keys_by_prefix_dataset(PREFIX) - iterator = dataset.make_initializable_iterator() - get_next_op = iterator.get_next() - with tf.Session() as sess: - print('Initializing the iterator.') - sess.run(iterator.initializer) - print('Retrieving rows:') - row_index = 0 - while True: - try: - row_key = sess.run(get_next_op) - print('Row key %d: %s' % (row_index, row_key)) - row_index += 1 - except tf.errors.OutOfRangeError: - print('Finished reading data!') - break + print('Retrieving rows:') + row_index = 0 + for row_key in dataset: + print('Row key %d: %s' % (row_index, row_key)) + row_index += 1 + print('Finished reading data!') if __name__ == '__main__': main() diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc index f083ce6f44b..e95dc577184 100644 --- a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc +++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc @@ -366,6 +366,39 @@ BigtableTestClient::MutateRows( return MakeUnique(request.entries_size()); } +std::unique_ptr> +BigtableTestClient::AsyncMutateRow( + grpc::ClientContext* context, + google::bigtable::v2::MutateRowRequest const& request, + grpc::CompletionQueue* cq) { + LOG(WARNING) << "Call to InMemoryDataClient::" << __func__ + << "(); this will likely cause a crash!"; + return nullptr; +} + +std::unique_ptr<::grpc::ClientAsyncReaderInterface< + ::google::bigtable::v2::SampleRowKeysResponse>> +BigtableTestClient::AsyncSampleRowKeys( + ::grpc::ClientContext* context, + const ::google::bigtable::v2::SampleRowKeysRequest& request, + ::grpc::CompletionQueue* cq, void* tag) { + LOG(WARNING) << "Call to InMemoryDataClient::" << __func__ + << "(); this will likely cause a crash!"; + return nullptr; +} + +std::unique_ptr<::grpc::ClientAsyncReaderInterface< + 
::google::bigtable::v2::MutateRowsResponse>> +BigtableTestClient::AsyncMutateRows( + ::grpc::ClientContext* context, + const ::google::bigtable::v2::MutateRowsRequest& request, + ::grpc::CompletionQueue* cq, void* tag) { + LOG(WARNING) << "Call to InMemoryDataClient::" << __func__ + << "(); this will likely cause a crash!"; + return nullptr; +} + std::shared_ptr BigtableTestClient::Channel() { LOG(WARNING) << "Call to InMemoryDataClient::Channel(); this will likely " "cause a crash!"; diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h index dac2b16a216..c4a1f06bc50 100644 --- a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h +++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h @@ -61,6 +61,25 @@ class BigtableTestClient : public ::google::cloud::bigtable::DataClient { MutateRows(grpc::ClientContext* context, google::bigtable::v2::MutateRowsRequest const& request) override; + std::unique_ptr> + AsyncMutateRow(grpc::ClientContext* context, + google::bigtable::v2::MutateRowRequest const& request, + grpc::CompletionQueue* cq) override; + + std::unique_ptr<::grpc::ClientAsyncReaderInterface< + ::google::bigtable::v2::SampleRowKeysResponse>> + AsyncSampleRowKeys( + ::grpc::ClientContext* context, + const ::google::bigtable::v2::SampleRowKeysRequest& request, + ::grpc::CompletionQueue* cq, void* tag) override; + + std::unique_ptr<::grpc::ClientAsyncReaderInterface< + ::google::bigtable::v2::MutateRowsResponse>> + AsyncMutateRows(::grpc::ClientContext* context, + const ::google::bigtable::v2::MutateRowsRequest& request, + ::grpc::CompletionQueue* cq, void* tag) override; + std::shared_ptr Channel() override; private: diff --git a/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py b/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py index 316da9ebe15..197f5578eb0 100644 --- a/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py +++ b/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py @@ -57,7 +57,7 @@ class BigtableOpsTest(test.TestCase): sess.run(write_op) def runReadKeyTest(self, read_ds): - itr = read_ds.make_initializable_iterator() + itr = dataset_ops.make_initializable_iterator(read_ds) n = itr.get_next() expected = list(self.COMMON_ROW_KEYS) expected.reverse() @@ -78,7 +78,7 @@ class BigtableOpsTest(test.TestCase): self.runReadKeyTest(self._table.keys_by_range_dataset("r1", "r4")) def runScanTest(self, read_ds): - itr = read_ds.make_initializable_iterator() + itr = dataset_ops.make_initializable_iterator(read_ds) n = itr.get_next() expected_keys = list(self.COMMON_ROW_KEYS) expected_keys.reverse() @@ -120,7 +120,7 @@ class BigtableOpsTest(test.TestCase): def testLookup(self): ds = self._table.keys_by_prefix_dataset("r") ds = ds.apply(self._table.lookup_columns(cf1="c1")) - itr = ds.make_initializable_iterator() + itr = dataset_ops.make_initializable_iterator(ds) n = itr.get_next() expected_keys = list(self.COMMON_ROW_KEYS) expected_values = list(self.COMMON_VALUES) @@ -141,7 +141,7 @@ class BigtableOpsTest(test.TestCase): def testSampleKeys(self): ds = self._table.sample_keys() - itr = ds.make_initializable_iterator() + itr = dataset_ops.make_initializable_iterator(ds) n = itr.get_next() expected_key = self.COMMON_ROW_KEYS[0] with self.cached_session() as sess: @@ -161,7 +161,7 @@ class BigtableOpsTest(test.TestCase): sess.run(n) def runSampleKeyPairsTest(self, ds, 
expected_key_pairs): - itr = ds.make_initializable_iterator() + itr = dataset_ops.make_initializable_iterator(ds) n = itr.get_next() with self.cached_session() as sess: self._writeCommonValues(sess) @@ -218,7 +218,7 @@ class BigtableOpsTest(test.TestCase): def testSampleKeyPairsPrefixAndStartKey(self): ds = bigtable_api._BigtableSampleKeyPairsDataset( self._table, prefix="r", start="r1", end="") - itr = ds.make_initializable_iterator() + itr = dataset_ops.make_initializable_iterator(ds) with self.cached_session() as sess: with self.assertRaises(errors.InvalidArgumentError): sess.run(itr.initializer) @@ -226,14 +226,14 @@ class BigtableOpsTest(test.TestCase): def testSampleKeyPairsPrefixAndEndKey(self): ds = bigtable_api._BigtableSampleKeyPairsDataset( self._table, prefix="r", start="", end="r3") - itr = ds.make_initializable_iterator() + itr = dataset_ops.make_initializable_iterator(ds) with self.cached_session() as sess: with self.assertRaises(errors.InvalidArgumentError): sess.run(itr.initializer) def testParallelScanPrefix(self): ds = self._table.parallel_scan_prefix(prefix="r", cf1="c1") - itr = ds.make_initializable_iterator() + itr = dataset_ops.make_initializable_iterator(ds) n = itr.get_next() with self.cached_session() as sess: self._writeCommonValues(sess) @@ -251,7 +251,7 @@ class BigtableOpsTest(test.TestCase): def testParallelScanRange(self): ds = self._table.parallel_scan_range(start="r1", end="r4", cf1="c1") - itr = ds.make_initializable_iterator() + itr = dataset_ops.make_initializable_iterator(ds) n = itr.get_next() with self.cached_session() as sess: self._writeCommonValues(sess) diff --git a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py index 7c87b0daeb0..9f97934193d 100644 --- a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py +++ b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py @@ -222,7 +222,7 @@ class BigtableTable(object): A `tf.data.Dataset`. containing `tf.string` Tensors corresponding to all of the row keys matching that prefix. """ - return _BigtablePrefixKeyDataset(self, prefix) + return dataset_ops.DatasetV1Adapter(_BigtablePrefixKeyDataset(self, prefix)) def sample_keys(self): """Retrieves a sampling of row keys from the Bigtable table. @@ -234,7 +234,7 @@ class BigtableTable(object): Returns: A `tf.data.Dataset` returning string row keys. """ - return _BigtableSampleKeysDataset(self) + return dataset_ops.DatasetV1Adapter(_BigtableSampleKeysDataset(self)) def scan_prefix(self, prefix, probability=None, columns=None, **kwargs): """Retrieves row (including values) from the Bigtable service. @@ -279,7 +279,8 @@ class BigtableTable(object): """ probability = _normalize_probability(probability) normalized = _normalize_columns(columns, kwargs) - return _BigtableScanDataset(self, prefix, "", "", normalized, probability) + return dataset_ops.DatasetV1Adapter( + _BigtableScanDataset(self, prefix, "", "", normalized, probability)) def scan_range(self, start, end, probability=None, columns=None, **kwargs): """Retrieves rows (including values) from the Bigtable service. 
@@ -324,7 +325,8 @@ class BigtableTable(object): """ probability = _normalize_probability(probability) normalized = _normalize_columns(columns, kwargs) - return _BigtableScanDataset(self, "", start, end, normalized, probability) + return dataset_ops.DatasetV1Adapter( + _BigtableScanDataset(self, "", start, end, normalized, probability)) def parallel_scan_prefix(self, prefix, @@ -380,7 +382,8 @@ class BigtableTable(object): """ probability = _normalize_probability(probability) normalized = _normalize_columns(columns, kwargs) - ds = _BigtableSampleKeyPairsDataset(self, prefix, "", "") + ds = dataset_ops.DatasetV1Adapter( + _BigtableSampleKeyPairsDataset(self, prefix, "", "")) return self._make_parallel_scan_dataset(ds, num_parallel_scans, probability, normalized) @@ -442,7 +445,8 @@ class BigtableTable(object): """ probability = _normalize_probability(probability) normalized = _normalize_columns(columns, kwargs) - ds = _BigtableSampleKeyPairsDataset(self, "", start, end) + ds = dataset_ops.DatasetV1Adapter( + _BigtableSampleKeyPairsDataset(self, "", start, end)) return self._make_parallel_scan_dataset(ds, num_parallel_scans, probability, normalized) diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD index 14b6fc4ac26..d3b23d949ee 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD +++ b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD @@ -132,6 +132,7 @@ py_library( srcs = ["estimator.py"], srcs_version = "PY2AND3", deps = [ + ":custom_loss_head", ":estimator_utils", ":model", "//tensorflow/contrib/boosted_trees:losses", diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py index a3df272e692..b314b4d74df 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py @@ -41,7 +41,8 @@ def make_custom_export_strategy(name, convert_fn, feature_columns, export_input_fn, - use_core_columns=False): + use_core_columns=False, + feature_engineering_fn=None): """Makes custom exporter of GTFlow tree format. Args: @@ -52,6 +53,7 @@ def make_custom_export_strategy(name, export_input_fn: A function that takes no arguments and returns an `InputFnOps`. use_core_columns: A boolean, whether core feature columns were used. + feature_engineering_fn: Feature eng function to be called on the input. Returns: An `ExportStrategy`. 
@@ -59,9 +61,12 @@ def make_custom_export_strategy(name, base_strategy = saved_model_export_utils.make_export_strategy( serving_input_fn=export_input_fn, strip_default_attrs=True) input_fn = export_input_fn() + features = input_fn.features + if feature_engineering_fn is not None: + features, _ = feature_engineering_fn(features, labels=None) (sorted_feature_names, dense_floats, sparse_float_indices, _, _, sparse_int_indices, _, _) = gbdt_batch.extract_features( - input_fn.features, feature_columns, use_core_columns) + features, feature_columns, use_core_columns) def export_fn(estimator, export_dir, checkpoint_path=None, eval_result=None): """A wrapper to export to SavedModel, and convert it to other formats.""" diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py index ca73e4af2fb..358404cd946 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py @@ -36,7 +36,7 @@ from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import head as head_lib from tensorflow.python.estimator import estimator as core_estimator from tensorflow.contrib.learn.python.learn.estimators import model_fn -from tensorflow.python.feature_column import feature_column as feature_column_lib +from tensorflow.python.feature_column import feature_column_lib from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py index 38d19976ef3..a178820841c 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools + from tensorflow.contrib.boosted_trees.estimator_batch import model from tensorflow.contrib.boosted_trees.python.utils import losses from tensorflow.contrib.learn.python.learn.estimators import estimator @@ -26,7 +28,8 @@ from tensorflow.python.estimator.canned import head as core_head_lib from tensorflow.python.estimator import estimator as core_estimator from tensorflow.python.ops import math_ops from tensorflow.python.ops.losses import losses as core_losses - +from tensorflow.contrib.boosted_trees.estimator_batch import custom_loss_head +from tensorflow.python.ops import array_ops # ================== Old estimator interface=================================== # The estimators below were designed for old feature columns and old estimator @@ -414,6 +417,108 @@ class GradientBoostedDecisionTreeRanker(estimator.Estimator): config=config, feature_engineering_fn=feature_engineering_fn) +# When using this estimator, make sure to regularize the hessian (at least l2, +# min_node_weight)! +# TODO(nponomareva): extend to take multiple quantiles in one go. +class GradientBoostedDecisionTreeQuantileRegressor(estimator.Estimator): + """An estimator that does quantile regression and returns quantile estimates. 
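+
+ For a quantile q in (0, 1), training uses a smoothed version of the quantile
+ (pinball) loss, which charges q per unit of under-prediction and (1 - q) per
+ unit of over-prediction; as a result, roughly a fraction q of the labels
+ should fall below the fitted estimate. For example:
+
+   regressor = GradientBoostedDecisionTreeQuantileRegressor(
+       learner_config=learner_config,
+       examples_per_layer=1000,
+       quantiles=[0.95],
+       num_trees=100)
+   regressor.fit(input_fn=train_input_fn, steps=1000)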
+ """ + + def __init__(self, + learner_config, + examples_per_layer, + quantiles, + label_dimension=1, + num_trees=None, + feature_columns=None, + weight_column_name=None, + model_dir=None, + config=None, + feature_engineering_fn=None, + logits_modifier_function=None, + center_bias=True, + use_core_libs=False, + output_leaf_index=False, + override_global_step_value=None, + num_quantiles=100): + """Initializes a GradientBoostedDecisionTreeQuantileRegressor instance. + + Args: + learner_config: A config for the learner. + examples_per_layer: Number of examples to accumulate before growing a + layer. It can also be a function that computes the number of examples + based on the depth of the layer that's being built. + quantiles: a list of quantiles for the loss, each between 0 and 1. + label_dimension: Dimension of regression label. This is the size + of the last dimension of the labels `Tensor` (typically, this has shape + `[batch_size, label_dimension]`). When label_dimension>1, it is + recommended to use multiclass strategy diagonal hessian or full hessian. + num_trees: An int, number of trees to build. + feature_columns: A list of feature columns. + weight_column_name: Name of the column for weights, or None if not + weighted. + model_dir: Directory for model exports, etc. + config: `RunConfig` object to configure the runtime settings. + feature_engineering_fn: Feature engineering function. Takes features and + labels which are the output of `input_fn` and returns features and + labels which will be fed into the model. + logits_modifier_function: A modifier function for the logits. + center_bias: Whether a separate tree should be created for first fitting + the bias. + use_core_libs: Whether feature columns and loss are from the core (as + opposed to contrib) version of tensorflow. + output_leaf_index: whether to output leaf indices along with predictions + during inference. The leaf node indexes are available in predictions + dict by the key 'leaf_index'. For example, + result_dict = classifier.predict(...) + for example_prediction_result in result_dict: + # access leaf index list by example_prediction_result["leaf_index"] + # which contains one leaf index per tree + override_global_step_value: If after the training is done, global step + value must be reset to this value. This should be used to reset global + step to a number > number of steps used to train the current ensemble. + For example, the usual way is to train a number of trees and set a very + large number of training steps. When the training is done (number of + trees were trained), this parameter can be used to set the global step + to a large value, making it look like that number of training steps ran. + If None, no override of global step will happen. + num_quantiles: Number of quantiles to build for numeric feature values. + """ + + if len(quantiles) > 1: + raise ValueError('For now, just one quantile per estimator is supported') + + def _quantile_regression_head(quantile): + # Use quantile regression. 
+ head = custom_loss_head.CustomLossHead( + loss_fn=functools.partial( + losses.per_example_quantile_regression_loss, quantile=quantile), + link_fn=array_ops.identity, + logit_dimension=label_dimension) + return head + + learner_config.num_classes = max(2, label_dimension) + + super(GradientBoostedDecisionTreeQuantileRegressor, self).__init__( + model_fn=model.model_builder, + params={ + 'head': _quantile_regression_head(quantiles[0]), + 'feature_columns': feature_columns, + 'learner_config': learner_config, + 'num_trees': num_trees, + 'weight_column_name': weight_column_name, + 'examples_per_layer': examples_per_layer, + 'logits_modifier_function': logits_modifier_function, + 'center_bias': center_bias, + 'use_core_libs': use_core_libs, + 'output_leaf_index': False, + 'override_global_step_value': override_global_step_value, + 'num_quantiles': num_quantiles, + }, + model_dir=model_dir, + config=config, + feature_engineering_fn=feature_engineering_fn) + # ================== New Estimator interface=================================== # The estimators below use new core Estimator interface and must be used with # new feature columns and heads. @@ -437,12 +542,42 @@ def core_multiclass_head( # pylint:disable=protected-access head_fn = core_head_lib._multi_class_head_with_softmax_cross_entropy_loss( - n_classes=n_classes, loss_fn=loss_fn, loss_reduction=loss_reduction) + n_classes=n_classes, + loss_fn=loss_fn, + loss_reduction=loss_reduction, + weight_column=weight_column) # pylint:enable=protected-access return head_fn +# For quantile regression, use this head with Core..Estimator, or use +# Core..QuantileRegressor directly, +def core_quantile_regression_head( + quantiles, + label_dimension=1, + weight_column=None, + loss_reduction=core_losses.Reduction.SUM_OVER_NONZERO_WEIGHTS): + """Core head for quantile regression problems.""" + + def loss_fn(labels, logits): + result = losses.per_example_quantile_regression_loss( + labels=labels, + predictions=logits, + weights=weight_column, + quantile=quantiles) + return result[0] + + # pylint:disable=protected-access + head_fn = core_head_lib._regression_head( + label_dimension=label_dimension, + loss_fn=loss_fn, + loss_reduction=loss_reduction, + weight_column=weight_column) + # pylint:enable=protected-access + return head_fn + + class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator): """An estimator using gradient boosted decision trees. @@ -606,3 +741,104 @@ class CoreGradientBoostedDecisionTreeRanker(core_estimator.Estimator): super(CoreGradientBoostedDecisionTreeRanker, self).__init__( model_fn=_model_fn, model_dir=model_dir, config=config) + + +# When using this estimator, make sure to regularize the hessian (at least l2, +# min_node_weight)! +# TODO(nponomareva): extend to take multiple quantiles in one go. +class CoreGradientBoostedDecisionTreeQuantileRegressor( + core_estimator.Estimator): + """An estimator that does quantile regression and returns quantile estimates. + """ + + def __init__(self, + learner_config, + examples_per_layer, + quantiles, + label_dimension=1, + num_trees=None, + feature_columns=None, + weight_column_name=None, + model_dir=None, + config=None, + label_keys=None, + feature_engineering_fn=None, + logits_modifier_function=None, + center_bias=True, + output_leaf_index=False, + num_quantiles=100): + """Initializes a core version of GradientBoostedDecisionTreeEstimator. + + Args: + learner_config: A config for the learner. + examples_per_layer: Number of examples to accumulate before growing a + layer. 
It can also be a function that computes the number of examples + based on the depth of the layer that's being built. + quantiles: a list of quantiles for the loss, each between 0 and 1. + label_dimension: Dimension of regression label. This is the size + of the last dimension of the labels `Tensor` (typically, this has shape + `[batch_size, label_dimension]`). When label_dimension>1, it is + recommended to use multiclass strategy diagonal hessian or full hessian. + num_trees: An int, number of trees to build. + feature_columns: A list of feature columns. + weight_column_name: Name of the column for weights, or None if not + weighted. + model_dir: Directory for model exports, etc. + config: `RunConfig` object to configure the runtime settings. + label_keys: Optional list of strings with size `[n_classes]` defining the + label vocabulary. Only supported for `n_classes` > 2. + feature_engineering_fn: Feature engineering function. Takes features and + labels which are the output of `input_fn` and returns features and + labels which will be fed into the model. + logits_modifier_function: A modifier function for the logits. + center_bias: Whether a separate tree should be created for first fitting + the bias. + output_leaf_index: whether to output leaf indices along with predictions + during inference. The leaf node indexes are available in predictions + dict by the key 'leaf_index'. For example, + result_dict = classifier.predict(...) + for example_prediction_result in result_dict: + # access leaf index list by example_prediction_result["leaf_index"] + # which contains one leaf index per tree + num_quantiles: Number of quantiles to build for numeric feature values. + """ + if len(quantiles) > 1: + raise ValueError('For now, just one quantile per estimator is supported') + + def _model_fn(features, labels, mode, config): + return model.model_builder( + features=features, + labels=labels, + mode=mode, + config=config, + params={ + 'head': + core_quantile_regression_head( + quantiles[0], label_dimension=label_dimension), + 'feature_columns': + feature_columns, + 'learner_config': + learner_config, + 'num_trees': + num_trees, + 'weight_column_name': + weight_column_name, + 'examples_per_layer': + examples_per_layer, + 'center_bias': + center_bias, + 'logits_modifier_function': + logits_modifier_function, + 'use_core_libs': + True, + 'output_leaf_index': + output_leaf_index, + 'override_global_step_value': + None, + 'num_quantiles': + num_quantiles, + }, + output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC) + + super(CoreGradientBoostedDecisionTreeQuantileRegressor, self).__init__( + model_fn=_model_fn, model_dir=model_dir, config=config) diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py index c155128c0e4..ee052ac6038 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py @@ -25,6 +25,7 @@ from tensorflow.contrib.boosted_trees.proto import learner_pb2 from tensorflow.contrib.layers.python.layers import feature_column as contrib_feature_column from tensorflow.contrib.learn.python.learn.estimators import run_config from tensorflow.python.estimator.canned import head as head_lib +from tensorflow.python.estimator.inputs import numpy_io from tensorflow.python.feature_column import feature_column_lib as core_feature_column from tensorflow.python.framework import constant_op from tensorflow.python.framework 
import dtypes @@ -47,8 +48,8 @@ def _multiclass_train_input_fn(): features = { "x": constant_op.constant([[2.], [1.], [1.], [5.], [3.5], [4.6], [3.5]]) } - label = constant_op.constant( - [[1], [0], [0], [2], [2], [0], [1]], dtype=dtypes.int32) + label = constant_op.constant([[1], [0], [0], [2], [2], [0], [1]], + dtype=dtypes.int32) return features, label @@ -77,6 +78,59 @@ def _infer_ranking_train_input_fn(): return features, None +_QUANTILE_REGRESSION_SIZE = 1000 + + +def _quantile_regression_input_fns(two_dimension=False): + # The data generation is taken from + # http://scikit-learn.org/stable/auto_examples/ensemble/plot_gradient_boosting_quantile.html + np.random.seed(1) + + def f(x): + """The function to predict.""" + return x * np.sin(x) + + def g(x): + """The function to predict.""" + return x * np.cos(x) + + # Training data. + x = np.atleast_2d(np.random.uniform(0, 10.0, + size=_QUANTILE_REGRESSION_SIZE)).T + x = x.astype(np.float32) + + # Labels. + if not two_dimension: + y = f(x).ravel() + else: + y = np.column_stack((f(x).ravel(), g(x).ravel())) + + # Add random noise. + dy = 1.5 + 1.0 * np.random.random(y.shape) + noise = np.random.normal(0, dy) + y += noise + y_original = y.astype(np.float32) + if not two_dimension: + y = y.reshape(_QUANTILE_REGRESSION_SIZE, 1) + + train_input_fn = numpy_io.numpy_input_fn( + x=x, + y=y, + batch_size=_QUANTILE_REGRESSION_SIZE, + num_epochs=None, + shuffle=True) + + # Test on the training data to make sure the predictions are calibrated. + test_input_fn = numpy_io.numpy_input_fn( + x=x, + y=y, + batch_size=_QUANTILE_REGRESSION_SIZE, + num_epochs=1, + shuffle=False) + + return train_input_fn, test_input_fn, y_original + + class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): def setUp(self): @@ -341,6 +395,130 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): for prediction_dict in result_iter: self.assertTrue("classes" in prediction_dict) + # One dimensional quantile regression. + def testQuantileRegression(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 3 + learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE + learner_config.constraints.min_node_weight = 1 / _QUANTILE_REGRESSION_SIZE + learner_config.regularization.l2 = 1.0 / _QUANTILE_REGRESSION_SIZE + learner_config.regularization.l1 = 1.0 / _QUANTILE_REGRESSION_SIZE + learner_config.regularization.tree_complexity = ( + 1.0 / _QUANTILE_REGRESSION_SIZE) + + train_input_fn, test_input_fn, y = _quantile_regression_input_fns() + + # 95% percentile. + model_upper = estimator.GradientBoostedDecisionTreeQuantileRegressor( + quantiles=[0.95], + learner_config=learner_config, + num_trees=100, + examples_per_layer=_QUANTILE_REGRESSION_SIZE, + center_bias=False) + + model_upper.fit(input_fn=train_input_fn, steps=1000) + result_iter = model_upper.predict(input_fn=test_input_fn) + upper = [] + for prediction_dict in result_iter: + upper.append(prediction_dict["scores"]) + + frac_below_upper = round(1. 
* np.count_nonzero(upper > y) / len(y), 3) + # +/- 3% + self.assertTrue(frac_below_upper >= 0.92) + self.assertTrue(frac_below_upper <= 0.98) + + train_input_fn, test_input_fn, _ = _quantile_regression_input_fns() + model_lower = estimator.GradientBoostedDecisionTreeQuantileRegressor( + quantiles=[0.05], + learner_config=learner_config, + num_trees=100, + examples_per_layer=_QUANTILE_REGRESSION_SIZE, + center_bias=False) + + model_lower.fit(input_fn=train_input_fn, steps=1000) + result_iter = model_lower.predict(input_fn=test_input_fn) + lower = [] + for prediction_dict in result_iter: + lower.append(prediction_dict["scores"]) + + frac_above_lower = round(1. * np.count_nonzero(lower < y) / len(y), 3) + # +/- 3% + self.assertTrue(frac_above_lower >= 0.92) + self.assertTrue(frac_above_lower <= 0.98) + + # Multi-dimensional quantile regression. + def testQuantileRegressionMultiDimLabel(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 3 + learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE + learner_config.constraints.min_node_weight = 1 / _QUANTILE_REGRESSION_SIZE + learner_config.regularization.l2 = 1.0 / _QUANTILE_REGRESSION_SIZE + learner_config.regularization.l1 = 1.0 / _QUANTILE_REGRESSION_SIZE + learner_config.regularization.tree_complexity = ( + 1.0 / _QUANTILE_REGRESSION_SIZE) + + train_input_fn, test_input_fn, y = _quantile_regression_input_fns( + two_dimension=True) + + # 95% percentile. + model_upper = estimator.GradientBoostedDecisionTreeQuantileRegressor( + quantiles=[0.95], + learner_config=learner_config, + label_dimension=2, + num_trees=100, + examples_per_layer=_QUANTILE_REGRESSION_SIZE, + center_bias=False) + + model_upper.fit(input_fn=train_input_fn, steps=1000) + result_iter = model_upper.predict(input_fn=test_input_fn) + upper = [] + for prediction_dict in result_iter: + upper.append(prediction_dict["scores"]) + + count_below_upper = np.count_nonzero(upper > y, axis=0) + count_both_below_upper = np.count_nonzero(np.prod(upper > y, axis=1)) + frac_below_upper_0 = round(1. * count_below_upper[0] / len(y), 3) + frac_below_upper_1 = round(1. * count_below_upper[1] / len(y), 3) + frac_both_below_upper = round(1. * count_both_below_upper / len(y), 3) + # +/- 3% + self.assertTrue(frac_below_upper_0 >= 0.92) + self.assertTrue(frac_below_upper_0 <= 0.98) + self.assertTrue(frac_below_upper_1 >= 0.92) + self.assertTrue(frac_below_upper_1 <= 0.98) + self.assertTrue(frac_both_below_upper >= 0.92) + self.assertTrue(frac_both_below_upper <= 0.98) + + train_input_fn, test_input_fn, _ = _quantile_regression_input_fns( + two_dimension=True) + model_lower = estimator.GradientBoostedDecisionTreeQuantileRegressor( + quantiles=[0.05], + learner_config=learner_config, + label_dimension=2, + num_trees=100, + examples_per_layer=_QUANTILE_REGRESSION_SIZE, + center_bias=False) + + model_lower.fit(input_fn=train_input_fn, steps=1000) + result_iter = model_lower.predict(input_fn=test_input_fn) + lower = [] + for prediction_dict in result_iter: + lower.append(prediction_dict["scores"]) + + count_above_lower = np.count_nonzero(lower < y, axis=0) + count_both_aboce_lower = np.count_nonzero(np.prod(lower < y, axis=1)) + frac_above_lower_0 = round(1. * count_above_lower[0] / len(y), 3) + frac_above_lower_1 = round(1. * count_above_lower[1] / len(y), 3) + frac_both_above_lower = round(1. 
* count_both_aboce_lower / len(y), 3) + # +/- 3% + self.assertTrue(frac_above_lower_0 >= 0.92) + self.assertTrue(frac_above_lower_0 <= 0.98) + self.assertTrue(frac_above_lower_1 >= 0.92) + self.assertTrue(frac_above_lower_1 <= 0.98) + self.assertTrue(frac_both_above_lower >= 0.92) + self.assertTrue(frac_both_above_lower <= 0.98) + class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase): @@ -489,8 +667,8 @@ class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase): feature_columns = [ core_feature_column.weighted_categorical_column( - categorical_column=core_feature_column. - categorical_column_with_vocabulary_list( + categorical_column=core_feature_column + .categorical_column_with_vocabulary_list( key="word", vocabulary_list=["the", "cat", "dog"]), weight_feature_key="weight") ] @@ -509,8 +687,8 @@ class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase): # Weights for the words are 5 - cat, 6- dog and 1 -the. features_dict["word"] = sparse_tensor.SparseTensor( indices=[[0, 0], [0, 1], [1, 0], [3, 0]], - values=constant_op.constant( - ["the", "cat", "dog", "the"], dtype=dtypes.string), + values=constant_op.constant(["the", "cat", "dog", "the"], + dtype=dtypes.string), dense_shape=[4, 3]) features_dict["weight"] = sparse_tensor.SparseTensor( indices=[[0, 0], [0, 1], [1, 0], [3, 0]], @@ -534,6 +712,132 @@ class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase): est.evaluate(input_fn=input_fn, steps=1) est.predict(input_fn=input_fn) + # One dimensional quantile regression. + def testQuantileRegression(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 3 + learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE + learner_config.constraints.min_node_weight = 1 / _QUANTILE_REGRESSION_SIZE + learner_config.regularization.l2 = 1.0 / _QUANTILE_REGRESSION_SIZE + learner_config.regularization.l1 = 1.0 / _QUANTILE_REGRESSION_SIZE + learner_config.regularization.tree_complexity = ( + 1.0 / _QUANTILE_REGRESSION_SIZE) + + train_input_fn, test_input_fn, y = _quantile_regression_input_fns() + y = y.reshape(_QUANTILE_REGRESSION_SIZE, 1) + + # 95% percentile. + model_upper = estimator.CoreGradientBoostedDecisionTreeQuantileRegressor( + quantiles=[0.95], + learner_config=learner_config, + num_trees=100, + examples_per_layer=_QUANTILE_REGRESSION_SIZE, + center_bias=False) + + model_upper.train(input_fn=train_input_fn, steps=1000) + result_iter = model_upper.predict(input_fn=test_input_fn) + upper = [] + for prediction_dict in result_iter: + upper.append(prediction_dict["predictions"]) + + frac_below_upper = round(1. * np.count_nonzero(upper > y) / len(y), 3) + # +/- 3% + self.assertTrue(frac_below_upper >= 0.92) + self.assertTrue(frac_below_upper <= 0.98) + + train_input_fn, test_input_fn, _ = _quantile_regression_input_fns() + model_lower = estimator.CoreGradientBoostedDecisionTreeQuantileRegressor( + quantiles=[0.05], + learner_config=learner_config, + num_trees=100, + examples_per_layer=_QUANTILE_REGRESSION_SIZE, + center_bias=False) + + model_lower.train(input_fn=train_input_fn, steps=1000) + result_iter = model_lower.predict(input_fn=test_input_fn) + lower = [] + for prediction_dict in result_iter: + lower.append(prediction_dict["predictions"]) + + frac_above_lower = round(1. 
* np.count_nonzero(lower < y) / len(y), 3) + # +/- 3% + self.assertTrue(frac_above_lower >= 0.92) + self.assertTrue(frac_above_lower <= 0.98) + + # Multi-dimensional quantile regression. + def testQuantileRegressionMultiDimLabel(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 3 + learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE + learner_config.constraints.min_node_weight = 1 / _QUANTILE_REGRESSION_SIZE + learner_config.regularization.l2 = 1.0 / _QUANTILE_REGRESSION_SIZE + learner_config.regularization.l1 = 1.0 / _QUANTILE_REGRESSION_SIZE + learner_config.regularization.tree_complexity = ( + 1.0 / _QUANTILE_REGRESSION_SIZE) + + train_input_fn, test_input_fn, y = _quantile_regression_input_fns( + two_dimension=True) + y = y.reshape(_QUANTILE_REGRESSION_SIZE, 2) + + # 95% percentile. + model_upper = estimator.CoreGradientBoostedDecisionTreeQuantileRegressor( + quantiles=[0.95], + learner_config=learner_config, + num_trees=100, + label_dimension=2, + examples_per_layer=_QUANTILE_REGRESSION_SIZE, + center_bias=False) + + model_upper.train(input_fn=train_input_fn, steps=1000) + result_iter = model_upper.predict(input_fn=test_input_fn) + upper = [] + for prediction_dict in result_iter: + upper.append(prediction_dict["predictions"]) + + count_below_upper = np.count_nonzero(upper > y, axis=0) + count_both_below_upper = np.count_nonzero(np.prod(upper > y, axis=1)) + frac_below_upper_0 = round(1. * count_below_upper[0] / len(y), 3) + frac_below_upper_1 = round(1. * count_below_upper[1] / len(y), 3) + frac_both_below_upper = round(1. * count_both_below_upper / len(y), 3) + # +/- 3% + self.assertTrue(frac_below_upper_0 >= 0.92) + self.assertTrue(frac_below_upper_0 <= 0.98) + self.assertTrue(frac_below_upper_1 >= 0.92) + self.assertTrue(frac_below_upper_1 <= 0.98) + self.assertTrue(frac_both_below_upper >= 0.92) + self.assertTrue(frac_both_below_upper <= 0.98) + + train_input_fn, test_input_fn, _ = _quantile_regression_input_fns( + two_dimension=True) + model_lower = estimator.CoreGradientBoostedDecisionTreeQuantileRegressor( + quantiles=[0.05], + learner_config=learner_config, + num_trees=100, + label_dimension=2, + examples_per_layer=_QUANTILE_REGRESSION_SIZE, + center_bias=False) + + model_lower.train(input_fn=train_input_fn, steps=1000) + result_iter = model_lower.predict(input_fn=test_input_fn) + lower = [] + for prediction_dict in result_iter: + lower.append(prediction_dict["predictions"]) + + count_above_lower = np.count_nonzero(lower < y, axis=0) + count_both_aboce_lower = np.count_nonzero(np.prod(lower < y, axis=1)) + frac_above_lower_0 = round(1. * count_above_lower[0] / len(y), 3) + frac_above_lower_1 = round(1. * count_above_lower[1] / len(y), 3) + frac_both_above_lower = round(1. 
* count_both_aboce_lower / len(y), 3) + # +/- 3% + self.assertTrue(frac_above_lower_0 >= 0.92) + self.assertTrue(frac_above_lower_0 <= 0.98) + self.assertTrue(frac_above_lower_1 >= 0.92) + self.assertTrue(frac_above_lower_1 <= 0.98) + self.assertTrue(frac_both_above_lower >= 0.92) + self.assertTrue(frac_both_above_lower <= 0.98) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/contrib/boosted_trees/examples/boston.py b/tensorflow/contrib/boosted_trees/examples/boston.py index 54c4ff059e3..09b240a7006 100644 --- a/tensorflow/contrib/boosted_trees/examples/boston.py +++ b/tensorflow/contrib/boosted_trees/examples/boston.py @@ -90,13 +90,13 @@ def _make_experiment_fn(output_dir): (x_train, y_train), (x_test, y_test) = tf.keras.datasets.boston_housing.load_data() - train_input_fn = tf.estimator.inputs.numpy_input_fn( + train_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn( x={"x": x_train}, y=y_train, batch_size=FLAGS.batch_size, num_epochs=None, shuffle=True) - eval_input_fn = tf.estimator.inputs.numpy_input_fn( + eval_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn( x={"x": x_test}, y=y_test, num_epochs=1, shuffle=False) feature_columns = [ diff --git a/tensorflow/contrib/boosted_trees/examples/boston_combined.py b/tensorflow/contrib/boosted_trees/examples/boston_combined.py index e04b56afbfd..d640af354f5 100644 --- a/tensorflow/contrib/boosted_trees/examples/boston_combined.py +++ b/tensorflow/contrib/boosted_trees/examples/boston_combined.py @@ -80,13 +80,13 @@ def _make_experiment_fn(output_dir): (x_train, y_train), (x_test, y_test) = tf.keras.datasets.boston_housing.load_data() - train_input_fn = tf.estimator.inputs.numpy_input_fn( + train_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn( x={"x": x_train}, y=y_train, batch_size=FLAGS.batch_size, num_epochs=None, shuffle=True) - eval_input_fn = tf.estimator.inputs.numpy_input_fn( + eval_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn( x={"x": x_test}, y=y_test, num_epochs=1, shuffle=False) feature_columns = [ diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc index 8edb5d6c640..6d78e27e8f6 100644 --- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc @@ -834,8 +834,13 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel { root_gradient_stats *= normalizer_ratio; NodeStats root_stats = state->ComputeNodeStats(root_gradient_stats); int32 best_feature_idx = 0; + bool best_feature_updated = false; NodeStats best_right_node_stats(0); NodeStats best_left_node_stats(0); + CHECK(end_index - start_index >= 2) + << "Partition should have a non bias feature. 
Start index " + << start_index << " and end index " << end_index; + for (int64 feature_idx = start_index + 1; feature_idx < end_index; ++feature_idx) { GradientStats left_gradient_stats(*gradients_t, *hessians_t, @@ -845,11 +850,13 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel { root_gradient_stats - left_gradient_stats; NodeStats left_stats = state->ComputeNodeStats(left_gradient_stats); NodeStats right_stats = state->ComputeNodeStats(right_gradient_stats); - if (left_stats.gain + right_stats.gain > best_gain) { + if (!best_feature_updated || + left_stats.gain + right_stats.gain > best_gain) { best_gain = left_stats.gain + right_stats.gain; best_left_node_stats = left_stats; best_right_node_stats = right_stats; best_feature_idx = feature_idx; + best_feature_updated = true; } } SplitInfo split_info; @@ -864,7 +871,7 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel { << feature_ids(best_feature_idx, 0) << ", " << feature_ids(best_feature_idx, 1) << "\nPartition IDS: " << partition_ids(start_index) << " " - << partition_ids(best_feature_idx); + << partition_ids(best_feature_idx) << " and best gain " << best_gain; equality_split->set_feature_id(feature_ids(best_feature_idx, 0)); auto* left_child = split_info.mutable_left_child(); auto* right_child = split_info.mutable_right_child(); diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py index 4da25298cb8..d26af584197 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py @@ -119,7 +119,7 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler): def not_active_inputs(): return (constant_op.constant([], dtype=dtypes.int32), - constant_op.constant([], dtype=dtypes.int64, shape=[1, 2]), + constant_op.constant_v1([], dtype=dtypes.int64, shape=[1, 2]), empty_gradients, empty_hessians) def active_inputs(): diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py index a2f708081a4..386dc19fc7b 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py @@ -36,9 +36,9 @@ def get_empty_tensors(gradient_shape, hessian_shape): empty_hess_shape = [1] + hessian_shape.as_list() empty_grad_shape = [1] + gradient_shape.as_list() - empty_gradients = constant_op.constant( + empty_gradients = constant_op.constant_v1( [], dtype=dtypes.float32, shape=empty_grad_shape) - empty_hessians = constant_op.constant( + empty_hessians = constant_op.constant_v1( [], dtype=dtypes.float32, shape=empty_hess_shape) return empty_gradients, empty_hessians @@ -486,8 +486,8 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase): gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0]) hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13]) partition_ids = [0, 0, 0, 1] - indices = array_ops.constant([], dtype=dtypes.int64, shape=[0, 2]) - values = array_ops.constant([], dtype=dtypes.int64) + indices = constant_op.constant_v1([], dtype=dtypes.int64, shape=[0, 2]) + values = constant_op.constant_v1([], dtype=dtypes.int64) gradient_shape = tensor_shape.scalar() hessian_shape = tensor_shape.scalar() diff --git 
a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py index 1fffbb5f660..0476bed2cd3 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py @@ -605,7 +605,7 @@ def dense_make_stats_update(is_active, are_buckets_ready, float_column, quantile_buckets, example_partition_ids, gradients, hessians, weights, empty_gradients, empty_hessians): """Updates the state for dense split handler.""" - empty_float = constant_op.constant([], dtype=dtypes.float32) + empty_float = constant_op.constant_v1([], dtype=dtypes.float32) quantile_values, quantile_weights = control_flow_ops.cond( is_active[1], # For the next layer, this handler is inactive. @@ -621,8 +621,8 @@ def dense_make_stats_update(is_active, are_buckets_ready, float_column, return (example_partition_ids, quantized_feature, gradients, hessians) def not_ready_inputs_fn(): - return (constant_op.constant([], dtype=dtypes.int32), - constant_op.constant([[]], dtype=dtypes.int64, shape=[1, 2]), + return (constant_op.constant_v1([], dtype=dtypes.int32), + constant_op.constant_v1([[]], dtype=dtypes.int64, shape=[1, 2]), empty_gradients, empty_hessians) example_partition_ids, feature_ids, gradients, hessians = ( @@ -708,11 +708,11 @@ def sparse_make_stats_update( def quantiles_not_ready(): """The subgraph for when the quantiles are not ready.""" - return (constant_op.constant([], dtype=dtypes.int32), - constant_op.constant([], dtype=dtypes.int64, shape=[1, 2]), + return (constant_op.constant_v1([], dtype=dtypes.int32), + constant_op.constant_v1([], dtype=dtypes.int64, shape=[1, 2]), empty_gradients, empty_hessians) - empty_float = constant_op.constant([], dtype=dtypes.float32) + empty_float = constant_op.constant_v1([], dtype=dtypes.float32) handler_not_active = (constant_op.constant( [], dtype=dtypes.int64, shape=[0, 2]), empty_float, constant_op.constant([0, 1], dtype=dtypes.int64), diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py index 74b0ea6989c..4a1b528646e 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py @@ -39,9 +39,9 @@ def get_empty_tensors(gradient_shape, hessian_shape): empty_hess_shape = [1] + hessian_shape.as_list() empty_grad_shape = [1] + gradient_shape.as_list() - empty_gradients = constant_op.constant( + empty_gradients = constant_op.constant_v1( [], dtype=dtypes.float32, shape=empty_grad_shape) - empty_hessians = constant_op.constant( + empty_hessians = constant_op.constant_v1( [], dtype=dtypes.float32, shape=empty_hess_shape) return empty_gradients, empty_hessians @@ -1476,9 +1476,9 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): def testEmpty(self): with self.cached_session() as sess: - indices = array_ops.constant([], dtype=dtypes.int64, shape=[0, 2]) + indices = constant_op.constant_v1([], dtype=dtypes.int64, shape=[0, 2]) # No values in this feature column in this mini-batch. 
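
[Editor's note] The recurring edit in the surrounding split-handler hunks swaps `constant_op.constant` (or `array_ops.constant`) for `constant_op.constant_v1` wherever an empty value is paired with an explicit shape such as `[1, 2]` or `[0, 2]`, presumably to keep the v1 `shape` semantics for that value/shape combination. A minimal sketch of the call pattern as it appears in the patch, nothing beyond what the hunks already show:

```python
# Call pattern used throughout this patch for empty placeholder tensors.
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes

empty_indices = constant_op.constant_v1([], dtype=dtypes.int64, shape=[0, 2])
empty_values = constant_op.constant_v1([], dtype=dtypes.float32)
```
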
- values = array_ops.constant([], dtype=dtypes.float32) + values = constant_op.constant_v1([], dtype=dtypes.float32) sparse_column = sparse_tensor.SparseTensor(indices, values, [4, 1]) gradient_shape = tensor_shape.scalar() @@ -1549,8 +1549,9 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): sparse_column = array_ops.sparse_placeholder(dtypes.float32) # We have two batches - at first, a sparse feature is empty. - empty_indices = array_ops.constant([], dtype=dtypes.int64, shape=[0, 2]) - empty_values = array_ops.constant([], dtype=dtypes.float32) + empty_indices = constant_op.constant_v1([], dtype=dtypes.int64, + shape=[0, 2]) + empty_values = constant_op.constant_v1([], dtype=dtypes.float32) empty_sparse_column = sparse_tensor.SparseTensor(empty_indices, empty_values, [4, 2]) empty_sparse_column = empty_sparse_column.eval(session=sess) diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index ab5713fbe26..9fdc2fc0c2c 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -897,9 +897,9 @@ class GradientBoostedDecisionTreeModel(object): empty_hess_shape = [1] + self._hessian_shape.as_list() empty_grad_shape = [1] + self._gradient_shape.as_list() - empty_gradients = constant_op.constant( + empty_gradients = constant_op.constant_v1( [], dtype=dtypes.float32, shape=empty_grad_shape) - empty_hessians = constant_op.constant( + empty_hessians = constant_op.constant_v1( [], dtype=dtypes.float32, shape=empty_hess_shape) active_handlers = array_ops.unstack(active_handlers, axis=0) @@ -1257,13 +1257,12 @@ class GradientBoostedDecisionTreeModel(object): def _get_replica_device_setter(self, worker_device): """Creates a replica device setter.""" ps_tasks = self._num_ps_replicas - ps_ops = [ - "Variable", - "VariableV2", + ps_ops = list(device_setter.STANDARD_PS_OPS) + ps_ops.extend([ "DecisionTreeEnsembleResourceHandleOp", "StatsAccumulatorScalarResourceHandleOp", "StatsAccumulatorTensorResourceHandleOp", - ] + ]) ps_strategy = _OpRoundRobinStrategy(ps_ops, ps_tasks) return device_setter.replica_device_setter( worker_device=worker_device, diff --git a/tensorflow/contrib/boosted_trees/python/utils/losses.py b/tensorflow/contrib/boosted_trees/python/utils/losses.py index b5ebaf19995..220e981618b 100644 --- a/tensorflow/contrib/boosted_trees/python/utils/losses.py +++ b/tensorflow/contrib/boosted_trees/python/utils/losses.py @@ -48,6 +48,47 @@ def per_example_logistic_loss(labels, weights, predictions): labels=labels, logits=predictions) return unweighted_loss * weights, control_flow_ops.no_op() +# MUST USE WITH HESSIAN REGULARIZATION, +# This loss can have zero hessian, so it must be used with l2 or min_node_weight +# regularization. +# An example config is +# learner_config.constraints.min_node_weight = 1 / num_examples_per_layer +# learner_config.regularization.l2 = 1.0 / num_examples_per_layer +# TODO(nponomareva): make it multidimensional so we can estimate several +# quantiles at once. +def per_example_quantile_regression_loss(labels, weights, predictions, + quantile): + """Smoothed loss for quantile regression. + + The standard quantile regression loss is quantile*(y-y') when y>y' and + (quantile-1)*(y-y') otherwise, y' is a prediction, y is a label. The impl + below is this loss but squared in the region where the loss value < 1. 
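
[Editor's note] The docstring above defines the quantile loss piecewise and says the implementation squares it where the value is below 1. A small standalone NumPy sketch (not part of the patch) that mirrors the `array_ops.where` logic in the hunk below may make the smoothing concrete; the inputs are made-up numbers.

```python
import numpy as np

def smoothed_quantile_loss(labels, predictions, quantile):
  """Mirrors per_example_quantile_regression_loss, without weights."""
  error = labels - predictions
  # Branch taken when the label is above the prediction: squared while small.
  right = np.where(error * quantile < 1.0,
                   np.square(quantile * error),
                   quantile * error)
  # Branch taken otherwise; (quantile - 1) * error is non-negative here.
  left = np.where(error * (quantile - 1) < 1.0,
                  np.square((quantile - 1) * error),
                  (quantile - 1) * error)
  return np.where(error > 0, right, left)

print(smoothed_quantile_loss(np.array([1.0, 5.0]), np.array([2.0, 1.0]), 0.9))
# -> approximately [0.01, 3.6]
```

As the comment preceding the new function stresses, this loss can have a zero hessian, so it is meant to be paired with l2 or min_node_weight regularization.
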
+ + Args: + labels: Rank 2 (N, D) tensor of per-example labels. + weights: Rank 2 (N, 1) tensor of per-example weights. + predictions: Rank 2 (N, D) tensor of per-example predictions. + quantile: The quantile to use. + + Returns: + loss: A Rank 2 (N, 1) tensor of per-example quantile loss. + update_op: An update operation to update the loss's internal state. + """ + labels = math_ops.to_float(labels) + error = labels - predictions + square_loss_right = array_ops.where(error * quantile < 1.0, + math_ops.square(quantile * error), + quantile * error) + square_loss_left = array_ops.where(error * (quantile - 1) < 1, + math_ops.square((quantile - 1) * error), + (quantile - 1) * error) + + unweighted_loss = array_ops.where(error > 0, square_loss_right, + square_loss_left) + if weights is None: + return unweighted_loss, control_flow_ops.no_op() + else: + return unweighted_loss * weights, control_flow_ops.no_op() # This is classical form of Maximum entropy loss, that is twice differentiable # (sparse_softmax_cross_entropy which is what we go for is not twice @@ -78,8 +119,7 @@ def per_example_maxent_loss(labels, weights, logits, num_classes, eps=1e-15): labels = array_ops.expand_dims(labels, 1) # Labels are indices of classes, convert them to one hot encodings. target_one_hot = array_ops.one_hot(indices=labels, depth=num_classes) - labels = math_ops.reduce_sum( - input_tensor=target_one_hot, reduction_indices=[1]) + labels = math_ops.reduce_sum(input_tensor=target_one_hot, axis=[1]) labels = math_ops.to_float(labels) # Calculate softmax probabilities for each class. diff --git a/tensorflow/contrib/checkpoint/python/containers.py b/tensorflow/contrib/checkpoint/python/containers.py index 242c1e8ba45..5418e2605b7 100644 --- a/tensorflow/contrib/checkpoint/python/containers.py +++ b/tensorflow/contrib/checkpoint/python/containers.py @@ -46,6 +46,10 @@ class UniqueNameTracker(data_structures.CheckpointableDataStructure): self._maybe_initialize_checkpointable() self._name_counts = {} + @property + def _values(self): + return [dep.ref for dep in self._checkpoint_dependencies] + def track(self, checkpointable, base_name): """Add a dependency on `checkpointable`. 
diff --git a/tensorflow/contrib/cluster_resolver/BUILD b/tensorflow/contrib/cluster_resolver/BUILD index 9e1867ea9d0..f944b7f8843 100644 --- a/tensorflow/contrib/cluster_resolver/BUILD +++ b/tensorflow/contrib/cluster_resolver/BUILD @@ -21,85 +21,18 @@ py_library( py_library( name = "cluster_resolver_py", - srcs = [ + srcs = glob([ "__init__.py", - "python/training/__init__.py", - ], + "python/training/*.py", + ]), srcs_version = "PY2AND3", visibility = ["//visibility:public"], - deps = [ - ":base_cluster_resolver_py", - ":gce_cluster_resolver_py", - ":kubernetes_cluster_resolver_py", - ":slurm_cluster_resolver_py", - ":tfconfig_cluster_resolver_py", - ":tpu_cluster_resolver_py", - "//tensorflow/python:util", - ], -) - -py_library( - name = "base_cluster_resolver_py", - srcs = ["python/training/cluster_resolver.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/python:training", - ], -) - -py_library( - name = "gce_cluster_resolver_py", - srcs = ["python/training/gce_cluster_resolver.py"], - srcs_version = "PY2AND3", - deps = [ - ":base_cluster_resolver_py", - "//tensorflow/python:training", - ], -) - -py_library( - name = "tfconfig_cluster_resolver_py", - srcs = ["python/training/tfconfig_cluster_resolver.py"], - srcs_version = "PY2AND3", - deps = [ - ":base_cluster_resolver_py", - "//tensorflow/python:training", - ], -) - -py_library( - name = "tpu_cluster_resolver_py", - srcs = ["python/training/tpu_cluster_resolver.py"], - srcs_version = "PY2AND3", - deps = [ - ":base_cluster_resolver_py", - "//tensorflow/python:training", - ], -) - -py_library( - name = "slurm_cluster_resolver_py", - srcs = ["python/training/slurm_cluster_resolver.py"], - srcs_version = "PY2AND3", - deps = [ - ":base_cluster_resolver_py", - "//tensorflow/python:training", - ], -) - -py_library( - name = "kubernetes_cluster_resolver_py", - srcs = ["python/training/kubernetes_cluster_resolver.py"], - srcs_version = "PY2AND3", - deps = [ - ":base_cluster_resolver_py", - "//tensorflow/python:training", - ], + deps = ["//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib"], ) tf_py_test( - name = "base_cluster_resolver_py_test", - srcs = ["python/training/cluster_resolver_test.py"], + name = "cluster_resolver_initialization_test", + srcs = ["cluster_resolver_initialization_test.py"], additional_deps = [ ":cluster_resolver_py", "//tensorflow/python:client_testlib", @@ -108,86 +41,5 @@ tf_py_test( "//tensorflow/python:platform_test", "//tensorflow/python:training", ], - main = "python/training/cluster_resolver_test.py", -) - -tf_py_test( - name = "gce_cluster_resolver_py_test", - size = "small", - srcs = ["python/training/gce_cluster_resolver_test.py"], - additional_deps = [ - ":cluster_resolver_py", - ":gce_cluster_resolver_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:platform_test", - "//tensorflow/python:training", - ], - main = "python/training/gce_cluster_resolver_test.py", -) - -tf_py_test( - name = "tfconfig_cluster_resolver_py_test", - size = "small", - srcs = ["python/training/tfconfig_cluster_resolver_test.py"], - additional_deps = [ - ":tfconfig_cluster_resolver_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:platform_test", - "//tensorflow/python:training", - ], - grpc_enabled = True, - main = "python/training/tfconfig_cluster_resolver_test.py", 
-) - -tf_py_test( - name = "tpu_cluster_resolver_py_test", - size = "small", - srcs = ["python/training/tpu_cluster_resolver_test.py"], - additional_deps = [ - ":tpu_cluster_resolver_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:platform_test", - "//tensorflow/python:training", - ], - grpc_enabled = True, - main = "python/training/tpu_cluster_resolver_test.py", -) - -tf_py_test( - name = "slurm_cluster_resolver_py_test", - size = "small", - srcs = ["python/training/slurm_cluster_resolver_test.py"], - additional_deps = [ - ":cluster_resolver_py", - ":slurm_cluster_resolver_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:platform_test", - "//tensorflow/python:training", - ], - main = "python/training/slurm_cluster_resolver_test.py", - tags = [], -) - -tf_py_test( - name = "kubernetes_cluster_resolver_py_test", - size = "small", - srcs = ["python/training/kubernetes_cluster_resolver_test.py"], - additional_deps = [ - ":cluster_resolver_py", - ":kubernetes_cluster_resolver_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:platform_test", - "//tensorflow/python:training", - ], - main = "python/training/kubernetes_cluster_resolver_test.py", + main = "cluster_resolver_initialization_test.py", ) diff --git a/tensorflow/contrib/cluster_resolver/__init__.py b/tensorflow/contrib/cluster_resolver/__init__.py index fd1263fe81a..390b3e7550b 100644 --- a/tensorflow/contrib/cluster_resolver/__init__.py +++ b/tensorflow/contrib/cluster_resolver/__init__.py @@ -20,12 +20,14 @@ from __future__ import division from __future__ import print_function # pylint: disable=wildcard-import,unused-import -from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver -from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import SimpleClusterResolver -from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import UnionClusterResolver -from tensorflow.contrib.cluster_resolver.python.training.gce_cluster_resolver import GceClusterResolver -from tensorflow.contrib.cluster_resolver.python.training.slurm_cluster_resolver import SlurmClusterResolver -from tensorflow.contrib.cluster_resolver.python.training.tpu_cluster_resolver import TPUClusterResolver +from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver +from tensorflow.python.distribute.cluster_resolver.cluster_resolver import SimpleClusterResolver +from tensorflow.python.distribute.cluster_resolver.cluster_resolver import UnionClusterResolver +from tensorflow.python.distribute.cluster_resolver.gce_cluster_resolver import GceClusterResolver +from tensorflow.python.distribute.cluster_resolver.kubernetes_cluster_resolver import KubernetesClusterResolver +from tensorflow.python.distribute.cluster_resolver.slurm_cluster_resolver import SlurmClusterResolver +from tensorflow.python.distribute.cluster_resolver.tfconfig_cluster_resolver import TFConfigClusterResolver +from tensorflow.python.distribute.cluster_resolver.tpu_cluster_resolver import TPUClusterResolver # pylint: enable=wildcard-import,unused-import from tensorflow.python.util.all_util import remove_undocumented @@ -35,6 +37,8 @@ _allowed_symbols = [ 
'SimpleClusterResolver', 'UnionClusterResolver', 'GceClusterResolver', + 'KubernetesClusterResolver', + 'TFConfigClusterResolver', 'TPUClusterResolver', 'SlurmClusterResolver', ] diff --git a/tensorflow/contrib/cluster_resolver/cluster_resolver_initialization_test.py b/tensorflow/contrib/cluster_resolver/cluster_resolver_initialization_test.py new file mode 100644 index 00000000000..01ff1478c69 --- /dev/null +++ b/tensorflow/contrib/cluster_resolver/cluster_resolver_initialization_test.py @@ -0,0 +1,53 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests to ensure ClusterResolvers are usable via the old contrib path.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.cluster_resolver import SimpleClusterResolver +from tensorflow.contrib.cluster_resolver.python.training import cluster_resolver +from tensorflow.contrib.cluster_resolver.python.training import UnionClusterResolver +from tensorflow.python.platform import test +from tensorflow.python.training import server_lib + + +class ClusterResolverInitializationTest(test.TestCase): + + def testCreateSimpleClusterResolverFromLib(self): + base_cluster_spec = server_lib.ClusterSpec({ + "ps": ["ps0:2222", "ps1:2222"], + "worker": ["worker0:2222", "worker1:2222", "worker2:2222"] + }) + cluster_resolver.SimpleClusterResolver(base_cluster_spec) + + def testCreateSimpleClusterResolver(self): + base_cluster_spec = server_lib.ClusterSpec({ + "ps": ["ps0:2222", "ps1:2222"], + "worker": ["worker0:2222", "worker1:2222", "worker2:2222"] + }) + SimpleClusterResolver(base_cluster_spec) + + def testCreateUnionClusterResolver(self): + base_cluster_spec = server_lib.ClusterSpec({ + "ps": ["ps0:2222", "ps1:2222"], + "worker": ["worker0:2222", "worker1:2222", "worker2:2222"] + }) + simple_cr = SimpleClusterResolver(base_cluster_spec) + UnionClusterResolver(simple_cr) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/cluster_resolver/python/training/__init__.py b/tensorflow/contrib/cluster_resolver/python/training/__init__.py index 6d9120a3b96..10d93549ebb 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/__init__.py +++ b/tensorflow/contrib/cluster_resolver/python/training/__init__.py @@ -18,11 +18,36 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver -from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import SimpleClusterResolver -from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import UnionClusterResolver -from tensorflow.contrib.cluster_resolver.python.training.gce_cluster_resolver import GceClusterResolver -from 
tensorflow.contrib.cluster_resolver.python.training.kubernetes_cluster_resolver import KubernetesClusterResolver -from tensorflow.contrib.cluster_resolver.python.training.slurm_cluster_resolver import SlurmClusterResolver -from tensorflow.contrib.cluster_resolver.python.training.tfconfig_cluster_resolver import TFConfigClusterResolver -from tensorflow.contrib.cluster_resolver.python.training.tpu_cluster_resolver import TPUClusterResolver +# This file (and all files in this directory in general) is a backwards +# compatibility shim that exists to re-export ClusterResolvers such that +# existing OSS code will not be broken. + +from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver +from tensorflow.python.distribute.cluster_resolver.cluster_resolver import SimpleClusterResolver +from tensorflow.python.distribute.cluster_resolver.cluster_resolver import UnionClusterResolver +from tensorflow.python.distribute.cluster_resolver.gce_cluster_resolver import GceClusterResolver +from tensorflow.python.distribute.cluster_resolver.kubernetes_cluster_resolver import KubernetesClusterResolver +from tensorflow.python.distribute.cluster_resolver.slurm_cluster_resolver import SlurmClusterResolver +from tensorflow.python.distribute.cluster_resolver.tfconfig_cluster_resolver import TFConfigClusterResolver +from tensorflow.python.distribute.cluster_resolver.tpu_cluster_resolver import TPUClusterResolver + +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + 'cluster_resolver', + 'gce_cluster_resolver', + 'kubernetes_cluster_resolver', + 'slurm_cluster_resolver', + 'tfconfig_cluster_resolver', + 'tpu_cluster_resolver', + 'ClusterResolver', + 'SimpleClusterResolver', + 'UnionClusterResolver', + 'GceClusterResolver', + 'KubernetesClusterResolver', + 'TFConfigClusterResolver', + 'TPUClusterResolver', + 'SlurmClusterResolver', +] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py index 40b1e667ee6..99840fb5166 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py @@ -1,4 +1,4 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,333 +12,29 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Cluster Resolvers are used for dynamic cluster IP/hostname resolution.""" +"""Stub file for ClusterResolver to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import abc +# This file (and all files in this directory in general) is a backwards +# compatibility shim that exists to re-export ClusterResolvers such that +# existing OSS code will not be broken. 
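
[Editor's note] The shim pattern applied across these contrib files keeps the old import paths alive while the real implementations move under `tensorflow.python.distribute.cluster_resolver`; `remove_undocumented` then trims each stub module down to the re-exported names in `_allowed_symbols`. A short sketch of what the new `cluster_resolver_initialization_test.py` exercises, with made-up host addresses:

```python
# After this patch, the contrib path and the new distribute path expose the
# same class objects (mirrors cluster_resolver_initialization_test.py).
from tensorflow.contrib.cluster_resolver import SimpleClusterResolver as contrib_simple
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import (
    SimpleClusterResolver as core_simple)
from tensorflow.python.training import server_lib

assert contrib_simple is core_simple  # one implementation, two import paths

spec = server_lib.ClusterSpec({"worker": ["worker0:2222", "worker1:2222"]})
resolver = contrib_simple(spec)
```
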
-import six +# pylint: disable=unused-import +from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver +from tensorflow.python.distribute.cluster_resolver.cluster_resolver import SimpleClusterResolver +from tensorflow.python.distribute.cluster_resolver.cluster_resolver import UnionClusterResolver +# pylint: enable=unused-import -from tensorflow.python.training.server_lib import ClusterSpec +from tensorflow.python.util.all_util import remove_undocumented +_allowed_symbols = [ + 'ClusterResolver', + 'SimpleClusterResolver', + 'UnionClusterResolver', +] -def _format_master_url(master, rpc_layer=None): - if rpc_layer: - return '%s://%s' % (rpc_layer, master) - else: - return master +remove_undocumented(__name__, _allowed_symbols) - -@six.add_metaclass(abc.ABCMeta) -class ClusterResolver(object): - """Abstract class for all implementations of ClusterResolvers. - - This defines the skeleton for all implementations of ClusterResolvers. - ClusterResolvers are a way for TensorFlow to communicate with various cluster - management systems (e.g. GCE, AWS, etc...). - - By letting TensorFlow communicate with these systems, we will be able to - automatically discover and resolve IP addresses for various TensorFlow - workers. This will eventually allow us to automatically recover from - underlying machine failures and scale TensorFlow worker clusters up and down. - """ - - @abc.abstractmethod - def cluster_spec(self): - """Retrieve the current state of the cluster and returns a ClusterSpec. - - Returns: - A ClusterSpec representing the state of the cluster at the moment this - function is called. - - Implementors of this function must take care in ensuring that the - ClusterSpec returned is up-to-date at the time of calling this function. - This usually means retrieving the information from the underlying cluster - management system every time this function is invoked and reconstructing - a cluster_spec, rather than attempting to cache anything. - """ - raise NotImplementedError( - 'cluster_spec is not implemented for {}.'.format(self)) - - @abc.abstractmethod - def master(self, task_type=None, task_index=None, rpc_layer=None): - """Retrieves the name or URL of the session master. - - Args: - task_type: (Optional) The type of the TensorFlow task of the master. - task_index: (Optional) The index of the TensorFlow task of the master. - rpc_layer: (Optional) The RPC protocol for the given cluster. - - Returns: - The name or URL of the session master. - - Implementors of this function must take care in ensuring that the master - returned is up-to-date at the time to calling this function. This usually - means retrieving the master every time this function is invoked. 
- """ - raise NotImplementedError('master is not implemented for {}.'.format(self)) - - -class SimpleClusterResolver(ClusterResolver): - """Simple implementation of ClusterResolver that accepts a ClusterSpec.""" - - def __init__(self, cluster_spec, master='', task_type=None, task_index=None, - environment='', num_accelerators_per_worker=0, - rpc_layer=None): - """Creates a SimpleClusterResolver from a ClusterSpec.""" - super(SimpleClusterResolver, self).__init__() - - self._task_type = task_type - self._task_index = task_index - self._environment = environment - self._num_accelerators_per_worker = num_accelerators_per_worker - self._rpc_layer = rpc_layer - - if not isinstance(cluster_spec, ClusterSpec): - raise TypeError('cluster_spec must be a ClusterSpec.') - self._cluster_spec = cluster_spec - - if not isinstance(master, str): - raise TypeError('master must be a string.') - self._master = master - - def cluster_spec(self): - """Returns the ClusterSpec passed into the constructor.""" - return self._cluster_spec - - def master(self, task_type=None, task_index=None, rpc_layer=None): - """Returns the master address to use when creating a session. - - Args: - task_type: (Optional) The type of the TensorFlow task of the master. - task_index: (Optional) The index of the TensorFlow task of the master. - rpc_layer: (Optional) The RPC used by distributed TensorFlow. - - Returns: - The name or URL of the session master. - - If a task_type and task_index is given, this will override the `master` - string passed into the initialization function. - """ - if task_type is not None and task_index is not None: - master = self.cluster_spec().task_address(task_type, task_index) - else: - master = self._master - - return _format_master_url(master, rpc_layer or self._rpc_layer) - - @property - def task_type(self): - return self._task_type - - @property - def task_index(self): - return self._task_index - - @task_type.setter - def task_type(self, task_type): - self._task_type = task_type - - @task_index.setter - def task_index(self, task_index): - self._task_index = task_index - - @property - def environment(self): - return self._environment - - def num_accelerators_per_worker(self, session_config=None): - """Returns the number of accelerator cores per worker. - - Args: - session_config: Unused. The SimpleClusterResolver does not do automatic - detection of accelerators, so a TensorFlow session will never be - created, and thus a `session_config` is never necessary here, and will - be ignored. - """ - del session_config - return self._num_accelerators_per_worker - - @property - def rpc_layer(self): - return self._rpc_layer - - @rpc_layer.setter - def rpc_layer(self, rpc_layer): - self._rpc_layer = rpc_layer - - -class UnionClusterResolver(ClusterResolver): - """Performs a union on underlying ClusterResolvers. - - This class performs a union given two or more existing ClusterResolvers. It - merges the underlying ClusterResolvers, and returns one unified ClusterSpec - when cluster_spec is called. The details of the merge function is - documented in the cluster_spec function. - - For additional Cluster Resolver properties such as task type, task index, - rpc layer, environment, etc..., we will return the value from the first - ClusterResolver in the union. - """ - - def __init__(self, *args, **kwargs): - """Initializes a UnionClusterResolver with other ClusterResolvers. - - Args: - *args: `ClusterResolver` objects to be unionized. 
- **kwargs: - rpc_layer - (Optional) Override value for the RPC layer used by - TensorFlow. - task_type - (Optional) Override value for the current task type. - task_index - (Optional) Override value for the current task index. - - Raises: - TypeError: If any argument is not a subclass of `ClusterResolvers`. - ValueError: If there are no arguments passed. - """ - super(UnionClusterResolver, self).__init__() - - self._rpc_layer = kwargs.pop('rpc_layer', None) - self._task_type = kwargs.pop('task_type', None) - self._task_index = kwargs.pop('task_index', None) - - if kwargs: - raise ValueError('Unexpected kwargs provided {!r}'.format(kwargs)) - - if not args: - raise ValueError('At least one ClusterResolver is required.') - - for cluster_resolver in args: - if not isinstance(cluster_resolver, ClusterResolver): - raise TypeError('All arguments must be a sub-class of ' - '`ClusterResolver.`') - self._cluster_resolvers = args - - def cluster_spec(self): - """Returns a union of all the ClusterSpecs from the ClusterResolvers. - - Returns: - A ClusterSpec containing host information merged from all the underlying - ClusterResolvers. - - Raises: - KeyError: If there are conflicting keys detected when merging two or - more dictionaries, this exception is raised. - - Note: If there are multiple ClusterResolvers exposing ClusterSpecs with the - same job name, we will merge the list/dict of workers. - - If *all* underlying ClusterSpecs expose the set of workers as lists, we will - concatenate the lists of workers, starting with the list of workers from - the first ClusterResolver passed into the constructor. - - If *any* of the ClusterSpecs expose the set of workers as a dict, we will - treat all the sets of workers as dicts (even if they are returned as lists) - and will only merge them into a dict if there is no conflicting keys. If - there is a conflicting key, we will raise a `KeyError`. - """ - - merged_cluster = {} - - # We figure out whether it is all lists for a particular job, or whether - # there are dicts inside. - for cluster_resolver in self._cluster_resolvers: - cluster_spec = cluster_resolver.cluster_spec() - cluster_dict = cluster_spec.as_dict() - - for job_name, tasks in cluster_dict.items(): - if job_name in merged_cluster: - # If we see a dict, then we write a dict out regardless. - if isinstance(tasks, dict): - merged_cluster[job_name] = {} - else: - # We take whichever type is present. - if isinstance(tasks, list): - merged_cluster[job_name] = [] - else: - merged_cluster[job_name] = {} - - # We then do the merge as appropriate in merged_cluster[job]. - for cluster_resolver in self._cluster_resolvers: - cluster_spec = cluster_resolver.cluster_spec() - cluster_dict = cluster_spec.as_dict() - - for job_name, tasks in cluster_dict.items(): - if isinstance(merged_cluster[job_name], list): - # We all have lists, we can just concatenate and be done. - merged_cluster[job_name].extend(tasks) - else: - if isinstance(tasks, list): - # We convert to a dictionary if the type is a list. - task_dict = dict(zip(range(0, len(tasks)), tasks)) - else: - # We can simply make a copy (for update) and be done. - task_dict = tasks.copy() - - # We detect if there are duplicates, and raise an error if so. 
- task_keys = set(task_dict) - merged_keys = set(merged_cluster[job_name].keys()) - intersected_keys = task_keys.intersection(merged_keys) - if intersected_keys: - raise KeyError('Duplicate keys detected when merging two ' - 'ClusterSpecs: %s' % repr(intersected_keys)) - - # We do the merge after all the processing. - merged_cluster[job_name].update(task_dict) - - return ClusterSpec(merged_cluster) - - def master(self, task_type=None, task_index=None, rpc_layer=None): - """Returns the master address to use when creating a session. - - This usually returns the master from the first ClusterResolver passed in, - but you can override this by specifying the task_type and task_index. - - Args: - task_type: (Optional) The type of the TensorFlow task of the master. - task_index: (Optional) The index of the TensorFlow task of the master. - rpc_layer: (Optional) The RPC protocol for the given cluster. - - Returns: - The name or URL of the session master. - """ - if task_type is not None and task_index is not None: - master = self.cluster_spec().task_address(task_type, task_index) - return _format_master_url(master, rpc_layer or self._rpc_layer) - - return self._cluster_resolvers[0].master(rpc_layer=rpc_layer) - - @property - def task_type(self): - return self._task_type or self._cluster_resolvers[0].task_type - - @property - def task_index(self): - return self._task_index or self._cluster_resolvers[0].task_index - - @task_type.setter - def task_type(self, task_type): - self._task_type = task_type - - @task_index.setter - def task_index(self, task_index): - self._task_index = task_index - - @property - def environment(self): - return self._cluster_resolvers[0].environment - - def num_accelerators_per_worker(self, session_config=None): - return self._cluster_resolvers[0].num_accelerators_per_worker( - session_config) - - @property - def rpc_layer(self): - return self._rpc_layer or self._cluster_resolvers[0].rpc_layer - - @rpc_layer.setter - def rpc_layer(self, rpc_layer): - self._rpc_layer = rpc_layer diff --git a/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py index 195b68959b6..55e61155c68 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py @@ -1,4 +1,4 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,197 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Implementation of Cluster Resolvers for GCE Instance Groups.""" +"""Stub file for GceClusterResolver to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function +# This file (and all files in this directory in general) is a backwards +# compatibility shim that exists to re-export ClusterResolvers such that +# existing OSS code will not be broken. 
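
[Editor's note] The SimpleClusterResolver and UnionClusterResolver code deleted above (and re-homed under tensorflow.python.distribute.cluster_resolver) formats the master URL with the rpc layer and merges job lists across resolvers, raising KeyError only on conflicting task keys. An illustrative sketch based on that deleted implementation; the host names are made up.

```python
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import (
    SimpleClusterResolver, UnionClusterResolver)
from tensorflow.python.training.server_lib import ClusterSpec

workers = SimpleClusterResolver(
    ClusterSpec({"worker": ["worker0:2222", "worker1:2222"]}),
    master="worker0:2222", rpc_layer="grpc")
ps = SimpleClusterResolver(ClusterSpec({"ps": ["ps0:2222"]}))

print(workers.master())             # grpc://worker0:2222
print(workers.master("worker", 1))  # grpc://worker1:2222

union = UnionClusterResolver(workers, ps)
print(union.cluster_spec().as_dict())
# {'ps': ['ps0:2222'], 'worker': ['worker0:2222', 'worker1:2222']}
```
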
-from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver -from tensorflow.python.training.server_lib import ClusterSpec +# pylint: disable=unused-import +from tensorflow.python.distribute.cluster_resolver.gce_cluster_resolver import GceClusterResolver +# pylint: enable=unused-import -_GOOGLE_API_CLIENT_INSTALLED = True -try: - from googleapiclient import discovery # pylint: disable=g-import-not-at-top - from oauth2client.client import GoogleCredentials # pylint: disable=g-import-not-at-top -except ImportError: - _GOOGLE_API_CLIENT_INSTALLED = False +from tensorflow.python.util.all_util import remove_undocumented +_allowed_symbols = [ + 'GceClusterResolver', +] -def _format_master_url(master, rpc_layer=None): - return '%s://%s' % (rpc_layer, master) if rpc_layer else master - - -class GceClusterResolver(ClusterResolver): - """Cluster Resolver for Google Compute Engine. - - This is an implementation of cluster resolvers for the Google Compute Engine - instance group platform. By specifying a project, zone, and instance group, - this will retrieve the IP address of all the instances within the instance - group and return a Cluster Resolver object suitable for use for distributed - TensorFlow. - """ - - def __init__(self, - project, - zone, - instance_group, - port, - task_type='worker', - task_index=0, - rpc_layer='grpc', - num_accelerators_per_worker=0, - credentials='default', - service=None): - """Creates a new GceClusterResolver object. - - This takes in a few parameters and creates a GceClusterResolver project. It - will then use these parameters to query the GCE API for the IP addresses of - each instance in the instance group. - - Args: - project: Name of the GCE project. - zone: Zone of the GCE instance group. - instance_group: Name of the GCE instance group. - port: Port of the listening TensorFlow server (default: 8470) - task_type: Name of the TensorFlow job this GCE instance group of VM - instances belong to. - task_index: The task index for this particular VM, within the GCE - instance group. In particular, every single instance should be assigned - a unique ordinal index within an instance group manually so that they - can be distinguished from each other. - rpc_layer: The RPC layer TensorFlow should use to communicate across - instances. - num_accelerators_per_worker: Number of accelerators (GPUs) present per - instance. - credentials: GCE Credentials. If nothing is specified, this defaults to - GoogleCredentials.get_application_default(). - service: The GCE API object returned by the googleapiclient.discovery - function. (Default: discovery.build('compute', 'v1')). If you specify a - custom service object, then the credentials parameter will be ignored. - - Raises: - ImportError: If the googleapiclient is not installed. 
- """ - self._project = project - self._zone = zone - self._instance_group = instance_group - self._task_type = task_type - self._task_index = task_index - self._rpc_layer = rpc_layer - self._port = port - self._credentials = credentials - - if credentials == 'default': - if _GOOGLE_API_CLIENT_INSTALLED: - self._credentials = GoogleCredentials.get_application_default() - - if service is None: - if not _GOOGLE_API_CLIENT_INSTALLED: - raise ImportError('googleapiclient must be installed before using the ' - 'GCE cluster resolver') - self._service = discovery.build( - 'compute', 'v1', - credentials=self._credentials) - else: - self._service = service - - def cluster_spec(self): - """Returns a ClusterSpec object based on the latest instance group info. - - This returns a ClusterSpec object for use based on information from the - specified instance group. We will retrieve the information from the GCE APIs - every time this method is called. - - Returns: - A ClusterSpec containing host information retrieved from GCE. - """ - request_body = {'instanceState': 'RUNNING'} - request = self._service.instanceGroups().listInstances( - project=self._project, - zone=self._zone, - instanceGroups=self._instance_group, - body=request_body, - orderBy='name') - - worker_list = [] - - while request is not None: - response = request.execute() - - items = response['items'] - for instance in items: - instance_name = instance['instance'].split('/')[-1] - - instance_request = self._service.instances().get( - project=self._project, - zone=self._zone, - instance=instance_name) - - if instance_request is not None: - instance_details = instance_request.execute() - ip_address = instance_details['networkInterfaces'][0]['networkIP'] - instance_url = '%s:%s' % (ip_address, self._port) - worker_list.append(instance_url) - - request = self._service.instanceGroups().listInstances_next( - previous_request=request, - previous_response=response) - - worker_list.sort() - return ClusterSpec({self._task_type: worker_list}) - - def master(self, task_type=None, task_index=None, rpc_layer=None): - task_type = task_type if task_type is not None else self._task_type - task_index = task_index if task_index is not None else self._task_index - - if task_type is not None and task_index is not None: - master = self.cluster_spec().task_address(task_type, task_index) - if rpc_layer or self._rpc_layer: - return '%s://%s' % (rpc_layer or self._rpc_layer, master) - else: - return master - - return '' - - @property - def task_type(self): - return self._task_type - - @property - def task_index(self): - return self._task_index - - @task_type.setter - def task_type(self, task_type): - raise RuntimeError( - 'You cannot reset the task_type of the GceClusterResolver after it has ' - 'been created.') - - @task_index.setter - def task_index(self, task_index): - self._task_index = task_index - - @property - def environment(self): - """Returns the current environment which TensorFlow is running in. - - For users in the GCE environment, the environment property is always an - empty string, and Google users will not use this ClusterResolver for running - on internal systems. - """ - return '' - - @property - def rpc_layer(self): - return self._rpc_layer - - @rpc_layer.setter - def rpc_layer(self, rpc_layer): - self._rpc_layer = rpc_layer - - def num_accelerators_per_worker(self, session_config=None): - del session_config # Unused, since this is set manually in __init__. 
- return self._num_accelerators_per_worker - +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/cluster_resolver/python/training/kubernetes_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/kubernetes_cluster_resolver.py index ddae64839f0..a8eaf33629a 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/kubernetes_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/kubernetes_cluster_resolver.py @@ -12,121 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Implementation of Cluster Resolvers for Kubernetes.""" +"""Stub file for KubernetesClusterResolver for backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver -from tensorflow.python.training import server_lib +# This file (and all files in this directory in general) is a backwards +# compatibility shim that exists to re-export ClusterResolvers such that +# existing OSS code will not be broken. -_KUBERNETES_API_CLIENT_INSTALLED = True -try: - from kubernetes import client as k8sclient # pylint: disable=g-import-not-at-top - from kubernetes import config as k8sconfig # pylint: disable=g-import-not-at-top -except ImportError: - _KUBERNETES_API_CLIENT_INSTALLED = False +# pylint: disable=unused-import +from tensorflow.python.distribute.cluster_resolver.kubernetes_cluster_resolver import KubernetesClusterResolver +# pylint: enable=unused-import +from tensorflow.python.util.all_util import remove_undocumented -class KubernetesClusterResolver(ClusterResolver): - """Cluster Resolver for Kubernetes. +_allowed_symbols = [ + 'KubernetesClusterResolver', +] - This is an implementation of cluster resolvers for Kubernetes. When given the - the Kubernetes namespace and label selector for pods, we will retrieve the - pod IP addresses of all running pods matching the selector, and return a - ClusterSpec based on that information. - """ +remove_undocumented(__name__, _allowed_symbols) - def __init__(self, - job_to_label_mapping=None, - tf_server_port=8470, - override_client=None): - """Initializes a new KubernetesClusterResolver. - - This initializes a new Kubernetes Cluster Resolver. The Cluster Resolver - will attempt to talk to the Kubernetes master to retrieve all the instances - of pods matching a label selector. - - Args: - job_to_label_mapping: A mapping of TensorFlow jobs to label selectors. - This allows users to specify many TensorFlow jobs in one Cluster - Resolver, and each job can have pods belong with different label - selectors. For example, a sample mapping might be - ``` - {'worker': ['job-name=worker-cluster-a', 'job-name=worker-cluster-b'], - 'ps': ['job-name=ps-1', 'job-name=ps-2']} - ``` - tf_server_port: The port the TensorFlow server is listening on. - override_client: The Kubernetes client (usually automatically retrieved - using `from kubernetes import client as k8sclient`). If you pass this - in, you are responsible for setting Kubernetes credentials manually. - - Raises: - ImportError: If the Kubernetes Python client is not installed and no - `override_client` is passed in. 
- """ - if _KUBERNETES_API_CLIENT_INSTALLED: - k8sconfig.load_kube_config() - - if not job_to_label_mapping: - job_to_label_mapping = {'worker': ['job-name=tensorflow']} - - if not override_client and not _KUBERNETES_API_CLIENT_INSTALLED: - raise ImportError('The Kubernetes Python client must be installed before' - 'using the Kubernetes Cluster Resolver. To install the' - 'Kubernetes Python client, run `pip install ' - 'kubernetes` on your command line.') - - self._job_to_label_mapping = job_to_label_mapping - self._tf_server_port = tf_server_port - self._override_client = override_client - - def master(self): - # TODO(frankchn): Figure out a standard way to pass in the current task type - # and task id via Kubernetes. - pass - - def get_master(self): - return self.master() - - def get_job_name(self): - return self._job_name - - def cluster_spec(self): - """Returns a ClusterSpec object based on the latest info from Kubernetes. - - We retrieve the information from the Kubernetes master every time this - method is called. - - Returns: - A ClusterSpec containing host information returned from Kubernetes. - - Raises: - RuntimeError: If any of the pods returned by the master is not in the - `Running` phase. - """ - if not self._override_client: - k8sconfig.load_kube_config() - - client = self._override_client or k8sclient.CoreV1Api() - cluster_map = {} - - for tf_job in self._job_to_label_mapping: - all_pods = [] - for selector in self._job_to_label_mapping[tf_job]: - ret = client.list_pod_for_all_namespaces(label_selector=selector) - selected_pods = [] - - # Sort the list by the name to make sure it doesn't change call to call. - for pod in sorted(ret.items, key=lambda x: x.metadata.name): - if pod.status.phase == 'Running': - selected_pods.append( - '%s:%s' % (pod.status.host_ip, self._tf_server_port)) - else: - raise RuntimeError('Pod "%s" is not running; phase: "%s"' % - (pod.metadata.name, pod.status.phase)) - all_pods.extend(selected_pods) - cluster_map[tf_job] = all_pods - - return server_lib.ClusterSpec(cluster_map) diff --git a/tensorflow/contrib/cluster_resolver/python/training/slurm_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/slurm_cluster_resolver.py index dabe2fe1d39..fcd2a846eeb 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/slurm_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/slurm_cluster_resolver.py @@ -12,185 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Implementation of Cluster Resolvers for Slurm workload manager.""" +"""Stub file for SlurmClusterResolver to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections -import os -import subprocess +# This file (and all files in this directory in general) is a backwards +# compatibility shim that exists to re-export ClusterResolvers such that +# existing OSS code will not be broken. 
-from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver -from tensorflow.python.training.server_lib import ClusterSpec +# pylint: disable=unused-import +from tensorflow.python.distribute.cluster_resolver.slurm_cluster_resolver import SlurmClusterResolver +# pylint: enable=unused-import +from tensorflow.python.util.all_util import remove_undocumented -class SlurmClusterResolver(ClusterResolver): - """Cluster Resolver for system with Slurm workload manager. +_allowed_symbols = [ + 'SlurmClusterResolver', +] - This is an implementation of cluster resolvers for Slurm clusters. This allows - the specification of jobs and task counts, number of tasks per node, number of - GPUs on each node and number of GPUs for each task, It retrieves system - attributes by Slurm environment variables, resolves allocated computing node - names, construct a cluster and return a Cluster Resolver object which an be - use for distributed TensorFlow. - """ - - def _resolve_hostnames(self): - """Resolve host names of nodes allocated in current jobs. - - Returns: - A list of node names as strings. - """ - hostlist = (subprocess.check_output(['scontrol', 'show', 'hostname']). - decode('utf-8').strip().split('\n')) - return hostlist - - def __init__(self, - jobs, - port_base=8888, - gpus_per_node=1, - gpus_per_task=1, - tasks_per_node=None, - auto_set_gpu=True): - """Creates a new SlurmClusterResolver object. - - This takes in parameters and creates a SlurmClusterResolver object. It uses - those parameters to check which nodes will processes reside and resolves - their hostnames. With the number of the GPUs on each node and number of GPUs - for each task it offsets the port number for each processes and allocate - GPUs to tasks by setting environment variables. The resolver currently - supports homogeneous tasks and default Slurm process allocation. - - Args: - jobs: Dictionary with job names as key and number of tasks in the job as - value - port_base: The first port number to start with for processes on a node. - gpus_per_node: Number of GPUs available on each node. - gpus_per_task: Number of GPUs to be used for each task. - tasks_per_node: Number of tasks to run on each node, if not set defaults - to Slurm's output environment variable SLURM_NTASKS_PER_NODE. - auto_set_gpu: Set the visible CUDA devices automatically while resolving - the cluster by setting CUDA_VISIBLE_DEVICES environment variable. - Defaults to True. - - Returns: - A ClusterResolver object which can be used with distributed TensorFlow. - - Raises: - RuntimeError: If requested more GPUs per node then available or requested - more tasks then assigned tasks. 
- """ - - # check if launched by mpirun - if 'OMPI_COMM_WORLD_RANK' in os.environ: - self._rank = int(os.environ['OMPI_COMM_WORLD_RANK']) - num_tasks = int(os.environ['OMPI_COMM_WORLD_SIZE']) - else: - self._rank = int(os.environ['SLURM_PROCID']) - num_tasks = int(os.environ['SLURM_NTASKS']) - - self._jobs = collections.OrderedDict(sorted(jobs.items())) - self._port_base = port_base - - # user specification overrides SLURM specification - if tasks_per_node is not None: - self._tasks_per_node = tasks_per_node - elif tasks_per_node is None and 'SLURM_NTASKS_PER_NODE' in os.environ: - self._tasks_per_node = int(os.environ['SLURM_NTASKS_PER_NODE']) - else: - raise RuntimeError('Neither `tasks_per_node` or ' - 'SLURM_NTASKS_PER_NODE is set.') - - self._gpus_per_node = gpus_per_node - self._gpus_per_task = gpus_per_task - - self._auto_set_gpu = auto_set_gpu - self._job_name = None - self._task_index = None - - self._gpu_allocation = [] - self._cluster_allocation = {} - - if self._tasks_per_node * self._gpus_per_task > self._gpus_per_node: - raise RuntimeError('Requested more GPUs per node then available.') - - if sum(self._jobs.values()) != num_tasks: - raise RuntimeError('Requested more tasks then assigned tasks.') - - def cluster_spec(self): - """Returns a ClusterSpec object based on the latest instance group info. - - This returns a ClusterSpec object for use based on information from the - specified initialization parameters and Slurm environment variables. The - cluster specification is resolved each time this function is called. The - resolver extract hostnames of nodes by scontrol and pack tasks in that - order until a node a has number of tasks that is equal to specification. - GPUs on nodes are allocated to tasks by specification through setting - CUDA_VISIBLE_DEVICES environment variable. - - Returns: - A ClusterSpec containing host information retrieved from Slurm's - environment variables. - """ - hostlist = self._resolve_hostnames() - - task_list = [] - self._gpu_allocation = [] - self._cluster_allocation = {} - - for host in hostlist: - for port_offset, gpu_offset in zip( - range(self._tasks_per_node), - range(0, self._gpus_per_node, self._gpus_per_task)): - - host_addr = '%s:%d' % (host, self._port_base + port_offset) - task_list.append(host_addr) - gpu_id_list = [] - - for gpu_id in range(gpu_offset, gpu_offset + self._gpus_per_task): - gpu_id_list.append(str(gpu_id)) - - self._gpu_allocation.append(','.join(gpu_id_list)) - - cluster_rank_offset_start = 0 - cluster_rank_offset_end = 0 - - for job_name, num_tasks in self._jobs.items(): - cluster_rank_offset_end = cluster_rank_offset_start + num_tasks - - self._cluster_allocation[job_name] = \ - task_list[cluster_rank_offset_start:cluster_rank_offset_end] - - if self._rank >= cluster_rank_offset_start and \ - self._rank < cluster_rank_offset_end: - - self._job_name = job_name - self._task_index = self._rank - cluster_rank_offset_start - - cluster_rank_offset_start = cluster_rank_offset_end - - if self._auto_set_gpu is True: - os.environ['CUDA_VISIBLE_DEVICES'] = self._gpu_allocation[self._rank] - - return ClusterSpec(self._cluster_allocation) - - def get_task_info(self): - """Returns job name and task_index for the process which calls this. - - This returns the job name and task index for the process which calls this - function according to its rank and cluster specification. The job name and - task index are set after a cluster is constructed by cluster_spec otherwise - defaults to None. 
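
[Editor's note] A hypothetical usage sketch for the SlurmClusterResolver being relocated here; the job names and counts are made up, and the call is only meaningful inside a Slurm allocation (or under mpirun), since the constructor shown above reads SLURM_PROCID / SLURM_NTASKS or the OMPI_* equivalents from the environment.

```python
from tensorflow.python.distribute.cluster_resolver.slurm_cluster_resolver import (
    SlurmClusterResolver)

resolver = SlurmClusterResolver(
    jobs={"ps": 1, "worker": 3},  # must sum to the number of Slurm tasks
    port_base=8888,
    gpus_per_node=2,
    gpus_per_task=1,
    tasks_per_node=2)

cluster_spec = resolver.cluster_spec()  # also sets CUDA_VISIBLE_DEVICES
job_name, task_index = resolver.get_task_info()
```
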
- - Returns: - A string specifying job name the process belongs to and an integner - specifying the task index the process belongs to in that job. - """ - return self._job_name, self._task_index - - def master(self, task_type=None, task_index=None): - if task_type and task_index: - return self.cluster_spec().task_address(task_type, task_index) - return self._cluster_allocation[str(self._job_name)][self._task_index] +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/cluster_resolver/python/training/tfconfig_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tfconfig_cluster_resolver.py index 7bbd189d03d..9db7f47dcb4 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tfconfig_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tfconfig_cluster_resolver.py @@ -12,81 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Implementation of Cluster Resolvers for TF_CONFIG Environment Variables.""" - +"""Stub file for TFConfigClusterResolver to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import json -import os +# This file (and all files in this directory in general) is a backwards +# compatibility shim that exists to re-export ClusterResolvers such that +# existing OSS code will not be broken. -from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver -from tensorflow.python.training.server_lib import ClusterSpec +# pylint: disable=unused-import +from tensorflow.python.distribute.cluster_resolver.tfconfig_cluster_resolver import TFConfigClusterResolver +# pylint: enable=unused-import -_TF_CONFIG_ENV = 'TF_CONFIG' -_SESSION_MASTER_KEY = 'session_master' +from tensorflow.python.util.all_util import remove_undocumented +_allowed_symbols = [ + 'TFConfigClusterResolver', +] -class TFConfigClusterResolver(ClusterResolver): - """Implementation of a ClusterResolver which reads the TF_CONFIG EnvVar.""" +remove_undocumented(__name__, _allowed_symbols) - def _load_tf_config(self): - return json.loads(os.environ.get(_TF_CONFIG_ENV, '{}')) - - def cluster_spec(self): - """Returns a ClusterSpec based on the TF_CONFIG environment variable. - - Returns: - A ClusterSpec with information from the TF_CONFIG environment variable. - """ - tf_config = self._load_tf_config() - if 'cluster' not in tf_config: - return ClusterSpec({}) - return ClusterSpec(tf_config['cluster']) - - def master(self, task_type=None, task_index=0): - """Returns the master address to use when creating a TensorFlow session. - - Args: - task_type: (String, optional) Overrides and sets the task_type of the - master. - task_index: (Integer, optional) Overrides and sets the task id of the - master. - - Returns: - The address of the master. - - Raises: - RuntimeError: If the task_type or task_id is not specified and the - `TF_CONFIG` environment variable does not contain a task section. - """ - - # If `session_master` is set, just use that. - tf_config = self._load_tf_config() - if _SESSION_MASTER_KEY in tf_config: - return tf_config[_SESSION_MASTER_KEY] - - if 'rpc_layer' in tf_config: - rpclayer = '%s://' % tf_config['rpc_layer'] - else: - rpclayer = '' - - # Return an empty string if we are the only job in the ClusterSpec. 
- cluster_spec = self.cluster_spec() - if (not cluster_spec.jobs or - (len(cluster_spec.jobs) == 1 and - len(cluster_spec.job_tasks(cluster_spec.jobs[0])) == 1)): - return '' - - # We try to auto-detect the task type and id, but uses the user-supplied one - # where available - if not task_type: - if 'task' not in tf_config: - raise RuntimeError('You must either specify a `task_type`, or your ' - 'TF_CONFIG must contain a `task` section.') - task_type = tf_config['task']['type'] - task_index = tf_config['task']['index'] - - return rpclayer + cluster_spec.task_address(task_type, task_index) diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py index 1f6803a9ff9..3a1eaccd06e 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py @@ -1,4 +1,4 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,341 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Implementation of Cluster Resolvers for Cloud TPUs.""" +"""Stub file for TPUClusterResolver to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os +# This file (and all files in this directory in general) is a backwards +# compatibility shim that exists to re-export ClusterResolvers such that +# existing OSS code will not be broken. -from six.moves.urllib.request import Request -from six.moves.urllib.request import urlopen +# pylint: disable=unused-import +from tensorflow.python.distribute.cluster_resolver.tpu_cluster_resolver import TPUClusterResolver +# pylint: enable=unused-import -from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver -from tensorflow.python.training import server_lib -from tensorflow.python.util import compat +from tensorflow.python.util.all_util import remove_undocumented -_GOOGLE_API_CLIENT_INSTALLED = True -try: - from googleapiclient import discovery # pylint: disable=g-import-not-at-top - from oauth2client.client import GoogleCredentials # pylint: disable=g-import-not-at-top -except ImportError: - _GOOGLE_API_CLIENT_INSTALLED = False +_allowed_symbols = [ + 'TPUClusterResolver', +] - -_GKE_ENV_VARIABLE = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS' -_ENDPOINTS_SEPARATOR = ',' -_DEFAULT_ENV_VARIABLE = 'TPU_NAME' -_DISCOVERY_SERVICE_URL_ENV_VARIABLE = 'TPU_API_DISCOVERY_URL' - - -class TPUClusterResolver(ClusterResolver): - """Cluster Resolver for Google Cloud TPUs. - - This is an implementation of cluster resolvers for the Google Cloud TPU - service. As Cloud TPUs are in alpha, you will need to specify a API definition - file for this to consume, in addition to a list of Cloud TPUs in your Google - Cloud Platform project. - """ - - def _tpuService(self): - """Creates a new Cloud TPU API object. - - This works around an issue where the underlying HTTP connection sometimes - times out when the script has been running for too long. 
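For readers following the shim above: the TF_CONFIG handling removed here (and now re-exported from `tensorflow.python.distribute.cluster_resolver`) consumes a JSON environment variable with a `cluster` map plus optional `task`, `rpc_layer` and `session_master` entries. A minimal sketch of that layout, with hypothetical addresses; the expected results in the comments follow the implementation shown above.

```python
import json
import os

# Hypothetical two-worker cluster; this process is worker 1.
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'worker': ['10.1.0.1:2222', '10.1.0.2:2222'],
        'ps': ['10.1.0.3:2222'],
    },
    'task': {'type': 'worker', 'index': 1},
    'rpc_layer': 'grpc',
})

# Import path taken from the re-export in this patch.
from tensorflow.python.distribute.cluster_resolver.tfconfig_cluster_resolver import TFConfigClusterResolver

resolver = TFConfigClusterResolver()
print(resolver.cluster_spec().as_dict())
# Per the parsing logic shown above, master() prefixes the task's own address
# with the rpc layer, i.e. 'grpc://10.1.0.2:2222' for this configuration.
print(resolver.master())
```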
Other methods in - this object calls this method to get a new API object whenever they need - to communicate with the Cloud API. - - Returns: - A Google Cloud TPU API object. - """ - if self._service: - return self._service - - credentials = self._credentials - if credentials is None or credentials == 'default': - credentials = GoogleCredentials.get_application_default() - - if self._discovery_url: - return discovery.build( - 'tpu', 'v1alpha1', - credentials=credentials, - discoveryServiceUrl=self._discovery_url) - else: - return discovery.build( - 'tpu', 'v1alpha1', - credentials=credentials) - - def _requestComputeMetadata(self, path): - req = Request('http://metadata/computeMetadata/v1/%s' % path, - headers={'Metadata-Flavor': 'Google'}) - resp = urlopen(req) - return compat.as_bytes(resp.read()) - - def _shouldResolve(self): - if (self._tpu == compat.as_bytes('') or - self._tpu == compat.as_bytes('local') or - self._tpu.startswith(compat.as_bytes('/bns')) or - self._tpu.startswith(compat.as_bytes('localhost:')) or - self._tpu.startswith(compat.as_bytes('grpc://'))): - return False - return True - - @staticmethod - def _inGke(): - """When running in GKE, the environment variable will be set.""" - return _GKE_ENV_VARIABLE in os.environ - - @staticmethod - def _gkeEndpoints(): - return os.environ[_GKE_ENV_VARIABLE] - - @staticmethod - def _envVarFallback(): - if _DEFAULT_ENV_VARIABLE in os.environ: - return os.environ[_DEFAULT_ENV_VARIABLE] - return None - - @staticmethod - def _environmentDiscoveryUrl(): - return os.environ.get(_DISCOVERY_SERVICE_URL_ENV_VARIABLE) - - def __init__(self, - tpu=None, - zone=None, - project=None, - job_name='worker', - coordinator_name=None, - coordinator_address=None, - credentials='default', - service=None, - discovery_url=None): - """Creates a new TPUClusterResolver object. - - The ClusterResolver will then use the parameters to query the Cloud TPU APIs - for the IP addresses and ports of each Cloud TPU listed. - - Args: - tpu: Either a string, or a list of strings corresponding to the TPUs to - use. If the single string is the empty string, the string 'local', or a - string that begins with 'grpc://' or '/bns', then it is assumed to not - correspond with a Cloud TPU and will instead be passed as the session - master and no ClusterSpec propagation will be done. - zone: Zone where the TPUs are located. If omitted or empty, we will assume - that the zone of the TPU is the same as the zone of the GCE VM, which we - will try to discover from the GCE metadata service. - project: Name of the GCP project containing Cloud TPUs. If omitted or - empty, we will try to discover the project name of the GCE VM from the - GCE metadata service. - job_name: Name of the TensorFlow job the TPUs belong to. - coordinator_name: The name to use for the coordinator. Set to None if the - coordinator should not be included in the computed ClusterSpec. - coordinator_address: The address of the coordinator (typically an ip:port - pair). If set to None, a TF server will be started. If coordinator_name - is None, a TF server will not be started even if coordinator_address is - None. - credentials: GCE Credentials. If None, then we use default credentials - from the oauth2client - service: The GCE API object returned by the googleapiclient.discovery - function. If you specify a custom service object, then the credentials - parameter will be ignored. - discovery_url: A URL template that points to the location of - the discovery service. 
It should have two parameters {api} and - {apiVersion} that when filled in produce an absolute URL to the - discovery document for that service. The environment variable - 'TPU_API_DISCOVERY_URL' will override this. - - Raises: - ImportError: If the googleapiclient is not installed. - ValueError: If no TPUs are specified. - """ - if isinstance(tpu, list): - if not tpu: - raise ValueError('At least one TPU must be specified.') - if len(tpu) != 1: - raise NotImplementedError( - 'Using multiple TPUs in a single session is not yet implemented') - tpu = tpu[0] - - in_gke = self._inGke() - # When using GKE with Cloud TPUs, the env variable will be set. - if tpu is None: - if in_gke: - tpu = self._gkeEndpoints() - else: - tpu = self._envVarFallback() - - if tpu is None: - raise ValueError('Please provide a TPU Name to connect to.') - - self._tpu = compat.as_bytes(tpu) # self._tpu is always bytes - self._job_name = job_name - - # Whether we should actually attempt to contact Cloud APIs - should_resolve = self._shouldResolve() - - # We error out if we are in a non-Cloud environment which cannot talk to the - # Cloud APIs using the standard class and a special object is not passed in. - self._service = service - if (self._service is None and should_resolve and - not _GOOGLE_API_CLIENT_INSTALLED): - raise ImportError('googleapiclient and oauth2client must be installed ' - 'before using the TPU cluster resolver. Execute: ' - '`pip install --upgrade google-api-python-client` ' - 'and `pip install --upgrade oauth2client` to ' - 'install with pip.') - - # We save user-passed credentials, unless the user didn't pass in anything. - self._credentials = credentials - if (credentials == 'default' and should_resolve and - _GOOGLE_API_CLIENT_INSTALLED): - self._credentials = None - - # Automatically detect project and zone if unspecified. - if not project and should_resolve: - project = compat.as_str( - self._requestComputeMetadata('project/project-id')) - if not zone and should_resolve: - zone_path = compat.as_str(self._requestComputeMetadata('instance/zone')) - zone = zone_path.split('/')[-1] - self._project = project - self._zone = zone - - self._discovery_url = self._environmentDiscoveryUrl() or discovery_url - - self._coordinator_name = coordinator_name - if (coordinator_name and not coordinator_address and - (should_resolve or in_gke)): - self._start_local_server() - else: - self._coordinator_address = coordinator_address - - def master(self, task_type=None, task_index=None): - """Get the Master string to be used for the session. - - In the normal case, this returns the grpc path (grpc://1.2.3.4:8470) of - first instance in the ClusterSpec returned by the cluster_spec function. - - If a non-TPU name is used when constructing a TPUClusterResolver, that will - be returned instead (e.g. If the tpus argument's value when constructing - this TPUClusterResolver was 'grpc://10.240.1.2:8470', - 'grpc://10.240.1.2:8470' will be returned). - - Args: - task_type: (Optional) The type of the TensorFlow task of the master. - task_index: (Optional) The index of the TensorFlow task of the master. - - Returns: - string, the connection string to use when creating a session. - - Raises: - ValueError: If none of the TPUs specified exists. 
- """ - if not self._shouldResolve(): - return self._tpu.split(compat.as_bytes(_ENDPOINTS_SEPARATOR))[0] - - cluster_spec = self.cluster_spec() - if task_type and task_index: - return cluster_spec.task_address(task_type, task_index) - - job_tasks = cluster_spec.job_tasks(self._job_name) - if not job_tasks: - raise ValueError('No TPUs exists with the specified names exist.') - - return 'grpc://' + job_tasks[0] - - def get_master(self): - return self.master() - - def get_job_name(self): - if self._shouldResolve(): - return self._job_name - - def cluster_spec(self): - """Returns a ClusterSpec object based on the latest TPU information. - - We retrieve the information from the GCE APIs every time this method is - called. - - Returns: - A ClusterSpec containing host information returned from Cloud TPUs. - - Raises: - RuntimeError: If the provided TPU is not healthy. - """ - ############################################################################ - # There are 5 potential cases this code must handle: - # 1. [Normal case.] We should resolve the TPU name to a set of tasks, and - # a. Create a ClusterSpec that includes the coordinator job - # b. Create a ClusterSpec without the coordinator job. - # 2. [GKE / No API Access.] We should not resolve the TPU name to a set of - # tasks and - # a. Create a ClusterSpec with the coordinator - # b. Create a ClusterSpec without the coordinator - # 3. [Other (legacy non-gRPC).] We should return an empty ClusterSpec. - ############################################################################ - - if self._shouldResolve(): - # Case 1. - full_name = 'projects/%s/locations/%s/nodes/%s' % ( - self._project, self._zone, compat.as_text(self._tpu)) - service = self._tpuService() - request = service.projects().locations().nodes().get(name=full_name) - response = request.execute() - - if 'state' in response and response['state'] != 'READY': - raise RuntimeError('TPU "%s" is not yet ready; state: "%s"' % - (compat.as_text(self._tpu), response['state'])) - - if 'health' in response and response['health'] != 'HEALTHY': - raise RuntimeError('TPU "%s" is unhealthy: "%s"' % - (compat.as_text(self._tpu), response['health'])) - - if 'networkEndpoints' in response: - worker_list = [ - '%s:%s' % (endpoint['ipAddress'], endpoint['port']) - for endpoint in response['networkEndpoints'] - ] - else: - # Fall back to the deprecated response format - instance_url = '%s:%s' % (response['ipAddress'], response['port']) - worker_list = [instance_url] - - cluster_spec = {self._job_name: worker_list} - else: - if not self._tpu.startswith(compat.as_bytes('grpc://')): - # Case 3. - return None - # Case 2. 
- cluster_spec = { - self._job_name: [ - x[len(compat.as_bytes('grpc://')):] - for x in self._tpu.split(compat.as_bytes(_ENDPOINTS_SEPARATOR)) - ] - } - - if self._coordinator_address: - # {1, 2}.a - cluster_spec[self._coordinator_name] = [self._coordinator_address] - - return server_lib.ClusterSpec(cluster_spec) - - def _start_local_server(self): - address = self._requestComputeMetadata('instance/network-interfaces/0/ip') - self._server = server_lib.Server( - { - 'local': ['0.0.0.0:0'] - }, protocol='grpc', config=None, start=True) - # self._server.target is of the form: grpc://ipaddress:port - target = compat.as_bytes(self._server.target) - splits = target.split(compat.as_bytes(':')) - assert len(splits) == 3, self._server.target - assert splits[0] == compat.as_bytes('grpc'), self._server.target - self._coordinator_port = compat.as_text(splits[2]) - self._coordinator_address = '%s:%s' % ( - address, compat.as_text(self._coordinator_port)) - - def __deepcopy__(self, memo): - # TODO(b/73668574): Remove this once RunConfig avoids performing deepcopy. - return self +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt index a63366e1361..124d6cfd478 100644 --- a/tensorflow/contrib/cmake/CMakeLists.txt +++ b/tensorflow/contrib/cmake/CMakeLists.txt @@ -12,7 +12,7 @@ if(WIN32) endif() # Project -project(tensorflow C CXX) +project(tensorflow VERSION 1.12.0 LANGUAGES C CXX) # Set C++14 as standard for the whole project set(CMAKE_CXX_STANDARD 14) @@ -193,6 +193,7 @@ if(WIN32) set(CMAKE_SUPPRESS_REGENERATION ON) endif() + if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-exceptions -std=c++11") endif() @@ -281,6 +282,14 @@ else (systemlib_ZLIB) ${zlib_STATIC_LIBRARIES}) endif (systemlib_ZLIB) +if (systemlib_ABSEIL_CPP) + set(tensorflow_EXTERNAL_LIBRARIES ${tensorflow_EXTERNAL_LIBRARIES} + ${abseil_cpp_LIBRARIES}) +else (systemlib_ABSEIL_CPP) + set(tensorflow_EXTERNAL_LIBRARIES ${tensorflow_EXTERNAL_LIBRARIES} + ${abseil_cpp_STATIC_LIBRARIES}) +endif (systemlib_ABSEIL_CPP) + set(tensorflow_EXTERNAL_DEPENDENCIES zlib_copy_headers_to_destination gif_copy_headers_to_destination @@ -394,6 +403,7 @@ if (tensorflow_ENABLE_GPU) set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};--include-path ${PROJECT_BINARY_DIR}/$\{build_configuration\};--expt-relaxed-constexpr) set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-ftz=true) # Flush denormals to zero set(CUDA_INCLUDE ${CUDA_TOOLKIT_TARGET_DIR} ${CUDA_TOOLKIT_TARGET_DIR}/extras/CUPTI/include) + include_directories(${CUDA_INCLUDE}) if (WIN32) add_definitions(-DGOOGLE_CUDA=1 -DTF_EXTRA_CUDA_CAPABILITIES=3.7,5.2,6.0,6.1,7.0) @@ -546,14 +556,20 @@ if (tensorflow_ENABLE_GPU) cudnn_version_number=${tensorflow_CUDNN_VERSION}) endif(WIN32) else(tensorflow_ENABLE_GPU) - set(tensorflow_BUILD_INFO_FLAGS --build_config cpu --key_value - msvcp_dll_name=msvcp140.dll) + if(WIN32) + set(tensorflow_BUILD_INFO_FLAGS --build_config cpu --key_value + msvcp_dll_name=msvcp140.dll) + else() + set(tensorflow_BUILD_INFO_FLAGS --build_config cpu) + endif() endif(tensorflow_ENABLE_GPU) -# Find python executable -include(FindPythonInterp) -if(NOT ${PYTHONINTERP_FOUND}) - message(FATAL_ERROR "CMake was unable to find a python interpreter.") +if(tensorflow_BUILD_PYTHON_BINDINGS) + # Find python executable + include(FindPythonInterp) + if(NOT ${PYTHONINTERP_FOUND}) + message(FATAL_ERROR "CMake was unable to find a python interpreter.") + endif() endif() # Let's get to work! 
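Returning briefly to the TPU resolver rewritten earlier in this patch: per its docstring, a `tpu` argument that already looks like a gRPC endpoint bypasses the Cloud TPU API entirely, while a bare TPU name is resolved through the API (which requires `googleapiclient`, `oauth2client` and valid credentials). A hedged usage sketch; the TPU name, zone and project below are placeholders.

```python
# Import path taken from the re-export in this patch.
from tensorflow.python.distribute.cluster_resolver.tpu_cluster_resolver import TPUClusterResolver

# Case A: a concrete gRPC endpoint. No Cloud API call is made and master()
# simply echoes the endpoint.
direct = TPUClusterResolver(tpu='grpc://10.240.1.2:8470')
print(direct.master())  # grpc://10.240.1.2:8470

# Case B: a TPU name (hypothetical). Resolution goes through the Cloud TPU
# API; master() then returns 'grpc://<ip>:<port>' of the first resolved
# worker, and cluster_spec() lists all workers under the job name.
named = TPUClusterResolver(tpu='my-tpu', zone='us-central1-b',
                           project='my-gcp-project')
print(named.cluster_spec().as_dict())  # e.g. {'worker': ['<ip>:<port>', ...]}
print(named.master())
```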
@@ -574,6 +590,7 @@ include(tf_cc_ops.cmake) include(tf_c.cmake) include(tf_grappler.cmake) include(tf_core_profiler.cmake) +include(tf_core_eager_runtime.cmake) if(tensorflow_BUILD_CC_EXAMPLE) include(tf_tutorials.cmake) include(tf_label_image_example.cmake) @@ -587,4 +604,4 @@ if(tensorflow_BUILD_SHARED_LIB) endif() if(tensorflow_BUILD_CC_TESTS OR tensorflow_BUILD_PYTHON_TESTS) include(tf_tests.cmake) -endif() +endif() \ No newline at end of file diff --git a/tensorflow/contrib/cmake/README.md b/tensorflow/contrib/cmake/README.md index 84c679162c3..df5ff6cd532 100644 --- a/tensorflow/contrib/cmake/README.md +++ b/tensorflow/contrib/cmake/README.md @@ -6,9 +6,9 @@ platforms. For details, see the [TensorFlow install guide](https://www.tensorflow.org/install/). This directory contains CMake files for building TensorFlow on Microsoft -Windows. [CMake](https://cmake.org) is a cross-platform tool that can +Windows and Linux. [CMake](https://cmake.org) is a cross-platform tool that can generate build scripts for multiple build systems, including Microsoft -Visual Studio. +Visual Studio and GCC. "The method has not been tested on Mac OS X. **N.B.** We provide Linux build instructions primarily for the purpose of testing the build. We recommend using the standard Bazel-based build on @@ -23,6 +23,7 @@ for instructions on how to install a pre-built TensorFlow package on Windows. ### Current known limitations * It is not possible to load a custom Op library. * GCS file system is not supported. +* Debug build is not available since Python for Windows is no longer distributed with a debug library. ## Building with CMake @@ -53,12 +54,12 @@ bindings. ### Known-good configurations * Microsoft Windows 10 - - Microsoft Visual Studio Enterprise 2015 with Visual C++ 2015 + - Microsoft Visual Studio Enterprise/ Community 2015 with Visual C++ 2015 - [Anaconda 4.1.1 (Python 3.5 64-bit)](https://www.anaconda.com/download/) - [Git for Windows version 2.9.2.windows.1](https://git-scm.com/download/win) - [swigwin-3.0.10](http://www.swig.org/download.html) - - [NVidia CUDA Toolkit 8.0](https://developer.nvidia.com/cuda-downloads) - - [NVidia CUDNN 5.1](https://developer.nvidia.com/cudnn) + - [NVidia CUDA Toolkit 9.0](https://developer.nvidia.com/cuda-downloads) + - [NVidia CUDNN 7](https://developer.nvidia.com/cudnn) - [CMake 3.6](https://cmake.org/files/v3.6/cmake-3.6.3-win64-x64.msi) * Ubuntu 14.04 @@ -66,8 +67,8 @@ bindings. - Docker 1.9.1 (for automated testing) ### Current known limitations - - The Python package supports **Python 3.5 only**, because that is the only - version for which standard Python binaries exist and those binaries are + - The Python package supports **Python 3.5/3.6 only**, because these are the only + versions for which standard Python binaries exist and those binaries are compatible with the TensorFlow runtime. (On Windows, the standard Python binaries for versions earlier than 3.5 were compiled with older compilers that do not have all of the features (e.g. C++11 support) needed to compile @@ -104,8 +105,151 @@ We are actively working on improving CMake and Windows support, and addressing these limitations. We would appreciate pull requests that implement missing ops or APIs. +CMake GUI build (all platforms) +================================== +Install from CMake GUI would be a convenient way to generate C++ build projects. The software supports Windows, MacOS and Linux, while the posix platform provides an extra ccmake binary to run command line GUI. 
Both cmake, ccmake and cmake-gui work on the same principle; they differ only in the interface they provide for project configuration and dependency settings.

-Step-by-step Windows build
+0. Pre-build checklist:
+   The following binaries/libraries should be available on the system path; otherwise you need to set their locations manually via cmake.
+   * Compiler (GCC for Linux, MSVC for Windows)
+     * Make sure the compiler directory has been added to the system path
+   * CUDA 9.0 (GPU build)
+   * CUDNN (GPU build)
+   * NCCL (GPU build on Linux)
+   * SWIG (python binding)
+   * Perl (required if you need ssl support, optional)
+   * Go (required if you need ssl support, optional)
+   * NASM/YASM (required by grpc for ssl support, optional)
+1. Start CMake GUI.
+2. Click on `Browse Source` and point it to the folder `/tensorflow/contrib/cmake`.
+3. Click on `Browse Build` and specify a location where you want tensorflow to be built.
+4. Click on `Configure`; a new window will pop up asking for the generator to use for project generation. For Windows, choose `Visual Studio Win64`; for Linux, choose `Unix Makefiles`; then press `Finish`. Wait a moment while the default project dependencies are generated automatically.
+5. There are a few options with which you can customize your build. **The settings here are crucial for a successful build, please check all items carefully.**
+   * `tensorflow_BUILD_ALL_KERNELS` should always be `on`
+   * `tensorflow_BUILD_CC_EXAMPLE` defaults to `on`. This can help you to test the build (optional)
+   * `tensorflow_BUILD_CONTRIB_KERNELS` defaults to `on`, but it won't affect tensorflow functionality; turn it to `off` if you want a slim build. (optional)
+   * `tensorflow_BUILD_PYTHON_BINDING` defaults to `on`. Set to `off` if you don't need the python interface. If SWIG is not on the system path, you need to set it manually. (optional)
+   * `tensorflow_BUILD_SHARED_LIB` defaults to `off`. Set to `on` if you want the c++ interface. (optional)
+   * `tensorflow_ENABLE_GPU` defaults to `off`. Set to `on` if you want GPU support. It will search for the CUDA and CUDNN dependencies if you have added them to the system path; otherwise CMake will report an error and ask you to set them manually. (optional)
+   * `tensorflow_ENABLE_GRPC_SUPPORT` defaults to `on`. For a Linux build this option must always be `on`, and it also needs to be `on` for a gpu build. Remember that Perl, Go and NASM/YASM are required for this option if you want to build grpc with official SSL support.
+   * `tensorflow_ENABLE_POSITION_INDEPENDENT_CODE` should always be `on`
+   * `tensorflow_ENABLE_SNAPPY_SUPPORT` should always be `on`
+   * `tensorflow_OPTIMIZE_FOR_NATIVE_ARCH` should always be `on`
+   * `CMAKE_INSTALL_PREFIX` is the location where the final package will be installed. You may change it to your own preferred path (optional)
+
+6. After changing the configuration in step 5, press `Configure` again.
+7. If no error is found, press `Generate`.
+
+#### Windows
+
+1. Open `tensorflow.sln` in the build folder (Windows). Change the build type from `Debug` to `Release`. Choose `Build`->`Build Solution`. Compilation may take several hours. If everything is alright, the output window will show no errors.
+
+   ##### Python
+
+   In the solution explorer, right click on `tf_python_build_pip_package` -> `build`. It will generate the wheel file in `/tf_python/dist`. Install it with the following command:
+
+   ```pip install --upgrade tensorflow-.whl```
+
+   ***The wheel name varies depending on your config. Change it to your own wheel filename.***
+
+   Note that some pip installations require a command prompt with administrator rights.
+
+   ##### C++
+
+   You can use the build folder tree directly for the C++ interface with cmake. If you want to do an installation for an api release, right click on `Install` -> `build`. The headers and library will be installed in the directory specified by `CMAKE_INSTALL_PREFIX` during configuration.
+
+2. On computers with less RAM, an out-of-heap-space error may appear during the build. Switching to a command prompt build is an alternative way to carry out step 1.
+
+   Open `VS2015 x64 Native Tools Command Prompt`. You can open it by pressing `Start` and then typing the binary name. Use `VS2017 x64 Native Tools Command Prompt` if you are using MSVC 2017.
+
+   ##### Python
+
+   Build the python wheel package directly with the following command:
+
+   ```MSBuild /p:Configuration=Release ```
+
+   Remember to change `` to the actual path of the file; it can be found at the root of the build directory.
+
+   Install the wheel file generated as instructed by step 1.
+
+   ##### C++ interface
+   Build from the VS native toolchain with the following command:
+   ```MSBuild /p:Configuration=Release ```
+
+   Headers are located throughout the build folders. The Tensorflow library can be found at `/Release`, namely `tensorflow.dll` and `tensorflow.lib`.
+
+   * Build to install for an api release (optional):
+   ```MSBuild /p:Configuration=Release ```
+
+   Remember to change `` and `` to the actual paths of the files; they can be found at the root of the build directory.
+
+#### Linux/MacOS (command line GNU build)
+
+1. Open the terminal and change the working directory to the one specified in step 3.
+
+2. Type the following command:
+
+   ```make -sj all```
+
+   ##### Python
+
+   **Important Note** The CMake-generated python wheel for Linux/MacOS is currently under development. Please use the bazel build.
+
+   The following is the expected Linux/MacOS python package build flow once that development work is completed.
+
+   ```
+   make -sj tf_python_build_pip_package
+   cd tf_python
+   pip install --upgrade tensorflow-.whl
+   ```
+
+   ##### C++ interface
+
+   ```make -sj install```
+
+   Where `` is the number of threads used for the compilation; change it to any integer less than or equal to your computer's maximum thread count.
+
+   Headers are located throughout the build folders. The Tensorflow library can be found at ``, namely `tensorflow.so` (Linux) or `tensorflow.dylib` (MacOS).
+
+#### Start a Tensorflow C++ project with CMake
+Here we assume that you have basic knowledge of gathering dependencies with `CMakeLists.txt`. Below we show how the C++ api works with the [official hello world tutorial](https://www.tensorflow.org/api_guides/cc/guide).
+
+1. Create a new working directory and create a new text file named `CMakeLists.txt` and the c++ file `main.cxx`.
+2. Fill in `main.cxx` with the code provided in the [official c++ api basic](https://www.tensorflow.org/api_guides/cc/guide).
+3. 
Fill in the `CMakeLists.txt` with following code: + ``` cmake + cmake_minimum_required (VERSION 2.6) + project (tf_hello) + + # Tensorflow + find_package(Tensorflow REQUIRED) + include_directories(${TENSORFLOW_INCLUDE_DIRS}) + + # compiler setting required by tensorflow, to be tested on all compilers + # currently only tested on MSVC and GCC + if (${CMAKE_CXX_COMPILER_ID} STREQUAL MSVC) + add_definitions(-DCOMPILER_MSVC) + elseif (${CMAKE_CXX_COMPILER_ID} STREQUAL GNU) + if (${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS "3") + add_definitions(-DCOMPILER_GCC3) + else() + add_definitions(-D__GNUC__) + endif() + else() + message(ERROR " compiler ${CMAKE_CXX_COMPILER_ID} not supported by this CMakeList.txt, under development") + endif() + + add_executable(tf_hello main.cxx) + target_link_libraries(tf_hello ${TENSORFLOW_LIBRARIES}) + ``` +4. Configure the folder with cmake-gui, an error should be prompted out, requesting you to locate the folder containing `TensorflowConfig.cmake`. This file can be found at `` or `` (for those have build install in previous steps). + +5. Configure again, generate the project. +6. Compile the project with `Release` config (Windows). For Linux users, just compile the project. +7. Copy the `tensorflow.dll`(Windows)/`tensorflow.so`(Linux) from build directory to the build folder containing `tf_hello` binary. +8. Run `tf_hello` binary + +Step-by-step Windows build (command prompt) ========================== 1. Install the prerequisites detailed above, and set up your environment. @@ -292,4 +436,4 @@ $ cd tensorflow $ tensorflow/tools/ci_build/ci_build.sh CMAKE tensorflow/tools/ci_build/builds/cmake.sh ``` -That's it. Dependencies included. +That's it. Dependencies included. \ No newline at end of file diff --git a/tensorflow/contrib/cmake/TensorflowConfig.cmake.in b/tensorflow/contrib/cmake/TensorflowConfig.cmake.in new file mode 100644 index 00000000000..cc04db6e952 --- /dev/null +++ b/tensorflow/contrib/cmake/TensorflowConfig.cmake.in @@ -0,0 +1,16 @@ +# - Config file for the Tensorflow package +# It defines the following variables +# TENSORFLOW_INCLUDE_DIRS - include directories for FooBar +# TENSORFLOW_LIBRARIES - libraries to link against + +# Compute paths +get_filename_component(TENSORFLOW_CMAKE_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH) +set(TENSORFLOW_INCLUDE_DIRS "@CONF_INCLUDE_DIRS@") + +# Our library dependencies (contains definitions for IMPORTED targets) +if(NOT TENSORFLOW_BINARY_DIR) + include("${TENSORFLOW_CMAKE_DIR}/TensorflowTargets.cmake") +endif() + +# These are IMPORTED targets created by TensorflowTargets.cmake +set(TENSORFLOW_LIBRARIES tensorflow) \ No newline at end of file diff --git a/tensorflow/contrib/cmake/TensorflowConfigVersion.cmake.in b/tensorflow/contrib/cmake/TensorflowConfigVersion.cmake.in new file mode 100644 index 00000000000..2a9609ddb9c --- /dev/null +++ b/tensorflow/contrib/cmake/TensorflowConfigVersion.cmake.in @@ -0,0 +1,11 @@ +set(PACKAGE_VERSION "@TENSORFLOW_VERSION@") + +# Check whether the requested PACKAGE_FIND_VERSION is compatible +if("${PACKAGE_VERSION}" VERSION_LESS "${PACKAGE_FIND_VERSION}") + set(PACKAGE_VERSION_COMPATIBLE FALSE) +else() + set(PACKAGE_VERSION_COMPATIBLE TRUE) + if ("${PACKAGE_VERSION}" VERSION_EQUAL "${PACKAGE_FIND_VERSION}") + set(PACKAGE_VERSION_EXACT TRUE) + endif() +endif() \ No newline at end of file diff --git a/tensorflow/contrib/cmake/external/abseil_cpp.cmake b/tensorflow/contrib/cmake/external/abseil_cpp.cmake index 4546dbdecc0..46a193971c5 100644 --- 
a/tensorflow/contrib/cmake/external/abseil_cpp.cmake +++ b/tensorflow/contrib/cmake/external/abseil_cpp.cmake @@ -31,27 +31,24 @@ if (systemlib_ABSEIL_CPP) message(STATUS " abseil_cpp includes: ${ABSEIL_CPP_INCLUDE_DIR}") message(STATUS " abseil_cpp libraries: ${ABSEIL_CPP_LIBRARIES}") - add_custom_target(abseil_cpp_build) - list(APPEND tensorflow_EXTERNAL_DEPENDENCIES abseil_cpp_build) + add_custom_target(abseil_cpp) + list(APPEND tensorflow_EXTERNAL_DEPENDENCIES abseil_cpp) else (systemlib_ABSEIL_CPP) include (ExternalProject) - set(abseil_cpp_INCLUDE_DIR ${CMAKE_BINARY_DIR}/abseil_cpp/src/abseil_cpp_build) + set(abseil_cpp_INCLUDE_DIR ${CMAKE_BINARY_DIR}/abseil_cpp/src/abseil_cpp) set(abseil_cpp_URL https://github.com/abseil/abseil-cpp/archive/e01d95528ea2137a4a27a88d1f57c6cb260aafed.tar.gz) set(abseil_cpp_HASH SHA256=84043ed402d2a2a6ba4cdddb7e85118b1158fd81fe4ac3a14adc343d054c1e2e) - set(abseil_cpp_BUILD ${CMAKE_BINARY_DIR}/abseil_cpp/src/abseil_cpp_build) + set(abseil_cpp_BUILD ${CMAKE_BINARY_DIR}/abseil_cpp/src/abseil_cpp-build) if(WIN32) if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") set(abseil_cpp_STATIC_LIBRARIES ${abseil_cpp_BUILD}/absl/base/Release/absl_base.lib - ${abseil_cpp_BUILD}/absl/base/Release/absl_spinlock_wait.lib ${abseil_cpp_BUILD}/absl/base/Release/absl_dynamic_annotations.lib - ${abseil_cpp_BUILD}/absl/base/Release/absl_malloc_internal.lib - ${abseil_cpp_BUILD}/absl/base/Release/absl_throw_delegate.lib - ${abseil_cpp_BUILD}/absl/numeric/Release/absl_int128.lib + ${abseil_cpp_BUILD}/absl/base/Release/absl_internal_malloc_internal.lib ${abseil_cpp_BUILD}/absl/strings/Release/absl_strings.lib ${abseil_cpp_BUILD}/absl/strings/Release/str_format_internal.lib ${abseil_cpp_BUILD}/absl/types/Release/absl_bad_optional_access.lib) @@ -80,15 +77,12 @@ else (systemlib_ABSEIL_CPP) ${abseil_cpp_BUILD}/absl/types/libabsl_bad_optional_access.a) endif() - ExternalProject_Add(abseil_cpp_build + ExternalProject_Add(abseil_cpp PREFIX abseil_cpp URL ${abseil_cpp_URL} URL_HASH ${abseil_cpp_HASH} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" - BUILD_IN_SOURCE 1 BUILD_BYPRODUCTS ${abseil_cpp_STATIC_LIBRARIES} - BUILD_COMMAND ${CMAKE_COMMAND} --build . --config Release - COMMAND ${CMAKE_COMMAND} --build . 
--config Release INSTALL_COMMAND "" CMAKE_CACHE_ARGS -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=${tensorflow_ENABLE_POSITION_INDEPENDENT_CODE} @@ -99,6 +93,6 @@ else (systemlib_ABSEIL_CPP) include_directories(${abseil_cpp_INCLUDE_DIR}) list(APPEND tensorflow_EXTERNAL_LIBRARIES ${abseil_cpp_STATIC_LIBRARIES}) - list(APPEND tensorflow_EXTERNAL_DEPENDENCIES abseil_cpp_build) + list(APPEND tensorflow_EXTERNAL_DEPENDENCIES abseil_cpp) -endif (systemlib_ABSEIL_CPP) +endif (systemlib_ABSEIL_CPP) \ No newline at end of file diff --git a/tensorflow/contrib/cmake/external/png.cmake b/tensorflow/contrib/cmake/external/png.cmake index 1a147e9c8e5..32e6d78e508 100644 --- a/tensorflow/contrib/cmake/external/png.cmake +++ b/tensorflow/contrib/cmake/external/png.cmake @@ -59,6 +59,7 @@ ExternalProject_Add(png -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF -DCMAKE_INSTALL_PREFIX:STRING=${png_INSTALL} -DZLIB_ROOT:STRING=${ZLIB_INSTALL} + -DPNG_TESTS:BOOL=OFF ) ## put png includes in the directory where they are expected diff --git a/tensorflow/contrib/cmake/modules/FindAbseilCpp.cmake b/tensorflow/contrib/cmake/modules/FindAbseilCpp.cmake index d4f8bb1bec9..944ae3997a9 100644 --- a/tensorflow/contrib/cmake/modules/FindAbseilCpp.cmake +++ b/tensorflow/contrib/cmake/modules/FindAbseilCpp.cmake @@ -24,10 +24,10 @@ if(EXISTS "${ABSEIL_CPP_INCLUDE_DIR}" AND NOT "${ABSEIL_CPP_INCLUDE_DIR}" STREQU # search all libraries if no COMPONENTS was requested set(AbseilCpp_FIND_COMPONENTS "absl_algorithm;absl_any;absl_bad_any_cast" - "absl_bad_optional_access;absl_base absl_container;absl_debugging" + "absl_bad_optional_access;absl_base;absl_container;absl_debugging" "absl_dynamic_annotations;absl_examine_stack;absl_failure_signal_handler" - "absl_int128;absl_leak_check;absl_malloc_internal;absl_memory;absl_meta" - "absl_numeric;absl_optional;absl_span;absl_spinlock_wait;absl_stack_consumption" + "absl_int128;absl_leak_check;absl_internal_malloc_internal;absl_memory;absl_meta" + "absl_numeric;absl_optional;absl_span;absl_internal_spinlock_wait;absl_stack_consumption" "absl_stacktrace;absl_str_format;absl_strings;absl_symbolize;absl_synchronization" "absl_throw_delegate;absl_time;absl_utility;str_format_extension_internal" "str_format_internal;test_instance_tracker_lib") diff --git a/tensorflow/contrib/cmake/tf_c.cmake b/tensorflow/contrib/cmake/tf_c.cmake index 7a30eb94f54..a04142bd249 100644 --- a/tensorflow/contrib/cmake/tf_c.cmake +++ b/tensorflow/contrib/cmake/tf_c.cmake @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== + ######################################################## # tf_c_framework library ######################################################## diff --git a/tensorflow/contrib/cmake/tf_core_cpu.cmake b/tensorflow/contrib/cmake/tf_core_cpu.cmake index a54cbff33b6..d8884d464fb 100644 --- a/tensorflow/contrib/cmake/tf_core_cpu.cmake +++ b/tensorflow/contrib/cmake/tf_core_cpu.cmake @@ -39,6 +39,8 @@ file(GLOB_RECURSE tf_core_cpu_exclude_srcs "${tensorflow_source_dir}/tensorflow/core/*test*.h" "${tensorflow_source_dir}/tensorflow/core/*test*.cc" "${tensorflow_source_dir}/tensorflow/core/*main.cc" + "${tensorflow_source_dir}/tensorflow/core/common_runtime/eager/*.cc" + "${tensorflow_source_dir}/tensorflow/core/common_runtime/eager/*.h" "${tensorflow_source_dir}/tensorflow/core/common_runtime/gpu/*.cc" "${tensorflow_source_dir}/tensorflow/core/common_runtime/gpu_device_factory.cc" "${tensorflow_source_dir}/tensorflow/core/common_runtime/direct_session.cc" diff --git a/tensorflow/contrib/cmake/tf_core_eager_runtime.cmake b/tensorflow/contrib/cmake/tf_core_eager_runtime.cmake new file mode 100644 index 00000000000..78e4c0d3035 --- /dev/null +++ b/tensorflow/contrib/cmake/tf_core_eager_runtime.cmake @@ -0,0 +1,57 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +######################################################## +# tf_core_eager_runtime library +######################################################## +file(GLOB_RECURSE tf_core_eager_runtime_srcs + "${tensorflow_source_dir}/tensorflow/core/common_runtime/eager/*.cc" + "${tensorflow_source_dir}/tensorflow/core/common_runtime/eager/*.h" +) + +file(GLOB_RECURSE tf_core_eager_runtime_exclude_srcs + "${tensorflow_source_dir}/tensorflow/core/common_runtime/eager/*test*.h" + "${tensorflow_source_dir}/tensorflow/core/common_runtime/eager/*test*.cc" +) + +list(REMOVE_ITEM tf_core_eager_runtime_srcs ${tf_core_eager_runtime_exclude_srcs}) + +add_library(tf_core_eager_runtime OBJECT ${tf_core_eager_runtime_srcs}) +add_dependencies( + tf_core_eager_runtime + tf_c + tf_core_lib) + + +file(GLOB_RECURSE tf_c_eager_srcs + "${tensorflow_source_dir}/tensorflow/c/eager/*.cc" + "${tensorflow_source_dir}/tensorflow/c/eager/*.h" +) + +file(GLOB_RECURSE tf_c_eager_exlclude_srcs + "${tensorflow_source_dir}/tensorflow/c/eager/*test*.h" + "${tensorflow_source_dir}/tensorflow/c/eager/*test*.cc" +) + +list(REMOVE_ITEM tf_c_eager_srcs ${tf_c_eager_exlclude_srcs}) + +add_library(tf_c_eager OBJECT ${tf_c_eager_srcs}) +add_dependencies( + tf_c_eager + tf_core_eager_runtime + tf_c + tf_cc_framework + tf_cc_while_loop + tf_core_lib + tf_protos_cc) \ No newline at end of file diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake index 7e806685b84..d7b2a1339e0 100644 --- a/tensorflow/contrib/cmake/tf_core_framework.cmake +++ b/tensorflow/contrib/cmake/tf_core_framework.cmake @@ -140,16 +140,19 @@ set(tf_proto_text_srcs "tensorflow/core/example/example.proto" "tensorflow/core/example/feature.proto" "tensorflow/core/framework/allocation_description.proto" + "tensorflow/core/framework/api_def.proto" "tensorflow/core/framework/attr_value.proto" "tensorflow/core/framework/cost_graph.proto" "tensorflow/core/framework/device_attributes.proto" "tensorflow/core/framework/function.proto" "tensorflow/core/framework/graph.proto" "tensorflow/core/framework/graph_transfer_info.proto" + "tensorflow/core/framework/iterator.proto" "tensorflow/core/framework/kernel_def.proto" "tensorflow/core/framework/log_memory.proto" "tensorflow/core/framework/node_def.proto" "tensorflow/core/framework/op_def.proto" + "tensorflow/core/framework/reader_base.proto" "tensorflow/core/framework/remote_fused_graph_execute_info.proto" "tensorflow/core/framework/resource_handle.proto" "tensorflow/core/framework/step_stats.proto" @@ -159,6 +162,7 @@ set(tf_proto_text_srcs "tensorflow/core/framework/tensor_shape.proto" "tensorflow/core/framework/tensor_slice.proto" "tensorflow/core/framework/types.proto" + "tensorflow/core/framework/variable.proto" "tensorflow/core/framework/versions.proto" "tensorflow/core/lib/core/error_codes.proto" "tensorflow/core/protobuf/cluster.proto" @@ -204,10 +208,10 @@ file(GLOB tf_core_platform_srcs "${tensorflow_source_dir}/tensorflow/core/framework/resource_handle.h" "${tensorflow_source_dir}/tensorflow/core/framework/resource_handle.cc") if (NOT tensorflow_ENABLE_GPU) - file(GLOB tf_core_platform_gpu_srcs + file(GLOB tf_core_platform_gpu_srcs_exclude "${tensorflow_source_dir}/tensorflow/core/platform/cuda_libdevice_path.*" "${tensorflow_source_dir}/tensorflow/core/platform/default/cuda_libdevice_path.*") - list(REMOVE_ITEM tf_core_platform_srcs ${tf_core_platform_gpu_srcs}) + list(REMOVE_ITEM 
tf_core_platform_srcs ${tf_core_platform_gpu_srcs_exclude}) else() file(GLOB tf_core_platform_srcs_exclude "${tensorflow_source_dir}/tensorflow/core/platform/default/device_tracer.cc") diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake index 9cfa8b90749..6e75963313a 100644 --- a/tensorflow/contrib/cmake/tf_core_ops.cmake +++ b/tensorflow/contrib/cmake/tf_core_ops.cmake @@ -13,13 +13,14 @@ # limitations under the License. # ============================================================================== set(tf_op_lib_names - "audio_ops" "array_ops" + "audio_ops" "batch_ops" "bitwise_ops" "boosted_trees_ops" "candidate_sampling_ops" "checkpoint_ops" + "collective_ops" "control_flow_ops" "ctc_ops" "cudnn_rnn_ops" @@ -32,8 +33,8 @@ set(tf_op_lib_names "io_ops" "linalg_ops" "list_ops" - "lookup_ops" "logging_ops" + "lookup_ops" "manip_ops" "math_ops" "nn_ops" @@ -43,10 +44,11 @@ set(tf_op_lib_names "remote_fused_graph_ops" "resource_variable_ops" "rpc_ops" + "scoped_allocator_ops" "script_ops" "sdca_ops" - "set_ops" "sendrecv_ops" + "set_ops" "sparse_ops" "spectral_ops" "state_ops" @@ -54,6 +56,7 @@ set(tf_op_lib_names "string_ops" "summary_ops" "training_ops" + "word2vec_ops" ) foreach(tf_op_lib_name ${tf_op_lib_names}) diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index df7b854afcc..50284985982 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -313,15 +313,14 @@ function(GENERATE_PYTHON_OP_LIB tf_python_op_lib_name) ${GENERATE_PYTHON_OP_LIB_DESTINATION} PARENT_SCOPE) endfunction() -GENERATE_PYTHON_OP_LIB("audio_ops") GENERATE_PYTHON_OP_LIB("array_ops") +GENERATE_PYTHON_OP_LIB("audio_ops") GENERATE_PYTHON_OP_LIB("batch_ops") GENERATE_PYTHON_OP_LIB("bitwise_ops") GENERATE_PYTHON_OP_LIB("boosted_trees_ops") -GENERATE_PYTHON_OP_LIB("math_ops") -GENERATE_PYTHON_OP_LIB("functional_ops") GENERATE_PYTHON_OP_LIB("candidate_sampling_ops") GENERATE_PYTHON_OP_LIB("checkpoint_ops") +GENERATE_PYTHON_OP_LIB("collective_ops") GENERATE_PYTHON_OP_LIB("control_flow_ops" ADDITIONAL_LIBRARIES $) GENERATE_PYTHON_OP_LIB("ctc_ops") @@ -332,14 +331,18 @@ GENERATE_PYTHON_OP_LIB("decode_proto_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/proto/python/ops/gen_decode_proto_op.py) GENERATE_PYTHON_OP_LIB("encode_proto_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/proto/python/ops/gen_encode_proto_op.py) +GENERATE_PYTHON_OP_LIB("function_ops") +GENERATE_PYTHON_OP_LIB("functional_ops") GENERATE_PYTHON_OP_LIB("image_ops") GENERATE_PYTHON_OP_LIB("io_ops") GENERATE_PYTHON_OP_LIB("linalg_ops") GENERATE_PYTHON_OP_LIB("list_ops") GENERATE_PYTHON_OP_LIB("logging_ops") GENERATE_PYTHON_OP_LIB("lookup_ops") -GENERATE_PYTHON_OP_LIB("nn_ops") GENERATE_PYTHON_OP_LIB("manip_ops") +GENERATE_PYTHON_OP_LIB("math_ops") +GENERATE_PYTHON_OP_LIB("nn_ops") +GENERATE_PYTHON_OP_LIB("no_op") GENERATE_PYTHON_OP_LIB("parsing_ops") GENERATE_PYTHON_OP_LIB("random_ops") GENERATE_PYTHON_OP_LIB("remote_fused_graph_ops" @@ -347,17 +350,21 @@ GENERATE_PYTHON_OP_LIB("remote_fused_graph_ops" GENERATE_PYTHON_OP_LIB("resource_variable_ops") GENERATE_PYTHON_OP_LIB("rpc_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/rpc/python/ops/gen_rpc_op.py) +GENERATE_PYTHON_OP_LIB("scoped_allocator_ops") GENERATE_PYTHON_OP_LIB("script_ops") GENERATE_PYTHON_OP_LIB("sdca_ops") +GENERATE_PYTHON_OP_LIB("sendrecv_ops") 
GENERATE_PYTHON_OP_LIB("set_ops") -GENERATE_PYTHON_OP_LIB("state_ops") GENERATE_PYTHON_OP_LIB("sparse_ops") GENERATE_PYTHON_OP_LIB("spectral_ops") +GENERATE_PYTHON_OP_LIB("state_ops") +GENERATE_PYTHON_OP_LIB("stateless_random_ops") GENERATE_PYTHON_OP_LIB("string_ops") GENERATE_PYTHON_OP_LIB("summary_ops") GENERATE_PYTHON_OP_LIB("user_ops") GENERATE_PYTHON_OP_LIB("training_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/training/gen_training_ops.py) +GENERATE_PYTHON_OP_LIB("word2vec_ops") GENERATE_PYTHON_OP_LIB("contrib_boosted_trees_model_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/boosted_trees/python/ops/gen_model_ops.py) @@ -391,11 +398,8 @@ GENERATE_PYTHON_OP_LIB("contrib_layers_sparse_feature_cross_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/layers/ops/gen_sparse_feature_cross_op.py) GENERATE_PYTHON_OP_LIB("contrib_memory_stats_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/memory_stats/ops/gen_memory_stats_ops.py) -GENERATE_PYTHON_OP_LIB("contrib_nccl_ops" - DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/nccl/ops/gen_nccl_ops.py) GENERATE_PYTHON_OP_LIB("contrib_periodic_resample_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/periodic_resample/python/ops/gen_periodic_resample_op.py) - GENERATE_PYTHON_OP_LIB("contrib_nearest_neighbor_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/nearest_neighbor/ops/gen_nearest_neighbor_ops.py) GENERATE_PYTHON_OP_LIB("contrib_resampler_ops" @@ -524,11 +528,13 @@ if(WIN32) add_library(pywrap_tensorflow_internal_static STATIC ${pywrap_tensorflow_internal_src} $ + $ $ $ $ $ $ + $ $ $ $ @@ -581,11 +587,13 @@ endif(WIN32) add_library(pywrap_tensorflow_internal SHARED ${pywrap_tensorflow_internal_src} $ + $ $ $ $ $ $ + $ $ $ $ @@ -615,13 +623,28 @@ target_include_directories(pywrap_tensorflow_internal PUBLIC ${NUMPY_INCLUDE_DIR} ) -target_link_libraries(pywrap_tensorflow_internal PRIVATE +if(CMAKE_COMPILER_IS_GNUCC AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 5.0) + # There is a bug in GCC 5 resulting in undefined reference to a __cpu_model function when + # linking to the tensorflow library. Adding the following libraries fixes it. + # See issue on github: https://github.com/tensorflow/tensorflow/issues/9593 + target_link_libraries(pywrap_tensorflow_internal PRIVATE + ${tf_core_gpu_kernels_lib} + ${tensorflow_EXTERNAL_LIBRARIES} + tf_protos_cc + tf_python_protos_cc + ${PYTHON_LIBRARIES} + gcc_s + gcc +) +else() + target_link_libraries(pywrap_tensorflow_internal PRIVATE ${tf_core_gpu_kernels_lib} ${tensorflow_EXTERNAL_LIBRARIES} tf_protos_cc tf_python_protos_cc ${PYTHON_LIBRARIES} ) +endif() if(WIN32) diff --git a/tensorflow/contrib/cmake/tf_shared_lib.cmake b/tensorflow/contrib/cmake/tf_shared_lib.cmake index fdf522f1fd9..62005dd113b 100644 --- a/tensorflow/contrib/cmake/tf_shared_lib.cmake +++ b/tensorflow/contrib/cmake/tf_shared_lib.cmake @@ -23,6 +23,8 @@ if(WIN32) # we need. # add_library(tensorflow_static STATIC + $ + $ $ $ $ @@ -65,6 +67,8 @@ endif(WIN32) # tensorflow is a shared library containing all of the # TensorFlow runtime and the standard ops and kernels. 
add_library(tensorflow SHARED + $ + $ $ $ $ @@ -96,6 +100,27 @@ if(CMAKE_COMPILER_IS_GNUCC AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 5.0) target_link_libraries(tensorflow PRIVATE gcc_s gcc) endif() +# Offer the user the choice of overriding the installation directories +set(INSTALL_LIB_DIR lib CACHE PATH "Installation directory for libraries") +set(INSTALL_BIN_DIR bin CACHE PATH "Installation directory for executables") +set(INSTALL_INCLUDE_DIR include CACHE PATH + "Installation directory for header files") +if(WIN32 AND NOT CYGWIN) + set(DEF_INSTALL_CMAKE_DIR cmake) +else() + set(DEF_INSTALL_CMAKE_DIR lib/cmake) +endif() +set(INSTALL_CMAKE_DIR ${DEF_INSTALL_CMAKE_DIR} CACHE PATH + "Installation directory for CMake files") + +# Make relative paths absolute (needed later on) +foreach(p LIB BIN INCLUDE CMAKE) + set(var INSTALL_${p}_DIR) + if(NOT IS_ABSOLUTE "${${var}}") + set(${var} "${CMAKE_INSTALL_PREFIX}/${${var}}") + endif() +endforeach() + if(WIN32) add_dependencies(tensorflow tensorflow_static) endif(WIN32) @@ -103,14 +128,57 @@ endif(WIN32) target_include_directories(tensorflow PUBLIC $) -install(TARGETS tensorflow EXPORT tensorflow_export - RUNTIME DESTINATION bin - LIBRARY DESTINATION lib - ARCHIVE DESTINATION lib) +# Add all targets to build-tree export set +export(TARGETS tensorflow + FILE ${PROJECT_BINARY_DIR}/TensorflowTargets.cmake) + +# Export the package for use from the build-tree +export(PACKAGE Tensorflow) + +# Create the TensorflowConfig.cmake and TensorflowConfigVersion files +file(RELATIVE_PATH REL_INCLUDE_DIR "${INSTALL_CMAKE_DIR}" + "${INSTALL_INCLUDE_DIR}") +# for the build tree +set(CONF_INCLUDE_DIRS "${tensorflow_source_dir}" + "${PROJECT_BINARY_DIR}" + "${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/src" + "${CMAKE_CURRENT_BINARY_DIR}/nsync/install/include" # Please if there is a better directory + "${CMAKE_CURRENT_BINARY_DIR}/eigen/src/eigen/Eigen/" + "${CMAKE_CURRENT_BINARY_DIR}/external/eigen_archive/" + "${tensorflow_source_dir}/third_party/eigen3/" + "${CMAKE_CURRENT_BINARY_DIR}/eigen/src/eigen/unsupported/Eigen/") +configure_file(TensorflowConfig.cmake.in + "${PROJECT_BINARY_DIR}/TensorflowConfig.cmake" @ONLY) +# for the install tree, yet to be complete +set(CONF_INCLUDE_DIRS "\${TENSORFLOW_CMAKE_DIR}/${REL_INCLUDE_DIR}") +configure_file(TensorflowConfig.cmake.in + "${PROJECT_BINARY_DIR}/${CMAKE_FILES_DIRECTORY}/TensorflowConfig.cmake" @ONLY) +# for both +configure_file(TensorflowConfigVersion.cmake.in + "${PROJECT_BINARY_DIR}/TensorflowConfigVersion.cmake" @ONLY) + +# install(TARGETS tensorflow EXPORT tensorflow_export +# RUNTIME DESTINATION ${INSTALL_BIN_DIR} +# LIBRARY DESTINATION ${INSTALL_LIB_DIR} +# ARCHIVE DESTINATION ${INSTALL_LIB_DIR}) + +# install(EXPORT tensorflow_export +# FILE TensorflowConfig.cmake +# DESTINATION ${INSTALL_CMAKE_DIR}) -install(EXPORT tensorflow_export - FILE TensorflowConfig.cmake - DESTINATION lib/cmake) +install(FILES + "${PROJECT_BINARY_DIR}/${CMAKE_FILES_DIRECTORY}/TensorflowConfig.cmake" + "${PROJECT_BINARY_DIR}/TensorflowConfigVersion.cmake" + DESTINATION "${INSTALL_CMAKE_DIR}" COMPONENT dev) + +# install the export set for use with the install-tree +install(EXPORT TensorflowTargets + DESTINATION ${INSTALL_CMAKE_DIR}) + +install(TARGETS tensorflow EXPORT TensorflowTargets + RUNTIME DESTINATION ${INSTALL_BIN_DIR} + LIBRARY DESTINATION ${INSTALL_LIB_DIR} + ARCHIVE DESTINATION ${INSTALL_LIB_DIR}) # install necessary headers # tensorflow headers @@ -145,6 +213,10 @@ install(DIRECTORY 
${tensorflow_source_dir}/third_party/eigen3/ # unsupported Eigen directory install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/eigen/src/eigen/unsupported/Eigen/ DESTINATION include/unsupported/Eigen) +# absl directory +install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/abseil_cpp/src/abseil_cpp/absl/ + DESTINATION include/absl + FILES_MATCHING PATTERN "*.h") # mkl if (tensorflow_ENABLE_MKL_SUPPORT) install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/include/ diff --git a/tensorflow/contrib/compiler/BUILD b/tensorflow/contrib/compiler/BUILD index 1630f010ab6..e4566437c60 100644 --- a/tensorflow/contrib/compiler/BUILD +++ b/tensorflow/contrib/compiler/BUILD @@ -58,6 +58,7 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/compiler/jit:xla_ops_py", + "//tensorflow/compiler/jit/ops:xla_ops_grad", "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_ops", diff --git a/tensorflow/contrib/compiler/xla.py b/tensorflow/contrib/compiler/xla.py index 335ac794648..f867cd15b67 100644 --- a/tensorflow/contrib/compiler/xla.py +++ b/tensorflow/contrib/compiler/xla.py @@ -23,6 +23,7 @@ import contextlib from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.compiler.jit.ops import xla_ops +from tensorflow.compiler.jit.ops import xla_ops_grad # pylint: disable=unused-import from tensorflow.core.framework import attr_value_pb2 from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.framework import ops diff --git a/tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py b/tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py index 41258edd908..6926c0d03fe 100644 --- a/tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py +++ b/tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py @@ -74,8 +74,8 @@ class ConstrainedMinimizationProblem(object): if (constraints_shape.ndims is None or proxy_constraints_shape.ndims is None or - any([ii is None for ii in constraints_shape.as_list()]) or - any([ii is None for ii in proxy_constraints_shape.as_list()])): + any(ii is None for ii in constraints_shape.as_list()) or + any(ii is None for ii in proxy_constraints_shape.as_list())): raise ValueError( "constraints and proxy_constraints must have fully-known shapes") if constraints_shape != proxy_constraints_shape: diff --git a/tensorflow/contrib/crf/python/ops/crf.py b/tensorflow/contrib/crf/python/ops/crf.py index 656633f0bf2..40e159b8fcb 100644 --- a/tensorflow/contrib/crf/python/ops/crf.py +++ b/tensorflow/contrib/crf/python/ops/crf.py @@ -38,12 +38,12 @@ tf_unary_scores, tf_sequence_lengths, tf_transition_params, _ = session.run( [unary_scores, sequence_lengths, transition_params, train_op]) for tf_unary_scores_, tf_sequence_length_ in zip(tf_unary_scores, tf_sequence_lengths): -# Remove padding. -tf_unary_scores_ = tf_unary_scores_[:tf_sequence_length_] + # Remove padding. + tf_unary_scores_ = tf_unary_scores_[:tf_sequence_length_] -# Compute the highest score and its tag sequence. -tf_viterbi_sequence, tf_viterbi_score = tf.contrib.crf.viterbi_decode( - tf_unary_scores_, tf_transition_params) + # Compute the highest score and its tag sequence. 
+ tf_viterbi_sequence, tf_viterbi_score = tf.contrib.crf.viterbi_decode( + tf_unary_scores_, tf_transition_params) """ from __future__ import absolute_import diff --git a/tensorflow/contrib/cudnn_rnn/BUILD b/tensorflow/contrib/cudnn_rnn/BUILD index 670b5494327..8d35622e393 100644 --- a/tensorflow/contrib/cudnn_rnn/BUILD +++ b/tensorflow/contrib/cudnn_rnn/BUILD @@ -42,10 +42,11 @@ tf_custom_op_py_library( cuda_py_test( name = "cudnn_rnn_ops_test", - size = "large", + size = "medium", srcs = ["python/kernel_tests/cudnn_rnn_ops_test.py"], additional_deps = [ ":cudnn_rnn_py", + "@absl_py//absl/testing:parameterized", "//tensorflow/core:protos_all_py", "//tensorflow/contrib/rnn:rnn_py", "//tensorflow/python/ops/losses:losses", @@ -61,7 +62,7 @@ cuda_py_test( "//tensorflow/python:training", "//tensorflow/python:variables", ], - shard_count = 6, + shard_count = 2, tags = [ "noasan", # http://b/62067814 "requires-gpu-sm35", diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py index ae839108ebe..a268415f0e6 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py @@ -18,24 +18,30 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import itertools import os import unittest +from absl.testing import parameterized import numpy as np from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops from tensorflow.core.protobuf import saver_pb2 +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.framework.test_util import TensorFlowTestCase from tensorflow.python.ops import array_ops -from tensorflow.python.ops import gradient_checker -from tensorflow.python.ops import math_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import init_ops from tensorflow.python.ops import random_ops +from tensorflow.python.ops import rnn +from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import googletest from tensorflow.python.platform import test @@ -56,714 +62,989 @@ CUDNN_RNN_TANH_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_RNN_TANH_PARAMS_PER_LAYER CUDNN_RNN_RELU_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_RNN_RELU_PARAMS_PER_LAYER -def _CreateModel(rnn_mode, - num_layers, - num_units, - input_size, - input_mode="linear_input", - direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION, - dtype=dtypes.float32, - dropout=0.): - del input_mode - if rnn_mode == cudnn_rnn_ops.CUDNN_LSTM: - model_fn = cudnn_rnn_ops.CudnnLSTM - elif rnn_mode == cudnn_rnn_ops.CUDNN_GRU: - model_fn = cudnn_rnn_ops.CudnnGRU - elif rnn_mode == cudnn_rnn_ops.CUDNN_RNN_TANH: - model_fn = cudnn_rnn_ops.CudnnRNNTanh - elif rnn_mode == cudnn_rnn_ops.CUDNN_RNN_RELU: - model_fn = cudnn_rnn_ops.CudnnRNNRelu +def RunLSTM(sess, + num_units, + input_size, + batch_size, + time, + num_layers=1, + is_training=True, + dropout=0., + num_dirs=True, + dtype=dtypes.float32): + # TODO(jamesqin): add multi-layer tests. 
+ # TODO(jamesqin): add multi-dir tests + assert num_layers == 1 + assert num_dirs == 1 + if is_training and not np.isclose(dropout, 0): + raise ValueError("dropout can not be 0. when test training.") + + # set graph level random seed and numpy random seed. + random_seed.set_random_seed(0) + np.random.seed(0) + + inputs = variable_scope.get_variable( + "inputs", + initializer=np.random.rand(time, batch_size, + input_size).astype(dtype.as_numpy_dtype), + dtype=dtype) + initial_h_op = variable_scope.get_variable( + "initial_h_op", + initializer=np.random.rand(batch_size, + num_units).astype(dtype.as_numpy_dtype), + dtype=dtype) + initial_c_op = variable_scope.get_variable( + "initial_c_op", + initializer=np.random.rand(batch_size, + num_units).astype(dtype.as_numpy_dtype), + dtype=dtype) + + initializer = init_ops.random_uniform_initializer( + -0.01, 0.01, dtype=dtype, seed=19980904) + + with variable_scope.variable_scope("test", initializer=initializer): + w = variable_scope.get_variable( + "rnn/lstm_cell/kernel", + shape=[input_size + num_units, num_units * 4], + dtype=dtype) + b = variable_scope.get_variable( + "rnn/lstm_cell/bias", shape=[num_units * 4], dtype=dtype) + + # canonical lstm. must set forget_bias to 0. to align with cudnn lstm. + cell = rnn_cell_impl.LSTMCell(num_units, forget_bias=0., reuse=True) + outputs_op, state_tuple_op = rnn.dynamic_rnn( + cell, + inputs, + initial_state=rnn_cell_impl.LSTMStateTuple( + h=initial_h_op, c=initial_c_op), + dtype=dtype, + time_major=True, + scope=None) + + # Convert to cudnn opaque param. + format_converter = cudnn_rnn_ops.CudnnParamsFormatConverterLSTM( + num_layers, num_units, input_size) + opaque_params = format_converter.tf_canonical_to_opaque([w, b]) + + cu_initial_h_op = array_ops.expand_dims(initial_h_op, axis=0) + cu_initial_c_op = array_ops.expand_dims(initial_c_op, axis=0) + cu_outputs_op, cu_h_op, cu_c_op = cudnn_rnn_ops._cudnn_rnn( + inputs, + cu_initial_h_op, + cu_initial_c_op, + opaque_params, + dropout=dropout, + is_training=is_training, + rnn_mode=cudnn_rnn_ops.CUDNN_LSTM) + # Remove the trivial 1st dimension. + cu_state_tuple_op = rnn_cell_impl.LSTMStateTuple( + c=array_ops.squeeze(cu_c_op, axis=0), + h=array_ops.squeeze(cu_h_op, axis=0)) + + if is_training: + (inp_grad_op, hgrad_op, + cgrad_op, wgrad_op, bgrad_op) = gradients_impl.gradients( + outputs_op, [inputs, initial_h_op, initial_c_op, w, b]) + + (cu_inp_grad_op, cu_hgrad_op, + cu_cgrad_op, opaque_grad_op) = gradients_impl.gradients( + cu_outputs_op, + [inputs, cu_initial_h_op, cu_initial_c_op, opaque_params]) + # Remove the trivial 1st dimension + cu_hgrad_op = array_ops.squeeze(cu_hgrad_op, axis=0) + # Remove the trivial 1st dimension + cu_cgrad_op = array_ops.squeeze(cu_cgrad_op, axis=0) + + cu_wgrad_op, cu_bgrad_op = format_converter.opaque_to_tf_canonical( + opaque_grad_op) + cu_wgrad_op = cu_wgrad_op[0] + cu_bgrad_op = cu_bgrad_op[0] + # cudnn lstm has 2 biases each gate. When converting to tf canonical format, + # the two biases are summed into one. Thus here bias gradient should be + # halved when comparing with tf lstm. 
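# --- Illustrative sketch (not part of the patch; plain NumPy, hypothetical
# names): why the converted cuDNN bias gradient comes out doubled. cuDNN keeps
# two additive bias vectors per gate (call them b_ih and b_hh) whose sum plays
# the role of the single LSTMCell bias, so each receives the same gradient as
# that single bias, and summing the pair in opaque_to_tf_canonical doubles it,
# hence the 0.5 factor applied right below.
import numpy as np

rng = np.random.RandomState(0)
s = rng.rand(4)                          # stand-in for W.x + R.h
b_ih, b_hh = rng.rand(4), rng.rand(4)    # cuDNN-style bias pair
b_tf = b_ih + b_hh                       # tf-canonical single bias

def loss(z):
  return np.sum(np.tanh(z) ** 2)         # any smooth scalar loss

eps, e0 = 1e-6, np.eye(4)[0]
g_tf = (loss(s + b_tf + eps * e0) - loss(s + b_tf)) / eps
g_ih = (loss(s + b_ih + eps * e0 + b_hh) - loss(s + b_ih + b_hh)) / eps
g_hh = (loss(s + b_ih + b_hh + eps * e0) - loss(s + b_ih + b_hh)) / eps
assert np.allclose([g_ih, g_hh], g_tf, rtol=1e-4)    # each equals the tf grad
assert np.isclose(g_ih + g_hh, 2 * g_tf, rtol=1e-4)  # their sum is ~2x it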
+ cu_bgrad_op *= 0.5 + + init_op = variables.global_variables_initializer() + sess.run(init_op) + + if is_training: + outputs, state_tuple, inp_grad, state_grad, wgrad, bgrad = sess.run([ + outputs_op, state_tuple_op, inp_grad_op, + (hgrad_op, cgrad_op), wgrad_op, bgrad_op + ]) + (cu_outputs, cu_state_tuple, cu_inp_grad, cu_state_grad, cu_wgrad, + cu_bgrad) = sess.run([ + cu_outputs_op, cu_state_tuple_op, cu_inp_grad_op, + (cu_hgrad_op, cu_cgrad_op), cu_wgrad_op, cu_bgrad_op + ]) + + logging.vlog(1, "outputs: %s" % outputs) + logging.vlog(1, "cu_outputs: %s" % cu_outputs) + logging.vlog(1, "state_tuple: %s" % str(state_tuple)) + logging.vlog(1, "cu_state_tuple: %s" % str(cu_state_tuple)) + logging.vlog(1, "inp_grad: %s" % inp_grad) + logging.vlog(1, "cu_inp_grad: %s" % cu_inp_grad) + logging.vlog(1, "state_grad: %s" % str(state_grad)) + logging.vlog(1, "cu_state_grad: %s" % str(cu_state_grad)) + logging.vlog(1, "wgrad: %s" % str(wgrad)) + logging.vlog(1, "bgrad: %s" % str(bgrad)) + logging.vlog(1, "cu_wgrad: %s" % str(cu_wgrad)) + logging.vlog(1, "cu_bgrad: %s" % str(cu_bgrad)) + return (outputs, cu_outputs, state_tuple, cu_state_tuple, inp_grad, + cu_inp_grad, state_grad, cu_state_grad, wgrad, bgrad, cu_wgrad, + cu_bgrad) else: - raise ValueError("Invalid rnn_mode: %s" % rnn_mode) - return model_fn( - num_layers, - num_units, - input_size, - direction=direction, - dtype=dtype, - dropout=dropout) + outputs, state_tuple = sess.run([outputs_op, state_tuple_op]) + cu_outputs, cu_state_tuple = sess.run([cu_outputs_op, cu_state_tuple_op]) + + logging.vlog(1, "outputs: %s" % outputs) + logging.vlog(1, "cu_outputs: %s" % cu_outputs) + logging.vlog(1, "state_tuple: %s" % str(state_tuple)) + logging.vlog(1, "cu_state_tuple: %s" % str(cu_state_tuple)) + return outputs, cu_outputs, state_tuple, cu_state_tuple -def _CreateParamsSavable(params, - model, - base_variable_scope=None, - name="params_canonical"): - """Create a RNNParamsSaveable for the weight and bias parameters. +# Basic set of RNN configs to test. They can be further extended in relevant +# test (e.g. adding num_dirs). +NAMED_RNN_TESTCASES = ({ + "testcase_name": "xsmall", + "num_units": 1, + "input_size": 1, + "batch_size": 1, + "time": 1, + "num_layers": 1, +}, { + "testcase_name": "small", + "num_units": 4, + "input_size": 4, + "batch_size": 4, + "time": 4, + "num_layers": 1, +}, { + "testcase_name": "medium", + "num_units": 128, + "input_size": 64, + "batch_size": 8, + "time": 16, + "num_layers": 1, +}, { + "testcase_name": "large", + "num_units": 128, + "input_size": 128, + "batch_size": 16, + "time": 32, + "num_layers": 1, +}) + + +def ExpandNamedTestCases(inputs, *remove_keys, **extra_configs): + """Expands testcase with new config dimensions. 
+ + Example: + inputs = ( + {'testcase_name': 'test1', 'gender': 'male'}, + {'testcase_name': 'test2', 'gender': 'female'}, + ) + remove_keys: empty + extra_configs = { + 'age': [40, 80], + 'height': [5, 6] + } + + Returns: + ( + {'testcase_name': 'test1_age_40_height_5', 'gender': 'male', 'age': + 40, 'height': 5} + {'testcase_name': 'test1_age_40_height_6', 'gender': 'male', 'age': 40, + 'height': 6} + {'testcase_name': 'test1_age_80_height_5', 'gender': 'male', 'age': 80, + 'height': 5} + {'testcase_name': 'test1_age_80_height_6', 'gender': 'male', 'age': 80, + 'height': 6} + + {'testcase_name': 'test2_age_40_height_5', 'gender': 'female', 'age': + 40, + 'height': 5} + {'testcase_name': 'test2_age_40_height_6', 'gender': 'female', 'age': + 40, + 'height': 6} + {'testcase_name': 'test2_age_80_height_5', 'gender': 'female', 'age': + 80, + 'height': 5} + {'testcase_name': 'test2_age_80_height_6', 'gender': 'female', 'age': + 80, + 'height': 6} + ) Args: - params: a Variable for weight and bias parameters. - model: a CudnnRNN model. - base_variable_scope: a string, prefix of names of saved variables. - name: a string, name of the RNNParamsSaveable object. + inputs: A list of dictionaries, each being a testcase. + *remove_keys: A list of keys in each testcase which are not needed in the new + testcases. + **extra_configs: A dict of new test dimensions and applicable values in that + dimension. + Returns: - a RNNParamsSaveable object. + A list of dictionaries with expanded test cases. """ - if model._rnn_mode == CUDNN_LSTM: - fn = cudnn_rnn_ops.CudnnLSTMSaveable - elif model._rnn_mode == CUDNN_GRU: - fn = cudnn_rnn_ops.CudnnGRUSaveable - elif model._rnn_mode == CUDNN_RNN_TANH: - fn = cudnn_rnn_ops.CudnnRNNTanhSaveable - elif model._rnn_mode == CUDNN_RNN_RELU: - fn = cudnn_rnn_ops.CudnnRNNReluSaveable - params_saveable = fn( - params, - model.num_layers, - model.num_units, - model.input_size, - model.input_mode, - model.direction, - scope=base_variable_scope, - name=name) - ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, params_saveable) - return params_saveable + res = [] + ordered_extra_configs = collections.OrderedDict(extra_configs) + keys = ordered_extra_configs.keys() + # A list of list of configs. + # The outer loop iterates over keys, the inner over the values of one key. + combined_kv = [[(k, v) for v in ordered_extra_configs[k]] for k in keys] + logging.info("combined_kv: %s", combined_kv) + + for inp in inputs: + # Each inp is a dict + for config in itertools.product(*combined_kv): + new_inp = dict(inp) + # config is a list in the form of [(k_i, v_j), (k_p, v_q), ...] + suffix = ["%s_%s" % (p[0], str(p[1])) for p in config] + suffix = "_".join(suffix) + new_inp["testcase_name"] += "_" + suffix + for k, v in config: + new_inp[k] = v + # Remove unused keys from the new test case. + if remove_keys: + if not isinstance(remove_keys, (list, tuple)): + remove_keys = [remove_keys] + for k in remove_keys: + new_inp.pop(k, None) + logging.info("new_inp: %s", new_inp) + res.append(new_inp) + # Dedup, necessary if `remove_keys` is set.
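# --- Illustrative sketch (not part of the patch): the dedup idiom used in the
# return statement just below. Dicts are unhashable, so each testcase dict is
# turned into a tuple of its items, collapsed through a set, and rebuilt. This
# assumes every config value is hashable and that duplicate dicts were built
# with the same key order (true for the literal testcases above).
cases = [{"testcase_name": "a", "num_units": 4},
         {"testcase_name": "a", "num_units": 4},   # duplicate entry
         {"testcase_name": "b", "num_units": 8}]
deduped = [dict(t) for t in {tuple(d.items()) for d in cases}]
assert len(deduped) == 2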
+ return [dict(t) for t in {tuple(d.items()) for d in res}] -def _MinLSTMParamSize(num_layers, - num_units, - input_size, - direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION): - if direction == cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION: - first_layer_weights = 4 * num_units * (num_units + input_size) - higher_layer_weights = 8 * (num_layers - 1) * num_units * num_units - all_biases = 8 * num_layers * num_units - return first_layer_weights + higher_layer_weights + all_biases - elif direction == cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION: - first_layer_weights = 4 * num_units * (num_units + input_size) - higher_layer_weights = (num_layers - 1) * ( - 4 * 2 * num_units * num_units + 4 * num_units**2) - all_biases = 8 * num_layers * num_units - return 2 * (first_layer_weights + higher_layer_weights + all_biases) +class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): + + def _test_training_helper(self, + num_units, + input_size, + batch_size, + time, + num_layers, + dtype, + rtol=2e-6, + atol=2e-6): + with self.session(use_gpu=True) as sess: + (outputs, cu_outputs, state_tuple, cu_state_tuple, inp_grad, cu_inp_grad, + state_grad, cu_state_grad, wgrad, bgrad, cu_wgrad, cu_bgrad) = RunLSTM( + sess, num_units, input_size, batch_size, time, num_layers) + + self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol) + for s, cu_s in zip(state_tuple, cu_state_tuple): + self.assertAllClose(s, cu_s, rtol=rtol, atol=atol) + for sg, cu_sg in zip(state_grad, cu_state_grad): + self.assertAllClose(sg, cu_sg, rtol=rtol, atol=atol) + self.assertAllClose(inp_grad, cu_inp_grad, rtol=rtol, atol=atol) + self.assertAllClose(bgrad, cu_bgrad, rtol=rtol, atol=atol) + self.assertAllClose(wgrad, cu_wgrad, rtol=rtol, atol=atol) + + @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def test_training(self, num_units, input_size, batch_size, time, num_layers): + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + self._test_training_helper(num_units, input_size, batch_size, time, + num_layers, dtypes.float32) + + @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def test_training_fp16(self, num_units, input_size, batch_size, time, + num_layers): + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + self._test_training_helper( + num_units, + input_size, + batch_size, + time, + num_layers, + dtypes.float16, + rtol=5e-3, + atol=5e-4) + + @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def test_inference(self, num_units, input_size, batch_size, time, num_layers): + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + with self.session(use_gpu=True) as sess: + (outputs, cu_outputs, state_tuple, cu_state_tuple) = RunLSTM( + sess, + num_units, + input_size, + batch_size, + time, + num_layers, + is_training=False) + + self.assertAllClose(outputs, cu_outputs) + # h + self.assertAllClose(state_tuple.h, cu_state_tuple.h) + # c + self.assertAllClose(state_tuple.c, cu_state_tuple.c) + + @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def test_inference_fp16(self, num_units, input_size, batch_size, time, + num_layers): + if not context.context().num_gpus(): + 
self.skipTest("No GPUs found") + with self.session(use_gpu=True) as sess: + (outputs, cu_outputs, state_tuple, cu_state_tuple) = RunLSTM( + sess, + num_units, + input_size, + batch_size, + time, + num_layers, + is_training=False, + dtype=dtypes.float16) + + rtol, atol = 5e-3, 5e-4 + self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol) + # h + self.assertAllClose( + state_tuple.h, cu_state_tuple.h, rtol=rtol, atol=atol) + # c + self.assertAllClose( + state_tuple.c, cu_state_tuple.c, rtol=rtol, atol=atol) + + @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def test_inference_with_dropout(self, num_units, input_size, batch_size, time, + num_layers): + """Validates that dropout does not affect Cudnn Rnn inference.""" + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + # Hand-picked dropouts are used below (0. and 1.) + with ops.Graph().as_default() as g: + with self.session(use_gpu=True, graph=g) as sess: + # 1st time w/o dropout. + (_, cu_outputs, _, cu_state_tuple) = RunLSTM( + sess, + num_units, + input_size, + batch_size, + time, + num_layers, + is_training=False, + dropout=0.) + + with ops.Graph().as_default() as g: + with self.session(use_gpu=True, graph=g) as sess: + (_, cu_outputs2, _, cu_state_tuple2) = RunLSTM( + sess, + num_units, + input_size, + batch_size, + time, + num_layers, + is_training=False, + dropout=1.) + + self.assertAllClose(cu_outputs, cu_outputs2) + # h + self.assertAllClose(cu_state_tuple.h, cu_state_tuple2.h) + # c + self.assertAllClose(cu_state_tuple.c, cu_state_tuple2.c) + + +def RunGRU(sess, + num_units, + input_size, + batch_size, + time, + num_layers=1, + is_training=True, + dropout=0., + num_dirs=True, + dtype=dtypes.float32): + # TODO(jamesqin): add multi-layer tests. + # TODO(jamesqin): add multi-dir tests + assert num_layers == 1 + assert num_dirs == 1 + if is_training and not np.isclose(dropout, 0): + raise ValueError("dropout can not be 0. when test training.") + + # set graph level random seed and numpy random seed. 
+ random_seed.set_random_seed(0) + np.random.seed(0) + + inputs = variable_scope.get_variable( + "inputs", + initializer=np.random.rand(time, batch_size, + input_size).astype(dtype.as_numpy_dtype), + dtype=dtype) + initial_h_op = variable_scope.get_variable( + "initial_h_op", + initializer=np.random.rand(batch_size, + num_units).astype(dtype.as_numpy_dtype), + dtype=dtype) + + initializer = init_ops.random_uniform_initializer( + -0.01, 0.01, dtype=dtype, seed=19980904) + with variable_scope.variable_scope("test", initializer=initializer): + gate_kernel = variable_scope.get_variable( + "rnn/cudnn_compatible_gru_cell/gates/kernel", + shape=[input_size + num_units, num_units * 2], + dtype=dtype) + gate_bias = variable_scope.get_variable( + "rnn/cudnn_compatible_gru_cell/gates/bias", + shape=[num_units * 2], + dtype=dtype) + candidate_inp_kernel = variable_scope.get_variable( + "rnn/cudnn_compatible_gru_cell/candidate/input_projection/kernel", + shape=[input_size, num_units], + dtype=dtype) + candidate_inp_bias = variable_scope.get_variable( + "rnn/cudnn_compatible_gru_cell/candidate/input_projection/bias", + shape=[num_units], + dtype=dtype) + candidate_hid_kernel = variable_scope.get_variable( + "rnn/cudnn_compatible_gru_cell/candidate/hidden_projection/kernel", + shape=[num_units, num_units], + dtype=dtype) + candidate_hid_bias = variable_scope.get_variable( + "rnn/cudnn_compatible_gru_cell/candidate/hidden_projection/bias", + shape=[num_units], + dtype=dtype) + + cell = cudnn_rnn_ops.CudnnCompatibleGRUCell(num_units, reuse=True) + outputs_op, h_op = rnn.dynamic_rnn( + cell, + inputs, + initial_state=initial_h_op, + dtype=dtype, + time_major=True, + scope=None) + + ws = [gate_kernel, candidate_inp_kernel, candidate_hid_kernel] + bs = [gate_bias, candidate_inp_bias, candidate_hid_bias] + # Convert to cudnn opaque param. + format_converter = cudnn_rnn_ops.CudnnParamsFormatConverterGRU( + num_layers, num_units, input_size) + opaque_params = format_converter.tf_canonical_to_opaque(ws + bs) + + cu_initial_h_op = array_ops.expand_dims(initial_h_op, axis=0) + cu_outputs_op, cu_h_op, _ = cudnn_rnn_ops._cudnn_rnn( + inputs, + cu_initial_h_op, + array_ops.zeros_like(cu_initial_h_op), # not used + opaque_params, + dropout=dropout, + is_training=is_training, + rnn_mode=cudnn_rnn_ops.CUDNN_GRU) + + if is_training: + (inp_grad_op, hgrad_op, gk_grad_op, cik_grad_op, chk_grad_op, gb_grad_op, + cib_grad_op, chb_grad_op) = gradients_impl.gradients( + outputs_op, [inputs, initial_h_op] + ws + bs) + + (cu_inp_grad_op, cu_hgrad_op, opaque_grad_op) = gradients_impl.gradients( + cu_outputs_op, [inputs, cu_initial_h_op, opaque_params]) + # Remove the trivial 1st dimension + cu_hgrad_op = array_ops.squeeze(cu_hgrad_op, axis=0) + + cu_wgrad_op, cu_bgrad_op = format_converter.opaque_to_tf_canonical( + opaque_grad_op) + (cu_gk_grad_op, cu_cik_grad_op, cu_chk_grad_op) = cu_wgrad_op + (cu_gb_grad_op, cu_cib_grad_op, cu_chb_grad_op) = cu_bgrad_op + # cudnn gru has 2 biases for reset and update gates. When converting to tf + # canonical format, the two biases are summed into one. Thus here relevant + # bias gradient should be halved before comparing with tf gru. 
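# --- Illustrative sketch (not part of the patch; assumes the cuDNN GRU
# formulation mirrored by CudnnCompatibleGRUCell): the reset/update gate biases
# are the only purely additive pair (z = sigmoid(W.x + b_i + R.h + b_h)), so
# only their summed canonical gradient is doubled and needs the 0.5 factor
# below. The candidate biases stay separate because the hidden-side one sits
# inside the r * (...) term, which scales its gradient by the reset gate:
import numpy as np

rng = np.random.RandomState(0)
x_proj, h_proj = rng.rand(3), rng.rand(3)        # stand-ins for W_in.x, R_hn.h
b_in, b_hn, r = rng.rand(3), rng.rand(3), rng.rand(3)

def loss(bi, bh):
  # candidate pre-activation: n = tanh(W_in.x + b_in + r * (R_hn.h + b_hn))
  return np.sum(np.tanh(x_proj + bi + r * (h_proj + bh)) ** 2)

eps, e0 = 1e-6, np.eye(3)[0]
g_in = (loss(b_in + eps * e0, b_hn) - loss(b_in, b_hn)) / eps
g_hn = (loss(b_in, b_hn + eps * e0) - loss(b_in, b_hn)) / eps
assert np.isclose(g_hn, r[0] * g_in, rtol=1e-3)  # not a simple 2x relationship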
+ cu_gb_grad_op *= 0.5 + + init_op = variables.global_variables_initializer() + sess.run(init_op) + + if is_training: + outputs, h, inp_grad, hgrad, wgrad, bgrad = sess.run([ + outputs_op, h_op, inp_grad_op, hgrad_op, + (gk_grad_op, cik_grad_op, chk_grad_op), + (gb_grad_op, cib_grad_op, chb_grad_op) + ]) + (cu_outputs, cu_h, cu_inp_grad, cu_hgrad, cu_wgrad, cu_bgrad) = sess.run([ + cu_outputs_op, cu_h_op, cu_inp_grad_op, cu_hgrad_op, + (cu_gk_grad_op, cu_cik_grad_op, cu_chk_grad_op), + (cu_gb_grad_op, cu_cib_grad_op, cu_chb_grad_op) + ]) + # Remove the trivial 1st dimension + cu_h = np.squeeze(cu_h, axis=0) + + logging.vlog(1, "outputs: %s" % outputs) + logging.vlog(1, "cu_outputs: %s" % cu_outputs) + logging.vlog(1, "h: %s" % h) + logging.vlog(1, "cu_h: %s" % h) + logging.vlog(1, "inp_grad: %s" % inp_grad) + logging.vlog(1, "cu_inp_grad: %s" % cu_inp_grad) + logging.vlog(1, "hgrad: %s" % hgrad) + logging.vlog(1, "cu_hgrad: %s" % cu_hgrad) + logging.vlog(1, "wgrad: %s" % str(wgrad)) + logging.vlog(1, "bgrad: %s" % str(bgrad)) + logging.vlog(1, "cu_wgrad: %s" % str(cu_wgrad)) + logging.vlog(1, "cu_bgrad: %s" % str(cu_bgrad)) + return (outputs, cu_outputs, h, cu_h, inp_grad, cu_inp_grad, hgrad, + cu_hgrad, wgrad, bgrad, cu_wgrad, cu_bgrad) else: - raise ValueError("%s direction is not supported.") + outputs, h = sess.run([outputs_op, h_op]) + cu_outputs, cu_h = sess.run([cu_outputs_op, cu_h_op]) + # Remove the trivial 1st dimension. + cu_h = np.squeeze(cu_h, axis=0) + + logging.vlog(1, "outputs: %s" % outputs) + logging.vlog(1, "cu_outputs: %s" % cu_outputs) + logging.vlog(1, "h: %s" % h) + logging.vlog(1, "cu_h: %s" % h) + return outputs, cu_outputs, h, cu_h -class CudnnRNNTestSaveRestore(TensorFlowTestCase): +class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): - def _CompareWeights(self, lhs, rhs): - self.assertEqual(len(lhs), len(rhs)) - for lw, rw in zip(lhs, rhs): - self.assertAllEqual(lw, rw) + def _test_training_helper(self, + num_units, + input_size, + batch_size, + time, + num_layers, + dtype, + rtol=2e-6, + atol=2e-6): + with self.session(use_gpu=True) as sess: + (outputs, cu_outputs, h, cu_h, inp_grad, cu_inp_grad, hgrad, + cu_hgrad, wgrad, bgrad, cu_wgrad, cu_bgrad) = RunGRU( + sess, num_units, input_size, batch_size, time, num_layers) - def _CompareBiases(self, lhs, rhs, rnn_mode, num_layers, direction): - self.assertEqual(len(lhs), len(rhs)) - if rnn_mode == CUDNN_LSTM: - num_params_per_layer = CUDNN_LSTM_PARAMS_PER_LAYER - elif rnn_mode == CUDNN_GRU: - num_params_per_layer = CUDNN_GRU_PARAMS_PER_LAYER - elif rnn_mode == CUDNN_RNN_TANH: - num_params_per_layer = CUDNN_RNN_TANH_PARAMS_PER_LAYER - else: - num_params_per_layer = CUDNN_RNN_RELU_PARAMS_PER_LAYER - num_dirs = 1 if direction == CUDNN_RNN_UNIDIRECTION else 2 - num_params_per_layer *= num_dirs - self.assertEqual(num_params_per_layer * num_layers, len(lhs)) + self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol) + self.assertAllClose(h, cu_h, rtol=rtol, atol=atol) + self.assertAllClose(hgrad, cu_hgrad, rtol=rtol, atol=atol) + self.assertAllClose(inp_grad, cu_inp_grad, rtol=rtol, atol=atol) + for bg, cu_bg in zip(bgrad, cu_bgrad): + self.assertAllClose(bg, cu_bg, rtol=rtol, atol=atol) + for wg, cu_wg in zip(wgrad, cu_wgrad): + self.assertAllClose(wg, cu_wg, rtol=rtol, atol=atol) - for i in range(num_layers): - layer_lhs = lhs[i * num_params_per_layer: (i+1) * num_params_per_layer] - layer_rhs = rhs[i * num_params_per_layer: (i+1) * num_params_per_layer] - if direction == CUDNN_RNN_UNIDIRECTION: - 
self._CompareSingleLayerBiases(layer_lhs, layer_rhs) - else: - size = len(layer_lhs) - fw_lhs, bw_lhs = layer_lhs[:size//2], layer_lhs[size//2:] - fw_rhs, bw_rhs = layer_rhs[:size//2], layer_rhs[size//2:] - self._CompareSingleLayerBiases(fw_lhs, fw_rhs) - self._CompareSingleLayerBiases(bw_lhs, bw_rhs) + @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def test_training(self, num_units, input_size, batch_size, time, num_layers): + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + self._test_training_helper(num_units, input_size, batch_size, time, + num_layers, dtypes.float32) - def _CompareSingleLayerBiases(self, lhs, rhs): - self.assertEqual(len(lhs), len(rhs)) + @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def test_training_fp16(self, num_units, input_size, batch_size, time, + num_layers): + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + self._test_training_helper( + num_units, + input_size, + batch_size, + time, + num_layers, + dtypes.float16, + rtol=5e-3, + atol=5e-4) - lf_lhs, rt_lhs = lhs[:len(lhs)//2], lhs[len(lhs)//2:] - lf_rhs, rt_rhs = rhs[:len(rhs)//2], rhs[len(rhs)//2:] - self.assertEqual(len(lf_lhs), len(rt_lhs)) - self.assertEqual(len(lf_rhs), len(rt_rhs)) + @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def test_inference(self, num_units, input_size, batch_size, time, num_layers): + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + with self.session(use_gpu=True) as sess: + (outputs, cu_outputs, h, cu_h) = RunGRU( + sess, + num_units, + input_size, + batch_size, + time, + num_layers, + is_training=False) + self.assertAllClose(outputs, cu_outputs) + self.assertAllClose(h, cu_h) - sum_lhs, sum_rhs = [], [] - for lf, rt in zip(lf_lhs, rt_lhs): - sum_lhs.append(lf + rt) - for lf, rt in zip(lf_rhs, rt_rhs): - sum_rhs.append(lf + rt) - self.assertEqual(len(sum_lhs), len(sum_rhs)) - for lf, rt in zip(sum_lhs, sum_rhs): - self.assertAllEqual(lf, rt) + @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def test_inference_fp16(self, num_units, input_size, batch_size, time, + num_layers): + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + with self.session(use_gpu=True) as sess: + (outputs, cu_outputs, h, cu_h) = RunGRU( + sess, + num_units, + input_size, + batch_size, + time, + num_layers, + is_training=False, + dtype=dtypes.float16) - def _testSaveRestoreVariable(self, rnn_mode, direction, dtype): - num_layers = 2 - num_units = 7 - input_size = 3 - with ops.Graph().as_default(): - model = _CreateModel( - rnn_mode, - num_layers=num_layers, - num_units=num_units, - input_size=input_size, - direction=direction, - dtype=dtype) - random_seed.set_random_seed(1234) - params_size_t = model.params_size() - params = variables.VariableV1( - random_ops.random_uniform([params_size_t], dtype=dtype), - dtype=dtype, - validate_shape=False) - saveable = _CreateParamsSavable(params, model) - weights, biases = saveable.format_converter._opaque_to_cu_canonical( - saveable._variables) - reset_params = state_ops.assign( - params, - array_ops.zeros([params_size_t], dtype=dtype), - validate_shape=False) - 
save_path = os.path.join(self.get_temp_dir(), - "save-restore-variable-test") - saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2) - # Passing graph explicitly, otherwise an old sess would be reused. - with self.test_session( - use_gpu=True, graph=ops.get_default_graph()) as sess: - sess.run(variables.global_variables_initializer()) - val = saver.save(sess, save_path) - self.assertEqual(save_path, val) + rtol, atol = 5e-3, 5e-4 + self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol) + self.assertAllClose(h, cu_h, rtol=rtol, atol=atol) - weights_v, biases_v = sess.run([weights, biases]) + @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def test_inference_with_dropout(self, num_units, input_size, batch_size, time, + num_layers): + """Validates that dropout does not affect Cudnn Rnn inference.""" + # Hand-picked dropouts are used below (0. and 1.) + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + with ops.Graph().as_default() as g: + with self.session(use_gpu=True, graph=g) as sess: + # 1st time w/o dropout. + (_, cu_outputs, _, cu_h) = RunGRU( + sess, + num_units, + input_size, + batch_size, + time, + num_layers, + is_training=False, + dropout=0.) - sess.run(reset_params) - saver.restore(sess, save_path) - weights_v_restored, biases_v_restored = sess.run([weights, biases]) + with ops.Graph().as_default() as g: + with self.session(use_gpu=True, graph=g) as sess: + (_, cu_outputs2, _, cu_h2) = RunGRU( + sess, + num_units, + input_size, + batch_size, + time, + num_layers, + is_training=False, + dropout=1.) - self._CompareWeights(weights_v, weights_v_restored) - self._CompareBiases(biases_v, biases_v_restored, rnn_mode, num_layers, - direction) + self.assertAllClose(cu_outputs, cu_outputs2) + self.assertAllClose(cu_h[0], cu_h2[0]) - def _testSaveRestoreTwoVariables(self, rnn_mode, direction, dtype): - num_layers = 2 - num_units = 7 - input_size = 3 - with ops.Graph().as_default(): - model = _CreateModel( - rnn_mode, - num_layers=num_layers, - num_units=num_units, - input_size=input_size, - direction=direction, - dtype=dtype) - random_seed.set_random_seed(1234) - params_size_t = model.params_size() - names = ["rnn_1", "rnn_2"] - param_vars = [ - variables.VariableV1( - random_ops.random_uniform([params_size_t], dtype=dtype), - dtype=dtype, - validate_shape=False) for name in names - ] - saveables = [] - for name, params in zip(names, param_vars): - saveables.append(_CreateParamsSavable(params, model, name, name)) - weights1, biases1 = saveables[0].format_converter._opaque_to_cu_canonical( - saveables[0]._variables) - weights2, biases2 = saveables[1].format_converter._opaque_to_cu_canonical( - saveables[1]._variables) - reset_params = [ - state_ops.assign( - params, - array_ops.zeros([params_size_t], dtype=dtype), - validate_shape=False) for params in param_vars - ] - save_path = os.path.join(self.get_temp_dir(), - "save-restore-variable-test") - saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2) - # Passing graph explicitly, otherwise an old sess would be reused. 
- with self.test_session(use_gpu=True, - graph=ops.get_default_graph()) as sess: - sess.run(variables.global_variables_initializer()) - val = saver.save(sess, save_path) - self.assertEqual(save_path, val) - weights1_v, biases1_v = sess.run([weights1, biases1]) - weights2_v, biases2_v = sess.run([weights2, biases2]) - sess.run(reset_params) - saver.restore(sess, save_path) - weights1_v_restored, biases1_v_restored = sess.run([weights1, biases1]) - weights2_v_restored, biases2_v_restored = sess.run([weights2, biases2]) +class CudnnParamsFormatConverterTest(TensorFlowTestCase, + parameterized.TestCase): + """Class for testing various format converters.""" - self._CompareWeights(weights1_v, weights1_v_restored) - self._CompareWeights(weights2_v, weights2_v_restored) - self._CompareBiases(biases1_v, biases1_v_restored, rnn_mode, num_layers, - direction) - self._CompareBiases(biases2_v, biases2_v_restored, rnn_mode, num_layers, - direction) + def _test_lstm_helper(self, num_units, input_size, num_layers, direction): + with self.session(use_gpu=True) as sess: + random_seed.set_random_seed(0) + np.random.seed(0) - def _testSaveRestoreOutput(self, rnn_mode, direction, dtype): - with ops.Graph().as_default(): - num_layers = 2 - num_units = 7 - input_size = 7 - seq_length = 10 - batch_size = 5 - dir_count = 1 if direction == cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION else 2 - model = _CreateModel( - rnn_mode, + num_dirs = 1 if direction == cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION else 2 + format_converter = cudnn_rnn_ops.CudnnParamsFormatConverterLSTM( + num_layers, num_units, input_size, direction=direction) + + ws, bs = [], [] + for _ in range(num_layers * num_dirs): + w = constant_op.constant( + np.random.rand(input_size + num_units, 4 * num_units), + dtype=dtypes.float32) + b = constant_op.constant( + np.random.rand(4 * num_units), dtype=dtypes.float32) + ws.append(w) + bs.append(b) + + opaque_params = format_converter.tf_canonical_to_opaque(ws + bs) + opaque_params_size = cudnn_rnn_ops.cudnn_rnn_opaque_params_size( + cudnn_rnn_ops.CUDNN_LSTM, num_layers, num_units, input_size, - direction=direction, - dtype=dtype) - params_size_t = model.params_size() - params = variables.VariableV1( - array_ops.ones([params_size_t], dtype=dtype), - validate_shape=False, - dtype=dtype) - _CreateParamsSavable(params, model) - save_path = os.path.join(self.get_temp_dir(), "save-restore-output-test") + direction=direction) + + ws_r, bs_r = format_converter.opaque_to_tf_canonical(opaque_params) + + # Test tf_canonical_to_opaque() followed by opaque_to_tf_canonical() + # returns the original input. 
+ ws, ws_r, bs, bs_r = sess.run([ws, ws_r, bs, bs_r]) + for w, w_r in zip(ws, ws_r): + self.assertAllClose(w, w_r) + for b, b_r in zip(bs, bs_r): + self.assertAllClose(b, b_r) + + # Test opaque_params size lower bound + opaque_params_size_v = sess.run(opaque_params_size) + min_params_size = sum(x.size for x in ws) + np.sum(x.size for x in bs) + logging.info("min_parm_size: %d vs actual_opaque_param_size: %d", + min_params_size, opaque_params_size_v) + self.assertLessEqual(min_params_size, opaque_params_size_v) + + @parameterized.named_parameters((c["testcase_name"], c["num_units"], + c["input_size"], c["num_layers"]) + for c in NAMED_RNN_TESTCASES) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def test_lstm(self, num_units, input_size, num_layers): + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + self._test_lstm_helper(num_units, input_size, num_layers, + cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION) + + @parameterized.named_parameters((c["testcase_name"], c["num_units"], + c["input_size"], c["num_layers"]) + for c in NAMED_RNN_TESTCASES) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def test_lstm_bidi(self, num_units, input_size, num_layers): + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + self._test_lstm_helper(num_units, input_size, num_layers, + cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION) + + def _test_gru_helper(self, num_units, input_size, num_layers, direction): + with self.session(use_gpu=True) as sess: + random_seed.set_random_seed(0) + np.random.seed(0) + + num_dirs = 1 if direction == cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION else 2 + format_converter = cudnn_rnn_ops.CudnnParamsFormatConverterGRU( + num_layers, num_units, input_size, direction=direction) + + ws, bs = [], [] + for _ in range(num_layers * num_dirs): + gate_kernel = constant_op.constant( + np.random.rand(input_size + num_units, num_units * 2), + dtype=dtypes.float32) + gate_bias = constant_op.constant( + np.random.rand(num_units * 2), dtype=dtypes.float32) + candidate_inp_kernel = constant_op.constant( + np.random.rand(input_size, num_units), dtype=dtypes.float32) + candidate_inp_bias = constant_op.constant( + np.random.rand(num_units), dtype=dtypes.float32) + candidate_hid_kernel = constant_op.constant( + np.random.rand(num_units, num_units), dtype=dtypes.float32) + candidate_hid_bias = constant_op.constant( + np.random.rand(num_units), dtype=dtypes.float32) + ws.extend([gate_kernel, candidate_inp_kernel, candidate_hid_kernel]) + bs.extend([gate_bias, candidate_inp_bias, candidate_hid_bias]) + + opaque_params = format_converter.tf_canonical_to_opaque(ws + bs) + opaque_params_size = cudnn_rnn_ops.cudnn_rnn_opaque_params_size( + cudnn_rnn_ops.CUDNN_GRU, + num_layers, + num_units, + input_size, + direction=direction) + + ws_r, bs_r = format_converter.opaque_to_tf_canonical(opaque_params) + + # Test tf_canonical_to_opaque() followed by opaque_to_tf_canonical() + # returns the original input. 
+ ws, ws_r, bs, bs_r = sess.run([ws, ws_r, bs, bs_r]) + for w, w_r in zip(ws, ws_r): + self.assertAllClose(w, w_r) + for b, b_r in zip(bs, bs_r): + self.assertAllClose(b, b_r) + + # Test opaque_params size lower bound + opaque_params_size_v = sess.run(opaque_params_size) + min_params_size = sum(x.size for x in ws) + sum(x.size for x in bs) + logging.info("min_parm_size: %d vs actual_opaque_param_size: %d", + min_params_size, opaque_params_size_v) + self.assertLessEqual(min_params_size, opaque_params_size_v) + + @parameterized.named_parameters((c["testcase_name"], c["num_units"], + c["input_size"], c["num_layers"]) + for c in NAMED_RNN_TESTCASES) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def test_gru(self, num_units, input_size, num_layers): + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + self._test_gru_helper(num_units, input_size, num_layers, + cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION) + + @parameterized.named_parameters((c["testcase_name"], c["num_units"], + c["input_size"], c["num_layers"]) + for c in NAMED_RNN_TESTCASES) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def test_gru_bidi(self, num_units, input_size, num_layers): + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + self._test_gru_helper(num_units, input_size, num_layers, + cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION) + + +class CudnnRnnSaveRestoreTest(TensorFlowTestCase, parameterized.TestCase): + """Class for testing various Cudnn Rnn SaveableObjects.""" + + def _create_opaque_param(self, + rnn_mode, + num_units, + input_size, + num_layers, + direction, + name=None): + param_size_t = cudnn_rnn_ops.cudnn_rnn_opaque_params_size( + rnn_mode, num_layers, num_units, input_size, direction=direction) + init_val = random_ops.random_uniform([param_size_t]) + return variable_scope.get_variable( + name or "opaque_param", initializer=init_val, validate_shape=False) + + def _create_saveable(self, opaque_param, rnn_mode, num_units, input_size, + num_layers, direction): + if rnn_mode == CUDNN_LSTM: + fn = cudnn_rnn_ops.CudnnLSTMSaveable + elif rnn_mode == CUDNN_GRU: + fn = cudnn_rnn_ops.CudnnGRUSaveable + elif rnn_mode == CUDNN_RNN_TANH: + fn = cudnn_rnn_ops.CudnnRNNTanhSaveable + elif rnn_mode == CUDNN_RNN_RELU: + fn = cudnn_rnn_ops.CudnnRNNReluSaveable + saveable = fn( + opaque_param, num_layers, num_units, input_size, direction=direction) + return saveable + + def _compare_weights(self, lhs, rhs): + self.assertLen(rhs, len(lhs)) + for lw, rw in zip(lhs, rhs): + self.assertAllEqual(lw, rw) + + def _compare_biases(self, lhs, rhs): + self.assertLen(rhs, len(lhs)) + for lf, rt in zip(lhs, rhs): + self.assertAllEqual(lf, rt) + + @parameterized.named_parameters( + ExpandNamedTestCases( + NAMED_RNN_TESTCASES, "time", "batch_size", **{ + "rnn_mode": [ + CUDNN_LSTM, CUDNN_GRU, CUDNN_RNN_RELU, CUDNN_RNN_TANH + ], + "direction": [CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION] + })) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def test_save_restore_variable(self, rnn_mode, num_units, input_size, + num_layers, direction): + # Verify the restored opaque param, once converted to tf_canonical format, + # is the same as the tf canonicals of the pre-restored param. 
+ if not context.context().num_gpus(): + self.skipTest("No GPUs found") + with self.session(use_gpu=True) as sess: + opaque_param = self._create_opaque_param(rnn_mode, num_units, input_size, + num_layers, direction) + saveable = self._create_saveable(opaque_param, rnn_mode, num_units, + input_size, num_layers, direction) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + weights_op, biases_op = saveable.format_converter.opaque_to_tf_canonical( + saveable._variables) + + save_path = os.path.join(self.get_temp_dir(), "save_restore_var_test") saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2) - np.random.seed(1234) - has_input_c = (rnn_mode == cudnn_rnn_ops.CUDNN_LSTM) - input_data = constant_op.constant( - np.random.randn(seq_length, batch_size, input_size), dtype=dtype) - input_h = constant_op.constant( - np.random.randn(num_layers * dir_count, batch_size, num_units), - dtype=dtype) - if has_input_c: - input_c = constant_op.constant( - np.random.randn(num_layers * dir_count, batch_size, num_units), - dtype=dtype) - outputs = model( - input_data=input_data, - input_h=input_h, - input_c=input_c, - params=params, - is_training=False) - else: - outputs = model( - input_data=input_data, - input_h=input_h, - params=params, - is_training=False) - total_sum = sum(map(math_ops.reduce_sum, outputs)) - # Passing graph explicitly, otherwise an old sess would be reused. - with self.test_session( - use_gpu=True, graph=ops.get_default_graph()) as sess: - sess.run(variables.global_variables_initializer()) - total_sum_v = sess.run(total_sum) - val = saver.save(sess, save_path) - self.assertEqual(save_path, val) - # Passing graph explicitly, otherwise an old sess would be reused. - with self.test_session( - use_gpu=True, graph=ops.get_default_graph()) as sess: - reset_params = state_ops.assign( - params, - array_ops.zeros([params_size_t], dtype=dtype), - validate_shape=False) - sess.run(reset_params) + init_op = variables.global_variables_initializer() + reset_op = state_ops.assign(opaque_param, + array_ops.zeros_like(opaque_param)) + sess.run(init_op) + self.assertEqual(save_path, saver.save(sess, save_path)) + + # Get the tf canonical vals before reset-restore + weights, biases = sess.run([weights_op, biases_op]) + + # Reset the opaque param value + sess.run(reset_op) + # Assert reset happened. + weights_z, biases_z = sess.run([weights_op, biases_op]) + for w in weights_z: + self.assertAllClose(w, np.zeros_like(w)) + for b in biases_z: + self.assertAllClose(b, np.zeros_like(b)) + + # Restore opaque param value from checkpoint. + saver.restore(sess, save_path) + weights_r, biases_r = sess.run([weights_op, biases_op]) + self._compare_weights(weights, weights_r) + self._compare_biases(biases, biases_r) + + @parameterized.named_parameters( + ExpandNamedTestCases( + NAMED_RNN_TESTCASES, "time", "batch_size", **{ + "rnn_mode": [ + CUDNN_LSTM, CUDNN_GRU, CUDNN_RNN_RELU, CUDNN_RNN_TANH + ], + "direction": [CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION] + })) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def test_save_restore_multi_variables(self, rnn_mode, num_units, input_size, + num_layers, direction): + # Verify the restored opaque param, once converted to tf_canonical format, + # is the same as the tf canonicals of the pre-restored param. 
+ if not context.context().num_gpus(): + self.skipTest("No GPUs found") + with self.session(use_gpu=True) as sess: + opaque_params = [] + saveables = [] + num_opaque_params = 2 + for i in range(num_opaque_params): + opaque_params.append( + self._create_opaque_param( + rnn_mode, + num_units, + input_size, + num_layers, + direction, + name="opaque_param_%d" % i)) + saveable = self._create_saveable(opaque_params[i], rnn_mode, num_units, + input_size, num_layers, direction) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + saveables.append(saveable) + + weights_ops, biases_ops = [], [] + for i in range(num_opaque_params): + weights_op, biases_op = ( + saveables[i].format_converter.opaque_to_tf_canonical( + saveables[i]._variables)) + weights_ops.append(weights_op) + biases_ops.append(biases_op) + + save_path = os.path.join(self.get_temp_dir(), "save_restore_var_test") + saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2) + + init_op = variables.global_variables_initializer() + reset_ops = [] + for i in range(num_opaque_params): + reset_ops.append( + state_ops.assign(opaque_params[i], + array_ops.zeros_like(opaque_params[i]))) + sess.run(init_op) + self.assertEqual(save_path, saver.save(sess, save_path)) + + # Get the tf canonical vals before reset-restore + for i in range(num_opaque_params): + weights, biases = sess.run([weights_ops[i], biases_ops[i]]) + + # Reset the opaque param value + sess.run(reset_ops[i]) + + # Assert reset happened. + weights_z, biases_z = sess.run([weights_ops[i], biases_ops[i]]) + for w in weights_z: + self.assertAllClose(w, np.zeros_like(w)) + for b in biases_z: + self.assertAllClose(b, np.zeros_like(b)) + + # Restore opaque param value from checkpoint. saver.restore(sess, save_path) - total_sum_v_restored = sess.run(total_sum) - self.assertAllClose(total_sum_v, total_sum_v_restored, atol=1e-5) - - @unittest.skipUnless(test.is_built_with_cuda(), - "Test only applicable when running on GPUs") - def testSaveRestore(self): - rnn_modes = [ - cudnn_rnn_ops.CUDNN_LSTM, cudnn_rnn_ops.CUDNN_GRU, - cudnn_rnn_ops.CUDNN_RNN_TANH, cudnn_rnn_ops.CUDNN_RNN_RELU - ] - directions = [ - cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION, - cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION - ] - dtype_list = [dtypes.float32, dtypes.float64] - for rnn_mode, direction, dtype in itertools.product(rnn_modes, directions, - dtype_list): - self._testSaveRestoreVariable(rnn_mode, direction, dtype) - self._testSaveRestoreTwoVariables(rnn_mode, direction, dtype) - self._testSaveRestoreOutput(rnn_mode, direction, dtype) - - -class CudnnRNNTestParamsSize(TensorFlowTestCase): - - def _testOneLSTMParamsSize(self, num_layers, num_units, input_size, - direction): - logging.info("Testing one lstm param size with config: %s", locals()) - min_params_size = _MinLSTMParamSize(num_layers, num_units, input_size, - direction) - model = _CreateModel( - cudnn_rnn_ops.CUDNN_LSTM, - num_layers, - num_units, - input_size, - direction=direction) - params_size = model.params_size() - with self.test_session(use_gpu=True, graph=ops.get_default_graph()) as sess: - params_size_v = sess.run(params_size) - self.assertLessEqual(min_params_size, params_size_v) - - @unittest.skipUnless(test.is_built_with_cuda(), - "Test only applicable when running on GPUs") - def testLSTMParamsSize(self): - test_configs = [ - [4, 200, 200], - [4, 200, 300], - [4, 200, 100], - [1, 100, 200], - [2, 200, 100], - [3, 200, 400], - ] - directions = [ - cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION, - cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION - ] - for (config, 
direction) in itertools.product(test_configs, directions): - num_layers, num_units, input_size = config - with ops.Graph().as_default(): - self._testOneLSTMParamsSize(num_layers, num_units, input_size, - direction) - - @unittest.skipUnless(test.is_built_with_cuda(), - "Test only applicable when running on GPUs") - def testLSTMParamsSizeShape(self): - with self.assertRaisesRegexp( - ValueError, "Shape must be rank 0 but is rank 1"): - model = _CreateModel( - cudnn_rnn_ops.CUDNN_LSTM, - constant_op.constant([4]), 200, 200, - direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION) - _ = model.params_size() - with self.assertRaisesRegexp( - ValueError, "Shape must be rank 0 but is rank 1"): - model = _CreateModel( - cudnn_rnn_ops.CUDNN_LSTM, - 4, constant_op.constant([200]), 200, - direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION) - _ = model.params_size() - with self.assertRaisesRegexp( - ValueError, "Shape must be rank 0 but is rank 1"): - model = _CreateModel( - cudnn_rnn_ops.CUDNN_LSTM, - 4, 200, constant_op.constant([200]), - direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION) - _ = model.params_size() - - -class CudnnRNNTestInference(TensorFlowTestCase): - - def _testOneSimpleInference(self, rnn_mode, num_layers, num_units, input_size, - batch_size, seq_length, dir_count, dropout, - expected, tolerance): - random_seed.set_random_seed(5678) - model = _CreateModel( - rnn_mode, - num_layers, - num_units, - input_size, - input_mode="auto_select", - direction=(cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION if dir_count == 1 - else cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION), - dropout=dropout) - has_input_c = (rnn_mode == cudnn_rnn_ops.CUDNN_LSTM) - params_size_t = model.params_size() - input_data = array_ops.ones([seq_length, batch_size, input_size]) - input_h = array_ops.ones([num_layers * dir_count, batch_size, num_units]) - params = variables.VariableV1( - array_ops.ones([params_size_t]), validate_shape=False) - if has_input_c: - input_c = array_ops.ones([num_layers * dir_count, batch_size, num_units]) - output, output_h, output_c = model( - input_data=input_data, - input_h=input_h, - input_c=input_c, - params=params, - is_training=False) - else: - output, output_h = model( - input_data=input_data, - input_h=input_h, - params=params, - is_training=False) - output_sum = math_ops.reduce_sum(output) - output_h_sum = math_ops.reduce_sum(output_h) - total_sum = output_sum + output_h_sum - if has_input_c: - output_c_sum = math_ops.reduce_sum(output_c) - total_sum += output_c_sum - with self.test_session(use_gpu=True, graph=ops.get_default_graph()) as sess: - sess.run(variables.global_variables_initializer()) - total_sum_v = sess.run([total_sum]) - - self.assertAllClose( - total_sum_v[0], expected, atol=tolerance, rtol=tolerance) - - @unittest.skipUnless(test.is_built_with_cuda(), - "Test only applicable when running on GPUs") - def testSimpleInference(self): - test_configs = [ - { - "rnn_mode": cudnn_rnn_ops.CUDNN_LSTM, - "expected": 231833.22, - "tolerance": 1e-2, - "shape": { - "num_layers": 4, - "num_units": 200, - "input_size": 200, - "batch_size": 20, - "seq_length": 10, - "dir_count": 1, - }, - }, - { - "rnn_mode": cudnn_rnn_ops.CUDNN_GRU, - "expected": 56000, - "tolerance": 1e-2, - "shape": { - "num_layers": 4, - "num_units": 200, - "input_size": 200, - "batch_size": 20, - "seq_length": 10, - "dir_count": 1, - }, - }, - { - "rnn_mode": cudnn_rnn_ops.CUDNN_RNN_TANH, - "expected": 56000, - "tolerance": 1e-2, - "shape": { - "num_layers": 4, - "num_units": 200, - "input_size": 200, - "batch_size": 20, - "seq_length": 10, 
- "dir_count": 1, - }, - }, - { - "rnn_mode": cudnn_rnn_ops.CUDNN_RNN_RELU, - "expected": 130688, - "tolerance": 1e-2, - "shape": { - "num_layers": 2, - "num_units": 8, - "input_size": 4, - "batch_size": 4, - "seq_length": 2, - "dir_count": 1, - }, - }, - ] - # Cudnn scales result for dropout during training, therefore dropout has no - # impact for inference results. - # (lstm, gru, rnn_tanh are saturated in the test. rnn_relu case is most - # demonstrative of the dropout-invariant nature of CudnnRnn.) - dropouts = [0., 0.5, 1.] - for (config, dropout) in itertools.product(test_configs, dropouts): - rnn_mode = config["rnn_mode"] - expected = config["expected"] - tolerance = config["tolerance"] - shape = config["shape"] - with ops.Graph().as_default(): - self._testOneSimpleInference( - rnn_mode, shape["num_layers"], shape["num_units"], - shape["input_size"], shape["batch_size"], shape["seq_length"], - shape["dir_count"], dropout, expected, tolerance) - - -class CudnnRNNTestTraining(TensorFlowTestCase): - - def _testOneSimpleTraining(self, rnn_mode, num_layers, num_units, input_size, - batch_size, seq_length, dir_count, dropout, dtype, - delta, tolerance): - # Gradient checking runs two forward ops with almost the same input. Need to - # make sure the drop patterns across the two runs are the same. - logging.info("Training test with config: %s", locals()) - old_env_state = os.environ.get("TF_CUDNN_RESET_RND_GEN_STATE", str(False)) - os.environ["TF_CUDNN_RESET_RND_GEN_STATE"] = str(True) - has_input_c = (rnn_mode == cudnn_rnn_ops.CUDNN_LSTM) - random_seed.set_random_seed(5678) - direction = (cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION if dir_count == 1 - else cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION) - model = _CreateModel( - rnn_mode, - num_layers, - num_units, - input_size, - direction=direction, - dtype=dtype, - dropout=dropout) - params_size_t = model.params_size() - input_data = variables.VariableV1( - random_ops.random_uniform( - [seq_length, batch_size, input_size], dtype=dtype), - dtype=dtype) - input_h = variables.VariableV1( - random_ops.random_uniform( - [num_layers * dir_count, batch_size, num_units], dtype=dtype), - dtype=dtype) - params = variables.VariableV1( - random_ops.random_uniform([params_size_t], dtype=dtype), - validate_shape=False, - dtype=dtype) - if has_input_c: - input_c = variables.VariableV1( - random_ops.random_uniform( - [num_layers * dir_count, batch_size, num_units], dtype=dtype), - dtype=dtype) - - output, output_h, output_c = model( - input_data=input_data, - input_h=input_h, - input_c=input_c, - params=params) - else: - output, output_h = model( - input_data=input_data, input_h=input_h, params=params) - output_sum = math_ops.reduce_sum(output) - output_h_sum = math_ops.reduce_sum(output_h) - total_sum = output_sum + output_h_sum - if has_input_c: - output_c_sum = math_ops.reduce_sum(output_c) - total_sum += output_c_sum - - with self.test_session(use_gpu=True, graph=ops.get_default_graph()) as sess: - params_size_v = sess.run(params_size_t) - inputs_and_shapes = [ - (input_data, [seq_length, batch_size, input_size]), - (input_h, [num_layers * dir_count, batch_size, num_units]), - (params, [params_size_v]), - ] - if has_input_c: - inputs_and_shapes.append( - (input_c, [num_layers * dir_count, batch_size, num_units]),) - sess.run(variables.global_variables_initializer()) - all_inputs = [entry[0] for entry in inputs_and_shapes] - all_shapes = [entry[1] for entry in inputs_and_shapes] - - err = gradient_checker.compute_gradient_error( - all_inputs, all_shapes, total_sum, 
[1], delta=delta) - - self.assertLess(err, tolerance) - os.environ["TF_CUDNN_RESET_RND_GEN_STATE"] = old_env_state - - @unittest.skipUnless(test.is_built_with_cuda(), - "Test only applicable when running on GPUs") - def DISABLED_testSimpleTraining(self): - # TODO(jamesqin): fix b/117989214 - test_configs = [ - { - "rnn_mode": cudnn_rnn_ops.CUDNN_LSTM, - "dtype": dtypes.float64, - "delta": 1e-4, - "tolerance": 5e-6, - "shape": { - "num_layers": 2, - "num_units": 3, - "input_size": 4, - "batch_size": 3, - "seq_length": 4, - "dir_count": 1, - }, - }, - { - "rnn_mode": cudnn_rnn_ops.CUDNN_GRU, - "dtype": dtypes.float64, - "delta": 1e-4, - "tolerance": 5e-6, - "shape": { - "num_layers": 2, - "num_units": 3, - "input_size": 4, - "batch_size": 3, - "seq_length": 4, - "dir_count": 1, - }, - }, - { - "rnn_mode": cudnn_rnn_ops.CUDNN_RNN_TANH, - "dtype": dtypes.float64, - "delta": 1e-4, - "tolerance": 5e-6, - "shape": { - "num_layers": 2, - "num_units": 3, - "input_size": 4, - "batch_size": 3, - "seq_length": 4, - "dir_count": 1, - }, - }, - { - "rnn_mode": cudnn_rnn_ops.CUDNN_RNN_RELU, - "dtype": dtypes.float64, - "delta": 1e-4, - "tolerance": 5e-6, - "shape": { - "num_layers": 2, - "num_units": 3, - "input_size": 4, - "batch_size": 3, - "seq_length": 4, - "dir_count": 1, - }, - }, - { - "rnn_mode": cudnn_rnn_ops.CUDNN_LSTM, - "dtype": dtypes.float32, - "tolerance": 1.5e-2, - "shape": { - "num_layers": 2, - "num_units": 3, - "input_size": 4, - "batch_size": 3, - "seq_length": 4, - }, - }, - { - "rnn_mode": cudnn_rnn_ops.CUDNN_GRU, - "dtype": dtypes.float32, - "tolerance": 4e-3, - "shape": { - "num_layers": 2, - "num_units": 3, - "input_size": 4, - "batch_size": 3, - "seq_length": 4, - }, - }, - { - "rnn_mode": cudnn_rnn_ops.CUDNN_RNN_TANH, - "dtype": dtypes.float32, - "tolerance": 5e-3, - "shape": { - "num_layers": 2, - "num_units": 3, - "input_size": 4, - "batch_size": 3, - "seq_length": 4, - }, - }, - { - "rnn_mode": cudnn_rnn_ops.CUDNN_RNN_RELU, - "dtype": dtypes.float32, - "tolerance": 5e-1, - "shape": { - "num_layers": 2, - "num_units": 3, - "input_size": 4, - "batch_size": 3, - "seq_length": 4, - }, - }, - ] - dropouts = [0., 0.5, 1.] 
- dir_counts = [1] - for config, dropout, dir_count in itertools.product(test_configs, dropouts, - dir_counts): - rnn_mode = config["rnn_mode"] - dtype = config.get("dtype", dtypes.float32) - delta = config.get("delta", 1e-3) - tolerance = config["tolerance"] - shape = config["shape"] - with ops.Graph().as_default(): - self._testOneSimpleTraining(rnn_mode, shape["num_layers"], - shape["num_units"], shape["input_size"], - shape["batch_size"], shape["seq_length"], - dir_count, dropout, dtype, delta, tolerance) + weights_r, biases_r = sess.run([weights_ops[i], biases_ops[i]]) + self._compare_weights(weights, weights_r) + self._compare_biases(biases, biases_r) if __name__ == "__main__": diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py index 1954f6717bb..7e1b4062ce4 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py @@ -536,7 +536,9 @@ class CudnnRNNTestSaveRestore(test_util.TensorFlowTestCase): save_path = os.path.join(self.get_temp_dir(), "save-restore-variable-test") saver = saver_lib.Saver() - weights, biases = model.rnn.saveable._OpaqueParamsToCanonical() + weights, biases = ( + model.rnn.saveable.format_converter._opaque_to_cu_canonical( + model.rnn.saveable._variables)) opaque_params = rnn.trainable_variables[0] # CudnnTestModel() creates CudnnOpaqueParamsSaveable that helps saver save # Cudnn vars in canonical format. @@ -583,8 +585,12 @@ class CudnnRNNTestSaveRestore(test_util.TensorFlowTestCase): dtype=dtype) opaque_params = (model1.rnn.trainable_variables[0], model2.rnn.trainable_variables[0]) - weights1, biases1 = model1.rnn.saveable._OpaqueParamsToCanonical() - weights2, biases2 = model2.rnn.saveable._OpaqueParamsToCanonical() + saveable1 = model1.rnn.saveable + weights1, biases1 = saveable1.format_converter._opaque_to_cu_canonical( + saveable1._variables) + saveable2 = model2.rnn.saveable + weights2, biases2 = saveable2.format_converter._opaque_to_cu_canonical( + saveable2._variables) reset_params = [ state_ops.assign(params, array_ops.zeros_like(params, dtype=dtype)) @@ -1039,8 +1045,8 @@ class CudnnRNNTestParamsSize(test_util.TensorFlowTestCase): # Min param size estimate = sum(weights.size) + sum(biases.size) min_params_size = ( - np.sum(list(map(np.prod, rnn.canonical_weight_shapes))) + - np.sum([sp[0] for sp in rnn.canonical_bias_shapes])) + sum(map(np.prod, rnn.canonical_weight_shapes)) + + sum(sp[0] for sp in rnn.canonical_bias_shapes)) opaque_params = rnn.trainable_variables[0] with self.test_session(use_gpu=True, graph=ops.get_default_graph()): diff --git a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py index 8bbcc7cd039..8e25637ed91 100644 --- a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py +++ b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py @@ -21,6 +21,7 @@ from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.keras.engine import input_spec from tensorflow.python.layers import base as base_layer from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops @@ -322,7 +323,7 @@ class _CudnnRNN(base_layer.Layer): raise ValueError("The last dimension of the inputs to `CudnnRNN` " "should
be defined. Found `None`.") self._input_size = input_shape[-1].value - self.input_spec = base_layer.InputSpec(ndim=3, axes={-1: self._input_size}) + self.input_spec = input_spec.InputSpec(ndim=3, axes={-1: self._input_size}) self._set_scope(None) diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index d06d0c6bdaa..1ce29b42d52 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -738,7 +738,7 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject): self._variables, opaque_params, validate_shape=False) def _checkpointable_save(self, save_buffer): - weights, biases = self.format_converter.opaque_params_to_tf_canonical( + weights, biases = self.format_converter.opaque_to_tf_canonical( self._variables) for name, tensor in zip(self._param_names, weights + biases): save_buffer[name] = array_ops.identity(tensor) diff --git a/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py b/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py index 0456463a192..6c5f8c6b009 100644 --- a/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py @@ -46,7 +46,7 @@ class AssertElementShapeTest(test_base.DatasetTestBase): result = dataset.apply(batching.assert_element_shape(expected_shapes)) self.assertEqual(expected_shapes, result.output_shapes) - iterator = result.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(result) init_op = iterator.initializer get_next = iterator.get_next() with self.cached_session() as sess: @@ -88,7 +88,7 @@ class AssertElementShapeTest(test_base.DatasetTestBase): result = dataset.apply(batching.assert_element_shape(expected_shapes)) self.assertEqual(expected_shapes, result.output_shapes) - iterator = result.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(result) init_op = iterator.initializer get_next = iterator.get_next() with self.cached_session() as sess: @@ -115,9 +115,8 @@ class AssertElementShapeTest(test_base.DatasetTestBase): wrong_shapes = (tensor_shape.TensorShape(2), tensor_shape.TensorShape((3, 10))) - iterator = ( - dataset.apply(batching.assert_element_shape(wrong_shapes)) - .make_initializable_iterator()) + iterator = dataset_ops.make_initializable_iterator( + dataset.apply(batching.assert_element_shape(wrong_shapes))) init_op = iterator.initializer get_next = iterator.get_next() with self.cached_session() as sess: @@ -142,7 +141,7 @@ class AssertElementShapeTest(test_base.DatasetTestBase): tensor_shape.TensorShape((3, 4))) self.assertEqual(actual_shapes, result.output_shapes) - iterator = result.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(result) init_op = iterator.initializer get_next = iterator.get_next() with self.cached_session() as sess: @@ -184,7 +183,7 @@ class AssertElementShapeTest(test_base.DatasetTestBase): result = dataset.apply(batching.assert_element_shape(expected_shapes)) self.assertEqual(expected_shapes, result.output_shapes) - iterator = result.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(result) init_op = iterator.initializer get_next = iterator.get_next() with self.cached_session() as sess: @@ -211,9 +210,8 @@ class AssertElementShapeTest(test_base.DatasetTestBase): wrong_shapes = 
(tensor_shape.TensorShape(2), tensor_shape.TensorShape((None, 10))) - iterator = ( - dataset.apply(batching.assert_element_shape(wrong_shapes)) - .make_initializable_iterator()) + iterator = dataset_ops.make_initializable_iterator( + dataset.apply(batching.assert_element_shape(wrong_shapes))) init_op = iterator.initializer get_next = iterator.get_next() with self.cached_session() as sess: diff --git a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py index d2a72272db1..b9840b1ff1a 100644 --- a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py @@ -23,6 +23,7 @@ import shutil from tensorflow.contrib.data.python.ops import readers from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -48,7 +49,7 @@ class LMDBDatasetTest(test_base.DatasetTestBase): num_repeats = 2 dataset = readers.LMDBDataset(filenames).repeat(num_repeats) - iterator = dataset.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(dataset) init_op = iterator.initializer get_next = iterator.get_next() diff --git a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py index c5a78623225..2527706709f 100644 --- a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py @@ -63,13 +63,13 @@ class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): # The pipeline is TensorSliceDataset -> MapDataset(square_3) -> # RepeatDataset(count) -> # _SlideDataset(window_size, window_shift, window_stride). - iterator = ( + iterator = dataset_ops.make_initializable_iterator( dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) .repeat(count).apply( sliding.sliding_window_batch( window_size=window_size_t, window_shift=window_shift_t, - window_stride=window_stride_t)).make_initializable_iterator()) + window_stride=window_stride_t))) init_op = iterator.initializer get_next = iterator.get_next() @@ -127,13 +127,13 @@ class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): # The pipeline is TensorSliceDataset -> MapDataset(square_3) -> # RepeatDataset(count) -> _SlideDataset(window_size, stride, window_stride). 
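The hunks in these dataset tests all make the same mechanical change: the `Dataset.make_initializable_iterator()` method is replaced by the module-level `dataset_ops.make_initializable_iterator()` helper, which dispatches correctly for both v1 and v2 `Dataset` objects. A minimal sketch of the new pattern using the public API (assuming a TF 1.x build where `tf.compat.v1.data.make_initializable_iterator` is available; the toy pipeline is illustrative only):

```python
import tensorflow as tf

# A small pipeline standing in for the datasets built in the tests above.
dataset = tf.data.Dataset.range(5).map(lambda x: x * 2)

# Old style (method on the dataset object):
#   iterator = dataset.make_initializable_iterator()
# New style (free function that dispatches on the dataset's type):
iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
next_element = iterator.get_next()

with tf.compat.v1.Session() as sess:
  sess.run(iterator.initializer)
  print([sess.run(next_element) for _ in range(5)])  # [0, 2, 4, 6, 8]
```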
- iterator = ( + iterator = dataset_ops.make_initializable_iterator( dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) .repeat(count).apply( sliding.sliding_window_batch( window_size=window_size_t, stride=stride_t, - window_stride=window_stride_t)).make_initializable_iterator()) + window_stride=window_stride_t))) init_op = iterator.initializer get_next = iterator.get_next() @@ -173,12 +173,12 @@ class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): window_shift_t = array_ops.placeholder(dtypes.int64, shape=[]) window_stride_t = array_ops.placeholder(dtypes.int64, shape=[]) - iterator = ( + iterator = dataset_ops.make_initializable_iterator( dataset_ops.Dataset.range(10).map(lambda x: x).repeat(count_t).apply( sliding.sliding_window_batch( window_size=window_size_t, window_shift=window_shift_t, - window_stride=window_stride_t)).make_initializable_iterator()) + window_stride=window_stride_t))) init_op = iterator.initializer with self.cached_session() as sess: @@ -204,9 +204,9 @@ class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): return sparse_tensor.SparseTensorValue( indices=[[0]], values=(i * [1]), dense_shape=[1]) - iterator = dataset_ops.Dataset.range(10).map(_sparse).apply( - sliding.sliding_window_batch( - window_size=5, window_shift=3)).make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator( + dataset_ops.Dataset.range(10).map(_sparse).apply( + sliding.sliding_window_batch(window_size=5, window_shift=3))) init_op = iterator.initializer get_next = iterator.get_next() @@ -233,9 +233,9 @@ class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): values=array_ops.fill([math_ops.to_int32(i)], i), dense_shape=[i]) - iterator = dataset_ops.Dataset.range(10).map(_sparse).apply( - sliding.sliding_window_batch( - window_size=5, window_shift=3)).make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator( + dataset_ops.Dataset.range(10).map(_sparse).apply( + sliding.sliding_window_batch(window_size=5, window_shift=3))) init_op = iterator.initializer get_next = iterator.get_next() @@ -265,11 +265,10 @@ class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): return sparse_tensor.SparseTensorValue( indices=[[0]], values=(i * [1]), dense_shape=[1]) - iterator = ( + iterator = dataset_ops.make_initializable_iterator( dataset_ops.Dataset.range(10).map(_sparse).apply( sliding.sliding_window_batch(window_size=4, window_shift=2)).apply( - sliding.sliding_window_batch(window_size=3, window_shift=1)) - .make_initializable_iterator()) + sliding.sliding_window_batch(window_size=3, window_shift=1))) init_op = iterator.initializer get_next = iterator.get_next() @@ -305,11 +304,10 @@ class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): yield [4.0, 5.0, 6.0] yield [7.0, 8.0, 9.0, 10.0] - iterator = ( + iterator = dataset_ops.make_initializable_iterator( dataset_ops.Dataset.from_generator( generator, dtypes.float32, output_shapes=[None]).apply( - sliding.sliding_window_batch(window_size=3, window_shift=1)) - .make_initializable_iterator()) + sliding.sliding_window_batch(window_size=3, window_shift=1))) next_element = iterator.get_next() with self.cached_session() as sess: diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD index 34dc2379d0c..0fb406f1167 100644 --- a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -188,8 +188,7 @@ py_library( 
"//tensorflow/python:framework_ops", "//tensorflow/python:function", "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/util:nest", - "//tensorflow/python/data/util:sparse", + "//tensorflow/python/data/util:structure", ], ) diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py index 4601376dff4..aa42782807a 100644 --- a/tensorflow/contrib/data/python/ops/readers.py +++ b/tensorflow/contrib/data/python/ops/readers.py @@ -355,7 +355,7 @@ def read_batch_features(file_pattern, shuffle=randomize_input, num_epochs=num_epochs, shuffle_buffer_size=capacity) - iterator = dataset.make_one_shot_iterator() + iterator = dataset_ops.make_one_shot_iterator(dataset) outputs = iterator.get_next() return outputs @@ -379,15 +379,13 @@ class LMDBDataset(dataset_ops.DatasetSource): (key value) pairs sequentially. For example: ```python + tf.enable_eager_execution() + dataset = tf.contrib.lmdb.LMDBDataset("/foo/bar.mdb") - iterator = dataset.make_one_shot_iterator() - next_element = iterator.get_next() + # Prints the (key, value) pairs inside a lmdb file. - while True: - try: - print(sess.run(next_element)) - except tf.errors.OutOfRangeError: - break + for key, value in dataset: + print(key, value) ``` Args: filenames: A `tf.string` tensor containing one or more filenames. diff --git a/tensorflow/contrib/data/python/ops/sliding.py b/tensorflow/contrib/data/python/ops/sliding.py index bcc383587c5..9ebdca317f2 100644 --- a/tensorflow/contrib/data/python/ops/sliding.py +++ b/tensorflow/contrib/data/python/ops/sliding.py @@ -18,11 +18,10 @@ from __future__ import division from __future__ import print_function from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import nest +from tensorflow.python.data.util import structure from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops from tensorflow.python.util import deprecation @@ -40,29 +39,31 @@ class _SlideDataset(dataset_ops.UnaryDataset): self._window_shift = ops.convert_to_tensor( window_shift, dtype=dtypes.int64, name="window_shift") + # pylint: disable=protected-access + input_structure = structure.Structure._from_legacy_structure( + input_dataset.output_types, input_dataset.output_shapes, + input_dataset.output_classes) + self._output_structure = input_structure._batch(None) + def _as_variant_tensor(self): - return gen_dataset_ops.slide_dataset( + return ged_ops.experimental_sliding_window_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access window_size=self._window_size, window_shift=self._window_shift, window_stride=self._window_stride, - **dataset_ops.flat_structure(self)) + **dataset_ops.flat_structure(structure=self._output_structure)) @property def output_classes(self): - return self._input_dataset.output_classes + return self._output_structure._to_legacy_output_classes() # pylint: disable=protected-access @property def output_shapes(self): - input_shapes = self._input_dataset.output_shapes - return nest.pack_sequence_as(input_shapes, [ - tensor_shape.vector(None).concatenate(s) - for s in nest.flatten(self._input_dataset.output_shapes) - ]) + return self._output_structure._to_legacy_output_shapes() # pylint: disable=protected-access @property def output_types(self): - return 
self._input_dataset.output_types + return self._output_structure._to_legacy_output_types() # pylint: disable=protected-access @deprecation.deprecated_args( diff --git a/tensorflow/contrib/distribute/BUILD b/tensorflow/contrib/distribute/BUILD index a87a5624c88..3ecd755d86f 100644 --- a/tensorflow/contrib/distribute/BUILD +++ b/tensorflow/contrib/distribute/BUILD @@ -26,7 +26,6 @@ py_library( visibility = ["//tensorflow:internal"], deps = [ "//tensorflow/contrib/distribute/python:collective_all_reduce_strategy", - "//tensorflow/contrib/distribute/python:cross_tower_ops", "//tensorflow/contrib/distribute/python:mirrored_strategy", "//tensorflow/contrib/distribute/python:monitor", "//tensorflow/contrib/distribute/python:one_device_strategy", @@ -35,6 +34,7 @@ py_library( "//tensorflow/contrib/distribute/python:tpu_strategy", "//tensorflow/python:training", "//tensorflow/python:util", + "//tensorflow/python/distribute:cross_device_ops", "//tensorflow/python/distribute:distribute_config", "//tensorflow/python/distribute:distribute_coordinator", ], diff --git a/tensorflow/contrib/distribute/README.md b/tensorflow/contrib/distribute/README.md index a938f8629d8..81574a2047e 100644 --- a/tensorflow/contrib/distribute/README.md +++ b/tensorflow/contrib/distribute/README.md @@ -134,7 +134,7 @@ def model_fn(features, labels, mode): return tf.estimator.EstimatorSpec(mode, loss=loss) if mode == tf.estimator.ModeKeys.TRAIN: - train_op = tf.train.GradientDescentOptimizer(0.2).minimize(loss_fn()) + train_op = tf.train.GradientDescentOptimizer(0.2).minimize(loss) return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op) ``` @@ -251,10 +251,10 @@ start multi-worker training using `tf.estimator.train_and_evaluate`: ```python def model_main(): - estimator = ... distribution = tf.contrib.distribute.CollectiveAllReduceStrategy( num_gpus_per_worker=2) config = tf.estimator.RunConfig(train_distribute=distribution) + estimator = tf.estimator.Estimator(model_fn=model_fn, config=config) train_spec = tf.estimator.TrainSpec(input_fn=input_fn) eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn) tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec) @@ -327,13 +327,13 @@ start training. On your laptop, you can run ```python -estimator = ... 
distribution = tf.contrib.distribute.CollectiveAllReduceStrategy( num_gpus_per_worker=2) config = tf.estimator.RunConfig( experimental_distribute=tf.contrib.distribute.DistributeConfig( train_distribute=distribution, remote_cluster={"worker": ["host1:port", "host2:port", "host3:port"]})) +estimator = tf.estimator.Estimator(model_fn=model_fn, config=config) train_spec = tf.estimator.TrainSpec(input_fn=input_fn) eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn) tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec) diff --git a/tensorflow/contrib/distribute/__init__.py b/tensorflow/contrib/distribute/__init__.py index ab2f221dc64..8ec73654e30 100644 --- a/tensorflow/contrib/distribute/__init__.py +++ b/tensorflow/contrib/distribute/__init__.py @@ -25,13 +25,13 @@ from __future__ import print_function # pylint: disable=unused-import,wildcard-import from tensorflow.contrib.distribute.python.collective_all_reduce_strategy import CollectiveAllReduceStrategy -from tensorflow.contrib.distribute.python.cross_tower_ops import * from tensorflow.contrib.distribute.python.mirrored_strategy import MirroredStrategy from tensorflow.contrib.distribute.python.monitor import Monitor from tensorflow.contrib.distribute.python.one_device_strategy import OneDeviceStrategy from tensorflow.contrib.distribute.python.parameter_server_strategy import ParameterServerStrategy from tensorflow.contrib.distribute.python.step_fn import * from tensorflow.contrib.distribute.python.tpu_strategy import TPUStrategy +from tensorflow.python.distribute.cross_device_ops import * from tensorflow.python.distribute.distribute_config import DistributeConfig from tensorflow.python.distribute.distribute_coordinator import run_standard_tensorflow_server from tensorflow.python.training.distribute import * @@ -46,6 +46,7 @@ _allowed_symbols = [ 'CrossDeviceOps', 'DistributeConfig', 'DistributionStrategy', + 'DistributionStrategyExtended', 'MirroredStrategy', 'Monitor', 'MultiWorkerAllReduce', @@ -62,6 +63,7 @@ _allowed_symbols = [ 'get_loss_reduction', 'get_replica_context', 'has_distribution_strategy', + 'in_cross_replica_context', 'require_replica_context', 'run_standard_tensorflow_server', 'UpdateContext', diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index 4094e52169a..4c9c35da5a3 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -16,45 +16,26 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_test") # TODO(priyag): Figure out testonly issues that are preventing us from # including our tests in pip for now. 
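Both README snippets above apply the same fix: the `RunConfig` that carries the distribution strategy must be created first, and the `Estimator` constructed from it; otherwise the strategy never reaches the estimator. A hedged, single-machine sketch of that ordering (the `model_fn`/`input_fn` bodies and the CPU-only `MirroredStrategy` device list are illustrative placeholders, not part of this change):

```python
import tensorflow as tf

def model_fn(features, labels, mode):
  logits = tf.layers.dense(features, 1)
  loss = tf.losses.mean_squared_error(labels=labels, predictions=logits)
  if mode == tf.estimator.ModeKeys.TRAIN:
    train_op = tf.train.GradientDescentOptimizer(0.2).minimize(
        loss, global_step=tf.train.get_or_create_global_step())
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
  return tf.estimator.EstimatorSpec(mode, loss=loss)

def input_fn():
  features = tf.data.Dataset.from_tensors([[1.0]]).repeat(100)
  labels = tf.data.Dataset.from_tensors([[2.0]]).repeat(100)
  return tf.data.Dataset.zip((features, labels)).batch(10)

# Build the strategy-carrying RunConfig first ...
distribution = tf.contrib.distribute.MirroredStrategy(["/cpu:0"])
config = tf.estimator.RunConfig(train_distribute=distribution)
# ... and only then the Estimator, so it actually picks up the strategy.
estimator = tf.estimator.Estimator(model_fn=model_fn, config=config)
estimator.train(input_fn=input_fn, steps=10)
```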
-py_library( - name = "values", - srcs = ["values.py"], - visibility = ["//tensorflow:internal"], - deps = [ - ":input_ops", - "//tensorflow/python:array_ops", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:device_util", - "//tensorflow/python:distribute", - "//tensorflow/python:framework_ops", - "//tensorflow/python:resource_variable_ops", - "//tensorflow/python:training", - "//tensorflow/python:util", - "//tensorflow/python/data/ops:multi_device_iterator_ops", - "//tensorflow/python/eager:context", - "//tensorflow/python/training/checkpointable:base", - "@six_archive//:six", - ], -) - cuda_py_test( name = "values_test", srcs = ["values_test.py"], additional_deps = [ + ":combinations", ":mirrored_strategy", ":multi_worker_test_base", - ":values", + "@absl_py//absl/testing:parameterized", "//tensorflow/core:protos_all_py", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python:errors", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", + "//tensorflow/python:errors", "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", "//tensorflow/python:training", "//tensorflow/python:variable_scope", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/distribute:device_util", + "//tensorflow/python/distribute:values", "//tensorflow/python/eager:context", - "//tensorflow/python:device_util", "//tensorflow/python/eager:test", "//tensorflow/python/estimator:estimator_py", ], @@ -68,25 +49,9 @@ py_library( srcs = ["mirrored_strategy.py"], visibility = ["//tensorflow:internal"], deps = [ - ":cross_tower_ops", - ":shared_variable_creator", - ":values", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:constant_op", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:device", - "//tensorflow/python:device_util", - "//tensorflow/python:distribute", - "//tensorflow/python:framework_ops", - "//tensorflow/python:pywrap_tensorflow", - "//tensorflow/python:training", - "//tensorflow/python:util", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - "//tensorflow/python/distribute:multi_worker_util", - "//tensorflow/python/eager:context", - "//tensorflow/python/eager:tape", + "//tensorflow/python/distribute:distribute_lib", + "//tensorflow/python/distribute:mirrored_strategy", + "//tensorflow/python/distribute:values", ], ) @@ -95,16 +60,17 @@ py_library( srcs = ["parameter_server_strategy.py"], visibility = ["//tensorflow:internal"], deps = [ - ":cross_tower_ops", ":mirrored_strategy", - ":values", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:framework_ops", "//tensorflow/python:resource_variable_ops", "//tensorflow/python:training", "//tensorflow/python:util", + "//tensorflow/python/distribute:cross_device_ops", "//tensorflow/python/distribute:multi_worker_util", + "//tensorflow/python/distribute:reduce_util", + "//tensorflow/python/distribute:values", "//tensorflow/python/eager:context", ], ) @@ -116,7 +82,7 @@ cuda_py_test( ":combinations", ":multi_worker_test_base", ":parameter_server_strategy", - ":values", + ":strategy_test_lib", "@absl_py//absl/testing:parameterized", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", @@ -127,10 +93,12 @@ cuda_py_test( "//tensorflow/python:gradients", "//tensorflow/python:layers", "//tensorflow/python:session", + "//tensorflow/python:tensor_util", "//tensorflow/python:training", "//tensorflow/python:variable_scope", 
"//tensorflow/python:variables", "//tensorflow/python/distribute:multi_worker_util", + "//tensorflow/python/distribute:values", "//tensorflow/python/eager:context", "//tensorflow/python/estimator:estimator_py", ], @@ -145,12 +113,13 @@ py_library( srcs = ["one_device_strategy.py"], visibility = ["//tensorflow:internal"], deps = [ - ":values", - "//tensorflow/contrib/eager/python:datasets", "//tensorflow/python:array_ops", - "//tensorflow/python:distribute", + "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", + "//tensorflow/python/distribute:distribute_lib", + "//tensorflow/python/distribute:reduce_util", + "//tensorflow/python/distribute:values", "//tensorflow/python/eager:context", "@six_archive//:six", ], @@ -161,16 +130,16 @@ py_library( srcs = ["collective_all_reduce_strategy.py"], visibility = ["//tensorflow:internal"], deps = [ - ":cross_tower_ops", - ":cross_tower_utils", ":mirrored_strategy", - ":values", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:collective_ops", "//tensorflow/python:framework_ops", "//tensorflow/python:training", + "//tensorflow/python/distribute:cross_device_ops", + "//tensorflow/python/distribute:cross_device_utils", "//tensorflow/python/distribute:multi_worker_util", + "//tensorflow/python/distribute:values", "//tensorflow/python/eager:context", ], ) @@ -187,11 +156,11 @@ py_library( "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", - "//tensorflow/python:distribute", "//tensorflow/python:framework_ops", "//tensorflow/python:layers", "//tensorflow/python:training", "//tensorflow/python:variables", + "//tensorflow/python/distribute:distribute_lib", "//tensorflow/python/eager:backprop", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", @@ -212,10 +181,10 @@ py_library( ":tpu_strategy", "//tensorflow/contrib/cluster_resolver:cluster_resolver_pip", "//tensorflow/contrib/optimizer_v2:training", - "//tensorflow/python:distribute", "//tensorflow/python:framework_ops", "//tensorflow/python:training", "//tensorflow/python:util", + "//tensorflow/python/distribute:distribute_lib", "//tensorflow/python/eager:context", "@absl_py//absl/testing:parameterized", ], @@ -233,28 +202,6 @@ py_test( ], ) -py_test( - name = "mirrored_strategy_test", - srcs = ["mirrored_strategy_test.py"], - srcs_version = "PY2AND3", - tags = [ - "no_pip", - ], - deps = [ - ":mirrored_strategy", - ":multi_worker_test_base", - ":strategy_test_lib", - "//tensorflow/python:constant_op", - "//tensorflow/python:distribute", - "//tensorflow/python:framework_ops", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", - "//tensorflow/python/eager:context", - "//tensorflow/python/eager:test", - ], -) - py_test( name = "one_device_strategy_test", srcs = ["one_device_strategy_test.py"], @@ -270,35 +217,32 @@ py_test( ], ) +# TODO(priyag): Rename this test to mirrored_strategy_test cuda_py_test( name = "mirrored_strategy_multigpu_test", srcs = ["mirrored_strategy_multigpu_test.py"], additional_deps = [ + ":combinations", ":mirrored_strategy", ":multi_worker_test_base", - ":values", ":strategy_test_lib", - "//tensorflow/python:distribute", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", + "//tensorflow/python:framework_test_lib", "//tensorflow/python:layers", "//tensorflow/python:state_ops", "//tensorflow/python:variable_scope", 
- "//tensorflow/python:framework_test_lib", + "//tensorflow/python/distribute:distribute_lib", + "//tensorflow/python/distribute:values", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", ], + shard_count = 5, tags = [ "guitar", - "no_pip", "multi_and_single_gpu", - # Do not perform the extra analysis on this test, because it is already - # performed for the `:mirrored_strategy_test` target. - "no_oss", - "noasan", - "notap", - "notsan", + "no_pip", ], ) @@ -337,12 +281,15 @@ py_library( visibility = ["//tensorflow:internal"], deps = [ ":one_device_strategy", - ":values", "//tensorflow/contrib/tpu:tpu_lib", "//tensorflow/python:constant_op", "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", + "//tensorflow/python:tensor_util", "//tensorflow/python:util", + "//tensorflow/python/distribute:reduce_util", + "//tensorflow/python/distribute:values", ], ) @@ -352,7 +299,6 @@ cuda_py_test( additional_deps = [ ":collective_all_reduce_strategy", ":combinations", - ":cross_tower_utils", ":multi_worker_test_base", ":strategy_test_lib", "@absl_py//absl/testing:parameterized", @@ -368,6 +314,7 @@ cuda_py_test( "//tensorflow/python:layers", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", + "//tensorflow/python/distribute:cross_device_utils", "//tensorflow/python/eager:context", "//tensorflow/python/estimator:estimator_py", ], @@ -469,6 +416,7 @@ cuda_py_test( "multi_and_single_gpu", "no_oss", # http://b/119349471 "no_pip", + "tf_integration_test", ], ) @@ -476,28 +424,18 @@ cuda_py_test( name = "keras_optimizer_v2_test", srcs = ["keras_optimizer_v2_test.py"], additional_deps = [ - ":combinations", - "@absl_py//absl/testing:parameterized", - "//third_party/py/numpy", - "//tensorflow/contrib/optimizer_v2:training", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/eager:test", - "//tensorflow/python/estimator:estimator_py", - "//tensorflow/python/feature_column", - "//tensorflow/python:framework_ops", - "//tensorflow/python:platform", - "//tensorflow/python:summary", + ":keras_test_lib", ], tags = [ "multi_and_single_gpu", "no_oss", # http://b/119349471 "no_pip", + "tf_integration_test", ], ) cuda_py_test( name = "estimator_training_test", - size = "large", srcs = ["estimator_training_test.py"], additional_deps = [ ":collective_all_reduce_strategy", @@ -508,7 +446,9 @@ cuda_py_test( "//third_party/py/numpy", "//tensorflow/contrib/optimizer_v2:training", "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/distribute", + "//tensorflow/python/distribute:distribute_config", + "//tensorflow/python/distribute:distribute_coordinator", + "//tensorflow/python/distribute:distribute_coordinator_context", "//tensorflow/python/eager:test", "//tensorflow/python/estimator:estimator_py", "//tensorflow/python/feature_column", @@ -516,7 +456,7 @@ cuda_py_test( "//tensorflow/python:platform", "//tensorflow/python:summary", ], - shard_count = 5, + shard_count = 48, tags = [ "multi_and_single_gpu", "no_pip", @@ -524,6 +464,7 @@ cuda_py_test( "noasan", "nomsan", "notsan", + "no_oss", # http://b/119349471 ], ) @@ -599,52 +540,16 @@ cuda_py_test( ], ) -py_library( - name = "shared_variable_creator", - srcs = ["shared_variable_creator.py"], - visibility = ["//tensorflow:internal"], -) - -py_test( - name = "shared_variable_creator_test", - srcs = ["shared_variable_creator_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":shared_variable_creator", - "//tensorflow/python:framework_test_lib", - 
"//tensorflow/python:variable_scope", - "//tensorflow/python/eager:test", - ], -) - -py_library( - name = "cross_tower_utils", - srcs = ["cross_tower_utils.py"], - srcs_version = "PY2AND3", - deps = [ - ":values", - "//tensorflow/contrib/all_reduce:all_reduce_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:collective_ops", - "//tensorflow/python:device", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:gradients", - "//tensorflow/python:math_ops", - "//tensorflow/python:nccl_ops", - ], -) - cuda_py_test( - name = "cross_tower_utils_test", - srcs = ["cross_tower_utils_test.py"], + name = "cross_device_utils_test", + srcs = ["cross_device_utils_test.py"], additional_deps = [ ":combinations", - ":cross_tower_utils", "@absl_py//absl/testing:parameterized", "//tensorflow/python:constant_op", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", + "//tensorflow/python/distribute:cross_device_utils", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", ], @@ -653,40 +558,20 @@ cuda_py_test( ], ) -py_library( - name = "cross_tower_ops", - srcs = ["cross_tower_ops.py"], - srcs_version = "PY2AND3", - deps = [ - ":cross_tower_utils", - ":values", - "//tensorflow/python:array_ops", - "//tensorflow/python:device_lib", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform", - "//tensorflow/python:resource_variable_ops", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", - "//tensorflow/python/eager:context", - "@six_archive//:six", - ], -) - cuda_py_test( - name = "cross_tower_ops_test", - srcs = ["cross_tower_ops_test.py"], + name = "cross_device_ops_test", + srcs = ["cross_device_ops_test.py"], additional_deps = [ ":combinations", - ":cross_tower_ops", ":multi_worker_test_base", ":mirrored_strategy", - ":values", "@absl_py//absl/testing:parameterized", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", + "//tensorflow/python/distribute:cross_device_ops", + "//tensorflow/python/distribute:values", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", ], @@ -696,37 +581,6 @@ cuda_py_test( ], ) -py_library( - name = "input_ops", - srcs = ["input_ops.py"], - visibility = ["//tensorflow:internal"], - deps = [ - "//tensorflow/python:framework_ops", - "//tensorflow/python/data/util:nest", - ], -) - -cuda_py_test( - name = "input_ops_test", - srcs = ["input_ops_test.py"], - additional_deps = [ - ":input_ops", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/contrib/data/python/ops:batching", - "//tensorflow/contrib/data/python/ops:interleave_ops", - "//tensorflow/python:errors", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_ops", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:io_ops", - "//tensorflow/python/data/ops:readers", - "//tensorflow/python:util", - ], - tags = [ - "no_pip", - ], -) - py_library( name = "keras_test_lib", testonly = 1, @@ -737,6 +591,7 @@ py_library( "//tensorflow/contrib/distribute/python:tpu_strategy", "//tensorflow/python:client_testlib", "//tensorflow/python:training", + "//tensorflow/python/eager:test", "//tensorflow/python/estimator:estimator_py", "//tensorflow/python/keras", "//third_party/py/numpy", @@ -766,7 +621,6 @@ py_library( srcs = ["metrics_v1_test.py"], deps = [ ":combinations", - "//tensorflow/contrib/data/python/ops:batching", 
"//tensorflow/python:math_ops", "//tensorflow/python:metrics", "//tensorflow/python:variables", diff --git a/tensorflow/contrib/distribute/python/checkpoint_utils_test.py b/tensorflow/contrib/distribute/python/checkpoint_utils_test.py index d38bdb592a3..31bd0e996a2 100644 --- a/tensorflow/contrib/distribute/python/checkpoint_utils_test.py +++ b/tensorflow/contrib/distribute/python/checkpoint_utils_test.py @@ -43,7 +43,9 @@ class CheckpointUtilsWithDistributionStrategyTest( distribution=[combinations.default_strategy, combinations.one_device_strategy, combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus], + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus], in_replica_mode=[True, False], mode=["graph"])) def testInitFromCheckpoint(self, distribution, in_replica_mode): diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py index efa99d1fc52..e988b63a287 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py @@ -18,12 +18,16 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib -from tensorflow.contrib.distribute.python import cross_tower_utils +import copy + from tensorflow.contrib.distribute.python import mirrored_strategy -from tensorflow.contrib.distribute.python import values from tensorflow.core.protobuf import rewriter_config_pb2 +from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib +from tensorflow.python.distribute import cross_device_utils +from tensorflow.python.distribute import device_util +from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import multi_worker_util +from tensorflow.python.distribute import values from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -32,7 +36,7 @@ from tensorflow.python.platform import tf_logging as logging # TODO(yuefengz): support in-graph replication. -class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): +class CollectiveAllReduceStrategy(distribute_lib.DistributionStrategy): """Distribution strategy that uses collective ops for all-reduce. It is similar to the MirroredStrategy but it uses collective ops for @@ -53,6 +57,17 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): num_gpus_per_worker: number of local GPUs or GPUs per worker, the default is 0 meaning CPU only. 
""" + super(CollectiveAllReduceStrategy, self).__init__( + CollectiveAllReduceExtended(self, num_gpus_per_worker)) + + +class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): + """Implementation of CollectiveAllReduceStrategy.""" + + def __init__(self, container_strategy, num_gpus_per_worker): + distribute_lib.DistributionStrategyExtended.__init__( + self, container_strategy) + self._cross_device_ops = None self._num_gpus_per_worker = num_gpus_per_worker self._initialize_local_worker(num_gpus_per_worker) @@ -67,14 +82,14 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): ] else: local_devices = ["/device:CPU:0"] + self._worker_device = device_util.canonicalize("/device:CPU:0") - self._collective_keys = cross_tower_utils.CollectiveKeys() - super(CollectiveAllReduceStrategy, self).__init__( - devices=local_devices, - cross_tower_ops=cross_tower_ops_lib.CollectiveAllReduce( - num_workers=1, - num_gpus_per_worker=num_gpus_per_worker, - collective_keys=self._collective_keys)) + self._collective_keys = cross_device_utils.CollectiveKeys() + self._initialize_local(local_devices) + self._cross_tower_ops = cross_device_ops_lib.CollectiveAllReduce( + num_workers=self._num_workers, + num_gpus_per_worker=num_gpus_per_worker, + collective_keys=self._collective_keys) self._cluster_spec = None self._task_type = None @@ -94,8 +109,7 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): "Unrecognized task_type: %r, valid task types are: \"chief\", " "\"worker\"." % task_type) cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec) - self._num_workers = len(cluster_spec.as_dict().get("worker", [])) + len( - cluster_spec.as_dict().get("chief", [])) + self._num_workers = multi_worker_util.worker_count(cluster_spec, task_type) if not self._num_workers: raise ValueError("No `worker` or `chief` tasks can be found in " "`cluster_spec`.") @@ -103,22 +117,21 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type, task_id) - worker_device = "/job:%s/task:%d" % (task_type, task_id) + self._worker_device = "/job:%s/task:%d" % (task_type, task_id) if num_gpus_per_worker: local_devices = [ - "%s/device:GPU:%d" % (worker_device, i) + "%s/device:GPU:%d" % (self._worker_device, i) for i in range(num_gpus_per_worker) ] else: - local_devices = [worker_device] + local_devices = [self._worker_device] - self._collective_keys = cross_tower_utils.CollectiveKeys() - super(CollectiveAllReduceStrategy, self).__init__( - devices=local_devices, - cross_tower_ops=cross_tower_ops_lib.CollectiveAllReduce( - num_workers=self._num_workers, - num_gpus_per_worker=num_gpus_per_worker, - collective_keys=self._collective_keys)) + self._collective_keys = cross_device_utils.CollectiveKeys() + self._initialize_local(local_devices) + self._cross_tower_ops = cross_device_ops_lib.CollectiveAllReduce( + num_workers=self._num_workers, + num_gpus_per_worker=num_gpus_per_worker, + collective_keys=self._collective_keys) # Add a default device so that ops without specified devices will not end up # on other workers. @@ -202,17 +215,40 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): return mirrored_strategy._create_mirrored_variable( devices, _real_mirrored_creator, *args, **kwargs) - def distribute_dataset(self, dataset_fn): + def _distribute_dataset(self, dataset_fn): """Distributes the dataset to each local GPU.""" # TODO(yuefengz): shard the dataset. 
return values.PerReplicaDataset( self._call_dataset_fn(dataset_fn), self._devices, True) - def configure(self, - session_config=None, - cluster_spec=None, - task_type=None, - task_id=None): + def _make_dataset_iterator(self, dataset): + worker_device_pairs = [(self._worker_device, self._devices)] + return values.DatasetIterator(dataset, worker_device_pairs, + self._num_replicas_in_sync) + + def _make_input_fn_iterator( + self, + input_fn, + replication_mode=distribute_lib.InputReplicationMode.PER_WORKER): + """Distributes the dataset to each local GPU.""" + if self._cluster_spec is None: + input_pipeline_id = 0 + else: + input_pipeline_id = multi_worker_util.id_in_cluster( + self._cluster_spec, self._task_type, self._task_id) + input_context = distribute_lib.InputContext( + num_input_pipelines=self._num_workers, + input_pipeline_id=input_pipeline_id, + num_replicas_in_sync=self._num_replicas_in_sync) + + return values.InputFunctionIterator( + input_fn, [(self._worker_device, self._devices)], [input_context]) + + def _configure(self, + session_config=None, + cluster_spec=None, + task_type=None, + task_id=None): """Configures the object. Args: @@ -232,13 +268,15 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): self._initialize_multi_worker(self._num_gpus_per_worker, cluster_spec, task_type, task_id) - if not session_config: - return + if session_config: + session_config.CopyFrom(self._update_config_proto(session_config)) + def _update_config_proto(self, config_proto): + updated_config = copy.deepcopy(config_proto) # Enable the scoped allocator optimization for CollectiveOps. This # optimization converts many small all-reduces into fewer larger # all-reduces. - rewrite_options = session_config.graph_options.rewrite_options + rewrite_options = updated_config.graph_options.rewrite_options rewrite_options.scoped_allocator_optimization = ( rewriter_config_pb2.RewriterConfig.ON) # We turn on ScopedAllocator only for CollectiveReduce op, i.e. enable_op = @@ -248,7 +286,7 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): rewrite_options.scoped_allocator_opts.enable_op.append("CollectiveReduce") if not self._cluster_spec: - return + return updated_config assert self._task_type assert self._task_id is not None @@ -256,26 +294,28 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): # Collective group leader is needed for collective ops to coordinate # workers. if "chief" in self._cluster_spec.jobs: - session_config.experimental.collective_group_leader = ( + updated_config.experimental.collective_group_leader = ( "/job:chief/replica:0/task:0") else: if "worker" not in self._cluster_spec.jobs: raise ValueError( "You must have `chief` or `worker` jobs in the `cluster_spec`.") - session_config.experimental.collective_group_leader = ( + updated_config.experimental.collective_group_leader = ( "/job:worker/replica:0/task:0") # The device filters prevent communication between workers. 
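The `InputContext` assembled in `_make_input_fn_iterator` above (pipeline count, pipeline id, replicas in sync) is what a multi-worker `input_fn` is expected to consume for sharding. A runnable sketch of that contract with a tiny stand-in context object (the namedtuple and the toy dataset are illustrative; the real context comes from the strategy):

```python
import collections
import tensorflow as tf

# Stand-in for the InputContext passed to input_fn (illustration only).
FakeInputContext = collections.namedtuple(
    "FakeInputContext", ["num_input_pipelines", "input_pipeline_id"])

def input_fn(ctx):
  # Every worker builds the same dataset and keeps only its own shard.
  return tf.data.Dataset.range(8).shard(ctx.num_input_pipelines,
                                        ctx.input_pipeline_id)

dataset = input_fn(FakeInputContext(num_input_pipelines=2, input_pipeline_id=1))
iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
next_element = iterator.get_next()
with tf.compat.v1.Session() as sess:
  print([sess.run(next_element) for _ in range(4)])  # [1, 3, 5, 7]
```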
- del session_config.device_filters[:] - session_config.device_filters.append( + del updated_config.device_filters[:] + updated_config.device_filters.append( "/job:%s/task:%d" % (self._task_type, self._task_id)) + return updated_config + @property - def between_graph(self): + def experimental_between_graph(self): return True @property - def should_init(self): + def experimental_should_init(self): return True @property @@ -287,6 +327,10 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): return self._is_chief @property - def num_replicas_in_sync(self): + def _num_replicas_in_sync(self): return len(self._devices) * self._num_workers + # TODO(priyag): Delete this once all strategies use global batch size. + @property + def _global_batch_size(self): + return False diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py index e3d919dd0d4..8a9e583f0af 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py @@ -23,13 +23,19 @@ import numpy as np from tensorflow.contrib.distribute.python import collective_all_reduce_strategy from tensorflow.contrib.distribute.python import combinations -from tensorflow.contrib.distribute.python import cross_tower_utils from tensorflow.contrib.distribute.python import multi_worker_test_base +from tensorflow.contrib.distribute.python import strategy_test_lib from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python import keras +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import cross_device_utils +from tensorflow.python.distribute import reduce_util +from tensorflow.python.distribute import values from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.layers import core from tensorflow.python.ops import array_ops @@ -51,9 +57,6 @@ class CollectiveAllReduceStrategyTestBase( collective_key_base = 0 def setUp(self): - self._run_options = config_pb2.RunOptions() - self._run_options.experimental.collective_graph_key = 6 - # We use a different key_base for each test so that collective keys won't be # reused. 
# TODO(yuefengz, tucker): enable it to reuse collective keys in different @@ -71,15 +74,16 @@ class CollectiveAllReduceStrategyTestBase( cluster_spec=self._cluster_spec, task_type=task_type, task_id=task_id) - collective_keys = cross_tower_utils.CollectiveKeys( + collective_keys = cross_device_utils.CollectiveKeys( group_key_start=10 * num_gpus + CollectiveAllReduceStrategyTestBase.collective_key_base, instance_key_start=num_gpus * 100 + CollectiveAllReduceStrategyTestBase.collective_key_base, instance_key_with_id_start=num_gpus * 10000 + CollectiveAllReduceStrategyTestBase.collective_key_base) - distribution._collective_keys = collective_keys - distribution._cross_tower_ops._collective_keys = collective_keys + distribution.extended._collective_keys = collective_keys + distribution.extended._inferred_cross_device_ops._collective_keys = ( + collective_keys) if task_type and task_id is not None: return distribution, 'grpc://' + self._cluster_spec[task_type][ task_id], session_config @@ -93,7 +97,8 @@ class CollectiveAllReduceStrategyTestBase( self.cached_session(config=config, target=master_target) as sess, \ d.scope(): - l = core.Dense(1, use_bias=False, name='gpu_%d' % d._num_gpus_per_worker) + l = core.Dense(1, use_bias=False, + name='gpu_%d' % d.extended._num_gpus_per_worker) def loss_fn(x): y = array_ops.reshape(l(x), []) - constant_op.constant(1.) @@ -127,8 +132,8 @@ class CollectiveAllReduceStrategyTestBase( before_list.append(fetched) with ops.control_dependencies([fetched]): # TODO(yuefengz): support non-Mirrored variable as destinations. - g = d.reduce( - variable_scope.VariableAggregation.SUM, g, destinations=v) + g = d.extended.reduce_to( + reduce_util.ReduceOp.SUM, g, destinations=v) with ops.control_dependencies( d.update(v, update, g, grouped=False)): after_list.append(d.read_var(v)) @@ -136,14 +141,13 @@ class CollectiveAllReduceStrategyTestBase( before_out, after_out = step() - if context.num_gpus() < d._num_gpus_per_worker: + if context.num_gpus() < d.extended._num_gpus_per_worker: return True - sess.run( - variables.global_variables_initializer(), options=self._run_options) + sess.run(variables.global_variables_initializer()) for i in range(10): - b, a = sess.run((before_out, after_out), options=self._run_options) + b, a = sess.run((before_out, after_out)) if i == 0: before, = b after, = a @@ -222,26 +226,54 @@ class CollectiveAllReduceStrategyTestBase( return array_ops.identity(x) x = distribution.call_for_each_replica(model_fn) - reduced_x = distribution.unwrap( - distribution.reduce( - variable_scope.VariableAggregation.MEAN, x, - destinations='/cpu:0'))[0] + reduced_x = distribution.reduce(reduce_util.ReduceOp.MEAN, x) x = distribution.unwrap(x)[0] - sess.run( - variables.global_variables_initializer(), options=self._run_options) + sess.run(variables.global_variables_initializer()) - x_value, reduced_x_value = sess.run([x, reduced_x], - options=self._run_options) + x_value, reduced_x_value = sess.run([x, reduced_x]) self.assertTrue( np.allclose(x_value, reduced_x_value, atol=1e-5), msg=('x_value = %r, reduced_x_value = %r' % (x_value, reduced_x_value))) return np.allclose(x_value, reduced_x_value, atol=1e-5) + def _test_input_fn_iterator(self, task_type, task_id, num_gpus, input_fn, + expected_values): + distribution, master_target, config = self._get_test_object( + task_type, task_id, num_gpus) + devices = distribution.extended.worker_devices + + with ops.Graph().as_default(), \ + self.cached_session(config=config, + target=master_target) as sess: + iterator = 
distribution.make_input_fn_iterator(input_fn) + sess.run(iterator.initialize()) + + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = sess.run( + [values.select_device(d, next_element) for d in devices]) + self.assertEqual(expected_value, computed_value) + + with self.assertRaises(errors.OutOfRangeError): + next_element = iterator.get_next() + sess.run([values.select_device(d, next_element) for d in devices]) + + # After re-initializing the iterator, should be able to iterate again. + sess.run(iterator.initialize()) + + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = sess.run( + [values.select_device(d, next_element) for d in devices]) + self.assertEqual(expected_value, computed_value) + class DistributedCollectiveAllReduceStrategyTest( - CollectiveAllReduceStrategyTestBase, parameterized.TestCase): + CollectiveAllReduceStrategyTestBase, + strategy_test_lib.DistributionTestBase, + parameterized.TestCase): @classmethod def setUpClass(cls): @@ -269,7 +301,7 @@ class DistributedCollectiveAllReduceStrategyTest( combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1)) def testVariableInitialization(self, num_gpus): if context.num_gpus() < num_gpus: - return + self.skipTest('Not enough GPUs') self._run_between_graph_clients( self._test_variable_initialization, self._cluster_spec, @@ -279,10 +311,56 @@ class DistributedCollectiveAllReduceStrategyTest( combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1)) def testComplexModel(self, num_gpus): if context.num_gpus() < num_gpus: - return + self.skipTest('Not enough GPUs') self._run_between_graph_clients( self._test_complex_model, self._cluster_spec, num_gpus=num_gpus) + # TODO(yuefengz): Update how we use num_gpus and required_gpus + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1)) + def testMakeInputFnIterator(self, num_gpus): + if context.num_gpus() < num_gpus: + self.skipTest('Not enough GPUs') + dataset_fn = lambda: dataset_ops.Dataset.range(100) + # We use CPU as the device when num_gpus = 0 + devices_per_worker = max(1, num_gpus) + expected_values = [[i+j for j in range(devices_per_worker)] + for i in range(0, 100, devices_per_worker)] + + input_fn = self._input_fn_to_test_input_context( + dataset_fn, + expected_num_replicas_in_sync=3*devices_per_worker, + expected_num_input_pipelines=3, + expected_input_pipeline_id=1) # because task_id = 1 + self._test_input_fn_iterator('worker', 1, num_gpus, + input_fn, expected_values) + + def testUpdateConfigProto(self): + distribution = collective_all_reduce_strategy.CollectiveAllReduceStrategy( + num_gpus_per_worker=2) + distribution.configure( + cluster_spec=self._cluster_spec, task_type='worker', task_id=1) + + config_proto = config_pb2.ConfigProto(device_filters=['to_be_overridden']) + rewrite_options = config_proto.graph_options.rewrite_options + rewrite_options.scoped_allocator_opts.enable_op.append('to_be_removed') + + new_config = distribution.update_config_proto(config_proto) + + # Verify group leader + self.assertEqual('/job:worker/replica:0/task:0', + new_config.experimental.collective_group_leader) + + # Verify device filters. + self.assertEqual(['/job:worker/task:1'], new_config.device_filters) + + # Verify rewrite options. 
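`testUpdateConfigProto` above exercises the new `update_config_proto` path, which returns an amended copy of a `ConfigProto` (group leader, device filters, scoped-allocator options) instead of mutating the session config in place. A hedged sketch of how a between-graph worker might use it when opening its session (the two-host cluster spec is an assumption for illustration):

```python
import tensorflow as tf

strategy = tf.contrib.distribute.CollectiveAllReduceStrategy(
    num_gpus_per_worker=0)
strategy.configure(
    cluster_spec={"worker": ["host1:2222", "host2:2222"]},
    task_type="worker",
    task_id=0)

# Start from whatever session config the caller already has ...
base_config = tf.ConfigProto(allow_soft_placement=True)
# ... and let the strategy fill in collective-ops details on a copy.
session_config = strategy.update_config_proto(base_config)
print(session_config.experimental.collective_group_leader)
# /job:worker/replica:0/task:0
print(list(session_config.device_filters))  # ['/job:worker/task:0']
```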
+ new_rewrite_options = new_config.graph_options.rewrite_options + self.assertEqual(rewriter_config_pb2.RewriterConfig.ON, + new_rewrite_options.scoped_allocator_optimization) + self.assertEqual(['CollectiveReduce'], + new_rewrite_options.scoped_allocator_opts.enable_op) + class DistributedCollectiveAllReduceStrategyTestWithChief( CollectiveAllReduceStrategyTestBase, parameterized.TestCase): @@ -293,10 +371,6 @@ class DistributedCollectiveAllReduceStrategyTestWithChief( cls._cluster_spec = multi_worker_test_base.create_in_process_cluster( num_workers=3, num_ps=0, has_chief=True) - def setUp(self): - super(DistributedCollectiveAllReduceStrategyTestWithChief, self).setUp() - self._run_options.experimental.collective_graph_key = 7 - @combinations.generate( combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1)) def testMinimizeLossGraph(self, num_gpus): @@ -323,20 +397,36 @@ class DistributedCollectiveAllReduceStrategyTestWithChief( class LocalCollectiveAllReduceStrategy(CollectiveAllReduceStrategyTestBase, + strategy_test_lib.DistributionTestBase, parameterized.TestCase): def testMinimizeLossGraph(self, num_gpus=2): # Collective ops doesn't support strategy with one device. if context.num_gpus() < num_gpus: - return + self.skipTest('Not enough GPUs') self._test_minimize_loss_graph(None, None, num_gpus) def testComplexModel(self, num_gpus=2): # Collective ops doesn't support strategy with one device. if context.num_gpus() < num_gpus: - return + self.skipTest('Not enough GPUs') self._test_complex_model(None, None, num_gpus) + def testMakeInputFnIterator(self, num_gpus=2): + # Collective ops doesn't support strategy with one device. + if context.num_gpus() < num_gpus: + self.skipTest('Not enough GPUs') + dataset_fn = lambda: dataset_ops.Dataset.range(10) + expected_values = [[i, i+1] for i in range(0, 10, 2)] + + input_fn = self._input_fn_to_test_input_context( + dataset_fn, + expected_num_replicas_in_sync=num_gpus, + expected_num_input_pipelines=1, + expected_input_pipeline_id=0) + self._test_input_fn_iterator(None, None, num_gpus, + input_fn, expected_values) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py index a5137165403..365ce5cdec7 100644 --- a/tensorflow/contrib/distribute/python/combinations.py +++ b/tensorflow/contrib/distribute/python/combinations.py @@ -53,11 +53,11 @@ from tensorflow.contrib.distribute.python import tpu_strategy as tpu_lib from tensorflow.contrib.optimizer_v2 import adagrad as adagrad_v2 from tensorflow.contrib.optimizer_v2 import adam as adam_v2 from tensorflow.contrib.optimizer_v2 import gradient_descent as gradient_descent_v2 +from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.training import adagrad from tensorflow.python.training import adam -from tensorflow.python.training import distribution_strategy_context from tensorflow.python.training import gradient_descent from tensorflow.python.training import rmsprop from tensorflow.python.util import tf_inspect @@ -168,6 +168,8 @@ def _augment_with_special_arguments(test_method): if GPU_TEST: self.skipTest("Test that doesn't require GPUs.") elif context.num_gpus() < required_gpus: + # TODO(priyag): Consider allowing tests in graph mode using soft + # placement. self.skipTest( "{} GPUs are not available for this test. {} GPUs are available". 
format(required_gpus, context.num_gpus())) @@ -190,7 +192,7 @@ def _augment_with_special_arguments(test_method): kwargs_to_pass[arg] = kwargs[arg] if mode == "eager": - with ops.Graph().as_default(), context.eager_mode(): + with context.eager_mode(): if distribution: kwargs_to_pass["distribution"] = distribution.strategy test_method(**kwargs_to_pass) @@ -335,6 +337,13 @@ tpu_strategy_one_step = NamedDistribution( "TPUOneStep", lambda: tpu_lib.TPUStrategy( TPUClusterResolver(""), steps_per_run=1), required_tpu=True) +mirrored_strategy_with_one_cpu = NamedDistribution( + "Mirrored1CPU", + lambda: mirrored_lib.MirroredStrategy(["/cpu:0"])) +mirrored_strategy_with_one_gpu = NamedDistribution( + "Mirrored1GPU", + lambda: mirrored_lib.MirroredStrategy(["/gpu:0"]), + required_gpus=1) mirrored_strategy_with_gpu_and_cpu = NamedDistribution( "MirroredCPUAndGPU", lambda: mirrored_lib.MirroredStrategy(["/gpu:0", "/cpu:0"]), @@ -343,6 +352,21 @@ mirrored_strategy_with_two_gpus = NamedDistribution( "Mirrored2GPUs", lambda: mirrored_lib.MirroredStrategy(["/gpu:0", "/gpu:1"]), required_gpus=2) +core_mirrored_strategy_with_one_cpu = NamedDistribution( + "CoreMirrored1CPU", + lambda: mirrored_lib.CoreMirroredStrategy(["/cpu:0"])) +core_mirrored_strategy_with_one_gpu = NamedDistribution( + "CoreMirrored1GPU", + lambda: mirrored_lib.CoreMirroredStrategy(["/gpu:0"]), + required_gpus=1) +core_mirrored_strategy_with_gpu_and_cpu = NamedDistribution( + "CoreMirroredCPUAndGPU", + lambda: mirrored_lib.CoreMirroredStrategy(["/gpu:0", "/cpu:0"]), + required_gpus=1) +core_mirrored_strategy_with_two_gpus = NamedDistribution( + "CoreMirrored2GPUs", + lambda: mirrored_lib.CoreMirroredStrategy(["/gpu:0", "/gpu:1"]), + required_gpus=2) gradient_descent_optimizer_v1_fn = NamedObject( @@ -373,8 +397,11 @@ def distributions_and_v1_optimizers(): """A common set of combination with DistributionStrategies and Optimizers.""" return combine( distribution=[ - one_device_strategy, mirrored_strategy_with_gpu_and_cpu, - mirrored_strategy_with_two_gpus + one_device_strategy, + mirrored_strategy_with_gpu_and_cpu, + mirrored_strategy_with_two_gpus, + core_mirrored_strategy_with_gpu_and_cpu, + core_mirrored_strategy_with_two_gpus, ], optimizer_fn=optimizers_v1) @@ -383,7 +410,10 @@ def distributions_and_v2_optimizers(): """DistributionStrategies and V2 Optimizers.""" return combine( distribution=[ - one_device_strategy, mirrored_strategy_with_gpu_and_cpu, - mirrored_strategy_with_two_gpus + one_device_strategy, + mirrored_strategy_with_gpu_and_cpu, + mirrored_strategy_with_two_gpus, + core_mirrored_strategy_with_gpu_and_cpu, + core_mirrored_strategy_with_two_gpus, ], optimizer_fn=optimizers_v2) diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py b/tensorflow/contrib/distribute/python/cross_device_ops_test.py similarity index 79% rename from tensorflow/contrib/distribute/python/cross_tower_ops_test.py rename to tensorflow/contrib/distribute/python/cross_device_ops_test.py index 3e274ba67ca..d6e9521c1c1 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py +++ b/tensorflow/contrib/distribute/python/cross_device_ops_test.py @@ -24,24 +24,24 @@ from absl.testing import parameterized import numpy as np from tensorflow.contrib.distribute.python import combinations -from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib -from tensorflow.contrib.distribute.python import cross_tower_utils from tensorflow.contrib.distribute.python import mirrored_strategy from 
tensorflow.contrib.distribute.python import multi_worker_test_base -from tensorflow.contrib.distribute.python import values as value_lib from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib +from tensorflow.python.distribute import cross_device_utils +from tensorflow.python.distribute import device_util +from tensorflow.python.distribute import reduce_util +from tensorflow.python.distribute import values as value_lib from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops -from tensorflow.python.ops import variable_scope as vs -from tensorflow.python.training import device_util def _make_per_replica(values, devices, regroup=False): - devices = cross_tower_ops_lib.get_devices_from(devices) + devices = cross_device_ops_lib.get_devices_from(devices) assert len(values) == len(devices) # We simulate the result of regroup called on PerReplica which strips the @@ -66,7 +66,7 @@ def _fake_mirrored(value, devices): All components of the returned Mirrored have the same objects, which is not true in reality. """ - devices = cross_tower_ops_lib.get_devices_from(devices) + devices = cross_device_ops_lib.get_devices_from(devices) return value_lib.Mirrored( {d: v for d, v in zip(devices, [value] * len(devices))}) @@ -118,8 +118,8 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase): self.assertEqual( sess.run(list(left._index.values())), list(right._index.values())) - def _testReductionAndBroadcast(self, cross_tower_ops, distribution): - devices = distribution.worker_devices + def _testReductionAndBroadcast(self, cross_device_ops, distribution): + devices = distribution.extended.worker_devices values = [constant_op.constant(float(d)) for d in range(len(devices))] per_replica = _make_per_replica(values, devices) @@ -132,35 +132,33 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase): destination_mirrored = _fake_mirrored(1., devices) destination_different = _fake_mirrored(1., _cpu_device) destination_str = _cpu_device - destination_list = devices all_destinations = [ destination_mirrored, destination_different, destination_str, - destination_list ] # test reduce() for destinations in all_destinations: self._assert_values_equal( - cross_tower_ops.reduce( - vs.VariableAggregation.MEAN, + cross_device_ops.reduce( + reduce_util.ReduceOp.MEAN, per_replica, destinations=destinations), _fake_mirrored(mean, destinations)) self._assert_values_equal( - cross_tower_ops.reduce( - vs.VariableAggregation.MEAN, + cross_device_ops.reduce( + reduce_util.ReduceOp.MEAN, per_replica_2, destinations=destinations), _fake_mirrored(mean_2, destinations)) self._assert_values_equal( - cross_tower_ops.reduce( - vs.VariableAggregation.SUM, per_replica, + cross_device_ops.reduce( + reduce_util.ReduceOp.SUM, per_replica, destinations=destinations), _fake_mirrored(mean * len(devices), destinations)) self._assert_values_equal( - cross_tower_ops.reduce( - vs.VariableAggregation.SUM, + cross_device_ops.reduce( + reduce_util.ReduceOp.SUM, per_replica_2, destinations=destinations), _fake_mirrored(mean_2 * len(devices), destinations)) @@ -168,16 +166,16 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase): # test batch_reduce() for d1, d2 in itertools.product(all_destinations, all_destinations): 
self._assert_values_equal( - cross_tower_ops.batch_reduce( - vs.VariableAggregation.MEAN, + cross_device_ops.batch_reduce( + reduce_util.ReduceOp.MEAN, [(per_replica, d1), (per_replica_2, d2)]), [ _fake_mirrored(mean, d1), _fake_mirrored(mean_2, d2) ]) self._assert_values_equal( - cross_tower_ops.batch_reduce( - vs.VariableAggregation.SUM, + cross_device_ops.batch_reduce( + reduce_util.ReduceOp.SUM, [(per_replica, d1), (per_replica_2, d2)]), [ _fake_mirrored(mean * len(devices), d1), @@ -187,7 +185,7 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase): # test broadcast() for destinations in all_destinations: self._assert_values_equal( - cross_tower_ops.broadcast(constant_op.constant(1.), destinations), + cross_device_ops.broadcast(constant_op.constant(1.), destinations), _fake_mirrored(1., destinations)) @@ -196,62 +194,65 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase): # combinations module so that we can pass in devices instead of a distribution # strategy. reduction_to_one_combinations = combinations.combine( - cross_tower_ops=[ + cross_device_ops=[ combinations.NamedObject( "DefaultReductionToOneDeviceCrossDeviceOps", - cross_tower_ops_lib.ReductionToOneDeviceCrossDeviceOps()), + cross_device_ops_lib.ReductionToOneDeviceCrossDeviceOps()), combinations.NamedObject( "ReductionToCPUDeviceCrossDeviceOps", - cross_tower_ops_lib.ReductionToOneDeviceCrossDeviceOps( + cross_device_ops_lib.ReductionToOneDeviceCrossDeviceOps( reduce_to_device=_cpu_device)), combinations.NamedObject( "AccumulateNCrossDeviceOp", - cross_tower_ops_lib.ReductionToOneDeviceCrossDeviceOps( + cross_device_ops_lib.ReductionToOneDeviceCrossDeviceOps( accumulation_fn=math_ops.accumulate_n)), ], distribution=[ combinations.one_device_strategy, combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus ], mode=["graph", "eager"]) allreduce_combinations = combinations.combine( - cross_tower_ops=[ + cross_device_ops=[ combinations.NamedObject( "AllReduce", - cross_tower_ops_lib.AllReduceCrossDeviceOps("nccl", 1, 0, 0)), + cross_device_ops_lib.AllReduceCrossDeviceOps("nccl", 1, 0, 0)), combinations.NamedObject( "HierarchicalCopy", - cross_tower_ops_lib.AllReduceCrossDeviceOps( + cross_device_ops_lib.AllReduceCrossDeviceOps( "hierarchical_copy", 8, 0, 0)), combinations.NamedObject( "AllReduceNoGradientRepacking", - cross_tower_ops_lib.AllReduceCrossDeviceOps("nccl", 0, 0, 0)), + cross_device_ops_lib.AllReduceCrossDeviceOps("nccl", 0, 0, 0)), combinations.NamedObject( "HierarchicalCopyAggregateSmallTensors", - cross_tower_ops_lib.AllReduceCrossDeviceOps( + cross_device_ops_lib.AllReduceCrossDeviceOps( "hierarchical_copy", 0, 100, 10)) ], - distribution=[combinations.mirrored_strategy_with_two_gpus], + distribution=[combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_two_gpus], mode=["graph", "eager"]) @combinations.generate(reduction_to_one_combinations + allreduce_combinations) - def testReductionAndBroadcast(self, cross_tower_ops, distribution): + def testReductionAndBroadcast(self, cross_device_ops, distribution): with distribution.scope(): - self._testReductionAndBroadcast(cross_tower_ops, distribution) + self._testReductionAndBroadcast(cross_device_ops, distribution) def testChooseAlgorithm(self): device_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 
2, 7], [0, 5, 6, 7], [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6]] - result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links) - self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossDeviceOps) + result = cross_device_ops_lib._choose_all_reduce_algorithm(device_links) + self.assertIsInstance(result, cross_device_ops_lib.AllReduceCrossDeviceOps) self.assertEqual(result._all_reduce_alg, "hierarchical_copy") self.assertEqual(result._num_packs, 8) # if there are only 4 devices device_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7]] - result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links) - self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossDeviceOps) + result = cross_device_ops_lib._choose_all_reduce_algorithm(device_links) + self.assertIsInstance(result, cross_device_ops_lib.AllReduceCrossDeviceOps) self.assertEqual(result._all_reduce_alg, "nccl") self.assertEqual(result._num_packs, 1) @@ -259,16 +260,16 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase): device_links = [[0, 1, 2, 3, 4], [0, 1, 2, 3, 5], [0, 1, 2, 3, 6], [0, 1, 2, 3, 7], [0, 4, 5, 6, 7], [1, 4, 5, 6, 7], [2, 4, 5, 6, 7], [3, 4, 5, 6, 7]] - result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links) - self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossDeviceOps) + result = cross_device_ops_lib._choose_all_reduce_algorithm(device_links) + self.assertIsInstance(result, cross_device_ops_lib.AllReduceCrossDeviceOps) self.assertEqual(result._all_reduce_alg, "hierarchical_copy") self.assertEqual(result._num_packs, 8) # if not dgx1-like links device_links = [[0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7], [0, 5, 6, 7], [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6], [1, 2, 3, 4]] - result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links) - self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossDeviceOps) + result = cross_device_ops_lib._choose_all_reduce_algorithm(device_links) + self.assertIsInstance(result, cross_device_ops_lib.AllReduceCrossDeviceOps) self.assertEqual(result._all_reduce_alg, "nccl") self.assertEqual(result._num_packs, 1) @@ -280,8 +281,8 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase): t0 = _make_indexed_slices([[1., 2.]], [1], [5, 2], devices[0]) t1 = _make_indexed_slices([[3., 4.], [5., 6.]], [1, 3], [5, 2], devices[1]) per_replica = value_lib.PerReplica({devices[0]: t0, devices[1]: t1}) - result = cross_tower_ops_lib._simple_reduce( - per_replica, devices[0], math_ops.add_n, vs.VariableAggregation.SUM) + result = cross_device_ops_lib._simple_reduce( + per_replica, devices[0], math_ops.add_n, reduce_util.ReduceOp.SUM) # Test that the result is semantically equal to both the concatenated # IndexedSlices with and without duplicate indices. 
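The `_simple_reduce` assertions above rely on a SUM over per-replica `IndexedSlices` being insensitive to whether duplicate indices are merged before or after the reduction. A minimal numpy sketch of that equivalence, reusing the test's values, indices, and [5, 2] dense shape (`densify` is a made-up helper for illustration only, not a TensorFlow API):

```
import numpy as np

def densify(values, indices, dense_shape):
  # Hypothetical helper: scatter-add each sparse row into a dense array.
  dense = np.zeros(dense_shape, dtype=np.float32)
  for row, idx in zip(values, indices):
    dense[idx] += np.asarray(row, dtype=np.float32)
  return dense

dense_shape = (5, 2)
# Same data as the test: replica 0 holds [[1, 2]] at index 1, replica 1
# holds [[3, 4], [5, 6]] at indices 1 and 3.
concat_with_dups = densify([[1., 2.], [3., 4.], [5., 6.]], [1, 1, 3], dense_shape)
concat_without_dups = densify([[4., 6.], [5., 6.]], [1, 3], dense_shape)

# A SUM reduction gives the same dense result either way.
assert np.allclose(concat_with_dups, concat_without_dups)
```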
@@ -294,19 +295,19 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase): @combinations.generate( combinations.combine( - cross_tower_ops_instance=[ + cross_device_ops_instance=[ combinations.NamedObject( "ReductionToOneDeviceCrossDeviceOps", - cross_tower_ops_lib.ReductionToOneDeviceCrossDeviceOps()), + cross_device_ops_lib.ReductionToOneDeviceCrossDeviceOps()), combinations.NamedObject( "AllReduceCrossDeviceOps", - cross_tower_ops_lib.AllReduceCrossDeviceOps()) + cross_device_ops_lib.AllReduceCrossDeviceOps()) ], - aggregation=[vs.VariableAggregation.SUM, vs.VariableAggregation.MEAN], + reduce_op=[reduce_util.ReduceOp.SUM, reduce_util.ReduceOp.MEAN], batch_reduce=[True, False], mode=["graph", "eager"], required_gpus=1)) - def testIndexedSlicesAllReduce(self, cross_tower_ops_instance, aggregation, + def testIndexedSlicesAllReduce(self, cross_device_ops_instance, reduce_op, batch_reduce): devices = ["/cpu:0", "/gpu:0"] dense_shape = [5, 2] @@ -316,20 +317,20 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase): per_replica = value_lib.PerReplica({devices[0]: t0, devices[1]: t1}) if batch_reduce: - result = cross_tower_ops_instance.batch_reduce( - aggregation, [(per_replica, devices)]) + result = cross_device_ops_instance.batch_reduce( + reduce_op, [(per_replica, per_replica)]) else: - result = cross_tower_ops_instance.reduce( - aggregation, per_replica, devices) + result = cross_device_ops_instance.reduce( + reduce_op, per_replica, per_replica) total_indices_with_dups = [1, 1, 3] total_indices_without_dups = [1, 3] - if aggregation == vs.VariableAggregation.SUM: + if reduce_op == reduce_util.ReduceOp.SUM: total_values_with_dups = [[1., 2.], [3., 4.], [5., 6.]] total_values_without_dups = [[4., 6.], [5., 6.]] else: - assert aggregation == vs.VariableAggregation.MEAN + assert reduce_op == reduce_util.ReduceOp.MEAN total_values_with_dups = [[0.5, 1.], [1.5, 2.], [2.5, 3.]] total_values_without_dups = [[2., 3.], [2.5, 3.]] @@ -356,49 +357,63 @@ class MultiWorkerCrossDeviceOpsTest(multi_worker_test_base.MultiWorkerTestBase, "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1" ] multi_worker_allreduce_combinations = combinations.combine( - cross_tower_ops=[ + cross_device_ops=[ combinations.NamedObject( "MultiWorkerAllReduce", - cross_tower_ops_lib.MultiWorkerAllReduce( + cross_device_ops_lib.MultiWorkerAllReduce( worker_devices, 2, ("pscpu/pscpu", 2, -1), 0, 0, 0)), combinations.NamedObject( "MultiWorkerAllReducePack", - cross_tower_ops_lib.MultiWorkerAllReduce( + cross_device_ops_lib.MultiWorkerAllReduce( worker_devices, 2, ("pscpu/pscpu", 2, -1), 1, 0, 0)), combinations.NamedObject( "MultiWorkerAllReduceAggregation", - cross_tower_ops_lib.MultiWorkerAllReduce( + cross_device_ops_lib.MultiWorkerAllReduce( worker_devices, 2, ("pscpu/pscpu", 2, -1), 0, 100, 10)), combinations.NamedObject( "MultiWorkerAllReduceMultipleSpecs", - cross_tower_ops_lib.MultiWorkerAllReduce( + cross_device_ops_lib.MultiWorkerAllReduce( worker_devices, 2, [("pscpu/pscpu", 2, 100), ("xring", 2, -1)], 0, 0, 0)), ], distribution=[ combinations.NamedDistribution( "MirroredCPU", - lambda: mirrored_strategy.MirroredStrategy(num_gpus=0), + lambda: mirrored_strategy.MirroredStrategy(num_gpus_per_worker=0), required_gpus=0), combinations.NamedDistribution( "Mirrored1GPU", - lambda: mirrored_strategy.MirroredStrategy(num_gpus=1), + lambda: mirrored_strategy.MirroredStrategy(num_gpus_per_worker=1), required_gpus=1), combinations.NamedDistribution( "Mirrored2GPUs", - lambda: 
mirrored_strategy.MirroredStrategy(num_gpus=2), + lambda: mirrored_strategy.MirroredStrategy(num_gpus_per_worker=2), + required_gpus=2), + # pylint: disable=g-long-lambda + combinations.NamedDistribution( + "CoreMirroredCPU", + lambda: mirrored_strategy.CoreMirroredStrategy(["/device:CPU:0"]), + required_gpus=0), + combinations.NamedDistribution( + "CoreMirrored1GPU", + lambda: mirrored_strategy.CoreMirroredStrategy(["/device:GPU:0"]), + required_gpus=1), + combinations.NamedDistribution( + "CoreMirrored2GPUs", + lambda: mirrored_strategy.CoreMirroredStrategy( + ["/device:GPU:0", "/device:GPU:1"]), required_gpus=2), ], mode=["graph"]) @combinations.generate(multi_worker_allreduce_combinations) - def testReductionAndBroadcast(self, cross_tower_ops, distribution): + def testReductionAndBroadcast(self, cross_device_ops, distribution): distribution.configure(cluster_spec={ "worker": ["/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1"] }) with distribution.scope(): - self._testReductionAndBroadcast(cross_tower_ops, distribution) + self._testReductionAndBroadcast(cross_device_ops, distribution) class MultiWorkerCollectiveAllReduceTest( @@ -419,7 +434,7 @@ class MultiWorkerCollectiveAllReduceTest( MultiWorkerCollectiveAllReduceTest.collective_key_base += 100000 def _get_test_objects(self, task_type, task_id, num_gpus=0, local_mode=False): - collective_keys = cross_tower_utils.CollectiveKeys( + collective_keys = cross_device_utils.CollectiveKeys( group_key_start=10 * num_gpus + MultiWorkerCollectiveAllReduceTest.collective_key_base, instance_key_start=num_gpus * 100 + @@ -427,7 +442,7 @@ class MultiWorkerCollectiveAllReduceTest( instance_key_with_id_start=num_gpus * 10000 + MultiWorkerCollectiveAllReduceTest.collective_key_base) if local_mode: - collective_all_reduce_ops = cross_tower_ops_lib.CollectiveAllReduce( + collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce( 1, num_gpus, collective_keys=collective_keys) if num_gpus: devices = ["/device:GPU:%d" % i for i in range(num_gpus)] @@ -435,7 +450,7 @@ class MultiWorkerCollectiveAllReduceTest( devices = ["/device:CPU:0"] return collective_all_reduce_ops, devices, "" else: - collective_all_reduce_ops = cross_tower_ops_lib.CollectiveAllReduce( + collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce( 3, num_gpus, collective_keys=collective_keys) if num_gpus: devices = [ @@ -491,37 +506,35 @@ class MultiWorkerCollectiveAllReduceTest( destination_mirrored = _fake_mirrored(1., devices) destination_different = _fake_mirrored(1., _cpu_device) destination_str = _cpu_device - destination_list = devices all_destinations = [ - destination_different, destination_mirrored, destination_str, - destination_list + destination_different, destination_mirrored, destination_str ] # test reduce() for destinations in all_destinations: self._assert_values_equal( collective_all_reduce.reduce( - vs.VariableAggregation.MEAN, + reduce_util.ReduceOp.MEAN, per_replica, destinations=destinations), _fake_mirrored(mean, destinations), sess) self._assert_values_equal( collective_all_reduce.reduce( - vs.VariableAggregation.MEAN, + reduce_util.ReduceOp.MEAN, per_replica_2, destinations=destinations), _fake_mirrored(mean_2, destinations), sess) self._assert_values_equal( collective_all_reduce.reduce( - vs.VariableAggregation.SUM, + reduce_util.ReduceOp.SUM, per_replica, destinations=destinations), _fake_mirrored(mean * len(devices) * num_workers, destinations), sess) self._assert_values_equal( collective_all_reduce.reduce( - 
vs.VariableAggregation.SUM, + reduce_util.ReduceOp.SUM, per_replica_2, destinations=destinations), _fake_mirrored(mean_2 * len(devices) * num_workers, destinations), @@ -530,7 +543,7 @@ class MultiWorkerCollectiveAllReduceTest( # test batch_reduce() for d1, d2 in itertools.product(all_destinations, all_destinations): self._assert_values_equal( - collective_all_reduce.batch_reduce(vs.VariableAggregation.MEAN, + collective_all_reduce.batch_reduce(reduce_util.ReduceOp.MEAN, [(per_replica, d1), (per_replica_2, d2)]), [ @@ -538,7 +551,7 @@ class MultiWorkerCollectiveAllReduceTest( _fake_mirrored(mean_2, d2) ], sess) self._assert_values_equal( - collective_all_reduce.batch_reduce(vs.VariableAggregation.SUM, + collective_all_reduce.batch_reduce(reduce_util.ReduceOp.SUM, [(per_replica, d1), (per_replica_2, d2)]), [ diff --git a/tensorflow/contrib/distribute/python/cross_tower_utils_test.py b/tensorflow/contrib/distribute/python/cross_device_utils_test.py similarity index 83% rename from tensorflow/contrib/distribute/python/cross_tower_utils_test.py rename to tensorflow/contrib/distribute/python/cross_device_utils_test.py index e46240abbfa..2303a31677a 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_utils_test.py +++ b/tensorflow/contrib/distribute/python/cross_device_utils_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for cross_tower_utils.""" +"""Tests for cross_device_utils.""" from __future__ import absolute_import from __future__ import division @@ -21,14 +21,14 @@ from __future__ import print_function from absl.testing import parameterized from tensorflow.contrib.distribute.python import combinations -from tensorflow.contrib.distribute.python import cross_tower_utils -from tensorflow.contrib.distribute.python import values as value_lib +from tensorflow.python.distribute import cross_device_utils +from tensorflow.python.distribute import device_util +from tensorflow.python.distribute import values as value_lib from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import math_ops -from tensorflow.python.training import device_util class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): @@ -43,7 +43,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): t0 = constant_op.constant([[1., 2.], [0, 0], [3., 4.]]) t1 = constant_op.constant([[0., 0.], [5, 6], [7., 8.]]) total = constant_op.constant([[1., 2.], [5, 6], [10., 12.]]) - result = cross_tower_utils.aggregate_tensors_or_indexed_slices([t0, t1]) + result = cross_device_utils.aggregate_tensors_or_indexed_slices([t0, t1]) self._assert_values_equal(total, result) @test_util.run_in_graph_and_eager_modes @@ -53,7 +53,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): t1 = math_ops._as_indexed_slices( constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) total = constant_op.constant([[1., 2.], [5, 6], [10., 12.]]) - result = cross_tower_utils.aggregate_tensors_or_indexed_slices([t0, t1]) + result = cross_device_utils.aggregate_tensors_or_indexed_slices([t0, t1]) self.assertIsInstance(result, ops.IndexedSlices) self._assert_values_equal(total, result) @@ -62,7 +62,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): t = 
constant_op.constant([[1., 2.], [0, 0], [3., 4.]]) n = 2 expected = constant_op.constant([[0.5, 1.], [0, 0], [1.5, 2.]]) - result = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(t, n) + result = cross_device_utils.divide_by_n_tensors_or_indexed_slices(t, n) self._assert_values_equal(expected, result) @test_util.run_in_graph_and_eager_modes @@ -71,7 +71,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) n = 2 expected = constant_op.constant([[0.5, 1.], [0, 0], [1.5, 2.]]) - result = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(t, n) + result = cross_device_utils.divide_by_n_tensors_or_indexed_slices(t, n) self.assertIsInstance(result, ops.IndexedSlices) self._assert_values_equal(expected, result) @@ -79,7 +79,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): def testIsIndexedSlices(self): t = math_ops._as_indexed_slices( constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) - self.assertTrue(cross_tower_utils.contains_indexed_slices(t)) + self.assertTrue(cross_device_utils.contains_indexed_slices(t)) @test_util.run_in_graph_and_eager_modes def testContainsIndexedSlices_List(self): @@ -87,7 +87,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) t1 = math_ops._as_indexed_slices( constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) - self.assertTrue(cross_tower_utils.contains_indexed_slices([t0, t1])) + self.assertTrue(cross_device_utils.contains_indexed_slices([t0, t1])) @test_util.run_in_graph_and_eager_modes def testContainsIndexedSlices_Tuple(self): @@ -95,7 +95,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) t1 = math_ops._as_indexed_slices( constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) - self.assertTrue(cross_tower_utils.contains_indexed_slices((t0, t1))) + self.assertTrue(cross_device_utils.contains_indexed_slices((t0, t1))) @test_util.run_in_graph_and_eager_modes def testContainsIndexedSlices_PerReplica(self): @@ -104,7 +104,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): t1 = math_ops._as_indexed_slices( constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) per_replica = value_lib.PerReplica({"/gpu:0": t0, "/cpu:0": t1}) - self.assertTrue(cross_tower_utils.contains_indexed_slices(per_replica)) + self.assertTrue(cross_device_utils.contains_indexed_slices(per_replica)) @combinations.generate(combinations.combine( mode=["graph", "eager"], @@ -113,7 +113,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): with ops.device("/cpu:0"): t = constant_op.constant([[1., 2.], [0, 0], [3., 4.]]) destination = "/gpu:0" - result = cross_tower_utils.copy_tensor_or_indexed_slices_to_device( + result = cross_device_utils.copy_tensor_or_indexed_slices_to_device( t, destination) self._assert_values_equal(t, result) @@ -128,7 +128,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): t = math_ops._as_indexed_slices( constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) destination = "/gpu:0" - result = cross_tower_utils.copy_tensor_or_indexed_slices_to_device( + result = cross_device_utils.copy_tensor_or_indexed_slices_to_device( t, destination) self.assertIsInstance(result, ops.IndexedSlices) diff --git a/tensorflow/contrib/distribute/python/estimator_integration_test.py b/tensorflow/contrib/distribute/python/estimator_integration_test.py index a1355c0b09e..e17085628ba 
100644 --- a/tensorflow/contrib/distribute/python/estimator_integration_test.py +++ b/tensorflow/contrib/distribute/python/estimator_integration_test.py @@ -34,7 +34,7 @@ from tensorflow.python.estimator.canned import dnn_linear_combined from tensorflow.python.estimator.canned import prediction_keys from tensorflow.python.estimator.export import export from tensorflow.python.estimator.inputs import numpy_io -from tensorflow.python.feature_column import feature_column +from tensorflow.python.feature_column import feature_column_lib as feature_column from tensorflow.python.framework import ops from tensorflow.python.platform import gfile from tensorflow.python.summary.writer import writer_cache @@ -63,7 +63,9 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase, distribution=[ combinations.one_device_strategy, combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus ], use_train_and_evaluate=[True, False])) def test_complete_flow_with_mode(self, distribution, use_train_and_evaluate): @@ -75,12 +77,12 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase, train_input_fn = self.dataset_input_fn( x={'x': data}, y=data, - batch_size=batch_size // len(distribution.worker_devices), + batch_size=batch_size // distribution.num_replicas_in_sync, shuffle=True) eval_input_fn = self.dataset_input_fn( x={'x': data}, y=data, - batch_size=batch_size // len(distribution.worker_devices), + batch_size=batch_size // distribution.num_replicas_in_sync, shuffle=False) predict_input_fn = numpy_io.numpy_input_fn( x={'x': data}, batch_size=batch_size, shuffle=False) diff --git a/tensorflow/contrib/distribute/python/estimator_training_test.py b/tensorflow/contrib/distribute/python/estimator_training_test.py index 8f82b4c92aa..b369a7fefe6 100644 --- a/tensorflow/contrib/distribute/python/estimator_training_test.py +++ b/tensorflow/contrib/distribute/python/estimator_training_test.py @@ -24,7 +24,6 @@ import json import os import sys import tempfile -import threading from absl.testing import parameterized import numpy as np @@ -45,11 +44,13 @@ from tensorflow.python.estimator import training as estimator_training from tensorflow.python.estimator.canned import dnn_linear_combined from tensorflow.python.estimator.canned import prediction_keys from tensorflow.python.estimator.export import export as export_lib -from tensorflow.python.feature_column import feature_column +from tensorflow.python.feature_column import feature_column_lib as feature_column from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.summary import summary_iterator from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import session_manager + BATCH_SIZE = 10 LABEL_DIMENSION = 2 @@ -68,57 +69,19 @@ PS = dc._TaskType.PS original_run_std_server = dc._run_std_server -class MockOsEnv(dict): - - def __init__(self, *args): - self._thread_local = threading.local() - super(MockOsEnv, self).__init__(*args) - - def get(self, key, default): - if not hasattr(self._thread_local, "dict"): - self._thread_local.dict = dict() - if key == "TF_CONFIG": - return dict.get(self._thread_local.dict, key, default) - else: - return dict.get(self, key, default) - - def __getitem__(self, key): - if not hasattr(self._thread_local, "dict"): - self._thread_local.dict = 
dict() - if key == "TF_CONFIG": - return dict.__getitem__(self._thread_local.dict, key) - else: - return dict.__getitem__(self, key) - - def __setitem__(self, key, val): - if not hasattr(self._thread_local, "dict"): - self._thread_local.dict = dict() - if key == "TF_CONFIG": - return dict.__setitem__(self._thread_local.dict, key, val) - else: - return dict.__setitem__(self, key, val) - - -class DistributeCoordinatorIntegrationTest(test.TestCase, - parameterized.TestCase): +class DistributeCoordinatorIntegrationTest( + multi_worker_test_base.IndependentWorkerTestBase, parameterized.TestCase): @classmethod def setUpClass(cls): """Create a local cluster with 2 workers.""" + super(DistributeCoordinatorIntegrationTest, cls).setUpClass() cls._cluster_spec = multi_worker_test_base.create_in_process_cluster( num_workers=3, num_ps=2, has_eval=True) def setUp(self): self._model_dir = tempfile.mkdtemp() - self._mock_os_env = MockOsEnv() - self._mock_context = test.mock.patch.object(os, "environ", - self._mock_os_env) super(DistributeCoordinatorIntegrationTest, self).setUp() - self._mock_context.__enter__() - - def tearDown(self): - self._mock_context.__exit__(None, None, None) - super(DistributeCoordinatorIntegrationTest, self).tearDown() def dataset_input_fn(self, x, y, batch_size, shuffle): @@ -141,8 +104,8 @@ class DistributeCoordinatorIntegrationTest(test.TestCase, def _extract_loss_and_global_step(self, event_folder): """Returns the loss and global step in last event.""" event_paths = glob.glob(os.path.join(event_folder, "events*")) - self.assertGreater(len(event_paths), 0, - msg="Event file not found in dir %s" % event_folder) + self.assertNotEmpty( + event_paths, msg="Event file not found in dir %s" % event_folder) loss = None global_step_count = None @@ -202,10 +165,10 @@ class DistributeCoordinatorIntegrationTest(test.TestCase, train_input_fn = self.dataset_input_fn( x={"x": DATA}, y=DATA, - batch_size=BATCH_SIZE // len(train_distribute.worker_devices), + batch_size=BATCH_SIZE // train_distribute.num_replicas_in_sync, shuffle=True) if eval_distribute: - eval_batch_size = BATCH_SIZE // len(eval_distribute.worker_devices) + eval_batch_size = BATCH_SIZE // eval_distribute.num_replicas_in_sync else: eval_batch_size = BATCH_SIZE eval_input_fn = self.dataset_input_fn( @@ -285,27 +248,34 @@ class DistributeCoordinatorIntegrationTest(test.TestCase, ]) self.assertAllEqual((BATCH_SIZE, LABEL_DIMENSION), predicted_proba.shape) + def _get_strategy_object(self, strategy_cls): + if strategy_cls == mirrored_strategy.CoreMirroredStrategy: + return strategy_cls(mirrored_strategy.all_local_devices()) + else: + return strategy_cls(num_gpus_per_worker=context.num_gpus()) + @combinations.generate( combinations.combine( mode=["graph"], train_distribute_cls=[ collective_all_reduce_strategy.CollectiveAllReduceStrategy, mirrored_strategy.MirroredStrategy, + mirrored_strategy.CoreMirroredStrategy, parameter_server_strategy.ParameterServerStrategy ], eval_distribute_cls=[ - None, mirrored_strategy.MirroredStrategy, + None, + mirrored_strategy.MirroredStrategy, + mirrored_strategy.CoreMirroredStrategy, parameter_server_strategy.ParameterServerStrategy, ], required_gpus=[0, 1])) def test_complete_flow_standalone_client(self, train_distribute_cls, eval_distribute_cls): - train_distribute = train_distribute_cls( - num_gpus_per_worker=context.num_gpus()) + train_distribute = self._get_strategy_object(train_distribute_cls) if eval_distribute_cls: - eval_distribute = eval_distribute_cls( - 
num_gpus_per_worker=context.num_gpus()) + eval_distribute = self._get_strategy_object(eval_distribute_cls) else: eval_distribute = None @@ -322,20 +292,20 @@ class DistributeCoordinatorIntegrationTest(test.TestCase, mode=["graph"], train_distribute_cls=[ mirrored_strategy.MirroredStrategy, + mirrored_strategy.CoreMirroredStrategy, ], eval_distribute_cls=[ None, mirrored_strategy.MirroredStrategy, + mirrored_strategy.CoreMirroredStrategy, ], required_gpus=[0, 1])) def test_estimator_standalone_client(self, train_distribute_cls, eval_distribute_cls): - train_distribute = train_distribute_cls( - num_gpus_per_worker=context.num_gpus()) + train_distribute = self._get_strategy_object(train_distribute_cls) if eval_distribute_cls: - eval_distribute = eval_distribute_cls( - num_gpus_per_worker=context.num_gpus()) + eval_distribute = self._get_strategy_object(eval_distribute_cls) else: eval_distribute = None @@ -355,47 +325,15 @@ class DistributeCoordinatorIntegrationTest(test.TestCase, self._barrier.wait() return ret - def _task_thread(self, train_distribute, eval_distribute, tf_config): - os.environ["TF_CONFIG"] = json.dumps(tf_config) + def _independent_worker_fn( + self, + train_distribute, + eval_distribute, + ): with test.mock.patch.object(dc, "_run_std_server", self._mock_run_std_server): self._complete_flow(train_distribute, eval_distribute) - def _run_task_in_thread(self, cluster_spec, task_type, task_id, - train_distribute, eval_distribute): - if task_type: - tf_config = { - "cluster": cluster_spec, - "task": { - "type": task_type, - "index": task_id - } - } - else: - tf_config = { - "cluster": cluster_spec, - "task": { - "type": task_type, - "index": task_id - } - } - t = threading.Thread( - target=self._task_thread, - args=(train_distribute, eval_distribute, tf_config)) - t.start() - return t - - def _run_multiple_tasks_in_threads(self, cluster_spec, train_distribute, - eval_distribute): - threads = {} - for task_type in cluster_spec.keys(): - threads[task_type] = [] - for task_id in range(len(cluster_spec[task_type])): - t = self._run_task_in_thread(cluster_spec, task_type, task_id, - train_distribute, eval_distribute) - threads[task_type].append(t) - return threads - @combinations.generate( combinations.combine( mode=["graph"], @@ -405,21 +343,20 @@ class DistributeCoordinatorIntegrationTest(test.TestCase, ], eval_distribute_cls=[ None, mirrored_strategy.MirroredStrategy, + mirrored_strategy.CoreMirroredStrategy, parameter_server_strategy.ParameterServerStrategy, ], required_gpus=[0, 1])) def test_complete_flow_indepedent_worker_between_graph( self, train_distribute_cls, eval_distribute_cls): - train_distribute = train_distribute_cls( - num_gpus_per_worker=context.num_gpus()) - if (context.num_gpus() < 2 and eval_distribute_cls == collective_all_reduce_strategy.CollectiveAllReduceStrategy): self.skipTest("`CollectiveAllReduceStrategy` needs at least two towers.") + train_distribute = self._get_strategy_object(train_distribute_cls) + if eval_distribute_cls: - eval_distribute = eval_distribute_cls( - num_gpus_per_worker=context.num_gpus()) + eval_distribute = self._get_strategy_object(eval_distribute_cls) else: eval_distribute = None @@ -435,8 +372,9 @@ class DistributeCoordinatorIntegrationTest(test.TestCase, # 3 workers and 1 evaluator. 
self._barrier = dc._Barrier(4) - threads = self._run_multiple_tasks_in_threads( - cluster_spec, train_distribute, eval_distribute) + threads = self.run_multiple_tasks_in_threads(self._independent_worker_fn, + cluster_spec, train_distribute, + eval_distribute) for task_type, ts in threads.items(): if task_type == PS: continue @@ -449,17 +387,22 @@ class DistributeCoordinatorIntegrationTest(test.TestCase, @combinations.generate( combinations.combine( mode=["graph"], - train_distribute_cls=[mirrored_strategy.MirroredStrategy], - eval_distribute_cls=[None, mirrored_strategy.MirroredStrategy], + train_distribute_cls=[ + mirrored_strategy.MirroredStrategy, + mirrored_strategy.CoreMirroredStrategy + ], + eval_distribute_cls=[ + None, + mirrored_strategy.MirroredStrategy, + mirrored_strategy.CoreMirroredStrategy + ], required_gpus=[0, 1])) def test_complete_flow_indepedent_worker_in_graph(self, train_distribute_cls, eval_distribute_cls): - train_distribute = train_distribute_cls( - num_gpus_per_worker=context.num_gpus()) + train_distribute = self._get_strategy_object(train_distribute_cls) if eval_distribute_cls: - eval_distribute = eval_distribute_cls( - num_gpus_per_worker=context.num_gpus()) + eval_distribute = self._get_strategy_object(eval_distribute_cls) else: eval_distribute = None @@ -467,8 +410,9 @@ class DistributeCoordinatorIntegrationTest(test.TestCase, num_workers=3, num_ps=0, has_eval=True) # 3 workers and 1 evaluator. self._barrier = dc._Barrier(4) - threads = self._run_multiple_tasks_in_threads( - cluster_spec, train_distribute, eval_distribute) + threads = self.run_multiple_tasks_in_threads(self._independent_worker_fn, + cluster_spec, train_distribute, + eval_distribute) threads[WORKER][0].join() threads[EVALUATOR][0].join() @@ -506,7 +450,8 @@ class RunConfigTest(test.TestCase): "os.environ", {"TF_CONFIG": json.dumps(TF_CONFIG_WITHOUT_TASK)}): run_config_lib.RunConfig( experimental_distribute=DistributeConfig( - train_distribute=mirrored_strategy.MirroredStrategy(num_gpus=2))) + train_distribute=mirrored_strategy.CoreMirroredStrategy( + ["/device:GPU:0", "/device:GPU:1"]))) def test_should_run_distribute_coordinator(self): """Tests that should_run_distribute_coordinator return a correct value.""" @@ -529,10 +474,12 @@ class RunConfigTest(test.TestCase): {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_CHIEF)}): config_with_train_distribute = run_config_lib.RunConfig( experimental_distribute=DistributeConfig( - train_distribute=mirrored_strategy.MirroredStrategy(num_gpus=2))) + train_distribute=mirrored_strategy.CoreMirroredStrategy( + ["/device:GPU:0", "/device:GPU:1"]))) config_with_eval_distribute = run_config_lib.RunConfig( experimental_distribute=DistributeConfig( - eval_distribute=mirrored_strategy.MirroredStrategy(num_gpus=2))) + eval_distribute=mirrored_strategy.CoreMirroredStrategy( + ["/device:GPU:0", "/device:GPU:1"]))) self.assertTrue( dc_training.should_run_distribute_coordinator( config_with_train_distribute)) @@ -545,26 +492,27 @@ class RunConfigTest(test.TestCase): {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_MASTER)}): config = run_config_lib.RunConfig( experimental_distribute=DistributeConfig( - train_distribute=mirrored_strategy.MirroredStrategy(num_gpus=2))) + train_distribute=mirrored_strategy.CoreMirroredStrategy( + ["/device:GPU:0", "/device:GPU:1"]))) self.assertFalse(dc_training.should_run_distribute_coordinator(config)) def test_init_run_config_duplicate_distribute(self): with self.assertRaises(ValueError): run_config_lib.RunConfig( - 
train_distribute=mirrored_strategy.MirroredStrategy(), + train_distribute=mirrored_strategy.CoreMirroredStrategy(), experimental_distribute=DistributeConfig( - train_distribute=mirrored_strategy.MirroredStrategy())) + train_distribute=mirrored_strategy.CoreMirroredStrategy())) with self.assertRaises(ValueError): run_config_lib.RunConfig( - eval_distribute=mirrored_strategy.MirroredStrategy(), + eval_distribute=mirrored_strategy.CoreMirroredStrategy(), experimental_distribute=DistributeConfig( - eval_distribute=mirrored_strategy.MirroredStrategy())) + eval_distribute=mirrored_strategy.CoreMirroredStrategy())) def test_init_run_config_none_distribute_coordinator_mode(self): # We don't use distribute coordinator for local training. config = run_config_lib.RunConfig( - train_distribute=mirrored_strategy.MirroredStrategy()) + train_distribute=mirrored_strategy.CoreMirroredStrategy()) dc_training.init_run_config(config, {}) self.assertIsNone(config._distribute_coordinator_mode) @@ -572,7 +520,7 @@ class RunConfigTest(test.TestCase): with test.mock.patch.dict("os.environ", {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_MASTER)}): config = run_config_lib.RunConfig( - train_distribute=mirrored_strategy.MirroredStrategy()) + train_distribute=mirrored_strategy.CoreMirroredStrategy()) self.assertIsNone(config._distribute_coordinator_mode) # When `train_distribute` is not specified, don't use distribute @@ -588,7 +536,7 @@ class RunConfigTest(test.TestCase): with test.mock.patch.dict("os.environ", {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_CHIEF)}): config = run_config_lib.RunConfig( - train_distribute=mirrored_strategy.MirroredStrategy()) + train_distribute=mirrored_strategy.CoreMirroredStrategy()) self.assertEqual(config._distribute_coordinator_mode, dc.CoordinatorMode.INDEPENDENT_WORKER) @@ -597,7 +545,7 @@ class RunConfigTest(test.TestCase): # `experimental.remote_cluster` is set use distribute coordinator with # STANDALONE_CLIENT mode. config = run_config_lib.RunConfig( - train_distribute=mirrored_strategy.MirroredStrategy(), + train_distribute=mirrored_strategy.CoreMirroredStrategy(), experimental_distribute=DistributeConfig( remote_cluster={"chief": ["fake_worker"]})) self.assertEqual(config._distribute_coordinator_mode, @@ -605,5 +553,15 @@ class RunConfigTest(test.TestCase): if __name__ == "__main__": + # Reduce `recovery_wait_secs` from 30 seconds so the test completes quickly. + orig_init = session_manager.SessionManager.__init__ + + def new_init(*args, **kwargs): + kwargs.pop("recovery_wait_secs", None) + kwargs["recovery_wait_secs"] = 0.5 + orig_init(*args, **kwargs) + + session_manager.SessionManager.__init__ = new_init + with test.mock.patch.object(sys, "exit", os._exit): test.main() diff --git a/tensorflow/contrib/distribute/python/examples/keras_mnist.py b/tensorflow/contrib/distribute/python/examples/keras_mnist.py index 0fd3acd0451..60fda996642 100644 --- a/tensorflow/contrib/distribute/python/examples/keras_mnist.py +++ b/tensorflow/contrib/distribute/python/examples/keras_mnist.py @@ -20,6 +20,10 @@ from __future__ import print_function import tensorflow as tf +from tensorflow.python.distribute import mirrored_strategy +from tensorflow.python.keras.optimizer_v2 import rmsprop + + NUM_CLASSES = 10 @@ -102,18 +106,23 @@ def main(_): # Build the train and eval datasets from the MNIST data. Also return the # input shape which is constructed based on the `image_data_format` # i.e channels_first or channels_last. 
+ tf.enable_eager_execution() + train_ds, eval_ds, input_shape = get_input_datasets() model = get_model(input_shape) # Instantiate the MirroredStrategy object. If we don't specify `num_gpus` or # the `devices` argument then all the GPUs available on the machine are used. - strategy = tf.contrib.distribute.MirroredStrategy() + # TODO(priyag): Use `tf.distribute.MirroredStrategy` once available. + strategy = mirrored_strategy.MirroredStrategy(['/gpu:0', '/cpu:0']) + + optimizer = rmsprop.RMSProp(learning_rate=0.001) # Compile the model by passing the distribution strategy object to the # `distribute` argument. `fit`, `evaluate` and `predict` will be distributed # based on the strategy instantiated. model.compile(loss=tf.keras.losses.categorical_crossentropy, - optimizer=tf.train.RMSPropOptimizer(learning_rate=0.001), + optimizer=optimizer, metrics=['accuracy'], distribute=strategy) diff --git a/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py b/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py index 46a1cf41c55..6dfd85bcc4f 100644 --- a/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py +++ b/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py @@ -25,18 +25,23 @@ import numpy as np import six from tensorflow.contrib.distribute.python import combinations -from tensorflow.contrib.distribute.python import mirrored_strategy from tensorflow.core.protobuf import config_pb2 +from tensorflow.python import keras from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.eager import context +from tensorflow.python.distribute import distribution_strategy_context as ds_context from tensorflow.python.estimator import run_config from tensorflow.python.estimator import training from tensorflow.python.estimator.canned import dnn_linear_combined from tensorflow.python.estimator.canned import prediction_keys from tensorflow.python.estimator.export import export from tensorflow.python.estimator.inputs import numpy_io -from tensorflow.python.feature_column import feature_column +from tensorflow.python.feature_column import feature_column_lib as feature_column +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.keras.optimizer_v2 import adam +from tensorflow.python.keras.optimizer_v2 import gradient_descent +from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import gfile @@ -64,7 +69,9 @@ class KerasOptimizerV2IntegrationTest(test.TestCase, parameterized.TestCase): distribution=[ combinations.one_device_strategy, combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus ], use_train_and_evaluate=[True, False])) def test_complete_flow_with_mode(self, distribution, use_train_and_evaluate): @@ -76,11 +83,11 @@ class KerasOptimizerV2IntegrationTest(test.TestCase, parameterized.TestCase): train_input_fn = self.dataset_input_fn( x={'x': data}, y=data, - batch_size=batch_size // len(distribution.worker_devices)) + batch_size=batch_size // distribution.num_replicas_in_sync) eval_input_fn = self.dataset_input_fn( x={'x': data}, y=data, - batch_size=batch_size // len(distribution.worker_devices)) + batch_size=batch_size // 
distribution.num_replicas_in_sync) predict_input_fn = numpy_io.numpy_input_fn( x={'x': data}, batch_size=batch_size, shuffle=False) @@ -136,44 +143,51 @@ class KerasOptimizerV2IntegrationTest(test.TestCase, parameterized.TestCase): shutil.rmtree(self._model_dir) -class MirroredStrategyOptimizerV2Test(test.TestCase): +def get_model(): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + model = keras.Model(x, y) + return model - def testKerasOptimizerWithUnequalInput(self): - if context.num_gpus() < 1: - self.skipTest('Not enough GPUs.') - def create_fn(device_id): +class MirroredStrategyOptimizerV2Test(test.TestCase, parameterized.TestCase): + + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=['graph'])) + def testKerasOptimizerWithUnequalInput(self, distribution): + def create_fn(): var = variables.Variable( 2.0, name='var', aggregation=variable_scope.VariableAggregation.SUM) # grad for cpu is 1, grad for gpu is 2, avg grad is 1.5. - loss = (device_id + 1) * var + loss = math_ops.cast(_replica_id() + 1, dtype=dtypes.float32) * var optimizer = adam.Adam(learning_rate=0.01, beta_1=0.2, beta_2=0.2) train_op = optimizer.minimize(loss, var_list=[var]) m = optimizer.get_slot(var, 'm') v = optimizer.get_slot(var, 'v') - return (var, m, v, train_op, optimizer.iteration) + return (var, m, v, train_op, optimizer.iterations) devices = ['/device:GPU:0', '/device:CPU:0'] - dist = mirrored_strategy.MirroredStrategy(devices) - with dist.scope(): - (var, m, v, op, counter) = dist.call_for_each_replica( - create_fn, args=[dist.worker_device_index]) + with distribution.scope(): + (var, m, v, op, counter) = distribution.call_for_each_replica(create_fn) self.evaluate(variables.global_variables_initializer()) var_val = [2.0, 2.0, 2.0] self.assertAllClose( var_val, self.evaluate( - [dist.read_var(var), + [distribution.read_var(var), var.get(devices[0]), var.get(devices[1])])) self.assertAllClose([0, 0, 0], self.evaluate([ - dist.read_var(counter), + distribution.read_var(counter), counter.get(devices[0]), counter.get(devices[1]) ])) - train_op = dist.unwrap(op) + train_op = distribution.unwrap(op) self.evaluate(train_op) # m(1) = beta1 * m(0) + (1-beta1) * grad = 0.2 * 0 + 0.8 * (1 + 2) / 2 m_val = [1.2, 1.2, 1.2] @@ -181,7 +195,7 @@ class MirroredStrategyOptimizerV2Test(test.TestCase): self.assertAllClose( m_val, self.evaluate( - [dist.read_var(m), + [distribution.read_var(m), m.get(devices[0]), m.get(devices[1])])) # v(1) = beta2 * v(0) + (1-beta2) * grad^2 = 0.2 * 0 + 0.8 * 2.25 @@ -189,7 +203,7 @@ class MirroredStrategyOptimizerV2Test(test.TestCase): self.assertAllClose( v_val, self.evaluate( - [dist.read_var(v), + [distribution.read_var(v), v.get(devices[0]), v.get(devices[1])])) # var(1) = var(0) - lr * m(1) * sqrt(1 - beta2) / sqrt(v(1)) / (1 - beta1) @@ -198,12 +212,12 @@ class MirroredStrategyOptimizerV2Test(test.TestCase): self.assertAllClose( var_val, self.evaluate( - [dist.read_var(var), + [distribution.read_var(var), var.get(devices[0]), var.get(devices[1])])) self.assertAllClose([1, 1, 1], self.evaluate([ - dist.read_var(counter), + distribution.read_var(counter), counter.get(devices[0]), counter.get(devices[1]) ])) @@ -214,7 +228,7 @@ class MirroredStrategyOptimizerV2Test(test.TestCase): self.assertAllClose( m_val, self.evaluate( - [dist.read_var(m), + [distribution.read_var(m), m.get(devices[0]), m.get(devices[1])])) # v(2) 
= beta2 * v(1) + (1-beta2) * grad^2 = 0.2 * 1.8 + 0.8 * 2.25 @@ -222,16 +236,50 @@ class MirroredStrategyOptimizerV2Test(test.TestCase): self.assertAllClose( v_val, self.evaluate( - [dist.read_var(v), + [distribution.read_var(v), v.get(devices[0]), v.get(devices[1])])) self.assertAllClose([2, 2, 2], self.evaluate([ - dist.read_var(counter), + distribution.read_var(counter), counter.get(devices[0]), counter.get(devices[1]) ])) + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=['graph'])) + def testOptimizerWithKerasModelAndNumpyArrays(self, distribution): + + with self.cached_session(): + model = get_model() + optimizer = gradient_descent.SGD(0.001) + loss = 'mse' + metrics = ['mae'] + model.compile(optimizer, loss, metrics=metrics, distribute=distribution) + + inputs = np.zeros((64, 3), dtype=np.float32) + targets = np.zeros((64, 4), dtype=np.float32) + + model.fit( + inputs, + targets, + epochs=1, + batch_size=2, + verbose=0, + validation_data=(inputs, targets)) + model.evaluate(inputs, targets) + model.predict(inputs) + + +def _replica_id(): + replica_id = ds_context.get_replica_context().replica_id_in_sync_group + if not isinstance(replica_id, ops.Tensor): + replica_id = constant_op.constant(replica_id) + return replica_id + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py index 0db5844e4c4..e530ab6f173 100644 --- a/tensorflow/contrib/distribute/python/keras_test.py +++ b/tensorflow/contrib/distribute/python/keras_test.py @@ -24,9 +24,10 @@ import numpy as np from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import mirrored_strategy from tensorflow.contrib.distribute.python import tpu_strategy -from tensorflow.contrib.distribute.python import values from tensorflow.python import keras from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import values +from tensorflow.python.eager import test from tensorflow.python.estimator import keras as keras_lib from tensorflow.python.estimator import run_config as run_config_lib from tensorflow.python.framework import constant_op @@ -35,14 +36,13 @@ from tensorflow.python.framework import random_seed from tensorflow.python.framework import test_util from tensorflow.python.keras import testing_utils from tensorflow.python.keras.engine import distributed_training_utils +from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras from tensorflow.python.ops.parsing_ops import gen_parsing_ops from tensorflow.python.platform import gfile -from tensorflow.python.platform import test from tensorflow.python.summary.writer import writer_cache from tensorflow.python.training import gradient_descent from tensorflow.python.training import rmsprop - _RANDOM_SEED = 1337 _TRAIN_SIZE = 200 _INPUT_SIZE = (10,) @@ -212,13 +212,18 @@ def multi_input_output_model(): return model -def get_correctness_test_inputs(use_numpy, with_distribution, +def get_correctness_test_inputs(use_numpy, use_validation_data, + with_distribution, x_train, y_train, x_predict): """Generates the inputs for correctness check when enable Keras with DS.""" global_batch_size = 64 batch_size = global_batch_size # TODO(b/118776054): Use global batch size for Keras/DS support. 
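The hunk that follows changes the correctness-test inputs from always dividing the batch across replicas to doing so only for strategies without global-batch-size support. A standalone sketch of that scaling rule, with the support flag modelled as a plain boolean argument purely for illustration (the real code calls `distributed_training_utils.global_batch_size_supported(strategy)`):

```
def scaled_batch_size(global_batch_size, num_replicas_in_sync,
                      global_batch_size_supported):
  # Illustrative stand-in for the per-core scaling applied below.
  if global_batch_size_supported:
    return global_batch_size
  return global_batch_size // num_replicas_in_sync

assert scaled_batch_size(64, 2, True) == 64   # strategy consumes the global batch
assert scaled_batch_size(64, 2, False) == 32  # per-core strategy: 32 per replica
```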
- if with_distribution: + use_per_core_batch_size = ( + with_distribution and + not distributed_training_utils.global_batch_size_supported( + with_distribution)) + if use_per_core_batch_size: batch_size //= with_distribution.num_replicas_in_sync if use_numpy: @@ -229,16 +234,17 @@ def get_correctness_test_inputs(use_numpy, with_distribution, 'epochs': 1, 'shuffle': False, } - eval_inputs = { - 'batch_size': batch_size, - 'x': x_train, - 'y': y_train, - } + + if use_validation_data: + eval_inputs = None + training_inputs['validation_data'] = (x_train, y_train) + else: + eval_inputs = { + 'batch_size': batch_size, + 'x': x_train, + 'y': y_train, + } predict_inputs = { - # TODO(b/119318587): We should not require batch_size when distribution - # is enabled. - 'batch_size': (len(x_predict) // with_distribution.num_replicas_in_sync - if with_distribution else None), 'x': np.array(x_predict, dtype=np.float32), } else: @@ -256,20 +262,28 @@ def get_correctness_test_inputs(use_numpy, with_distribution, 'shuffle': False, 'steps_per_epoch': len(x_train) // global_batch_size, } - eval_inputs = { - 'batch_size': None, - 'x': x, - 'y': None, - 'steps': 20, - } + if use_validation_data: + eval_inputs = None # Remove the eval_inputs + eval_dataset = dataset_ops.Dataset.from_tensor_slices( + (x_train, y_train)) + x = batch_wrapper(eval_dataset, batch_size, with_distribution) + training_inputs['validation_data'] = x + training_inputs['validation_steps'] = 5 + else: + eval_inputs = { + 'batch_size': None, + 'x': x, + 'y': None, + 'steps': 20, + } + predict_batch_size = len(x_predict) - if with_distribution: + if use_per_core_batch_size: predict_batch_size //= with_distribution.num_replicas_in_sync predict_dataset = dataset_ops.Dataset.from_tensor_slices(x_predict) predict_dataset = batch_wrapper(predict_dataset, predict_batch_size, with_distribution) predict_inputs = { - 'batch_size': None, 'steps': 1, 'x': predict_dataset, } @@ -277,47 +291,71 @@ def get_correctness_test_inputs(use_numpy, with_distribution, return training_inputs, eval_inputs, predict_inputs -strategies = [combinations.default_strategy, - combinations.one_device_strategy, - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus, - combinations.tpu_strategy, # steps_per_run=2 - combinations.tpu_strategy_one_step] +strategies_minus_tpu = [ + combinations.default_strategy, + combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus] + +tpu_strategies = [ + combinations.tpu_strategy, # steps_per_run=2 + combinations.tpu_strategy_one_step] def strategy_minus_tpu_combinations(): return combinations.combine( - distribution=[combinations.default_strategy, - combinations.one_device_strategy, - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus], - mode=['graph']) + distribution=strategies_minus_tpu, + mode=['graph', 'eager']) -def strategy_combinations(): +def tpu_strategy_combinations(): return combinations.combine( - distribution=strategies, + distribution=tpu_strategies, mode=['graph']) +def all_strategy_combinations(): + return strategy_minus_tpu_combinations() + tpu_strategy_combinations() + + +# TODO(priyag): Add v2 optimizers here. 
def strategy_and_optimizer_combinations(): + return combinations.times( + all_strategy_combinations(), + combinations.combine( + optimizer=[combinations.adagrad_optimizer_v1_fn, + combinations.adam_optimizer_v1_fn, + combinations.gradient_descent_optimizer_v1_fn, + combinations.rmsprop_optimizer_v1_fn])) + + +def strategy_and_input_combinations(): + return ( + combinations.times( + combinations.combine(distribution=strategies_minus_tpu), + combinations.combine(mode=['graph'], + use_numpy=[True, False], + use_validation_data=[True, False]) + + combinations.combine(mode=['eager'], + use_numpy=[False], + use_validation_data=[False])) + + combinations.times( + combinations.combine(distribution=tpu_strategies), + combinations.combine(mode=['graph'], + use_numpy=[True, False], + use_validation_data=[True, False]))) + + +def strategy_for_numpy_input_combinations(): return combinations.combine( - distribution=strategies, - optimizer=[combinations.adagrad_optimizer_v1_fn, - combinations.adam_optimizer_v1_fn, - combinations.gradient_descent_optimizer_v1_fn, - combinations.rmsprop_optimizer_v1_fn], + distribution=strategies_minus_tpu + tpu_strategies, mode=['graph']) -def strategy_and_inputs(): - return combinations.combine( - distribution=strategies, - use_numpy=[True, False], - mode=['graph']) - - -class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): +class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase, + parameterized.TestCase): def setUp(self): self._base_dir = os.path.join(self.get_temp_dir(), @@ -325,17 +363,18 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): gfile.MakeDirs(self._base_dir) self._config = run_config_lib.RunConfig( tf_random_seed=_RANDOM_SEED, model_dir=self._base_dir) - self._dist = mirrored_strategy.MirroredStrategy( - devices=['/device:GPU:0', '/device:GPU:1']) def tearDown(self): writer_cache.FileWriterCache.clear() if os.path.isdir(self._base_dir): gfile.DeleteRecursively(self._base_dir) - def test_train_functional_with_distribution_strategy(self): - dist = mirrored_strategy.MirroredStrategy( - devices=['/device:GPU:0', '/device:GPU:1']) + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_two_gpus], + mode=['graph'])) + def test_train_functional_with_distribution_strategy(self, distribution): keras_model = simple_functional_model() keras_model.compile( loss='categorical_crossentropy', @@ -343,8 +382,8 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.01)) config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED, model_dir=self._base_dir, - train_distribute=dist, - eval_distribute=dist) + train_distribute=distribution, + eval_distribute=distribution) with self.cached_session(): est_keras = keras_lib.model_to_estimator( keras_model=keras_model, config=config) @@ -358,9 +397,12 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): writer_cache.FileWriterCache.clear() gfile.DeleteRecursively(self._config.model_dir) - def test_train_sequential_with_distribution_strategy(self): - dist = mirrored_strategy.MirroredStrategy( - devices=['/device:GPU:0', '/device:GPU:1']) + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_two_gpus], + mode=['graph'])) + def test_train_sequential_with_distribution_strategy(self, distribution): 
keras_model = simple_sequential_model() keras_model.compile( loss='categorical_crossentropy', @@ -368,7 +410,7 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.01)) config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED, model_dir=self._base_dir, - train_distribute=dist) + train_distribute=distribution) with self.cached_session(): est_keras = keras_lib.model_to_estimator( keras_model=keras_model, config=config) @@ -382,7 +424,12 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): writer_cache.FileWriterCache.clear() gfile.DeleteRecursively(self._config.model_dir) - def test_multi_inputs_multi_outputs_with_input_fn_as_dict(self): + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_two_gpus], + mode=['graph'])) + def test_multi_inputs_multi_outputs_with_input_fn_as_dict(self, distribution): train_data, test_data = get_multi_inputs_multi_outputs_data() def train_input_fn(): @@ -412,14 +459,14 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): output_dict)).batch(16) self.do_test_multi_inputs_multi_outputs_with_input_fn( - train_input_fn, eval_input_fn) + distribution, train_input_fn, eval_input_fn) - def do_test_multi_inputs_multi_outputs_with_input_fn(self, train_input_fn, - eval_input_fn): + def do_test_multi_inputs_multi_outputs_with_input_fn( + self, distribution, train_input_fn, eval_input_fn): config = run_config_lib.RunConfig( tf_random_seed=_RANDOM_SEED, model_dir=self._base_dir, - train_distribute=self._dist) + train_distribute=distribution) with self.cached_session(): model = multi_inputs_multi_outputs_model() est_keras = keras_lib.model_to_estimator(keras_model=model, config=config) @@ -429,9 +476,12 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1) self.assertLess(eval_results['loss'], baseline_eval_results['loss']) - def test_keras_optimizer_with_distribution_strategy(self): - dist = mirrored_strategy.MirroredStrategy( - devices=['/device:GPU:0', '/device:GPU:1']) + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_two_gpus], + mode=['graph'])) + def test_keras_optimizer_with_distribution_strategy(self, distribution): keras_model = simple_sequential_model() keras_model.compile( loss='categorical_crossentropy', @@ -439,7 +489,7 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED, model_dir=self._base_dir, - train_distribute=dist) + train_distribute=distribution) with self.cached_session(): est_keras = keras_lib.model_to_estimator(keras_model=keras_model, config=config) @@ -455,7 +505,7 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): class TestDistributionStrategyWithNumpyArrays(test.TestCase, parameterized.TestCase): - @combinations.generate(strategy_combinations()) + @combinations.generate(strategy_for_numpy_input_combinations()) def test_creating_var_with_numpy_arrays(self, distribution): with self.cached_session(): x = np.asarray(np.random.random((64, 3)), dtype=np.float32) @@ -464,84 +514,135 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, # Verify that the numpy value is copied to the variable. 
self.assertAllEqual(x, val) - def test_calculating_batch_params(self): - # This verifies that we calculate the number of steps when the batch size - # is specified. + @combinations.generate(strategy_for_numpy_input_combinations()) + def test_calculating_input_params_no_steps_no_batch_size(self, distribution): + # Calculate the per_replica_batch_size scaling factor for strategies + # that use per_core_batch_size + replica_scale_factor = 1.0 + if not distributed_training_utils.global_batch_size_supported(distribution): + replica_scale_factor = distribution.num_replicas_in_sync + with self.cached_session(): - # 64 is the number of input samples. - inputs = np.zeros((64, 3), dtype=np.float32) - # The number of replicas is equal to 3. - strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', - '/device:CPU:0', - '/device:GPU:1']) + # Input samples of different sizes + input_20_samples = np.zeros((20, 3), dtype=np.float32) + input_63_samples = np.zeros((63, 3), dtype=np.float32) + input_64_samples = np.zeros((64, 3), dtype=np.float32) - with self.assertRaisesRegexp(ValueError, 'Please specify a batch_size ' - 'that is smaller than'): - # The batch size(128) is larger than the number of input - # samples(64). - distributed_training_utils.get_input_batch_params(inputs, - 128, - strategy) - - with self.assertRaisesRegexp(ValueError, 'is smaller than the number ' - 'of replicas'): - # The batch size(32) * num_replicas_in_sync(3) is 96 which is greater - # than the number of input samples(64). - distributed_training_utils.get_input_batch_params(inputs, - 32, - strategy) - - # The number of replicas now is equal to 2. - strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', - '/device:CPU:0']) - # 32 is the batch size per replica. - steps = distributed_training_utils.get_input_batch_params(inputs, - 32, - strategy) - # The number of batches is the ratio of input samples(64) to - # batch size(32) which is 2. The number of steps(1) is the ratio of - # number of batches(2) to the number of replicas(2). - self.assertEqual(steps, 1) - - # 16 is the batch size per replica. - steps = distributed_training_utils.get_input_batch_params(inputs, - 16, - strategy) - # The number of batches is the ratio of input samples(64) to - # batch size(16) which is 4. The number of steps(2) is the ratio of - # number of batches(4) to the number of replicas(2). + # Default global batch size 32 for input with 64 samples run in 2 steps + steps, batch_size = distributed_training_utils.get_input_params( + distribution, input_64_samples, steps=None, batch_size=None) + self.assertEqual(batch_size, 32 // replica_scale_factor) self.assertEqual(steps, 2) - def test_calculating_batch_size(self): + # Computed global batch size 20 is lower than 32 if we pass less samples. + steps, batch_size = distributed_training_utils.get_input_params( + distribution, input_20_samples, steps=None, batch_size=None) + self.assertEqual(batch_size, 20 // replica_scale_factor) + self.assertEqual(steps, 1) + + # Default global batch size 32 cannot be used with 63 samples. 
+ with self.assertRaisesRegexp(ValueError, 'not divisible by batch size'): + distributed_training_utils.get_input_params( + distribution, input_63_samples, steps=None, batch_size=None) + + @combinations.generate(strategy_for_numpy_input_combinations()) + def test_calculating_input_params_with_steps_no_batch_size(self, + distribution): + # Calculate the per_replica_batch_size scaling factor for strategies + # that use per_core_batch_size + replica_scale_factor = 1.0 + if not distributed_training_utils.global_batch_size_supported(distribution): + replica_scale_factor = distribution.num_replicas_in_sync + with self.cached_session(): - # 64 is the number of input samples. - inputs = np.zeros((64, 3), dtype=np.float32) - targets = np.zeros((64, 4), dtype=np.float32) + # Input samples of different sizes + input_63_samples = np.zeros((63, 3), dtype=np.float32) + input_64_samples = np.zeros((64, 3), dtype=np.float32) - model = get_model() - optimizer = gradient_descent.GradientDescentOptimizer(0.001) - loss = 'mse' - strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', - '/device:CPU:0']) - strategy._require_static_shapes = True + # Computed global batch size is correct when 1 step is specified + steps, batch_size = distributed_training_utils.get_input_params( + distribution, input_64_samples, steps=1, batch_size=None) + self.assertEqual(batch_size, 64 // replica_scale_factor) + self.assertEqual(steps, 1) - model.compile(optimizer, loss, distribute=strategy) - iterator = model._distribution_standardize_user_data(inputs, - targets, - batch_size=None, - check_steps=True, - steps_name='steps', - steps=3) + # Computed global batch size is correct when 2 steps are specified + steps, batch_size = distributed_training_utils.get_input_params( + distribution, input_64_samples, steps=2, batch_size=None) + self.assertEqual(batch_size, 32 // replica_scale_factor) + self.assertEqual(steps, 2) - # The global batch size(21) across all replicas is the ratio of the input - # samples(64) to the steps(3). - # The batch size(10) per device is the ratio of the global batch size(21) - # to the number of replicas(2). - # The global batch size and batch size are rounded integer values. - self.assertEqual(10, distributed_training_utils.get_batch_dimension( - iterator._iterator)) + # All samples cannot be consumed in the specified number of steps + with self.assertRaisesRegexp(ValueError, 'not divisible by steps'): + distributed_training_utils.get_input_params( + distribution, input_63_samples, steps=2, batch_size=None) - @combinations.generate(strategy_combinations()) + # This case differs between strategies because the supported batch size + # is either global or per-replica.
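The input-params tests in this hunk pin down a small contract for `get_input_params`: a default global batch capped at 32, divisibility checks between samples, steps, and batch size, and a per-replica adjustment for strategies that do not take a global batch size. A rough, framework-free model of that contract follows; it only illustrates what the assertions above and below check, and is not the actual `distributed_training_utils` implementation:

```
# Rough model of the behaviour these tests assert; the real
# get_input_params in distributed_training_utils may differ in details.
def infer_input_params(num_samples, steps=None, batch_size=None,
                       num_replicas=1, uses_global_batch_size=True):
  """Returns (steps, batch_size as the strategy expects it) or raises."""
  if steps is None and batch_size is None:
    global_batch = min(num_samples, 32)          # default cap of 32
    if num_samples % global_batch:
      raise ValueError('not divisible by batch size')
    steps = num_samples // global_batch
  elif batch_size is None:                       # only steps given
    if num_samples % steps:
      raise ValueError('not divisible by steps')
    global_batch = num_samples // steps
  else:
    # batch_size given: it is per replica for strategies that do not
    # support a global batch size.
    scale = 1 if uses_global_batch_size else num_replicas
    global_batch = batch_size * scale
    if steps is not None:                        # both given
      if num_samples < global_batch * steps:
        raise ValueError('less than samples required')
      return steps, batch_size
    if num_samples % global_batch:
      raise ValueError('not divisible by batch size')
    return num_samples // global_batch, batch_size
  if uses_global_batch_size:
    return steps, global_batch
  if global_batch % num_replicas:
    raise ValueError('could not be sharded evenly across the sync replicas')
  return steps, global_batch // num_replicas


# Matches the assertions above: 64 samples, no hints -> 2 steps of 32
# globally, or 2 steps of 16 per replica on a 2-replica per-core strategy.
assert infer_input_params(64) == (2, 32)
assert infer_input_params(64, num_replicas=2,
                          uses_global_batch_size=False) == (2, 16)
```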
+ if replica_scale_factor == 1: + # Computed global batch size is correct even if not shardable + steps, batch_size = distributed_training_utils.get_input_params( + distribution, input_63_samples, steps=3, batch_size=None) + self.assertEqual(batch_size, 21) + self.assertEqual(steps, 3) + else: + # Computed global batch size cannot be sharded across replicas + with self.assertRaisesRegexp(ValueError, 'could not be sharded evenly ' + 'across the sync replicas'): + distributed_training_utils.get_input_params( + distribution, input_63_samples, steps=1, batch_size=None) + + @combinations.generate(strategy_for_numpy_input_combinations()) + def test_calculating_input_params_no_steps_with_batch_size(self, + distribution): + # Calculate the per_replica_batch_size scaling factor for strategies + # that use per_core_batch_size + replica_scale_factor = 1.0 + if not distributed_training_utils.global_batch_size_supported(distribution): + replica_scale_factor = distribution.num_replicas_in_sync + + with self.cached_session(): + input_64_samples = np.zeros((64, 3), dtype=np.float32) + + # Computed steps is correct for specified batch size + steps, batch_size = distributed_training_utils.get_input_params( + distribution, input_64_samples, steps=None, batch_size=16) + self.assertEqual(batch_size, 16) + self.assertEqual(steps, 4 // replica_scale_factor) + + # Computed steps is correct for specified batch size + steps, batch_size = distributed_training_utils.get_input_params( + distribution, input_64_samples, steps=None, batch_size=32) + self.assertEqual(batch_size, 32) + self.assertEqual(steps, 2 // replica_scale_factor) + + # Number of samples is not divisible by the global batch size + with self.assertRaisesRegexp(ValueError, 'not divisible by batch size'): + distributed_training_utils.get_input_params( + distribution, input_64_samples, steps=None, batch_size=20) + + # Number of samples is not divisible by the global batch size + with self.assertRaisesRegexp(ValueError, 'not divisible by batch size'): + distributed_training_utils.get_input_params( + distribution, input_64_samples, steps=None, batch_size=3) + + @combinations.generate(strategy_for_numpy_input_combinations()) + def test_calculating_input_params_with_steps_with_batch_size(self, + distribution): + with self.cached_session(): + input_64_samples = np.zeros((64, 3), dtype=np.float32) + + # No change to steps and batch size if both specified and feasible + steps, batch_size = distributed_training_utils.get_input_params( + distribution, input_64_samples, steps=5, batch_size=3) + self.assertEqual(batch_size, 3) + self.assertEqual(steps, 5) + + # Number of samples is less than global batch size * steps + with self.assertRaisesRegexp(ValueError, 'less than samples required'): + distributed_training_utils.get_input_params( + distribution, input_64_samples, steps=10, batch_size=13) + + @combinations.generate(strategy_for_numpy_input_combinations()) def test_calling_model_with_numpy_arrays(self, distribution): with self.cached_session(): model = get_model() @@ -572,7 +673,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, # with batch_size model.predict(inputs, batch_size=8) - @combinations.generate(strategy_combinations()) + @combinations.generate(strategy_for_numpy_input_combinations()) def test_calling_model_with_nested_numpy_arrays(self, distribution): with self.cached_session(): model = multi_input_output_model() @@ -606,21 +707,22 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, # with batch_size model.predict(inputs, 
batch_size=8) - @combinations.generate(strategy_minus_tpu_combinations()) + @combinations.generate(combinations.combine( + distribution=strategies_minus_tpu, mode=['graph'])) def test_numpy_with_sample_weights(self, distribution): model = get_model() optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) loss = 'mse' model.compile(optimizer, loss, distribute=distribution) - inputs = np.zeros((10, 3), np.float32) - targets = np.zeros((10, 4), np.float32) - sample_weights = np.ones((10), np.float32) + inputs = np.zeros((20, 3), np.float32) + targets = np.zeros((20, 4), np.float32) + sample_weights = np.ones((20), np.float32) model.fit(inputs, targets, sample_weight=sample_weights, epochs=1, steps_per_epoch=2, verbose=1) - @combinations.generate(strategy_combinations()) + @combinations.generate(strategy_for_numpy_input_combinations()) def test_flatten_predict_outputs(self, distribution): with self.cached_session(): model = multi_input_output_model() @@ -638,7 +740,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, # `predict` a list that is equal in length to the number of model outputs. # In this test our model has two outputs and each element of `outs` # corresponds to all the samples of one of the model outputs. - self.assertEqual(2, len(outs)) + self.assertLen(outs, 2) # Each of the output samples have a dimension of 7. We should process all # the available input samples(6). self.assertAllEqual([6, 7], outs[0].shape) @@ -648,7 +750,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, class TestDistributionStrategyWithDatasets(test.TestCase, parameterized.TestCase): - @combinations.generate(strategy_combinations()) + @combinations.generate(all_strategy_combinations()) def test_calling_model_on_same_dataset(self, distribution): with self.cached_session(): model = get_model() @@ -667,7 +769,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase, validation_data=dataset, validation_steps=2) model.predict(get_predict_dataset(distribution), steps=2) - @combinations.generate(strategy_combinations()) + @combinations.generate(all_strategy_combinations()) def test_model_interleaved_eval_same_as_direct_eval(self, distribution): with self.cached_session(): user_controlled_model = get_model() @@ -710,16 +812,20 @@ class TestDistributionStrategyWithDatasets(test.TestCase, # TODO(priyag): Enable this test for TPU. Currently tuples/dict don't work # as clone_model's input_tensors argument only seems to accept list and not # tuples or dict. 
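The test that follows exercises `fit()` with tuple- and dict-structured datasets under a mirrored strategy. As a minimal standalone sketch of the dict-shaped input it relies on (the `'input_a'`/`'input_b'` and output-name keys, and the output width, are assumptions for illustration, not copied from the model in this file):

```
# Standalone sketch: tf.data preserves nested (dict, dict) structures
# through batching, which is what lets model.fit(dataset_dict, ...)
# map features and targets to layers by name.
import numpy as np
import tensorflow as tf

features = {'input_a': np.random.random((10, 3)).astype(np.float32),
            'input_b': np.random.random((10, 5)).astype(np.float32)}
targets = {'output_1': np.random.random((10, 7)).astype(np.float32),
           'output_2': np.random.random((10, 7)).astype(np.float32)}

dataset_dict = tf.data.Dataset.from_tensor_slices((features, targets))
dataset_dict = dataset_dict.repeat().batch(4)
```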
- def test_fit_with_tuple_and_dict_dataset_inputs(self): + + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=['graph', 'eager'])) + def test_fit_with_tuple_and_dict_dataset_inputs(self, distribution): with self.cached_session(): model = multi_input_output_model() optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.001) loss = 'mse' metrics = ['mae', keras.metrics.CategoricalAccuracy()] - strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', - '/device:CPU:0']) - model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + model.compile(optimizer, loss, metrics=metrics, distribute=distribution) input_a_np = np.random.random((10, 3)) input_b_np = np.random.random((10, 5)) @@ -743,7 +849,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase, model.fit(dataset_dict, epochs=1, steps_per_epoch=2, verbose=1) - @combinations.generate(strategy_combinations()) + @combinations.generate(all_strategy_combinations()) def test_fit_eval_and_predict_methods_on_dataset(self, distribution): with self.cached_session(): model = get_model() @@ -792,25 +898,18 @@ class TestDistributionStrategyWithDatasets(test.TestCase, model.evaluate(dataset, steps=2, verbose=1) model.predict(dataset, steps=2) - def test_dataset_input_shape_validation(self): + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_two_gpus], + mode=['graph', 'eager'])) + def test_dataset_wrong_input_shape(self, distribution): with self.cached_session(): model = get_model() optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) loss = 'mse' - strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1', - '/device:GPU:0']) - - model.compile(optimizer, loss, distribute=strategy) - - # User forgets to batch the dataset - inputs = np.zeros((10, 3), dtype=np.float32) - targets = np.zeros((10, 4), dtype=np.float32) - dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) - dataset = dataset.repeat(100) - - with self.assertRaisesRegexp(ValueError, 'expected input to have shape'): - model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0) + model.compile(optimizer, loss, distribute=distribution) # Wrong input shape inputs = np.zeros((10, 5), dtype=np.float32) @@ -823,6 +922,26 @@ class TestDistributionStrategyWithDatasets(test.TestCase, 'expected input to have shape'): model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0) + @combinations.generate(combinations.combine( + distribution=[combinations.mirrored_strategy_with_two_gpus], + mode=['graph', 'eager'])) + def test_dataset_no_batch_input_validation(self, distribution): + with self.cached_session(): + model = get_model() + + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss, distribute=distribution) + + # User forgets to batch the dataset + inputs = np.zeros((10, 3), dtype=np.float32) + targets = np.zeros((10, 4), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.repeat(100) + + with self.assertRaisesRegexp(ValueError, 'expected input to have shape'): + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0) + @combinations.generate(combinations.combine( distribution=[combinations.tpu_strategy_one_step], mode=['graph'])) @@ -842,7 +961,12 @@ class 
TestDistributionStrategyWithDatasets(test.TestCase, with self.assertRaisesRegexp(ValueError, 'requires fully defined shapes'): model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0) - def test_learning_phase_value(self): + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_two_gpus], + mode=['graph', 'eager'])) + def test_learning_phase_value(self, distribution): # TODO(anjalisridhar): Modify this test to use Lambdas since we can compare # meaningful values. Currently we don't pass the learning phase if the # Lambda layer uses the learning phase. @@ -856,15 +980,17 @@ class TestDistributionStrategyWithDatasets(test.TestCase, optimizer = gradient_descent.GradientDescentOptimizer(0.005) loss = 'mse' metrics = ['acc'] - strategy = mirrored_strategy.MirroredStrategy( - ['/device:GPU:0', '/device:GPU:1']) + model.compile(optimizer, loss, metrics=metrics, distribute=distribution) - model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + batch_size = 8 + if isinstance(distribution, mirrored_strategy.CoreMirroredStrategy): + # CoreMirroredStrategy uses global batch size. + batch_size = 8 * distribution.num_replicas_in_sync inputs = np.ones((10, 1), dtype=np.float32) targets = np.ones((10, 1), dtype=np.float32) dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) - dataset = dataset.repeat().batch(8) + dataset = dataset.repeat().batch(batch_size) hist = model.fit(dataset, epochs=1, steps_per_epoch=20, verbose=1) self.assertAlmostEqual(hist.history['acc'][0], 0, 0) @@ -875,24 +1001,51 @@ class TestDistributionStrategyWithDatasets(test.TestCase, inputs = np.ones((10, 1), dtype=np.float32) predict_dataset = dataset_ops.Dataset.from_tensor_slices(inputs) - predict_dataset = predict_dataset.repeat().batch(5) + + predict_dataset = predict_dataset.repeat().batch(batch_size) output = model.predict(predict_dataset, steps=10) - # `predict` runs for 10 steps and in each step you process 100 samples. 
- ref_output = np.ones((100, 1), dtype=np.float32) + # `predict` runs for 10 steps + ref_output = np.ones((160, 1), dtype=np.float32) self.assertArrayNear(output, ref_output, 1e-1) + @combinations.generate(strategy_minus_tpu_combinations()) + def testOptimizerWithCallbacks(self, distribution): + with self.cached_session(): + model = get_model() + + optimizer = gradient_descent_keras.SGD(0.01) + loss = 'mse' + model.compile(optimizer, loss, distribute=distribution) + + dataset = get_dataset(distribution) + + def schedule(_): + return 0.001 + + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, + callbacks=[keras.callbacks.LearningRateScheduler(schedule)]) + grouped_models = distribution.unwrap(model._grouped_model) + with distribution.scope(): + for m in grouped_models: + self.assertAllClose(0.001, keras.backend.get_value( + m.optimizer.lr), atol=1e-05, rtol=1e-05) + class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): - def test_validating_dataset_input_tensors_with_shape_mismatch(self): + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=['graph', 'eager'])) + def test_validating_dataset_input_tensors_with_shape_mismatch(self, + distribution): with self.cached_session(): - strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', - '/device:CPU:0']) a = constant_op.constant([1, 2], shape=(1, 2)) b = constant_op.constant([[1, 2], [1, 2]], shape=(2, 2)) x = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': b}) y = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': a}) - with strategy.scope(): + with distribution.scope(): # Removed device and input tensor shape details from the error message # since the order of the device and the corresponding input tensor shape # is not deterministic over different runs. @@ -901,17 +1054,21 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): 'distributed tensor inputs ' 'DistributedValues:.+'): distributed_training_utils.validate_distributed_dataset_inputs( - strategy, x, y) + distribution, x, y) - def test_validating_dataset_input_tensors_with_dtype_mismatch(self): + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=['graph', 'eager'])) + def test_validating_dataset_input_tensors_with_dtype_mismatch(self, + distribution): with self.cached_session(): - strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', - '/device:CPU:0']) a = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.int32) b = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.float64) x = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': b}) y = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': a}) - with strategy.scope(): + with distribution.scope(): # Removed device and input tensor dtype details from the error message # since the order of the device and the corresponding input tensor dtype # is not deterministic over different runs. 
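A quick arithmetic check of the `ref_output` assertion in the learning-phase test earlier in this hunk: both parameterizations run on two GPUs, so whether the strategy takes a global batch (CoreMirroredStrategy, 8 * 2 = 16) or a per-replica batch of 8 on each of two replicas (contrib MirroredStrategy), every predict step consumes 16 samples, and 10 steps produce the 160 rows the test expects.

```
# Worked numbers behind ref_output's shape of (160, 1) above.
num_replicas = 2            # the *_with_two_gpus strategy variants
per_replica_batch = 8
samples_per_step = per_replica_batch * num_replicas   # 16 either way
predict_steps = 10
assert predict_steps * samples_per_step == 160
```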
@@ -920,21 +1077,23 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): 'distributed tensor inputs ' 'DistributedValues:.+'): distributed_training_utils.validate_distributed_dataset_inputs( - strategy, x, y) + distribution, x, y) - def test_unsupported_features(self): + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_two_gpus], + mode=['graph', 'eager'])) + def test_unsupported_features(self, distribution): with self.cached_session(): model = get_model() optimizer = gradient_descent.GradientDescentOptimizer(0.001) loss = 'mse' metrics = ['mae'] - strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1', - '/device:GPU:0']) + model.compile(optimizer, loss, metrics=metrics, distribute=distribution) - model.compile(optimizer, loss, metrics=metrics, distribute=strategy) - - dataset = get_dataset(strategy) + dataset = get_dataset(distribution) # Test with validation split with self.assertRaisesRegexp( @@ -969,30 +1128,33 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): 'you should specify the `steps` argument'): model.predict(dataset, verbose=0) - def test_calling_with_unsupported_predefined_callbacks(self): + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_two_gpus], + mode=['graph', 'eager'])) + def test_calling_with_unsupported_predefined_callbacks(self, distribution): with self.cached_session(): model = get_model() optimizer = gradient_descent.GradientDescentOptimizer(0.001) loss = 'mse' metrics = ['mae'] - strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1', - '/device:GPU:0']) - model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + model.compile(optimizer, loss, metrics=metrics, distribute=distribution) - dataset = get_dataset(strategy) + dataset = get_dataset(distribution) def schedule(_): return 0.001 with self.assertRaisesRegexp(ValueError, - 'LearningRateScheduler callback is not ' - 'supported with DistributionStrategy.'): + 'You must specify a Keras Optimizer V2 when ' + 'using'): model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, callbacks=[keras.callbacks.LearningRateScheduler(schedule)]) with self.assertRaisesRegexp(ValueError, - 'ReduceLROnPlateau callback is not ' - 'supported with DistributionStrategy.'): + 'You must specify a Keras Optimizer V2 when ' + 'using'): model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, callbacks=[keras.callbacks.ReduceLROnPlateau()]) with self.assertRaisesRegexp(ValueError, @@ -1003,11 +1165,17 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): callbacks=[keras.callbacks.TensorBoard(histogram_freq=10)]) -class TestDistributionStrategyWithLossMasking(test.TestCase): +class TestDistributionStrategyWithLossMasking(test.TestCase, + parameterized.TestCase): # TODO(priyag): Enable all strategies for this test. Currently it does not # work for TPU due to some invalid datatype. 
- def test_masking(self): + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_two_gpus], + mode=['graph', 'eager'])) + def test_masking(self, distribution): with self.cached_session(): np.random.seed(1337) x = np.array([[[1], [1]], [[0], [0]]]) @@ -1016,12 +1184,9 @@ class TestDistributionStrategyWithLossMasking(test.TestCase): model.add( keras.layers.TimeDistributed( keras.layers.Dense(1, kernel_initializer='one'))) - strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1', - '/device:GPU:0']) - model.compile(loss='mse', optimizer=gradient_descent.GradientDescentOptimizer(0.01), - distribute=strategy) + distribute=distribution) y = np.array([[[1], [1]], [[1], [1]]]) dataset = dataset_ops.Dataset.from_tensor_slices((x, y)) dataset = dataset.repeat(100) @@ -1033,7 +1198,7 @@ class TestDistributionStrategyWithLossMasking(test.TestCase): class TestDistributionStrategyWithNormalizationLayer( test.TestCase, parameterized.TestCase): - @combinations.generate(strategy_combinations()) + @combinations.generate(all_strategy_combinations()) def test_batchnorm_correctness(self, distribution): with self.cached_session(): model = keras.models.Sequential() @@ -1065,7 +1230,7 @@ class TestDistributionStrategyWithNormalizationLayer( class TestDistributionStrategyCorrectness(test.TestCase, parameterized.TestCase): - @combinations.generate(strategy_combinations()) + @combinations.generate(all_strategy_combinations()) def test_metric_correctness(self, distribution): with self.cached_session(): keras.backend.set_image_data_format('channels_last') @@ -1088,22 +1253,32 @@ class TestDistributionStrategyCorrectness(test.TestCase, distribute=distribution) batch_size = 64 - batch_size //= distribution.num_replicas_in_sync + if not distributed_training_utils.global_batch_size_supported( + distribution): + batch_size //= distribution.num_replicas_in_sync train_dataset = dataset_ops.Dataset.from_tensor_slices((x_train, y_train)) train_dataset = batch_wrapper(train_dataset, batch_size, distribution) history = model.fit(x=train_dataset, epochs=1, steps_per_epoch=10) self.assertEqual(history.history['binary_accuracy'], [1.0]) - @combinations.generate(strategy_and_inputs()) - def test_correctness(self, distribution, use_numpy): + @combinations.generate(strategy_and_input_combinations()) + def test_correctness(self, distribution, use_numpy, use_validation_data): + with self.cached_session(): tolerance = 1e-5 - if isinstance(distribution, mirrored_strategy.MirroredStrategy): + if isinstance(distribution, (mirrored_strategy.MirroredStrategy, + mirrored_strategy.CoreMirroredStrategy)): # TODO(b/119257215): use the default one once the flakyness is fixed. tolerance = 1e-4 + if (use_validation_data and + not isinstance(distribution, tpu_strategy.TPUStrategy)): + # TODO(b/120435565): Enable tests with use_validation_data once the + # the underlying bug is fixed. + return + keras.backend.set_image_data_format('channels_last') np.random.seed(_RANDOM_SEED) random_seed.set_random_seed(_RANDOM_SEED) @@ -1123,49 +1298,72 @@ class TestDistributionStrategyCorrectness(test.TestCase, # This is used to initialize the model for both the distribution and # non-distribution run. In addition, we add few non-linear layers to make # it non-trivial. 
- model = keras.Sequential() - model.add(keras.layers.Dense(10, activation='relu', input_shape=(1,))) - model.add(keras.layers.Dense(10, activation='relu')) - model.add(keras.layers.Dense(10, activation='relu')) - model.add(keras.layers.Dense(1)) - initial_weights = model.get_weights() + def _create_model(): + model = keras.Sequential() + model.add(keras.layers.Dense(10, activation='relu', input_shape=(1,))) + model.add(keras.layers.Dense(10, activation='relu')) + model.add(keras.layers.Dense(10, activation='relu')) + model.add(keras.layers.Dense(1)) + return model - def fit_and_predict(with_distribution=None): + model = _create_model() + initial_weights = model.get_weights() + del model # avoid accidental usage. + + def fit_eval_and_predict(with_distribution=None): + model = _create_model() # We have initialized the model to the same weight for the distribution # and non-distribution run. model.set_weights(initial_weights) model.compile( loss=keras.losses.mean_squared_error, - optimizer=gradient_descent.GradientDescentOptimizer(0.5), + optimizer=gradient_descent_keras.SGD(0.5), distribute=with_distribution) training_inputs, eval_inputs, predict_inputs = ( - get_correctness_test_inputs(use_numpy, with_distribution, + get_correctness_test_inputs(use_numpy, use_validation_data, + with_distribution, x_train, y_train, x_predict)) - model.fit(**training_inputs) - eval_result = model.evaluate(**eval_inputs) + training_history = model.fit(**training_inputs).history + + if eval_inputs is not None: + eval_result = model.evaluate(**eval_inputs) + else: + # Creates a dummy identical eval_result to be compared later. + eval_result = 1.0 + weights = model.get_weights() predict_result = model.predict(**predict_inputs) - return weights, eval_result, predict_result + return weights, training_history, eval_result, predict_result - wts_with_ds, eval_with_ds, predict_with_ds = fit_and_predict( - with_distribution=distribution) - wts_without_ds, eval_without_ds, predict_without_ds = fit_and_predict( - with_distribution=None) + wts_with_ds, history_with_ds, eval_with_ds, predict_with_ds = ( + fit_eval_and_predict(with_distribution=distribution)) - # Verify that the weights, eval results, predict outputs are the same - # within some limits of tolerance. + (wts_without_ds, history_without_ds, eval_without_ds, + predict_without_ds) = fit_eval_and_predict(with_distribution=None) + + # Verify that the weights, training history, eval results, predict outputs + # are the same within some limits of tolerance. self.assertAllClose( - wts_with_ds, wts_without_ds, atol=tolerance, rtol=tolerance) - self.assertAllClose( - eval_with_ds, eval_without_ds, atol=tolerance, rtol=tolerance) - self.assertAllClose( - predict_with_ds, predict_without_ds, atol=tolerance, rtol=tolerance) + wts_with_ds, wts_without_ds, atol=tolerance, rtol=tolerance, + msg='Fail to assert weights after training.') + self.assertAllClose( + eval_with_ds, eval_without_ds, atol=tolerance, rtol=tolerance, + msg='Fail to assert eval results.') + self.assertAllClose( + predict_with_ds, predict_without_ds, atol=tolerance, rtol=tolerance, + msg='Fail to assert predict results.') -# TODO(priyag): Add a test for TPUStrategy with steps_per_run > 1. + if not (isinstance(distribution, tpu_strategy.TPUStrategy) + and distribution.extended.steps_per_run > 1): + # TODO(b/119894254): Enable this test for all cases once the underlying + # bug is fixed. 
+ self.assertAllClose( + history_with_ds, history_without_ds, atol=tolerance, rtol=tolerance, + msg='Fail to assert training history.') if __name__ == '__main__': diff --git a/tensorflow/contrib/distribute/python/metrics_v1_test.py b/tensorflow/contrib/distribute/python/metrics_v1_test.py index c28ab416518..8ac659abe96 100644 --- a/tensorflow/contrib/distribute/python/metrics_v1_test.py +++ b/tensorflow/contrib/distribute/python/metrics_v1_test.py @@ -72,14 +72,14 @@ def _regression_dataset_fn(): "predictions": [1., .75, .25, 0.]}).repeat() -# TODO(priyag): Add TPU Strategy to this once metrics aggregate correctly using -# ReplicaLocalVariables on TPUs. Submit http://cl/208914352. def all_combinations(): return combinations.combine( distribution=[combinations.default_strategy, combinations.one_device_strategy, combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus], + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus], mode=["graph"]) @@ -100,18 +100,19 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): if isinstance(distribution, tpu_strategy.TPUStrategy): def step_fn(ctx, inputs): value, update = distribution.call_for_each_replica( - metric_fn, args=[inputs]) + metric_fn, args=inputs) ctx.set_non_tensor_output(name="value", output=value) return distribution.group(update) ctx = distribution.run_steps_on_dataset( - step_fn, iterator, iterations=distribution.steps_per_run) + step_fn, iterator, iterations=distribution.extended.steps_per_run) update = ctx.run_op value = ctx.non_tensor_outputs["value"] # In each run, we run multiple steps, and each steps consumes as many # batches as number of replicas. 
batches_per_update = ( - distribution.num_replicas_in_sync * distribution.steps_per_run) + distribution.num_replicas_in_sync * + distribution.extended.steps_per_run) else: value, update = distribution.call_for_each_replica( metric_fn, iterator.get_next()) diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py index c6562463edb..dcc9df4cda5 100644 --- a/tensorflow/contrib/distribute/python/minimize_loss_test.py +++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py @@ -25,6 +25,7 @@ from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python.single_loss_example import batchnorm_example from tensorflow.contrib.distribute.python.single_loss_example import minimize_loss_example from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import reduce_util from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.framework import constant_op @@ -63,7 +64,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): model_fn, dataset_fn, layer = minimize_loss_example( optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss) - def step_fn(ctx, *inputs): + def step_fn(ctx, inputs): del ctx # Unused return distribution.group( distribution.call_for_each_replica(model_fn, args=inputs)) @@ -157,7 +158,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): use_callable_loss=True, create_optimizer_inside_model_fn=True) - def step_fn(ctx, *inputs): + def step_fn(ctx, inputs): del ctx # Unused return distribution.group( distribution.call_for_each_replica(model_fn, args=inputs)) @@ -226,7 +227,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): renorm=renorm, update_ops_in_replica_mode=not update_ops_in_cross_replica_mode) - def step_fn(ctx, *inputs): + def step_fn(ctx, inputs): del ctx # Unused fetches = distribution.unwrap( distribution.call_for_each_replica(model_fn, args=inputs)) @@ -285,7 +286,9 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): distribution=[ combinations.one_device_strategy, combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus ]), combinations.combine( mode=["graph"], use_callable_loss=[True, False]) + @@ -321,10 +324,10 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): labels = dataset_ops.Dataset.from_tensors([[6.], [21.]]) return dataset_ops.Dataset.zip((features, labels)).repeat() - def step_fn(ctx, x, y): + def step_fn(ctx, inputs): del ctx # Unused return distribution.group( - distribution.call_for_each_replica(model_fn, args=(x, y))) + distribution.call_for_each_replica(model_fn, args=inputs)) iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) @@ -341,7 +344,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): run_step() v = all_vars[0] - self.assertTrue(all([v is vi for vi in all_vars[1:]])) + self.assertTrue(all(v is vi for vi in all_vars[1:])) weight = numpy.squeeze(self.evaluate(v)) # Our model is: # predict = x * w @@ -402,21 +405,21 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): train_op = optimizer.minimize(loss_fn) loss = loss_fn() output_context.set_last_step_output( - name="replica_loss_agg", + 
name="replica_loss_reduced", output=loss, - aggregation=variables_lib.VariableAggregation.MEAN) + reduce_op=reduce_util.ReduceOp.MEAN) output_context.set_non_tensor_output(key1, value1) return (train_op, loss) - def step_fn(output_context, *inputs): + def step_fn(output_context, inputs): (train_op, loss) = distribution.call_for_each_replica( model_fn, args=(output_context,) + inputs) output_context.set_last_step_output( - name="cross_replica_loss_agg", + name="cross_replica_loss_reduced", output=loss, - aggregation=variables_lib.VariableAggregation.MEAN) + reduce_op=reduce_util.ReduceOp.MEAN) output_context.set_last_step_output( - name="cross_replica_loss_noagg", + name="cross_replica_loss_not_reduced", output=loss) return distribution.group(train_op) @@ -424,16 +427,16 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): def run_step(): initial_loss = lambda: constant_op.constant(1e7) - # Initial values corresponding to aggregated losses are just single - # tensors. But for non aggregated losses, we need to have initial + # Initial values corresponding to reduced losses are just single + # tensors. But for non reduced losses, we need to have initial # values that are of the same structure as non reduced losses. In # MirroredStrategy, this will be a list of losses, in TPUStrategy # it will be single tensor. Using `broadcast` followed by `unwrap` # gives us the desired initial value structure. initial_loop_values = { - "replica_loss_agg": initial_loss(), - "cross_replica_loss_agg": initial_loss(), - "cross_replica_loss_noagg": + "replica_loss_reduced": initial_loss(), + "cross_replica_loss_reduced": initial_loss(), + "cross_replica_loss_not_reduced": distribution.unwrap(distribution.broadcast(initial_loss())) } ctx = distribution.run_steps_on_dataset( @@ -443,17 +446,17 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): self.assertEqual({key1: [value1]}, ctx.non_tensor_outputs) self._verify_loss_output( initial_loss(), - loss_output=ctx.last_step_outputs["replica_loss_agg"], - aggregated=True, distribution=distribution) + loss_output=ctx.last_step_outputs["replica_loss_reduced"], + reduced=True, distribution=distribution) self._verify_loss_output( initial_loss(), - loss_output=ctx.last_step_outputs["cross_replica_loss_agg"], - aggregated=True, distribution=distribution) + loss_output=ctx.last_step_outputs["cross_replica_loss_reduced"], + reduced=True, distribution=distribution) self._verify_loss_output( initial_loss(), - loss_output=ctx.last_step_outputs["cross_replica_loss_noagg"], - aggregated=False, distribution=distribution) - return (ctx.run_op, ctx.last_step_outputs["replica_loss_agg"]) + loss_output=ctx.last_step_outputs["cross_replica_loss_not_reduced"], + reduced=False, distribution=distribution) + return (ctx.run_op, ctx.last_step_outputs["replica_loss_reduced"]) self.evaluate(distribution.initialize()) if not context.executing_eagerly(): @@ -478,18 +481,16 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): error_is_not_increasing = all(y <= x for x, y in zip(error, error[1:])) self.assertTrue(error_is_not_increasing) - def _verify_loss_output(self, initial_loss, loss_output, aggregated, + def _verify_loss_output(self, initial_loss, loss_output, reduced, distribution): - if not aggregated: - self.assertEqual(distribution.num_replicas_in_sync, - len(distribution.unwrap(loss_output))) - loss_output = distribution.reduce( - aggregation=variables_lib.VariableAggregation.MEAN, - value=loss_output, destinations="/device:CPU:0") - - 
unwrapped_output = distribution.unwrap(loss_output) - self.assertEqual(1, len(unwrapped_output)) - loss_tensor = unwrapped_output[0] + if not reduced: + self.assertLen(distribution.unwrap(loss_output), + distribution.num_replicas_in_sync) + loss_tensor = distribution.reduce(reduce_util.ReduceOp.MEAN, loss_output) + else: + unwrapped_output = distribution.unwrap(loss_output) + self.assertLen(unwrapped_output, 1) + loss_tensor = unwrapped_output[0] self.assertEqual(initial_loss.dtype, loss_tensor.dtype) self.assertEqual(initial_loss.shape, loss_tensor.shape) diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index 2d75024e7a0..20f1a08d426 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -12,293 +12,35 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Class MirroredStrategy implementing DistributionStrategy.""" +"""Contrib version of MirroredStrategy.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import contextlib -from functools import partial -import threading +import functools -from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib -from tensorflow.contrib.distribute.python import shared_variable_creator -from tensorflow.contrib.distribute.python import values -from tensorflow.python import pywrap_tensorflow -from tensorflow.python.distribute import multi_worker_util -from tensorflow.python.eager import context -from tensorflow.python.eager import tape -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import device as tf_device -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import variables as variables_lib -from tensorflow.python.training import coordinator -from tensorflow.python.training import device_util -from tensorflow.python.training import distribute as distribute_lib -from tensorflow.python.util import nest +from tensorflow.python.distribute import device_util +from tensorflow.python.distribute import distribute_lib +from tensorflow.python.distribute import mirrored_strategy +from tensorflow.python.distribute import values -# TODO(josh11b): Replace asserts in this file with if ...: raise ... - - -@contextlib.contextmanager -def _enter_graph(g): - if context.executing_eagerly(): - with g.as_default(), context.eager_mode(): - yield - else: - with g.as_default(): - yield - - -def _cpu_device(device): - cpu_device = tf_device.DeviceSpec.from_string(device) - cpu_device.merge_from(tf_device.DeviceSpec(device_type="CPU", device_index=0)) - return cpu_device.to_string() - - -class _RequestedStop(Exception): - pass - - -# _call_for_each_replica and _reduce_non_distributed_value are not members of -# MirroredStrategy so that they are generally not allowed to use anything -# specific to MirroredStrategy and thus can be shared with other distribution -# strategies. - - -# TODO(yuefengz): maybe create a common class for those who need to call this -# _call_for_each_replica. 
-def _call_for_each_replica(distribution, fn, args, kwargs): - """Run `fn` in separate threads, once per replica/worker device. - - Args: - distribution: the DistributionStrategy object. - fn: function to run (will be run once per device, each in its own thread). - args: positional arguments for `fn` - kwargs: keyword arguments for `fn`. - - Returns: - Merged return value of `fn` across all replicas. - - Raises: - RuntimeError: If fn() calls get_replica_context().merge_call() a different - number of times from the available devices. - """ - # TODO(josh11b): Add this option once we add synchronization to variable - # creation. Until then, this is pretty unsafe to use. - run_concurrently = False - if not context.executing_eagerly(): - # Needed for per-thread device, etc. contexts in graph mode. - ops.get_default_graph().switch_to_thread_local() - - coord = coordinator.Coordinator(clean_stop_exception_types=(_RequestedStop,)) - - shared_variable_store = {} - - # TODO(isaprykin): Create these threads once instead of during every run() - # call. - threads = [] - for index, d in enumerate(distribution.worker_devices): - variable_creator_fn = shared_variable_creator.make_fn( - shared_variable_store, index) - t = MirroredStrategy._MirroredReplicaThread( # pylint: disable=protected-access - distribution, coord, d, variable_creator_fn, fn, - *values.select_device(d, args), **values.select_device(d, kwargs)) - threads.append(t) - - for t in threads: - t.start() - - # When `fn` starts `should_run` event is set on _MirroredReplicaThread - # (`MRT`) threads. The execution waits until - # `MRT.has_paused` is set, which indicates that either `fn` is - # complete or a `get_replica_context().merge_call()` is called. If `fn` is - # complete, then `MRT.done` is set to True. Otherwise, arguments - # of `get_replica_context().merge_call` from all paused threads are grouped - # and the `merge_fn` is performed. Results of the - # `get_replica_context().merge_call` are then set to `MRT.merge_result`. - # Each such `get_replica_context().merge_call` call returns the - # `MRT.merge_result` for that thread when `MRT.should_run` event - # is reset again. Execution of `fn` resumes. - - try: - with coord.stop_on_exception(): - all_done = False - while not all_done and not coord.should_stop(): - done = [] - if run_concurrently: - for t in threads: - t.should_run.set() - for t in threads: - t.has_paused.wait() - t.has_paused.clear() - if coord.should_stop(): - return None - done.append(t.done) - else: - for t in threads: - t.should_run.set() - t.has_paused.wait() - t.has_paused.clear() - if coord.should_stop(): - return None - done.append(t.done) - if coord.should_stop(): - return None - all_done = all(done) - if not all_done: - if any(done): - raise RuntimeError("Some replicas made a different number of " - "replica_context().merge_call() calls.") - # get_replica_context().merge_call() case - merge_args = values.regroup({t.device: t.merge_args for t in threads}) - merge_kwargs = values.regroup( - {t.device: t.merge_kwargs for t in threads}) - # We capture the name_scope of the MRT when we call merge_fn - # to ensure that if we have opened a name scope in the MRT, - # it will be respected when executing the merge function. We only - # capture the name_scope from the first MRT and assume it is - # the same for all other MRTs. 
- mtt_captured_name_scope = threads[0].captured_name_scope - with ops.name_scope(mtt_captured_name_scope): - merge_result = threads[0].merge_fn(distribution, *merge_args, - **merge_kwargs) - for t in threads: - t.merge_result = values.select_device(t.device, merge_result) - finally: - for t in threads: - t.should_run.set() - coord.join(threads) - - return values.regroup({t.device: t.main_result for t in threads}) - - -def _reduce_non_distributed_value(distribution, aggregation, value, - destinations): - """Reduce a non-DistributedValue `value` to `destinations`.""" - if isinstance(value, values.DistributedValues): - raise ValueError("You are passing a `DistributedValue` to " - "`_reduce_non_distributed_value`, which is not allowed.") - - # If the same value is present on all replicas then the PerReplica value will - # be a single value. We also handle the case when `value` is a single value - # and equal to 0. - if value == 0: - return 0 - # If the aggregation type is MEAN or ONLY_FIRST_REPLICA, then this - # essentially means that the same value should be on all destinations. - if aggregation in ( - variable_scope.VariableAggregation.MEAN, - variable_scope.VariableAggregation.ONLY_FIRST_REPLICA): - return value - - cross_tower_ops_lib.validate_destinations(destinations) - # We do not support an aggregation type of SUM if the value is the same across - # all replicas. We call this as part of assign functions for MirroredVariables - # and summing up identical values across replicas is not clearly defined. - if (len(distribution.worker_devices) != 1 or - not cross_tower_ops_lib.check_destinations(destinations)): - raise ValueError("A non-DistributedValues value %s cannot be reduced with " - "the given aggregation %s." % (value, aggregation)) - # TODO(anjalisridhar): Moves these methods to a device utility file? - devices = cross_tower_ops_lib.get_devices_from(destinations) - if len(devices) == 1: - with ops.device(devices[0]): - return array_ops.identity(value) - else: - value_updates = {} - for d in devices: - with ops.device(d): - value_updates[d] = array_ops.identity(value) - return values.Mirrored(value_updates) - - -def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs): # pylint: disable=g-missing-docstring - # Figure out what collections this variable should be added to. - # We'll add the MirroredVariable to those collections instead. - collections = kwargs.pop("collections", None) - if collections is None: - collections = [ops.GraphKeys.GLOBAL_VARIABLES] - kwargs["collections"] = [] - - # Get synchronization value - synchronization = kwargs.get("synchronization", - variable_scope.VariableSynchronization.ON_WRITE) - if synchronization == variable_scope.VariableSynchronization.NONE: - raise ValueError("`NONE` variable synchronization mode is not " - "supported with `Mirrored` distribution strategy. Please" - " change the `synchronization` for variable: " + - kwargs["name"]) - elif synchronization == variable_scope.VariableSynchronization.ON_READ: - # Variables that are to be synced on read are replica local. - is_replica_local = True - kwargs["trainable"] = False - elif (synchronization == variable_scope.VariableSynchronization.ON_WRITE or - synchronization == variable_scope.VariableSynchronization.AUTO): - # `AUTO` synchronization for `MirroredStrategy` is `ON_WRITE`. 
- is_replica_local = False - else: - raise ValueError("Invalid variable synchronization mode: " + - synchronization + " for variable: " + kwargs["name"]) - - # Get aggregation value - aggregation = kwargs.pop("aggregation", - variable_scope.VariableAggregation.NONE) - if aggregation not in ( - variable_scope.VariableAggregation.NONE, - variable_scope.VariableAggregation.SUM, - variable_scope.VariableAggregation.MEAN, - variable_scope.VariableAggregation.ONLY_FIRST_REPLICA - ): - raise ValueError("Invalid variable aggregation mode: " + aggregation + - " for variable: " + kwargs["name"]) - - # Ignore user-specified caching device, not needed for mirrored variables. - kwargs.pop("caching_device", None) - - # TODO(josh11b,apassos): It would be better if variable initialization - # was never recorded on the tape instead of having to do this manually - # here. - with tape.stop_recording(): - index = real_mirrored_creator(devices, *args, **kwargs) - - if is_replica_local: - result = values.ReplicaLocalVariable( - index, index[devices[0]], aggregation) - else: - result = values.MirroredVariable(index, index[devices[0]], aggregation) - - # Add the wrapped variable to the requested collections. - # The handling of eager mode and the global step matches - # ResourceVariable._init_from_args(). - if not context.executing_eagerly(): - g = ops.get_default_graph() - # If "trainable" is True, next_creator() will add the member variables - # to the TRAINABLE_VARIABLES collection, so we manually remove - # them and replace with the MirroredVariable. We can't set - # "trainable" to False for next_creator() since that causes functions - # like implicit_gradients to skip those variables. - if kwargs.get("trainable", True): - collections.append(ops.GraphKeys.TRAINABLE_VARIABLES) - l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES) - for v in index.values(): - if v in l: - l.remove(v) - g.add_to_collections(collections, result) - elif ops.GraphKeys.GLOBAL_STEP in collections: - ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result) - - return result +# pylint: disable=protected-access,invalid-name +_call_for_each_replica = mirrored_strategy._call_for_each_replica +_reduce_non_distributed_value = mirrored_strategy._reduce_non_distributed_value +_create_mirrored_variable = mirrored_strategy._create_mirrored_variable +all_local_devices = mirrored_strategy.all_local_devices +CoreMirroredStrategy = mirrored_strategy.MirroredStrategy +CoreMirroredExtended = mirrored_strategy.MirroredExtended +# pylint: enable=protected-access,invalid-name class MirroredStrategy(distribute_lib.DistributionStrategy): """Mirrors vars to distribute across multiple devices and machines. + *** contrib version *** + This strategy uses one replica per device and sync replication for its multi-GPU version. @@ -353,468 +95,66 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): cross_device_ops=None, auto_shard_dataset=False, cross_tower_ops=None): - super(MirroredStrategy, self).__init__() - assert not (cross_device_ops and cross_tower_ops) - self._cross_tower_ops = cross_device_ops or cross_tower_ops - self._auto_shard_dataset = auto_shard_dataset - # Remember num GPUs which might be needed by `configure` method. 
if num_gpus is not None and num_gpus_per_worker is not None: raise ValueError( "You cannot specify both `num_gpus` and `num_gpus_per_worker`.") - if num_gpus is not None: - self._num_gpus = num_gpus - else: - self._num_gpus = num_gpus_per_worker - - self._initialize_local(self._num_gpus, devices) - - def _initialize_local(self, num_gpus, devices): - """Initializes the object for local training.""" - self._cluster_spec = None - # Convert `num_gpus` into `devices`, shouldn't specify both. - if devices is None: - if num_gpus is None: - num_gpus = context.num_gpus() - if num_gpus == 0: - devices = ["/device:CPU:0"] - else: - devices = ["/device:GPU:%d" % d for d in range(num_gpus)] - elif num_gpus is not None: - raise ValueError("Must only specify one of `devices` and `num_gpus`.") - self._num_gpus = num_gpus - # TODO(yuefengz): consider setting the default device. - - assert devices, "Must specify at least one device." - assert len(set(devices)) == len(devices), ( - "No duplicates allowed in `devices` argument.") - # TODO(josh11b): Require at least 2 devices? - self._devices = [device_util.resolve(d) for d in devices] - self._canonical_device_set = set(self._devices) - self._device_index = values.PerReplica( - {d: i for i, d in enumerate(devices)}) - - def _initialize_multi_worker(self, num_gpus, cluster_spec): - """Initializes the object for multi-worker training.""" - cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec) - self._cluster_spec = cluster_spec - - self._workers = [] - for job in ["chief", "worker"]: - for task in range(len(cluster_spec.as_dict().get(job, []))): - self._workers.append("/job:%s/task:%d" % (job, task)) - if num_gpus is None: - raise ValueError("`num_gpus` is required if `cluster_spec` is given.") - if num_gpus > 0: - self._worker_devices = [ - (worker, [ - device_util.canonicalize(worker + "/device:GPU:%d" % gpu) - for gpu in range(num_gpus) - ]) for worker in self._workers - ] + num_gpus = num_gpus_per_worker + extended = MirroredExtended(self, devices, num_gpus, + cross_device_ops or cross_tower_ops, + auto_shard_dataset) + super(MirroredStrategy, self).__init__(extended) + + +class MirroredExtended(CoreMirroredExtended): + """Implementation of (contrib) MirroredStrategy.""" + + def __init__(self, + container_strategy, + devices=None, + num_gpus_per_worker=None, + cross_device_ops=None, + auto_shard_dataset=False): + if devices is None: + devices = mirrored_strategy.all_local_devices(num_gpus_per_worker) + elif num_gpus_per_worker is not None: + raise ValueError( + "Must only specify one of `devices` and `num_gpus_per_worker`.") + super(MirroredExtended, self).__init__(container_strategy, devices, + cross_device_ops) + self._auto_shard_dataset = auto_shard_dataset + + def _make_dataset_iterator(self, dataset): + """Make iterator from dataset without splitting the batch. + + This implementation is different than the one in + `tf.distribute.MirroredStrategy` for purposes of backward compatibility. + We treat the incoming dataset's batch size as per replica batch size. + + Args: + dataset: `tf.data.Dataset` for input. + Returns: + An `InputIterator` which returns inputs for each step of the computation. 
+ """ + if self._local_mode: + worker = device_util.canonicalize("/device:CPU:0") + worker_device_pairs = [(worker, self._devices)] else: - self._worker_devices = [ - (worker, [device_util.canonicalize(worker, "/device:CPU:0")]) - for worker in self._workers - ] + worker_device_pairs = self._worker_devices + return values.DatasetIterator(dataset, worker_device_pairs) - devices = nest.flatten([l for _, l in self._worker_devices]) - - # Setting `_default_device` will add a device scope in the - # distribution.scope. We set the default device to the first worker. When - # users specify device under distribution.scope by - # with tf.device("/cpu:0"): - # ... - # their ops will end up on the cpu device of its first worker, e.g. - # "/job:worker/task:0/device:CPU:0". Note this is not used in replica mode. - self._default_device = self._workers[0] - - assert devices, "Must specify at least one device." - assert len(set(devices)) == len(devices), ( - "No duplicates allowed in `devices` argument.") - # TODO(josh11b): Require at least 2 devices? - self._devices = [device_util.resolve(d) for d in devices] - self._canonical_device_set = set(self._devices) - self._device_index = values.PerReplica( - {d: i for i, d in enumerate(devices)}) - - def _create_variable(self, next_creator, *args, **kwargs): - """Create a mirrored variable. See `DistributionStrategy.scope`.""" - colocate_with = kwargs.pop("colocate_with", None) - devices = self._get_devices_from(colocate_with) - - def _real_mirrored_creator(devices, *args, **kwargs): # pylint: disable=g-missing-docstring - index = {} - for i, d in enumerate(devices): - with ops.device(d): - if i > 0: - # Give replicas meaningful distinct names: - var0name = index[devices[0]].name.split(":")[0] - # We append a / to variable names created on replicas with id > 0 to - # ensure that we ignore the name scope and instead use the given - # name as the absolute name of the variable. - kwargs["name"] = "%s/replica_%d/" % (var0name, i) - # Initialize replicas with the same value: - def initial_value_fn(device=d): - if context.executing_eagerly(): - init_value = index[devices[0]].value() - return array_ops.identity(init_value) - else: - with ops.device(device): - init_value = index[devices[0]].initial_value - return array_ops.identity(init_value) - kwargs["initial_value"] = initial_value_fn - with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT): - # Don't record operations (e.g. other variable reads) during - # variable creation. - with tape.stop_recording(): - v = next_creator(*args, **kwargs) - assert not isinstance(v, values.DistributedVariable) - index[d] = v - return index - - return _create_mirrored_variable(devices, _real_mirrored_creator, *args, - **kwargs) - - def distribute_dataset(self, dataset_fn): - if self._cluster_spec: - return values.MultiWorkerDataset( - partial(self._call_dataset_fn, dataset_fn), self._worker_devices, - auto_shard=self._auto_shard_dataset) - else: + def _distribute_dataset(self, dataset_fn): + if self._local_mode: return values.PerReplicaDataset( self._call_dataset_fn(dataset_fn), self._devices) - - # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed. 
- def _run_steps_on_dataset(self, fn, iterator, iterations, - initial_loop_values=None): - if initial_loop_values is None: - initial_loop_values = {} - initial_loop_values = nest.flatten(initial_loop_values) - - ctx = values.MultiStepContext() - def body(i, *args): - """A wrapper around `fn` to create the while loop body.""" - del args - fn_inputs = iterator.get_next() - if not isinstance(fn_inputs, tuple): - fn_inputs = (fn_inputs,) - fn_result = fn(ctx, *fn_inputs) - for (name, output) in ctx.last_step_outputs.items(): - # Convert all outputs to tensors, potentially from `DistributedValues`. - ctx.last_step_outputs[name] = self.unwrap(output) - flat_last_step_outputs = nest.flatten(ctx.last_step_outputs) - with ops.control_dependencies([fn_result]): - return [i + 1] + flat_last_step_outputs - - # We capture the control_flow_context at this point, before we run `fn` - # inside a while_loop. This is useful in cases where we might need to exit - # these contexts and get back to the outer context to do some things, for - # e.g. create an op which should be evaluated only once at the end of the - # loop on the host. One such usage is in creating metrics' value op. - self._outer_control_flow_context = ( - ops.get_default_graph()._get_control_flow_context()) # pylint: disable=protected-access - - cond = lambda i, *args: i < iterations - i = constant_op.constant(0) - loop_result = control_flow_ops.while_loop( - cond, body, [i] + initial_loop_values, name="", - parallel_iterations=1, back_prop=False, swap_memory=False, - return_same_structure=True) - del self._outer_control_flow_context - - ctx.run_op = control_flow_ops.group(loop_result) - - # Convert the last_step_outputs from a list to the original dict structure - # of last_step_outputs. - last_step_tensor_outputs = loop_result[1:] - last_step_tensor_outputs_dict = nest.pack_sequence_as( - ctx.last_step_outputs, last_step_tensor_outputs) - - for (name, aggregation) in ctx._last_step_outputs_aggregations.items(): # pylint: disable=protected-access - output = last_step_tensor_outputs_dict[name] - # For outputs that have already been aggregated, wrap them in a Mirrored - # container, else in a PerReplica container. - if aggregation is variables_lib.VariableAggregation.NONE: - last_step_tensor_outputs_dict[name] = values.regroup( - {d: t for d, t in zip(self._devices, output)}, values.PerReplica) - else: - assert len(output) == 1 - last_step_tensor_outputs_dict[name] = output[0] - - ctx._set_last_step_outputs(last_step_tensor_outputs_dict) # pylint: disable=protected-access - return ctx - - def _broadcast(self, tensor, destinations): - # TODO(josh11b): In eager mode, use one thread per device, or async mode. - return self._get_cross_tower_ops().broadcast(tensor, destinations or - self._devices) - - def _call_for_each_replica(self, fn, args, kwargs): - return _call_for_each_replica(self, fn, args, kwargs) - - def configure(self, - session_config=None, - cluster_spec=None, - task_type=None, - task_id=None): - del task_type, task_id - - if session_config: - session_config.isolate_session_state = True - - if cluster_spec: - self._initialize_multi_worker(self._num_gpus, cluster_spec) - - if self._cross_tower_ops is None: - if self._cluster_spec: - # It currently cannot detect the toplogy of remote workers. So we - # hard-code the multi-worker all-reduce algorithm for now. - if len(self._workers) == 1: - # The default is "nccl". 
- self._cross_tower_ops = cross_tower_ops_lib.AllReduceCrossDeviceOps() - else: - # The default is hierarchical reduce and broadcast. - self._cross_tower_ops = cross_tower_ops_lib.MultiWorkerAllReduce( - self._workers, self._num_gpus) - else: - self._cross_tower_ops = cross_tower_ops_lib.choose_the_best( - self._devices, session_config=session_config) - - def _get_cross_tower_ops(self): - if self._cross_tower_ops is None: - self._cross_tower_ops = ( - cross_tower_ops_lib.ReductionToOneDeviceCrossDeviceOps()) - return self._cross_tower_ops - - def _reduce(self, aggregation, value, destinations): - assert not isinstance(value, values.Mirrored) - if not isinstance(value, values.DistributedValues): - # This function handles reducing values that are not PerReplica or - # Mirrored values. For example, the same value could be present on all - # replicas in which case `value` would be a single value or value could - # be 0. - return _reduce_non_distributed_value(self, aggregation, value, - destinations) - if aggregation == variable_scope.VariableAggregation.ONLY_FIRST_REPLICA: - value = value.get(self._devices[0]) - if isinstance(value, (int, float)): - return value - return self.broadcast(value, destinations) - return self._get_cross_tower_ops().reduce( - aggregation, value, destinations=destinations) - - def _batch_reduce(self, aggregation, value_destination_pairs): - if aggregation == variable_scope.VariableAggregation.ONLY_FIRST_REPLICA: - return [self.broadcast(v.get(self._devices[0]), d) - for v, d in value_destination_pairs] - return self._get_cross_tower_ops().batch_reduce(aggregation, - value_destination_pairs) - - def _update(self, var, options, fn, *args, **kwargs): - # TODO(josh11b): In eager mode, use one thread per device. - assert isinstance(var, values.DistributedVariable) - should_group = options.pop("grouped") - assert not options # Validate that we are processing all of the options. - updates = {} - for d, v in var._index.items(): # pylint: disable=protected-access - name = "update_%d" % self._device_index.get(d) - with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name): - # If args and kwargs are not mirrored, the value is returned as is. - updates[d] = fn(v, - *values.select_device_mirrored(d, args), - **values.select_device_mirrored(d, kwargs)) - return values.update_regroup(self, updates, should_group) - - def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs): - assert isinstance(colocate_with, list) - should_group = options.pop("grouped") - assert not options # Validate that we are processing all of the options. - # TODO(josh11b): In eager mode, use one thread per device. - updates = {} - for d in colocate_with: - name = "update_%d" % self._device_index.get(d) - with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name): - updates[d] = fn(*values.select_device_mirrored(d, args), - **values.select_device_mirrored(d, kwargs)) - return values.update_regroup(self, updates, should_group) - - def read_var(self, replica_local_var): - """Read the aggregate value of a replica-local variable.""" - if isinstance(replica_local_var, values.ReplicaLocalVariable): - return replica_local_var._get_cross_replica() # pylint: disable=protected-access - assert isinstance(replica_local_var, values.Mirrored) - return array_ops.identity(replica_local_var.get()) - - def _unwrap(self, val): - if isinstance(val, values.DistributedValues): - # Return in a deterministic order. 
- if set(val.devices) == self._canonical_device_set: - return [val.get(device=d) for d in self._devices] - return [val.get(device=d) for d in sorted(val.devices)] - return [val] - - def value_container(self, val): - return values.value_container(val) - - @property - def num_replicas(self): - return len(self._devices) - - @property - def num_replicas_in_sync(self): - return len(self._devices) - - def _worker_device_index(self): - return self._device_index - - @property - def worker_devices(self): - # Make a copy to prevent users from accidentally mutating our copy. - return list(self._devices) - - @property - def parameter_devices(self): - return list(self._devices) - - @property - def between_graph(self): - return False - - @property - def should_init(self): - return True - - @property - def should_checkpoint(self): - return True - - @property - def should_save_summary(self): - return True - - def non_slot_devices(self, var_list): - del var_list - return list(self._devices) - - def _get_devices_from(self, colocate_with=None): - if colocate_with is None: - return self._devices else: - return cross_tower_ops_lib.get_devices_from(colocate_with) - - class _MirroredReplicaThread(threading.Thread): - """A thread that runs() a function on a device.""" - - def __init__(self, dist, coord, device, variable_creator_fn, fn, *args, - **kwargs): - super(MirroredStrategy._MirroredReplicaThread, self).__init__() # pylint: disable=protected-access - self.coord = coord - self.distribution = dist - self.device = device - self.replica_id = dist.worker_devices.index(device) - self.variable_creator_fn = variable_creator_fn - # State needed to run and return the results of `fn`. - self.main_fn = fn - self.main_args = args - self.main_kwargs = kwargs - self.main_result = None - self.done = False - # State needed to run the next merge_call() (if any) requested via - # ReplicaContext. - self.merge_fn = None - self.merge_args = None - self.merge_kwargs = None - self.merge_result = None - self.captured_name_scope = None - # We use a thread.Event for the main thread to signal when this - # thread should start running (`should_run`), and another for - # this thread to transfer control back to the main thread - # (`has_paused`, either when it gets to a - # `get_replica_context().merge_call` or when `fn` returns). In - # either case the event starts cleared, is signaled by calling - # set(). The receiving thread waits for the signal by calling - # wait() and then immediately clearing the event using clear(). - self.should_run = threading.Event() - self.has_paused = threading.Event() - # These fields have to do with inheriting various contexts from the - # parent thread: - # pylint: disable=protected-access - self.context_mode = context.context()._eager_context.mode - if not context.context()._context_handle: - context.context()._initialize_handle_and_devices() - self.context_device_policy = ( - pywrap_tensorflow.TFE_ContextGetDevicePlacementPolicy( - context.context()._context_handle)) - self.graph = ops.get_default_graph() - self._variable_creator_stack = self.graph._variable_creator_stack[:] - self._captured_var_scope = variable_scope.get_variable_scope() - # Adding a "/" at end lets us re-enter this scope later. 
- self._name_scope = self.graph.get_name_scope() - if self._name_scope: - self._name_scope += "/" - if self.replica_id > 0: - if not self._name_scope: - self._name_scope = "" - self._name_scope += "replica_%d/" % self.replica_id - - def run(self): - # pylint: disable=protected-access - self.graph._variable_creator_stack = self._variable_creator_stack - self.should_run.wait() - self.should_run.clear() - try: - if self.coord.should_stop(): - return - with self.coord.stop_on_exception(), \ - context.context()._mode(self.context_mode), \ - context.context().device_policy(self.context_device_policy), \ - _enter_graph(self.graph), \ - MirroredReplicaContext(self.distribution, self.replica_id), \ - ops.device(self.device), \ - ops.name_scope(self._name_scope), \ - variable_scope.variable_scope( - self._captured_var_scope, reuse=self.replica_id > 0), \ - variable_scope.variable_creator_scope(self.variable_creator_fn): - self.main_result = self.main_fn(*self.main_args, **self.main_kwargs) - self.done = True - finally: - self.has_paused.set() - - -class MirroredReplicaContext(distribute_lib.ReplicaContext): - """ReplicaContext used in MirroredStrategy.call_for_each_replica(). - - Opened in `_MirroredReplicaThread`, to allow the user to invoke - `MirroredStrategy`'s specific implementation of `merge_call()`, - which works by delegating the function and its arguments to - the main thread (the one that invoked - `MirroredStrategy.call_for_each_replica()`). - """ - - def _merge_call(self, fn, args, kwargs): - """Delegate to the main thread to actually perform merge_call().""" - t = threading.current_thread() # a _MirroredReplicaThread - t.merge_fn = fn - t.merge_args = args - t.merge_kwargs = kwargs - t.captured_name_scope = t.graph.get_name_scope() - # Adding a "/" at end lets us re-enter this scope later. - if t.captured_name_scope: - t.captured_name_scope += "/" - t.has_paused.set() - t.should_run.wait() - t.should_run.clear() - if t.coord.should_stop(): - raise _RequestedStop() - return t.merge_result + return values.MultiWorkerDataset( + functools.partial(self._call_dataset_fn, dataset_fn), + self._worker_devices, + auto_shard=self._auto_shard_dataset) + # TODO(priyag): Delete this once all strategies use global batch size. 
@property - def device(self): - raise RuntimeError("Use .devices instead") - - @property - def devices(self): - distribute_lib.require_replica_context(self) - return [self._distribution_strategy.worker_devices[self._replica_id]] + def _global_batch_size(self): + return False diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index 1fd18e09c01..66512f983e1 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -20,22 +20,27 @@ from __future__ import print_function import sys +from absl.testing import parameterized import numpy as np +from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import mirrored_strategy from tensorflow.contrib.distribute.python import multi_worker_test_base from tensorflow.contrib.distribute.python import strategy_test_lib -from tensorflow.contrib.distribute.python import values from tensorflow.core.protobuf import config_pb2 from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import device_util +from tensorflow.python.distribute import distribution_strategy_context as ds_context +from tensorflow.python.distribute import reduce_util +from tensorflow.python.distribute import values from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import function from tensorflow.python.eager import test from tensorflow.python.framework import constant_op +from tensorflow.python.framework import func_graph from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import test_util from tensorflow.python.keras.engine import training as keras_training from tensorflow.python.keras.layers import core as keras_core from tensorflow.python.layers import core @@ -46,8 +51,6 @@ from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables -from tensorflow.python.training import device_util -from tensorflow.python.training import distribution_strategy_context from tensorflow.python.training import gradient_descent from tensorflow.python.training import optimizer as optimizer_lib from tensorflow.python.training import server_lib @@ -56,248 +59,229 @@ from tensorflow.python.training import server_lib GPU_TEST = "test_gpu" in sys.argv[0] -class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase): +@combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus], + mode=["graph", "eager"])) +class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase, + parameterized.TestCase): - def _get_distribution_strategy(self): - devices = ["/device:CPU:0", "/device:GPU:0"] - if GPU_TEST: - self.assertGreater(context.num_gpus(), 0) - if context.num_gpus() > 1: - devices = ["/device:GPU:0", "/device:GPU:1"] - print(self.id().split(".")[-1], "devices:", ", ".join(devices)) - return mirrored_strategy.MirroredStrategy(devices) + def testMinimizeLoss(self, distribution): + if context.executing_eagerly(): + 
self._test_minimize_loss_eager(distribution) + else: + self._test_minimize_loss_graph(distribution) - def testMinimizeLossEager(self): - if not GPU_TEST: - self.skipTest("Not GPU test") - self._test_minimize_loss_eager(self._get_distribution_strategy()) + def testReplicaId(self, distribution): + self._test_replica_id(distribution) - def testMinimizeLossGraph(self): - soft_placement = not GPU_TEST - print("testMinimizeLossGraph soft_placement:", soft_placement) - self._test_minimize_loss_graph( - self._get_distribution_strategy(), soft_placement=soft_placement) + def testNumReplicasInSync(self, distribution): + self.assertEqual(2, distribution.num_replicas_in_sync) - def testDeviceIndex(self): - if not GPU_TEST: - self.skipTest("Not GPU test") - self._test_device_index(self._get_distribution_strategy()) + def testCallAndMergeExceptions(self, distribution): + self._test_call_and_merge_exceptions(distribution) - def testReplicaId(self): - if not GPU_TEST: - self.skipTest("Not GPU test") - self._test_replica_id(self._get_distribution_strategy()) - - def testNumReplicas(self): - if not GPU_TEST: - self.skipTest("Not GPU test") - self.assertEqual(2, self._get_distribution_strategy().num_replicas) - - def testNumReplicasInSync(self): - if not GPU_TEST: - self.skipTest("Not GPU test") - self.assertEqual(2, self._get_distribution_strategy(). - num_replicas_in_sync) - - @test_util.run_in_graph_and_eager_modes - def testCallAndMergeExceptions(self): - if not GPU_TEST: - self.skipTest("Not GPU test") - self._test_call_and_merge_exceptions(self._get_distribution_strategy()) - - @test_util.run_in_graph_and_eager_modes - def testRunRegroupError(self): - - def run_fn(device_id): + def testRunRegroupError(self, distribution): + def run_fn(): + replica_id = int(self.evaluate(_replica_id())) # Generates a list with different lengths on different devices. # Will fail in _regroup() (if more than one device). 
- return list(range(device_id)) + return list(range(replica_id)) - dist = self._get_distribution_strategy() - with dist.scope(), self.assertRaises(AssertionError): - dist.call_for_each_replica(run_fn, args=(dist.worker_device_index,)) + with distribution.scope(), self.assertRaises(AssertionError): + distribution.extended.call_for_each_replica(run_fn) - @test_util.run_in_graph_and_eager_modes - def testReduceToCpu(self): - if not GPU_TEST: - self.skipTest("Not GPU test") + def testReduceToCpu(self, distribution): + with distribution.scope(): + result = distribution.extended.call_for_each_replica(_replica_id) + reduced = distribution.reduce(reduce_util.ReduceOp.SUM, result) + expected = sum(range(distribution.num_replicas_in_sync)) + self.assertEqual(expected, self.evaluate(reduced)) - def run_fn(device_id): - return device_id + def testMakeInputFnIterator(self, distribution): + dataset_fn = lambda: dataset_ops.Dataset.range(10) + expected_values = [[i, i+1] for i in range(0, 10, 2)] - dist = self._get_distribution_strategy() - with dist.scope(): - result = dist.call_for_each_replica( - run_fn, args=(dist.worker_device_index,)) - reduced = dist.reduce( - variable_scope.VariableAggregation.SUM, - result, - destinations="/device:CPU:0") - unwrapped = dist.unwrap(reduced) - self.assertEqual(1, len(unwrapped)) - expected = sum(range(len(dist.worker_devices))) - self.assertEqual(expected, self.evaluate(unwrapped[0])) + input_fn = self._input_fn_to_test_input_context( + dataset_fn, + expected_num_replicas_in_sync=2, + expected_num_input_pipelines=1, + expected_input_pipeline_id=0) + iterator = distribution.make_input_fn_iterator(input_fn) + self._test_input_fn_iterator(iterator, distribution.extended.worker_devices, + expected_values) - @test_util.run_in_graph_and_eager_modes - def testReduceOnlyFirstReplicaUpdates(self): - if not GPU_TEST: - self.skipTest("Not GPU test") - - def run_fn(device_id): - return constant_op.constant(3 + 5 * device_id) - - dist = self._get_distribution_strategy() - with dist.scope(): - result = dist.call_for_each_replica( - run_fn, args=(dist.worker_device_index,)) - reduced = dist.reduce( - variable_scope.VariableAggregation.ONLY_FIRST_REPLICA, - result, - destinations="/device:CPU:0") - unwrapped = dist.unwrap(reduced) - self.assertEqual(1, len(unwrapped)) - self.assertEqual(3, self.evaluate(unwrapped[0])) - - @test_util.run_in_graph_and_eager_modes() - def testReduceToMultipleDestinations(self): - if not GPU_TEST: - self.skipTest("Not GPU test") - - devices = ["/device:GPU:0"] - if GPU_TEST: - self.assertGreater(context.num_gpus(), 0) - print(self.id().split(".")[-1], "devices:", ", ".join(devices)) - - dist = mirrored_strategy.MirroredStrategy(devices) - with dist.scope(): - reduced = dist.reduce( - variable_scope.VariableAggregation.SUM, - 1.0, - destinations=["/device:CPU:0", "/device:GPU:0"]) - unwrapped = dist.unwrap(reduced) - self.assertEqual(2, len(unwrapped)) - self.assertEqual(1.0, self.evaluate(unwrapped[0])) + def testGlobalStepUpdate(self, distribution): + self._test_global_step_update(distribution) +def one_device_combinations(): + return combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_one_cpu, + combinations.mirrored_strategy_with_one_gpu, + combinations.core_mirrored_strategy_with_one_cpu, + combinations.core_mirrored_strategy_with_one_gpu], + mode=["graph", "eager"]) + + +class MirroredOneDeviceDistributionTest( + strategy_test_lib.DistributionTestBase, + parameterized.TestCase): + + 
@combinations.generate(one_device_combinations()) + def testMinimizeLoss(self, distribution): + if context.executing_eagerly(): + self._test_minimize_loss_eager(distribution) + else: + self._test_minimize_loss_graph(distribution) + + @combinations.generate(one_device_combinations()) + def testReplicaId(self, distribution): + self._test_replica_id(distribution) + + @combinations.generate(one_device_combinations()) + def testCallAndMergeExceptions(self, distribution): + self._test_call_and_merge_exceptions(distribution) + + +class MirroredStrategyVariableCreatorStackTest( + test.TestCase, parameterized.TestCase): + + @combinations.generate(combinations.combine( + distribution=[combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=["graph"])) + def testCreatorStacksAreThreadLocal(self, distribution): + def model_fn(): + replica_id_str = str(self.evaluate(_replica_id())) + + def thread_creator_fn(next_creator, *args, **kwargs): + return next_creator(*args, **kwargs) + ":thread_" + replica_id_str + + with variable_scope.variable_creator_scope(thread_creator_fn): + # Create a variable in this scope. + v = variable_scope.variable(1.0) + + # This will pause the current thread, and execute the other thread. + ds_context.get_replica_context().merge_call(lambda _: _) + return v + + def main_thread_creator(next_creator, *args, **kwargs): + # We are not using the underlying next_creator for test purposes. + del next_creator, args, kwargs + return "main_thread" + + with context.graph_mode(), \ + distribution.scope(), \ + variable_scope.variable_creator_scope(main_thread_creator): + result = distribution.extended.call_for_each_replica(model_fn) + result = distribution.unwrap(result) + expected = ["main_thread:thread_0", "main_thread:thread_1"] + self.assertEqual(expected, result) + + +@combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=["graph", "eager"])) class MirroredStrategyVariableCreationTest(test.TestCase): - config = config_pb2.ConfigProto() - config.allow_soft_placement = True + # TODO(priyag): Modify more tests to use this helper and check more + # properties. + def _test_mv_properties(self, var, name): + self.assertIsInstance(var, values.MirroredVariable) + self.assertEqual(name, var.name) + for d in var.devices: + self.assertEqual(d, var.get(d).device) - def _skip_eager_if_gpus_less_than(self, num_gpus): - if context.num_gpus() < num_gpus and context.executing_eagerly(): - self.skipTest("Enough GPUs not available for this test in eager mode.") + def testVariableInFuncGraph(self, distribution): + def model_fn(): + v = variable_scope.variable(2.0, name="bar") + ds_context.get_replica_context().merge_call(lambda _: _) + return v - @test_util.run_in_graph_and_eager_modes(config=config) - def testSingleVariable(self): - self._skip_eager_if_gpus_less_than(1) + with func_graph.FuncGraph("fg").as_default(), distribution.scope(): + v1 = variable_scope.variable(1.0, name="foo") + v2 = distribution.extended.call_for_each_replica(model_fn) + self._test_mv_properties(v1, "foo:0") + self._test_mv_properties(v2, "bar:0") + + def testSingleVariable(self, distribution): def model_fn(): # This variable should be created only once across the threads because of - # special variable_creator functions used by `dist.call_for_each_replica`. 
+ # special variable_creator functions used by + # `distribution.extended.call_for_each_replica`. v = variable_scope.variable(1.0, name="foo") - distribution_strategy_context.get_replica_context().merge_call( - lambda _: _) + ds_context.get_replica_context().merge_call(lambda _: _) return v - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - result = dist.call_for_each_replica(model_fn) - self.assertIsInstance(result, values.MirroredVariable) - self.assertEquals("foo:0", result.name) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testUnnamedVariable(self): - self._skip_eager_if_gpus_less_than(1) + with distribution.scope(): + result = distribution.extended.call_for_each_replica(model_fn) + self._test_mv_properties(result, "foo:0") + def testUnnamedVariable(self, distribution): def model_fn(): v = variable_scope.variable(1.0) - distribution_strategy_context.get_replica_context().merge_call( - lambda _: _) + ds_context.get_replica_context().merge_call(lambda _: _) return v - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - result = dist.call_for_each_replica(model_fn) - self.assertIsInstance(result, values.MirroredVariable) - # Default name of "Variable" will be used. - self.assertEquals("Variable:0", result.name) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testMultipleVariables(self): - self._skip_eager_if_gpus_less_than(1) + with distribution.scope(): + result = distribution.extended.call_for_each_replica(model_fn) + self._test_mv_properties(result, "Variable:0") + def testMultipleVariables(self, distribution): def model_fn(): vs = [] for i in range(5): vs.append(variable_scope.variable(1.0, name="foo" + str(i))) - distribution_strategy_context.get_replica_context().merge_call( - lambda _: _) + ds_context.get_replica_context().merge_call(lambda _: _) return vs - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - result = dist.call_for_each_replica(model_fn) + with distribution.scope(): + result = distribution.extended.call_for_each_replica(model_fn) for i, v in enumerate(result): - self.assertIsInstance(v, values.MirroredVariable) - self.assertEquals("foo" + str(i) + ":0", v.name) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testMultipleVariablesWithSameCanonicalName(self): - self._skip_eager_if_gpus_less_than(1) + self._test_mv_properties(v, "foo" + str(i) + ":0") + def testMultipleVariablesWithSameCanonicalName(self, distribution): def model_fn(): vs = [] vs.append(variable_scope.variable(1.0, name="foo/bar")) vs.append(variable_scope.variable(1.0, name="foo_1/bar")) vs.append(variable_scope.variable(1.0, name="foo_1/bar_1")) vs.append(variable_scope.variable(1.0, name="foo/bar_1")) - distribution_strategy_context.get_replica_context().merge_call( - lambda _: _) + ds_context.get_replica_context().merge_call(lambda _: _) return vs - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - result = dist.call_for_each_replica(model_fn) + with distribution.scope(): + result = distribution.extended.call_for_each_replica(model_fn) for v in result: self.assertIsInstance(v, values.MirroredVariable) - self.assertEquals(4, len(result)) - self.assertEquals("foo/bar:0", result[0].name) - self.assertEquals("foo_1/bar:0", result[1].name) - self.assertEquals("foo_1/bar_1:0", result[2].name) - self.assertEquals("foo/bar_1:0", 
result[3].name) + self.assertEqual(4, len(result)) + self.assertEqual("foo/bar:0", result[0].name) + self.assertEqual("foo_1/bar:0", result[1].name) + self.assertEqual("foo_1/bar_1:0", result[2].name) + self.assertEqual("foo/bar_1:0", result[3].name) - @test_util.run_in_graph_and_eager_modes(config=config) - def testVariableWithSameCanonicalNameAcrossThreads(self): - self._skip_eager_if_gpus_less_than(1) - - def model_fn(device_id): - v = variable_scope.variable(1.0, name="foo_" + str(device_id)) - distribution_strategy_context.get_replica_context().merge_call( - lambda _: _) + def testVariableWithSameCanonicalNameAcrossThreads(self, distribution): + def model_fn(): + replica_id = self.evaluate(_replica_id()) + v = variable_scope.variable(1.0, name="foo_" + str(replica_id)) + ds_context.get_replica_context().merge_call(lambda _: _) return v - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - result = dist.call_for_each_replica( - model_fn, args=(dist.worker_device_index,)) + with distribution.scope(): + result = distribution.extended.call_for_each_replica(model_fn) self.assertIsInstance(result, values.MirroredVariable) # The resulting mirrored variable will use the name from the first device. - self.assertEquals("foo_0:0", result.name) + self.assertEqual("foo_0:0", result.name) - @test_util.run_in_graph_and_eager_modes(config=config) - def testWithLayers(self): - self._skip_eager_if_gpus_less_than(1) + def testWithLayers(self, distribution): def model_fn(features): with variable_scope.variable_scope("common"): layer1 = core.Dense(1) @@ -305,17 +289,14 @@ class MirroredStrategyVariableCreationTest(test.TestCase): layer2 = core.Dense(1) layer2(features) # This will pause the current thread, and execute the other thread. 
- distribution_strategy_context.get_replica_context().merge_call( - lambda _: _) + ds_context.get_replica_context().merge_call(lambda _: _) layer3 = core.Dense(1) layer3(features) return [(layer1.kernel, layer1.bias), (layer2.kernel, layer2.bias), (layer3.kernel, layer3.bias)] - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - ds = dist.distribute_dataset( + ds = distribution.distribute_dataset( lambda: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10)) if context.executing_eagerly(): iterator = ds.make_one_shot_iterator() @@ -325,26 +306,23 @@ class MirroredStrategyVariableCreationTest(test.TestCase): features = iterator.get_next() - with dist.scope(): - result = dist.call_for_each_replica(model_fn, args=(features,)) + with distribution.scope(): + result = distribution.extended.call_for_each_replica( + model_fn, args=(features,)) suffixes = ["", "_1", "_2"] for (kernel, bias), suffix in zip(result, suffixes): self.assertIsInstance(kernel, values.MirroredVariable) - self.assertEquals("common/dense" + suffix + "/kernel:0", kernel.name) + self.assertEqual("common/dense" + suffix + "/kernel:0", kernel.name) self.assertIsInstance(bias, values.MirroredVariable) - self.assertEquals("common/dense" + suffix + "/bias:0", bias.name) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testWithVariableAndVariableScope(self): - self._skip_eager_if_gpus_less_than(1) + self.assertEqual("common/dense" + suffix + "/bias:0", bias.name) + def testWithVariableAndVariableScope(self, distribution): def model_fn(): v0 = variable_scope.variable(1.0, name="var0", aggregation=None) with variable_scope.variable_scope("common"): v1 = variable_scope.variable(1.0, name="var1") # This will pause the current thread, and execute the other thread. 
- distribution_strategy_context.get_replica_context().merge_call( - lambda _: _) + ds_context.get_replica_context().merge_call(lambda _: _) v2 = variable_scope.variable( 1.0, name="var2", @@ -358,37 +336,31 @@ class MirroredStrategyVariableCreationTest(test.TestCase): return v0, v1, v2, v3 - devices = ["/device:CPU:0", "/device:GPU:0"] - dist = mirrored_strategy.MirroredStrategy(devices) - with dist.scope(): + with distribution.scope(): v = variable_scope.variable(1.0, name="var-main0") - self.assertEquals("var-main0:0", v.name) + self.assertEqual("var-main0:0", v.name) - result = dist.call_for_each_replica(model_fn) - self.assertEquals(4, len(result)) + result = distribution.extended.call_for_each_replica(model_fn) + self.assertEqual(4, len(result)) v0, v1, v2, v3 = result self.assertIsInstance(v0, values.MirroredVariable) - self.assertEquals("var0:0", v0.name) + self.assertEqual("var0:0", v0.name) self.assertIsInstance(v1, values.MirroredVariable) - self.assertEquals("common/var1:0", v1.name) + self.assertEqual("common/var1:0", v1.name) self.assertIsInstance(v2, values.ReplicaLocalVariable) - self.assertEquals("common/var2:0", v2.name) - self.assertEquals(variable_scope.VariableAggregation.SUM, v2.aggregation) + self.assertEqual("common/var2:0", v2.name) + self.assertEqual(variable_scope.VariableAggregation.SUM, v2.aggregation) self.assertIsInstance(v3, values.MirroredVariable) - self.assertEquals("common/var3:0", v3.name) - self.assertEquals(variable_scope.VariableAggregation.MEAN, v3.aggregation) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testWithGetVariableAndVariableScope(self): - self._skip_eager_if_gpus_less_than(1) + self.assertEqual("common/var3:0", v3.name) + self.assertEqual(variable_scope.VariableAggregation.MEAN, v3.aggregation) + def testWithGetVariableAndVariableScope(self, distribution): def model_fn(): v0 = variable_scope.get_variable("var0", [1]) with variable_scope.variable_scope("common"): v1 = variable_scope.get_variable("var1", [1]) # This will pause the current thread, and execute the other thread. 
- distribution_strategy_context.get_replica_context().merge_call( - lambda _: _) + ds_context.get_replica_context().merge_call(lambda _: _) v2 = variable_scope.get_variable( "var2", [1], synchronization=variable_scope.VariableSynchronization.ON_READ, @@ -400,33 +372,28 @@ class MirroredStrategyVariableCreationTest(test.TestCase): return v0, v1, v2, v3 - devices = ["/device:CPU:0", "/device:GPU:0"] - dist = mirrored_strategy.MirroredStrategy(devices) - with dist.scope(): + with distribution.scope(): with variable_scope.variable_scope("main"): v = variable_scope.get_variable("var-main0", [1]) - self.assertEquals("main/var-main0:0", v.name) + self.assertEqual("main/var-main0:0", v.name) - result = dist.call_for_each_replica(model_fn) - self.assertEquals(4, len(result)) + result = distribution.extended.call_for_each_replica(model_fn) + self.assertEqual(4, len(result)) v0, v1, v2, v3 = result self.assertIsInstance(v0, values.MirroredVariable) - self.assertEquals("main/var0:0", v0.name) + self.assertEqual("main/var0:0", v0.name) self.assertIsInstance(v1, values.MirroredVariable) - self.assertEquals("main/common/var1:0", v1.name) + self.assertEqual("main/common/var1:0", v1.name) self.assertIsInstance(v2, values.ReplicaLocalVariable) - self.assertEquals("main/common/var2:0", v2.name) - self.assertEquals(variable_scope.VariableAggregation.SUM, - v2.aggregation) + self.assertEqual("main/common/var2:0", v2.name) + self.assertEqual(variable_scope.VariableAggregation.SUM, + v2.aggregation) self.assertIsInstance(v3, values.MirroredVariable) - self.assertEquals("main/common/var3:0", v3.name) - self.assertEquals(variable_scope.VariableAggregation.MEAN, - v3.aggregation) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testOnlyFirstReplicaUpdatesVariables(self): - self._skip_eager_if_gpus_less_than(1) + self.assertEqual("main/common/var3:0", v3.name) + self.assertEqual(variable_scope.VariableAggregation.MEAN, + v3.aggregation) + def testOnlyFirstReplicaUpdatesVariables(self, distribution): def create_fn(): aggregation = variable_scope.VariableAggregation.ONLY_FIRST_REPLICA v0 = variable_scope.variable( @@ -442,71 +409,73 @@ class MirroredStrategyVariableCreationTest(test.TestCase): return v0, v1 devices = ["/device:GPU:0", "/device:CPU:0"] - dist = mirrored_strategy.MirroredStrategy(devices) - with dist.scope(): - v0, v1 = dist.call_for_each_replica(create_fn) + with distribution.scope(): + v0, v1 = distribution.extended.call_for_each_replica(create_fn) self.evaluate(v0.initializer) self.assertEqual(2.0, self.evaluate(v0.get(devices[0]))) self.assertEqual(2.0, self.evaluate(v0.get(devices[1]))) - self.assertEqual(2.0, self.evaluate(dist.read_var(v0))) + self.assertEqual(2.0, self.evaluate(distribution.extended.read_var(v0))) self.evaluate(v1.initializer) self.assertEqual(3.0, self.evaluate(v1.get(devices[0]))) self.assertEqual(3.0, self.evaluate(v1.get(devices[1]))) - self.assertEqual(3.0, self.evaluate(dist.read_var(v1))) + self.assertEqual(3.0, self.evaluate(distribution.extended.read_var(v1))) + + def replica_id_plus_one(): + return math_ops.cast(_replica_id() + 1, dtype=dtypes.float32) # Update using the assign_add member function. 
- def update_member_fn(device_id): - update0 = v0.assign_add(5.0 * (device_id + 1)) - update1 = v1.assign_add(7.0 * (device_id + 1)) + def update_member_fn(): + update0 = v0.assign_add(5.0 * replica_id_plus_one()) + update1 = v1.assign_add(7.0 * replica_id_plus_one()) return update0, update1 - update0a, update1a = dist.call_for_each_replica( - update_member_fn, args=(dist.worker_device_index,)) + update0a, update1a = distribution.extended.call_for_each_replica( + update_member_fn) # Update "sync on read" variable. - self.evaluate(dist.group(update0a)) + self.evaluate(distribution.group(update0a)) self.assertEqual(2.0 + 5.0, self.evaluate(v0.get(devices[0]))) # Writes are not synchronized for "sync on read" variables, # so device[1] can end up with a different value. self.assertEqual(2.0 + 2*5.0, self.evaluate(v0.get(devices[1]))) # Always reads from device 0. - self.assertEqual(2.0 + 5.0, self.evaluate(dist.read_var(v0))) + self.assertEqual(2.0 + 5.0, self.evaluate( + distribution.extended.read_var(v0))) # Update "sync on write" variable. - self.evaluate(dist.group(update1a)) + self.evaluate(distribution.group(update1a)) self.assertEqual(3.0 + 7.0, self.evaluate(v1.get(devices[0]))) # Writes are synchronized for v1, only the argument to assign_add on # device[0] is used. self.assertEqual(3.0 + 7.0, self.evaluate(v1.get(devices[1]))) - self.assertEqual(3.0 + 7.0, self.evaluate(dist.read_var(v1))) + self.assertEqual(3.0 + 7.0, self.evaluate( + distribution.extended.read_var(v1))) # Update using state_ops.assign_add global function. - def update_state_ops_fn(device_id): - update0 = state_ops.assign_add(v0, 11.0 * (device_id + 1)) - update1 = state_ops.assign_add(v1, 13.0 * (device_id + 1)) + def update_state_ops_fn(): + update0 = state_ops.assign_add(v0, 11.0 * replica_id_plus_one()) + update1 = state_ops.assign_add(v1, 13.0 * replica_id_plus_one()) return update0, update1 - update0b, update1b = dist.call_for_each_replica( - update_state_ops_fn, args=(dist.worker_device_index,)) - self.evaluate(dist.group(update0b)) + update0b, update1b = distribution.extended.call_for_each_replica( + update_state_ops_fn) + self.evaluate(distribution.group(update0b)) # Update "sync on read" variable. self.assertEqual(2.0 + 5.0 + 11.0, self.evaluate(v0.get(devices[0]))) self.assertEqual(2.0 + 2*5.0 + 2*11.0, self.evaluate(v0.get(devices[1]))) - self.assertEqual(2.0 + 5.0 + 11.0, self.evaluate(dist.read_var(v0))) + self.assertEqual(2.0 + 5.0 + 11.0, self.evaluate( + distribution.extended.read_var(v0))) # Update "sync on write" variable. - self.evaluate(dist.group(update1b)) + self.evaluate(distribution.group(update1b)) self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(v1.get(devices[0]))) self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(v1.get(devices[1]))) - self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(dist.read_var(v1))) + self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate( + distribution.extended.read_var(v1))) - @test_util.run_in_graph_and_eager_modes(config=config) - def testNoneSynchronizationWithGetVariable(self): - self._skip_eager_if_gpus_less_than(1) - devices = ["/device:CPU:0", "/device:GPU:0"] - dist = mirrored_strategy.MirroredStrategy(devices) - with dist.scope(): + def testNoneSynchronizationWithGetVariable(self, distribution): + with distribution.scope(): with self.assertRaisesRegexp( ValueError, "`NONE` variable synchronization mode is not " "supported with `Mirrored` distribution strategy. 
Please change " @@ -515,12 +484,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase): "v", [1], synchronization=variable_scope.VariableSynchronization.NONE) - @test_util.run_in_graph_and_eager_modes(config=config) - def testNoneSynchronizationWithVariable(self): - self._skip_eager_if_gpus_less_than(1) - devices = ["/device:CPU:0", "/device:GPU:0"] - dist = mirrored_strategy.MirroredStrategy(devices) - with dist.scope(): + def testNoneSynchronizationWithVariable(self, distribution): + with distribution.scope(): with self.assertRaisesRegexp( ValueError, "`NONE` variable synchronization mode is not " "supported with `Mirrored` distribution strategy. Please change " @@ -530,23 +495,15 @@ class MirroredStrategyVariableCreationTest(test.TestCase): name="v", synchronization=variable_scope.VariableSynchronization.NONE) - @test_util.run_in_graph_and_eager_modes(config=config) - def testInvalidSynchronizationWithVariable(self): - self._skip_eager_if_gpus_less_than(1) - devices = ["/device:CPU:0", "/device:GPU:0"] - dist = mirrored_strategy.MirroredStrategy(devices) - with dist.scope(): + def testInvalidSynchronizationWithVariable(self, distribution): + with distribution.scope(): with self.assertRaisesRegexp( ValueError, "Invalid variable synchronization mode: Invalid for " "variable: v"): variable_scope.variable(1.0, name="v", synchronization="Invalid") - @test_util.run_in_graph_and_eager_modes(config=config) - def testInvalidAggregationWithGetVariable(self): - self._skip_eager_if_gpus_less_than(1) - devices = ["/device:CPU:0", "/device:GPU:0"] - dist = mirrored_strategy.MirroredStrategy(devices) - with dist.scope(): + def testInvalidAggregationWithGetVariable(self, distribution): + with distribution.scope(): with self.assertRaisesRegexp( ValueError, "Invalid variable aggregation mode: invalid for " "variable: v"): @@ -555,12 +512,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase): synchronization=variable_scope.VariableSynchronization.ON_WRITE, aggregation="invalid") - @test_util.run_in_graph_and_eager_modes(config=config) - def testInvalidAggregationWithVariable(self): - self._skip_eager_if_gpus_less_than(1) - devices = ["/device:CPU:0", "/device:GPU:0"] - dist = mirrored_strategy.MirroredStrategy(devices) - with dist.scope(): + def testInvalidAggregationWithVariable(self, distribution): + with distribution.scope(): with self.assertRaisesRegexp( ValueError, "Invalid variable aggregation mode: invalid for " "variable: v"): @@ -570,55 +523,28 @@ class MirroredStrategyVariableCreationTest(test.TestCase): synchronization=variable_scope.VariableSynchronization.ON_WRITE, aggregation="invalid") - @test_util.run_in_graph_and_eager_modes(config=config) - def testThreeDevices(self): - self._skip_eager_if_gpus_less_than(2) - - def model_fn(): - v = variable_scope.variable(1.0, name="foo") - distribution_strategy_context.get_replica_context().merge_call( - lambda _: _) - return v - - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:GPU:1", "/device:CPU:0"]) - - with dist.scope(): - result = dist.call_for_each_replica(model_fn) - self.assertIsInstance(result, values.MirroredVariable) - self.assertEquals("foo:0", result.name) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testNonMatchingVariableCreation(self): - self._skip_eager_if_gpus_less_than(1) - + def testNonMatchingVariableCreation(self, distribution): def model_fn(name): v = variable_scope.variable(1.0, name=name) - distribution_strategy_context.get_replica_context().merge_call( - lambda _: 
_) + ds_context.get_replica_context().merge_call(lambda _: _) return v - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): + with distribution.scope(): names = values.DistributedValues({ "/device:CPU:0": "foo", "/device:GPU:0": "bar" }) with self.assertRaises(RuntimeError): - _ = dist.call_for_each_replica(model_fn, args=(names,)) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testReplicaLocalVariable(self): - self._skip_eager_if_gpus_less_than(1) + _ = distribution.extended.call_for_each_replica(model_fn, args=(names,)) + def testReplicaLocalVariable(self, distribution): all_v_sum = {} all_v_mean = {} components_sum = {} components_mean = {} - def model_fn(device_id): + def model_fn(): + replica_id = self.evaluate(_replica_id()) v_sum = variable_scope.variable( 1.0, synchronization=variable_scope.VariableSynchronization.ON_READ, @@ -629,26 +555,22 @@ class MirroredStrategyVariableCreationTest(test.TestCase): aggregation=variable_scope.VariableAggregation.MEAN) self.assertTrue(isinstance(v_sum, values.ReplicaLocalVariable)) self.assertTrue(isinstance(v_mean, values.ReplicaLocalVariable)) - updates = [v_sum.assign_add(2.0 + device_id), - v_mean.assign(6.0 * device_id)] - all_v_sum[device_id] = v_sum - all_v_mean[device_id] = v_mean + updates = [v_sum.assign_add(2.0 + replica_id), + v_mean.assign(6.0 * replica_id)] + all_v_sum[replica_id] = v_sum + all_v_mean[replica_id] = v_mean c_sum = v_sum.get() c_mean = v_mean.get() - components_sum[device_id] = c_sum - components_mean[device_id] = c_mean + components_sum[replica_id] = c_sum + components_mean[replica_id] = c_mean self.assertIsNot(v_sum, c_sum) self.assertIsNot(v_mean, c_mean) return updates, v_sum, v_mean, c_sum, c_mean - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): + with distribution.scope(): # Create "sum" and "mean" versions of ReplicaLocalVariables. ret_ops, ret_v_sum, ret_v_mean, regrouped_sum, regrouped_mean = ( - dist.call_for_each_replica( - model_fn, args=(dist.worker_device_index,))) + distribution.extended.call_for_each_replica(model_fn)) # Should see the same wrapping instance in all replicas. self.assertIs(all_v_sum[0], ret_v_sum) self.assertIs(all_v_mean[0], ret_v_mean) @@ -663,10 +585,10 @@ class MirroredStrategyVariableCreationTest(test.TestCase): # Apply updates self.evaluate(variables.global_variables_initializer()) - self.evaluate([y for x in ret_ops for y in dist.unwrap(x)]) + self.evaluate([y for x in ret_ops for y in distribution.unwrap(x)]) expected_sum = 0.0 expected_mean = 0.0 - for i, d in enumerate(dist.worker_devices): + for i, d in enumerate(distribution.extended.worker_devices): # Should see different values on different devices. v_sum_value = self.evaluate(ret_v_sum.get(d).read_value()) v_mean_value = self.evaluate(ret_v_mean.get(d).read_value()) @@ -676,135 +598,22 @@ class MirroredStrategyVariableCreationTest(test.TestCase): expected = i * 6.0 self.assertEqual(expected, v_mean_value) expected_mean += expected - expected_mean /= len(dist.worker_devices) + expected_mean /= len(distribution.extended.worker_devices) # Without get(device), should return the value you get by # applying the reduction across all replicas (whether you use # read_var(), get(), or nothing). 
- self.assertEqual(expected_sum, self.evaluate(dist.read_var(ret_v_sum))) - self.assertEqual(expected_mean, self.evaluate(dist.read_var(ret_v_mean))) + self.assertEqual(expected_sum, self.evaluate( + distribution.extended.read_var(ret_v_sum))) + self.assertEqual(expected_mean, self.evaluate( + distribution.extended.read_var(ret_v_mean))) self.assertEqual(expected_sum, self.evaluate(ret_v_sum.get())) self.assertEqual(expected_mean, self.evaluate(ret_v_mean.get())) self.assertEqual(expected_sum, self.evaluate(ret_v_sum)) self.assertEqual(expected_mean, self.evaluate(ret_v_mean)) - # NOTE(priyag): Names and name scopes are ignored in eager, hence we are not - # testing this in eager mode. - - def testNameScope(self): - def model_fn(): - with ops.name_scope("foo"): - a = constant_op.constant(1.0, name="a") - distribution_strategy_context.get_replica_context().merge_call( - lambda _: _) - b = constant_op.constant(1.0, name="b") - return a, b - - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with context.graph_mode(), dist.scope(): - with ops.name_scope("main"): - result = dist.call_for_each_replica(model_fn) - self.assertEquals(2, len(result)) - for v, name in zip(result, ["a", "b"]): - self.assertIsInstance(v, values.DistributedValues) - v0, v1 = dist.unwrap(v) - self.assertEquals("main/foo/" + name + ":0", v0.name) - self.assertEquals("main/replica_1/foo/" + name + ":0", v1.name) - - def testWithDefaultName(self): - def model_fn(): - with ops.name_scope(None, "foo"): - a = constant_op.constant(1.0, name="a") - distribution_strategy_context.get_replica_context().merge_call( - lambda _: _) - b = constant_op.constant(2.0, name="b") - return a, b - - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with context.graph_mode(), dist.scope(): - result = dist.call_for_each_replica(model_fn) - self.assertEquals(2, len(result)) - for v, name in zip(result, ["a", "b"]): - self.assertIsInstance(v, values.DistributedValues) - v0, v1 = dist.unwrap(v) - self.assertEquals("foo/" + name + ":0", v0.name) - self.assertEquals("replica_1/foo/" + name + ":0", v1.name) - - # variable_scope.variable() respects name scopes when creating - # variables. On the other hand variable_scope.get_variable() ignores name - # scopes when creating variables. We test both methods of creating variables - # to make sure that we have the same variable names in both cases. 
- def testNameScopeWithVariable(self): - def in_cross_replica(_): - c = variable_scope.variable(1.0, name="c") - return c - - def model_fn(): - b = variable_scope.variable(1.0, name="b") - with ops.name_scope("foo"): - c = distribution_strategy_context.get_replica_context().merge_call( - in_cross_replica) - return b, c - - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with context.graph_mode(), dist.scope(): - with ops.name_scope("main"): - a = variable_scope.variable(1.0, name="a") - result = dist.call_for_each_replica(model_fn) - result_b = result[0] - result_c = result[1] - self.assertIsInstance(result_b, values.DistributedValues) - self.assertIsInstance(result_c, values.DistributedValues) - a0, a1 = dist.unwrap(a) - b0, b1 = dist.unwrap(result_b) - c0, c1 = dist.unwrap(result_c) - self.assertEquals("main/a:0", a0.name) - self.assertEquals("main/a/replica_1:0", a1.name) - self.assertEquals("main/b:0", b0.name) - self.assertEquals("main/b/replica_1:0", b1.name) - self.assertEquals("main/foo/c:0", c0.name) - self.assertEquals("main/foo/c/replica_1:0", c1.name) - - def testNameScopeWithGetVariable(self): - def in_cross_replica(_): - c = variable_scope.get_variable("c", [1]) - return c - - def model_fn(): - b = variable_scope.get_variable("b", [1]) - with ops.name_scope("foo"): - c = distribution_strategy_context.get_replica_context().merge_call( - in_cross_replica) - return b, c - - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with context.graph_mode(), dist.scope(): - with ops.name_scope("main"): - a = variable_scope.get_variable("a", [1]) - result = dist.call_for_each_replica(model_fn) - result_b = result[0] - result_c = result[1] - self.assertIsInstance(result_b, values.DistributedValues) - self.assertIsInstance(result_c, values.DistributedValues) - a0, a1 = dist.unwrap(a) - b0, b1 = dist.unwrap(result_b) - c0, c1 = dist.unwrap(result_c) - self.assertEquals("a:0", a0.name) - self.assertEquals("a/replica_1:0", a1.name) - self.assertEquals("b:0", b0.name) - self.assertEquals("b/replica_1:0", b1.name) - self.assertEquals("c:0", c0.name) - self.assertEquals("c/replica_1:0", c1.name) - - def testDynamicRnnVariables(self): + # TODO(priyag): Update this test to work in eager mode as well. + def testDynamicRnnVariables(self, distribution): def model_fn(): inputs = constant_op.constant(2 * [2 * [[0.0, 1.0, 2.0, 3.0, 4.0]]]) cell_fw = rnn_cell_impl.LSTMCell(300) @@ -816,81 +625,208 @@ class MirroredStrategyVariableCreationTest(test.TestCase): dtype=dtypes.float32) return outputs - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with context.graph_mode(), dist.scope(): - result = dist.call_for_each_replica(model_fn) + with context.graph_mode(), distribution.scope(): + result = distribution.extended.call_for_each_replica(model_fn) # Two variables are created by the RNN layer. 
- self.assertEquals(2, len(result)) + self.assertEqual(2, len(result)) for v in result: self.assertIsInstance(v, values.DistributedValues) - _, v1 = dist.unwrap(v) - self.assertStartsWith(v1.name, "replica_1/") + _, v1 = distribution.unwrap(v) + self.assertStartsWith(v1._op.name, "replica_1/") - @test_util.run_in_graph_and_eager_modes(config=config) - def testReplicaLocalVariableUpdate(self): - with context.graph_mode(): + def testReplicaLocalVariableUpdate(self, distribution): + def model_fn(): + v_sum = variable_scope.variable( + 1.0, + synchronization=variable_scope.VariableSynchronization.ON_READ, + aggregation=variable_scope.VariableAggregation.SUM) + self.assertTrue(isinstance(v_sum, values.ReplicaLocalVariable)) + return v_sum - def model_fn(): - v_sum = variable_scope.variable( - 1.0, - synchronization=variable_scope.VariableSynchronization.ON_READ, - aggregation=variable_scope.VariableAggregation.SUM) - self.assertTrue(isinstance(v_sum, values.ReplicaLocalVariable)) - return v_sum + def update(var, value): + return var.assign(value) - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:GPU:1"]) + with distribution.scope(): + ret_v_sum = distribution.extended.call_for_each_replica(model_fn) - def update(var, value): - return var.assign(value) + # Initialize variables. + self.evaluate(variables.global_variables_initializer()) + # Assert that the aggregated value of the replica local vars is the sum + # of the individual values before running the update ops. + self.assertEqual(1.0, self.evaluate(ret_v_sum.get( + distribution.extended.worker_devices[0]).read_value())) + self.assertEqual(2.0, self.evaluate(ret_v_sum)) - with dist.scope(): - ret_v_sum = dist.call_for_each_replica(model_fn) - update_ops = dist.update(ret_v_sum, update, 5.0, grouped=False) - - # Initialize variables. - self.evaluate(variables.global_variables_initializer()) - # Assert that the aggregated value of the replica local vars is the sum - # of the individual values before running the update ops. - self.assertEquals(1.0, self.evaluate( - ret_v_sum.get(dist._devices[0]).read_value())) - self.assertEquals(2.0, self.evaluate(ret_v_sum)) - - # Apply updates. - self.evaluate(update_ops) - # Assert that the aggregated value of the replica local vars is the sum - # of the individual values after running the update ops. - self.assertEquals(5.0, self.evaluate( - ret_v_sum.get(dist._devices[0]).read_value())) - self.assertEquals(10.0, self.evaluate(ret_v_sum)) + # Apply updates. + update_ops = distribution.extended.update( + ret_v_sum, update, args=(5.0,), group=False) + self.evaluate(update_ops) + # Assert that the aggregated value of the replica local vars is the sum + # of the individual values after running the update ops. + self.assertEqual(5.0, self.evaluate(ret_v_sum.get( + distribution.extended.worker_devices[0]).read_value())) + self.assertEqual(10.0, self.evaluate(ret_v_sum)) +@combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=["graph"])) +class MirroredStrategyNameScopeTest(test.TestCase): + # NOTE(priyag): Names and name scopes are ignored in eager, hence we are not + # testing this in eager mode. 
+ + def testNameScope(self, distribution): + def model_fn(): + with ops.name_scope("foo"): + a = constant_op.constant(1.0, name="a") + ds_context.get_replica_context().merge_call(lambda _: _) + b = constant_op.constant(1.0, name="b") + return a, b + + with context.graph_mode(), distribution.scope(): + with ops.name_scope("main"): + result = distribution.extended.call_for_each_replica(model_fn) + self.assertEqual(2, len(result)) + for v, name in zip(result, ["a", "b"]): + self.assertIsInstance(v, values.DistributedValues) + v0, v1 = distribution.unwrap(v) + self.assertEqual("main/foo/" + name + ":0", v0.name) + self.assertEqual("main/replica_1/foo/" + name + ":0", v1.name) + + def testWithDefaultName(self, distribution): + def model_fn(): + with ops.name_scope(None, "foo"): + a = constant_op.constant(1.0, name="a") + ds_context.get_replica_context().merge_call(lambda _: _) + b = constant_op.constant(2.0, name="b") + return a, b + + with context.graph_mode(), distribution.scope(): + result = distribution.extended.call_for_each_replica(model_fn) + self.assertEqual(2, len(result)) + for v, name in zip(result, ["a", "b"]): + self.assertIsInstance(v, values.DistributedValues) + v0, v1 = distribution.unwrap(v) + self.assertEqual("foo/" + name + ":0", v0.name) + self.assertEqual("replica_1/foo/" + name + ":0", v1.name) + + # variable_scope.variable() respects name scopes when creating + # variables. On the other hand variable_scope.get_variable() ignores name + # scopes when creating variables. We test both methods of creating variables + # to make sure that we have the same variable names in both cases. + def testNameScopeWithVariable(self, distribution): + def in_cross_replica(_): + c = variable_scope.variable(1.0, name="c") + return c + + def model_fn(): + b = variable_scope.variable(1.0, name="b") + with ops.name_scope("foo"): + c = ds_context.get_replica_context().merge_call(in_cross_replica) + return b, c + + with context.graph_mode(), distribution.scope(): + with ops.name_scope("main"): + a = variable_scope.variable(1.0, name="a") + result = distribution.extended.call_for_each_replica(model_fn) + result_b = result[0] + result_c = result[1] + self.assertIsInstance(result_b, values.DistributedValues) + self.assertIsInstance(result_c, values.DistributedValues) + a0, a1 = distribution.unwrap(a) + b0, b1 = distribution.unwrap(result_b) + c0, c1 = distribution.unwrap(result_c) + self.assertEqual("main/a:0", a0.name) + self.assertEqual("main/a/replica_1:0", a1.name) + self.assertEqual("main/b:0", b0.name) + self.assertEqual("main/b/replica_1:0", b1.name) + self.assertEqual("main/foo/c:0", c0.name) + self.assertEqual("main/foo/c/replica_1:0", c1.name) + + def testNameScopeWithGetVariable(self, distribution): + def in_cross_replica(_): + c = variable_scope.get_variable("c", [1]) + return c + + def model_fn(): + b = variable_scope.get_variable("b", [1]) + with ops.name_scope("foo"): + c = ds_context.get_replica_context().merge_call(in_cross_replica) + return b, c + + with context.graph_mode(), distribution.scope(): + with ops.name_scope("main"): + a = variable_scope.get_variable("a", [1]) + result = distribution.extended.call_for_each_replica(model_fn) + result_b = result[0] + result_c = result[1] + self.assertIsInstance(result_b, values.DistributedValues) + self.assertIsInstance(result_c, values.DistributedValues) + a0, a1 = distribution.unwrap(a) + b0, b1 = distribution.unwrap(result_b) + c0, c1 = distribution.unwrap(result_c) + self.assertEqual("a:0", a0.name) + self.assertEqual("a/replica_1:0", 
a1.name) + self.assertEqual("b:0", b0.name) + self.assertEqual("b/replica_1:0", b1.name) + self.assertEqual("c:0", c0.name) + self.assertEqual("c/replica_1:0", c1.name) + + +@combinations.generate( + combinations.combine( + distribution=[ + combinations.NamedDistribution( + "Mirrored3Devices", + # pylint: disable=g-long-lambda + lambda: mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:GPU:1", "/device:CPU:0"]), + required_gpus=2), + combinations.NamedDistribution( + "CoreMirrored3Devices", + # pylint: disable=g-long-lambda + lambda: mirrored_strategy.CoreMirroredStrategy( + ["/device:GPU:0", "/device:GPU:1", "/device:CPU:0"]), + required_gpus=2) + ], + mode=["graph", "eager"])) +class MirroredThreeDeviceDistributionTest( + strategy_test_lib.DistributionTestBase, + parameterized.TestCase): + + def testThreeDevices(self, distribution): + def model_fn(): + v = variable_scope.variable(1.0, name="foo") + ds_context.get_replica_context().merge_call(lambda _: _) + return v + + with distribution.scope(): + result = distribution.extended.call_for_each_replica(model_fn) + self.assertIsInstance(result, values.MirroredVariable) + self.assertEqual("foo:0", result.name) + + +@combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=["graph", "eager"])) class MirroredVariableUpdateTest(test.TestCase): # The following tests check assign, assign_add and assign_sub on Mirrored # variables in replica and cross replica context. - config = config_pb2.ConfigProto() - config.allow_soft_placement = True - def _skip_eager_if_gpus_less_than(self, num_gpus): - if context.num_gpus() < num_gpus and context.executing_eagerly(): - self.skipTest("Enough GPUs not available for this test in eager mode.") - - @test_util.run_in_graph_and_eager_modes(config=config) - def testAssignMirroredVarReplicaContextWithoutAggregationType(self): + def testAssignMirroredVarReplicaContextWithoutAggregationType(self, + distribution): # Test that we always have an aggregation type set on the mirrored variable # if we assign to it in replica mode. - self._skip_eager_if_gpus_less_than(1) def var_fn(): v = variable_scope.variable(1.0, name="foo") return v - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn) + with distribution.scope(): + mirrored_var = distribution.extended.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) @@ -900,23 +836,19 @@ class MirroredVariableUpdateTest(test.TestCase): with self.assertRaisesRegexp( ValueError, "You must specify an aggregation method to update a " "MirroredVariable in Replica Context."): - self.evaluate(dist.unwrap(dist.call_for_each_replica(model_fn))) + self.evaluate(distribution.unwrap( + distribution.extended.call_for_each_replica(model_fn))) - @test_util.run_in_graph_and_eager_modes(config=config) - def testAssignMirroredVarReplicaContextWithSum(self): + def testAssignMirroredVarReplicaContextWithSum(self, distribution): # Test that we don't reduce a non-per-replica value with the "sum" # aggregation type. 
- self._skip_eager_if_gpus_less_than(1) def var_fn(): v = variable_scope.variable( 1.0, name="foo", aggregation=variable_scope.VariableAggregation.SUM) return v - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn) + with distribution.scope(): + mirrored_var = distribution.extended.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) @@ -925,219 +857,184 @@ class MirroredVariableUpdateTest(test.TestCase): with self.assertRaisesRegexp( ValueError, "A non-DistributedValues value 5.0 cannot be reduced " - "with the given aggregation VariableAggregation.SUM."): - self.evaluate(dist.unwrap(dist.call_for_each_replica(model_fn))) + "with the given reduce op ReduceOp.SUM."): + self.evaluate(distribution.unwrap( + distribution.extended.call_for_each_replica(model_fn))) - @test_util.run_in_graph_and_eager_modes(config=config) - def testAssignMirroredVarCrossDeviceContext(self): - self._skip_eager_if_gpus_less_than(1) + def testAssignMirroredVarCrossDeviceContext(self, distribution): def var_fn(): return variable_scope.variable(1.0, name="foo") - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn) + with distribution.scope(): + mirrored_var = distribution.extended.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) - self.assertEquals(1.0, self.evaluate(mirrored_var)) + self.assertEqual(1.0, self.evaluate(mirrored_var)) mirrored_var_result = self.evaluate(mirrored_var.assign(6.0)) - self.assertEquals(6.0, mirrored_var_result) + self.assertEqual(6.0, mirrored_var_result) - @test_util.run_in_graph_and_eager_modes(config=config) - def testAssignMirroredVarReplicaContext(self): - self._skip_eager_if_gpus_less_than(1) + def testAssignMirroredVarReplicaContext(self, distribution): def var_fn(): return variable_scope.variable( 1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN) - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn) + with distribution.scope(): + mirrored_var = distribution.extended.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) - self.assertEquals(1.0, self.evaluate(mirrored_var)) + self.assertEqual(1.0, self.evaluate(mirrored_var)) def model_fn(): value = math_ops.cast( - distribution_strategy_context.get_replica_context().replica_id, + ds_context.get_replica_context().replica_id_in_sync_group, mirrored_var.dtype) return mirrored_var.assign(value) - self.evaluate(dist.unwrap(dist.call_for_each_replica(model_fn))) - self.assertEquals(0.5, self.evaluate(mirrored_var)) + self.evaluate(distribution.unwrap( + distribution.extended.call_for_each_replica(model_fn))) + self.assertEqual(0.5, self.evaluate(mirrored_var)) - @test_util.run_in_graph_and_eager_modes(config=config) - def testAssignMirroredVarReplicaContextWithSingleValue(self): - self._skip_eager_if_gpus_less_than(1) + def testAssignMirroredVarReplicaContextWithSingleValue(self, distribution): def var_fn(): return variable_scope.variable( 1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN) - dist = 
mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn) + with distribution.scope(): + mirrored_var = distribution.extended.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) - self.assertEquals(1.0, self.evaluate(mirrored_var)) + self.assertEqual(1.0, self.evaluate(mirrored_var)) def model_fn(): return mirrored_var.assign(5.0) - self.evaluate(dist.unwrap(dist.call_for_each_replica(model_fn))) - self.assertEquals(5.0, self.evaluate(mirrored_var)) + self.evaluate(distribution.unwrap( + distribution.extended.call_for_each_replica(model_fn))) + self.assertEqual(5.0, self.evaluate(mirrored_var)) - @test_util.run_in_graph_and_eager_modes(config=config) - def testAssignAddMirroredVarCrossDeviceContext(self): - self._skip_eager_if_gpus_less_than(1) + def testAssignAddMirroredVarCrossDeviceContext(self, distribution): def var_fn(): return variable_scope.variable(1.0, name="foo") - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn) + with distribution.scope(): + mirrored_var = distribution.extended.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) - self.assertEquals(1.0, self.evaluate(mirrored_var)) + self.assertEqual(1.0, self.evaluate(mirrored_var)) # read_value == True mirrored_var_result = self.evaluate( mirrored_var.assign_add(6.0, read_value=True)) - self.assertEquals(7.0, mirrored_var_result) - self.assertEquals(7.0, self.evaluate(mirrored_var.get("/device:CPU:0"))) - self.assertEquals(7.0, self.evaluate(mirrored_var.get("/device:GPU:0"))) + self.assertEqual(7.0, mirrored_var_result) + self.assertEqual(7.0, self.evaluate(mirrored_var.get("/device:CPU:0"))) + self.assertEqual(7.0, self.evaluate(mirrored_var.get("/device:GPU:0"))) # read_value == False self.evaluate(mirrored_var.assign_add(2.0, read_value=False)) - self.assertEquals(9.0, self.evaluate(mirrored_var.get("/device:CPU:0"))) - self.assertEquals(9.0, self.evaluate(mirrored_var.get("/device:GPU:0"))) + self.assertEqual(9.0, self.evaluate(mirrored_var.get("/device:CPU:0"))) + self.assertEqual(9.0, self.evaluate(mirrored_var.get("/device:GPU:0"))) - @test_util.run_in_graph_and_eager_modes(config=config) - def testAssignAddMirroredVarReplicaContext(self): - self._skip_eager_if_gpus_less_than(1) + def testAssignAddMirroredVarReplicaContext(self, distribution): def var_fn(): return variable_scope.variable( 1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN) - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn) + with distribution.scope(): + mirrored_var = distribution.extended.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) - self.assertEquals(1.0, self.evaluate(mirrored_var)) + self.assertEqual(1.0, self.evaluate(mirrored_var)) def model_fn(): value = math_ops.cast( - distribution_strategy_context.get_replica_context().replica_id, + ds_context.get_replica_context().replica_id_in_sync_group, mirrored_var.dtype) return mirrored_var.assign_add(value) - self.evaluate(dist.unwrap(dist.call_for_each_replica(model_fn))) - 
self.assertEquals(1.5, self.evaluate(mirrored_var)) + self.evaluate(distribution.unwrap( + distribution.extended.call_for_each_replica(model_fn))) + self.assertEqual(1.5, self.evaluate(mirrored_var)) - @test_util.run_in_graph_and_eager_modes(config=config) - def testAssignAddMirroredVarReplicaContextWithSingleValue(self): - self._skip_eager_if_gpus_less_than(1) + def testAssignAddMirroredVarReplicaContextWithSingleValue(self, distribution): def var_fn(): return variable_scope.variable( 1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN) - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn) + with distribution.scope(): + mirrored_var = distribution.extended.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) - self.assertEquals(1.0, self.evaluate(mirrored_var)) + self.assertEqual(1.0, self.evaluate(mirrored_var)) def model_fn(): return mirrored_var.assign_add(5.0) - self.evaluate(dist.unwrap(dist.call_for_each_replica(model_fn))) - self.assertEquals(6.0, self.evaluate(mirrored_var)) + self.evaluate(distribution.unwrap( + distribution.extended.call_for_each_replica(model_fn))) + self.assertEqual(6.0, self.evaluate(mirrored_var)) - @test_util.run_in_graph_and_eager_modes(config=config) - def testAssignSubMirroredVarCrossDeviceContext(self): - self._skip_eager_if_gpus_less_than(1) + def testAssignSubMirroredVarCrossDeviceContext(self, distribution): def var_fn(): return variable_scope.variable(5.0, name="foo") - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn) + with distribution.scope(): + mirrored_var = distribution.extended.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) - self.assertEquals(5.0, self.evaluate(mirrored_var)) + self.assertEqual(5.0, self.evaluate(mirrored_var)) mirrored_var_result = self.evaluate(mirrored_var.assign_sub(2.0)) - self.assertEquals(3.0, mirrored_var_result) - self.assertEquals(3.0, self.evaluate(mirrored_var.get("/device:GPU:0"))) - self.assertEquals(3.0, self.evaluate(mirrored_var.get("/device:CPU:0"))) + self.assertEqual(3.0, mirrored_var_result) + self.assertEqual(3.0, self.evaluate(mirrored_var.get("/device:GPU:0"))) + self.assertEqual(3.0, self.evaluate(mirrored_var.get("/device:CPU:0"))) - @test_util.run_in_graph_and_eager_modes(config=config) - def testAssignSubMirroredVarReplicaContext(self): - self._skip_eager_if_gpus_less_than(1) + def testAssignSubMirroredVarReplicaContext(self, distribution): def var_fn(): return variable_scope.variable( 5.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN) - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn) + with distribution.scope(): + mirrored_var = distribution.extended.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) - self.assertEquals(5.0, self.evaluate(mirrored_var)) + self.assertEqual(5.0, self.evaluate(mirrored_var)) def model_fn(): value = math_ops.cast( - distribution_strategy_context.get_replica_context().replica_id, + 
ds_context.get_replica_context().replica_id_in_sync_group, mirrored_var.dtype) return mirrored_var.assign_sub(value) - self.evaluate(dist.unwrap(dist.call_for_each_replica(model_fn))) - self.assertEquals(4.5, self.evaluate(mirrored_var)) + self.evaluate(distribution.unwrap( + distribution.extended.call_for_each_replica(model_fn))) + self.assertEqual(4.5, self.evaluate(mirrored_var)) - @test_util.run_in_graph_and_eager_modes(config=config) - def testAssignSubMirroredVarReplicaContextWithSingleValue(self): - self._skip_eager_if_gpus_less_than(1) + def testAssignSubMirroredVarReplicaContextWithSingleValue(self, distribution): def var_fn(): return variable_scope.variable( 5.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN) - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn) + with distribution.scope(): + mirrored_var = distribution.extended.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) - self.assertEquals(5.0, self.evaluate(mirrored_var)) + self.assertEqual(5.0, self.evaluate(mirrored_var)) def model_fn(): return mirrored_var.assign_sub(1.0) - self.evaluate(dist.unwrap(dist.call_for_each_replica(model_fn))) - self.assertEquals(4.0, self.evaluate(mirrored_var)) + self.evaluate(distribution.unwrap( + distribution.extended.call_for_each_replica(model_fn))) + self.assertEqual(4.0, self.evaluate(mirrored_var)) +@combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=["graph", "eager"])) class MirroredAndReplicaLocalVariableInitializerTest(test.TestCase): - config = config_pb2.ConfigProto() - config.allow_soft_placement = True - def testAssignMirroredVarInitializer(self): + def testAssignMirroredVarInitializer(self, distribution): # This test is not eager compatible since in eager variables are initialized # upon construction instead of once the initialization op is run. with context.graph_mode(): @@ -1145,17 +1042,14 @@ class MirroredAndReplicaLocalVariableInitializerTest(test.TestCase): v = variable_scope.variable(1.0, name="foo") return v - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn) + with distribution.scope(): + mirrored_var = distribution.extended.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.assertFalse(self.evaluate(mirrored_var.is_initialized())) self.evaluate(mirrored_var.initializer) self.assertTrue(self.evaluate(mirrored_var.is_initialized())) - def testAssignReplicaLocalVarInitializer(self): + def testAssignReplicaLocalVarInitializer(self, distribution): # This test is not eager compatible since in eager variables are initialized # upon construction instead of once the initialization op is run. 
with context.graph_mode(): @@ -1167,11 +1061,9 @@ class MirroredAndReplicaLocalVariableInitializerTest(test.TestCase): self.assertTrue(isinstance(v_sum, values.ReplicaLocalVariable)) return v_sum - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - replica_local_var = dist.call_for_each_replica(model_fn) + with distribution.scope(): + replica_local_var = distribution.extended.call_for_each_replica( + model_fn) self.assertTrue(isinstance(replica_local_var, values.ReplicaLocalVariable)) self.assertFalse(self.evaluate(replica_local_var.is_initialized())) @@ -1179,17 +1071,14 @@ class MirroredAndReplicaLocalVariableInitializerTest(test.TestCase): self.assertTrue(self.evaluate(replica_local_var.is_initialized())) +@combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=["graph", "eager"])) class ReplicaLocalVariableAssignTest(test.TestCase): - config = config_pb2.ConfigProto() - config.allow_soft_placement = True - def _skip_eager_if_gpus_less_than(self, num_gpus): - if context.num_gpus() < num_gpus and context.executing_eagerly(): - self.skipTest("Not enough GPUs available for this test in eager mode.") - - @test_util.run_in_graph_and_eager_modes(config=config) - def testAssignReplicaLocalVarSumAggregation(self): - self._skip_eager_if_gpus_less_than(1) + def testAssignReplicaLocalVarSumAggregation(self, distribution): def model_fn(): v_sum = variable_scope.variable( 1.0, @@ -1197,18 +1086,16 @@ class ReplicaLocalVariableAssignTest(test.TestCase): aggregation=variable_scope.VariableAggregation.SUM) return v_sum - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - replica_local_var = dist.call_for_each_replica(model_fn) + with distribution.scope(): + replica_local_var = distribution.extended.call_for_each_replica(model_fn) self.assertTrue(isinstance(replica_local_var, values.ReplicaLocalVariable)) self.evaluate(variables.global_variables_initializer()) # Each replica has a value of 1.0 assigned to it in replica context. # When we read the value using `read_var` we should see the SUM of each of # values on each of the replicas. - self.assertEqual(2.0, self.evaluate(dist.read_var(replica_local_var))) + self.assertEqual(2.0, self.evaluate( + distribution.read_var(replica_local_var))) # Assigning 6.0 in cross replica context will assign a value of # 6.0/num_replicas to each replica. tlv_ops = replica_local_var.assign(6.0) @@ -1216,11 +1103,10 @@ class ReplicaLocalVariableAssignTest(test.TestCase): # On reading the replica local var we should get the assigned value back. # The value on all the replicas are added before being returned by # `read_var`. 
- self.assertEqual(6.0, self.evaluate(dist.read_var(replica_local_var))) + self.assertEqual(6.0, self.evaluate( + distribution.read_var(replica_local_var))) - @test_util.run_in_graph_and_eager_modes(config=config) - def testAssignReplicaLocalVarMeanAggregation(self): - self._skip_eager_if_gpus_less_than(1) + def testAssignReplicaLocalVarMeanAggregation(self, distribution): def model_fn(): v_sum = variable_scope.variable( 1.0, @@ -1228,23 +1114,22 @@ class ReplicaLocalVariableAssignTest(test.TestCase): aggregation=variable_scope.VariableAggregation.MEAN) return v_sum - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - replica_local_var = dist.call_for_each_replica(model_fn) + with distribution.scope(): + replica_local_var = distribution.extended.call_for_each_replica(model_fn) self.assertTrue(isinstance(replica_local_var, values.ReplicaLocalVariable)) self.evaluate(variables.global_variables_initializer()) # Each replica has a value of 1.0 assigned to it in replica context. # When we read the value using `read_var` we should see the MEAN of values # on all replicas which is the value assigned in replica context. - self.assertEqual(1.0, self.evaluate(dist.read_var(replica_local_var))) + self.assertEqual(1.0, self.evaluate( + distribution.read_var(replica_local_var))) tlv_ops = replica_local_var.assign(6.0) self.evaluate(tlv_ops) # On reading the replica local var we should get the MEAN of all values # which is equal to the value assigned. - self.assertEqual(6.0, self.evaluate(dist.read_var(replica_local_var))) + self.assertEqual(6.0, self.evaluate( + distribution.read_var(replica_local_var))) class MockModel(object): @@ -1278,24 +1163,25 @@ class MiniModel(keras_training.Model): return self.fc(inputs) +@combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=["graph", "eager"])) class MirroredStrategyDefunTest(test.TestCase): - def _skip_eager_if_gpus_less_than(self, num_gpus): - if context.num_gpus() < num_gpus and context.executing_eagerly(): - self.skipTest("Not enough GPUs available for this test in eager mode.") - - def _call_and_check(self, model_fn, inputs, expected_result, defuns, - two_variables=False): + def _call_and_check(self, distribution, model_fn, inputs, expected_result, + defuns, two_variables=False): cpu_dev = device_util.canonicalize("CPU:0") gpu_dev = device_util.canonicalize("GPU:0") devices = [cpu_dev, gpu_dev] - dist = mirrored_strategy.MirroredStrategy(devices) - with dist.scope(): + with distribution.scope(): mock_model = MockModel(two_variables) self.evaluate(variables.global_variables_initializer()) - result = dist.call_for_each_replica(model_fn, args=[mock_model] + inputs) + result = distribution.extended.call_for_each_replica( + model_fn, args=[mock_model] + inputs) for device in devices: device_result = values.select_device(device, result) device_expected_result = values.select_device(device, expected_result) @@ -1307,17 +1193,15 @@ class MirroredStrategyDefunTest(test.TestCase): # call_for_each has one trace per device. To check that the expected set # of variables was accessed on each trace, we first retrieve each # device-specific graph function. 
- per_replica_graph_functions = dist.call_for_each_replica( - defun.get_concrete_function, args=[mock_model] + inputs) + per_replica_graph_functions = ( + distribution.extended.call_for_each_replica( + defun.get_concrete_function, args=[mock_model] + inputs)) for device in devices: graph_function = per_replica_graph_functions.get(device=device) self.assertEqual(set(mock_model.variables), set(graph_function.graph.variables)) - @test_util.run_in_graph_and_eager_modes() - def testVariableInDefun(self): - self._skip_eager_if_gpus_less_than(1) - + def testVariableInDefun(self, distribution): @function.defun def times_two(mock_model): return mock_model() @@ -1325,12 +1209,9 @@ class MirroredStrategyDefunTest(test.TestCase): def model_fn(mock_model): return times_two(mock_model) - self._call_and_check(model_fn, [], 2.5, [times_two]) - - @test_util.run_in_graph_and_eager_modes() - def testVariableInNestedDefun(self): - self._skip_eager_if_gpus_less_than(1) + self._call_and_check(distribution, model_fn, [], 2.5, [times_two]) + def testVariableInNestedDefun(self, distribution): @function.defun def times_two(mock_model): return mock_model() @@ -1342,12 +1223,10 @@ class MirroredStrategyDefunTest(test.TestCase): def model_fn(mock_model): return two_x_plus_one(mock_model) - self._call_and_check(model_fn, [], 3.5, [times_two, two_x_plus_one]) - - @test_util.run_in_graph_and_eager_modes() - def testTwoVariablesInNestedDefun(self): - self._skip_eager_if_gpus_less_than(1) + self._call_and_check(distribution, model_fn, [], 3.5, + [times_two, two_x_plus_one]) + def testTwoVariablesInNestedDefun(self, distribution): @function.defun def fn1(mock_model): return mock_model() @@ -1359,12 +1238,10 @@ class MirroredStrategyDefunTest(test.TestCase): def model_fn(mock_model): return fn2(mock_model) - self._call_and_check(model_fn, [], 5.5, [fn1, fn2], two_variables=True) - - @test_util.run_in_graph_and_eager_modes() - def testGradientTapeOverNestedDefuns(self): - self._skip_eager_if_gpus_less_than(1) + self._call_and_check(distribution, model_fn, [], 5.5, [fn1, fn2], + two_variables=True) + def testGradientTapeOverNestedDefuns(self, distribution): @function.defun def fn1(mock_model): return mock_model() @@ -1380,13 +1257,10 @@ class MirroredStrategyDefunTest(test.TestCase): [v.get() for v in mock_model.variables]) return grads - self._call_and_check(model_fn, [], [2.0, 1.0], [fn1, fn2], + self._call_and_check(distribution, model_fn, [], [2.0, 1.0], [fn1, fn2], two_variables=True) - @test_util.run_in_graph_and_eager_modes() - def testPassPerReplica(self): - self._skip_eager_if_gpus_less_than(1) - + def testPassPerReplica(self, distribution): @function.defun def fn1(mock_model, factor): return mock_model(factor) @@ -1394,18 +1268,10 @@ class MirroredStrategyDefunTest(test.TestCase): factors = values.PerReplica({"CPU:0": 5.0, "GPU:0": 3.0}) expected_result = values.PerReplica({"CPU:0": 5.0 * 1.25, "GPU:0": 3.0 * 1.25}) - self._call_and_check(fn1, [factors], expected_result, [fn1]) + self._call_and_check(distribution, fn1, [factors], expected_result, [fn1]) - @test_util.run_in_graph_and_eager_modes() - def testTrain(self): - self._skip_eager_if_gpus_less_than(1) - - cpu_dev = device_util.canonicalize("CPU:0") - gpu_dev = device_util.canonicalize("GPU:0") - devices = [cpu_dev, gpu_dev] - dist = mirrored_strategy.MirroredStrategy(devices) - - with dist.scope(): + def testTrain(self, distribution): + with distribution.scope(): mock_model = MiniModel() mock_model.call = function.defun(mock_model.call) @@ -1415,10 +1281,11 @@ 
class MirroredStrategyDefunTest(test.TestCase): gradients_fn = backprop.implicit_grad(loss_fn) gradients_fn = optimizer_lib.get_filtered_grad_fn(gradients_fn) - grads_and_vars = dist.call_for_each_replica(gradients_fn, args=(None,)) + grads_and_vars = distribution.extended.call_for_each_replica( + gradients_fn, args=(None,)) optimizer = gradient_descent.GradientDescentOptimizer(0.25) - update_ops = optimizer._distributed_apply(dist, grads_and_vars) # pylint: disable=protected-access + update_ops = optimizer._distributed_apply(distribution, grads_and_vars) # pylint: disable=protected-access if not context.executing_eagerly(): self.evaluate(variables.global_variables_initializer()) @@ -1430,30 +1297,82 @@ class MirroredStrategyDefunTest(test.TestCase): self.assertAllEqual([0.5], updated_var_values[1]) +@combinations.generate( + combinations.combine( + distribution=[ + combinations.NamedDistribution( + "Mirrored", + # pylint: disable=g-long-lambda + lambda: mirrored_strategy.MirroredStrategy(num_gpus_per_worker= + context.num_gpus()), + required_gpus=1), + combinations.NamedDistribution( + "CoreMirrored", + # pylint: disable=g-long-lambda + lambda: mirrored_strategy.CoreMirroredStrategy( + mirrored_strategy.all_local_devices()), + required_gpus=1) + ], + mode=["graph"])) class MultiWorkerMirroredStrategyTest( multi_worker_test_base.MultiWorkerTestBase, strategy_test_lib.DistributionTestBase): - def _get_distribution_strategy(self): + def _configure_distribution_strategy(self, distribution): cluster_spec = server_lib.ClusterSpec({ "worker": ["/job:worker/task:0", "/job:worker/task:1"] }) - strategy = mirrored_strategy.MirroredStrategy(num_gpus=context.num_gpus()) - strategy.configure(cluster_spec=cluster_spec) - return strategy + distribution.configure(cluster_spec=cluster_spec) - def test_num_replicas_in_sync(self): - if not GPU_TEST: - self.skipTest("Not GPU test") - - strategy = self._get_distribution_strategy() + def test_num_replicas_in_sync(self, distribution): + self._configure_distribution_strategy(distribution) # We calculate the total number of gpus across the workers(2) specified in # the cluster spec. - self.assertEqual(context.num_gpus() * 2, strategy.num_replicas_in_sync) + self.assertEqual(context.num_gpus() * 2, distribution.num_replicas_in_sync) - def testMinimizeLossGraph(self): - self._test_minimize_loss_graph(self._get_distribution_strategy(), - learning_rate=0.05) + def testMinimizeLossGraph(self, distribution): + self._configure_distribution_strategy(distribution) + self._test_minimize_loss_graph(distribution, learning_rate=0.05) + + def testDeviceScope(self, distribution): + """Test the device scope of multi-worker MirroredStrategy.""" + self._configure_distribution_strategy(distribution) + with distribution.scope(): + a = constant_op.constant(1.) + with ops.device("/cpu:0"): + b = constant_op.constant(1.) + self.assertEqual(a.device, "/job:worker/task:0") + self.assertEqual(b.device, "/job:worker/task:0/device:CPU:0") + + def testMakeInputFnIterator(self, distribution): + self._configure_distribution_strategy(distribution) + dataset_fn = lambda: dataset_ops.Dataset.range(100) + num_gpus = context.num_gpus() + num_workers = 2 + + expected_values = [[i+j for j in range(num_gpus)] * num_workers + for i in range(0, 100, num_gpus)] + + with context.graph_mode(), self.cached_session() as sess: + # `expected_input_pipeline_id` is None because the input_fn will be called + # multiple times, each with a different input_pipeline_id. 
+ input_fn = self._input_fn_to_test_input_context( + dataset_fn, + expected_num_replicas_in_sync=num_workers*num_gpus, + expected_num_input_pipelines=num_workers, + expected_input_pipeline_id=None) + iterator = distribution.make_input_fn_iterator(input_fn) + self._test_input_fn_iterator( + iterator, distribution.extended.worker_devices, expected_values, sess) + + def testUpdateConfigProto(self, distribution): + distribution.configure(cluster_spec={"worker": ["fake1", "fake2"]}) + + config_proto = config_pb2.ConfigProto() + new_config = distribution.update_config_proto(config_proto) + + # Verify isolate_session_state + self.assertTrue(new_config.isolate_session_state) class MultiWorkerMirroredStrategyTestWithChief( @@ -1473,6 +1392,19 @@ class MultiWorkerMirroredStrategyTestWithChief( strategy.configure(cluster_spec=self._cluster_spec) self._test_minimize_loss_graph(strategy, learning_rate=0.05) + def testMinimizeLossGraphCoreMirroredStrategy(self): + strategy = mirrored_strategy.CoreMirroredStrategy( + mirrored_strategy.all_local_devices()) + strategy.configure(cluster_spec=self._cluster_spec) + self._test_minimize_loss_graph(strategy, learning_rate=0.05) + + +def _replica_id(): + replica_id = ds_context.get_replica_context().replica_id_in_sync_group + if not isinstance(replica_id, ops.Tensor): + replica_id = constant_op.constant(replica_id) + return replica_id + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py deleted file mode 100644 index bea684e77ca..00000000000 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py +++ /dev/null @@ -1,107 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Tests for class MirroredStrategy.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.distribute.python import mirrored_strategy -from tensorflow.contrib.distribute.python import strategy_test_lib -from tensorflow.python.eager import context -from tensorflow.python.eager import test -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import ops -from tensorflow.python.framework import test_util -from tensorflow.python.ops import variable_scope -from tensorflow.python.training import distribution_strategy_context - - -class MirroredOneCPUDistributionTest(strategy_test_lib.DistributionTestBase): - - def _get_distribution_strategy(self): - return mirrored_strategy.MirroredStrategy(["/device:CPU:0"]) - - def testMinimizeLossEager(self): - self._test_minimize_loss_eager(self._get_distribution_strategy()) - - def testMinimizeLossGraph(self): - self._test_minimize_loss_graph(self._get_distribution_strategy()) - - def testDeviceIndex(self): - self._test_device_index(self._get_distribution_strategy()) - - def testReplicaId(self): - self._test_replica_id(self._get_distribution_strategy()) - - @test_util.run_in_graph_and_eager_modes - def testCallAndMergeExceptions(self): - self._test_call_and_merge_exceptions(self._get_distribution_strategy()) - - -class VariableCreatorStackTest(test.TestCase): - - def testCreatorStacksAreThreadLocal(self): - devices = ["/device:CPU:0", "/device:GPU:0"] - dist = mirrored_strategy.MirroredStrategy(devices) - - def model_fn(device_id): - assert isinstance(device_id, int) - - def thread_creator_fn(next_creator, *args, **kwargs): - return next_creator(*args, **kwargs) + ":thread_" + str(device_id) - - with variable_scope.variable_creator_scope(thread_creator_fn): - # Create a variable in this scope. - v = variable_scope.variable(1.0) - - # This will pause the current thread, and execute the other thread. - distribution_strategy_context.get_replica_context().merge_call( - lambda _: _) - return v - - def main_thread_creator(next_creator, *args, **kwargs): - # We are not using the underlying next_creator for test purposes. - del next_creator, args, kwargs - return "main_thread" - - with context.graph_mode(), \ - dist.scope(), \ - variable_scope.variable_creator_scope(main_thread_creator): - result = dist.call_for_each_replica( - model_fn, args=(dist.worker_device_index,)) - result = dist.unwrap(result) - expected = ["main_thread:thread_0", "main_thread:thread_1"] - self.assertEquals(expected, result) - - -class MultiWorkerMirroredStrategyTest(test.TestCase): - - def testDeviceScope(self): - """Test the device scope of multi-worker MirroredStrategy.""" - with context.graph_mode(): - strategy = mirrored_strategy.MirroredStrategy(num_gpus=context.num_gpus()) - strategy.configure( - cluster_spec={"worker": ["/job:worker/task:0", "/job:worker/task:1"]}) - with strategy.scope(): - a = constant_op.constant(1.) - with ops.device("/cpu:0"): - b = constant_op.constant(1.) 
- self.assertEqual(a.device, "/job:worker/task:0") - self.assertEqual(b.device, "/job:worker/task:0/device:CPU:0") - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/distribute/python/moving_averages_test.py b/tensorflow/contrib/distribute/python/moving_averages_test.py index 7ecc852d205..c492d8bafc9 100644 --- a/tensorflow/contrib/distribute/python/moving_averages_test.py +++ b/tensorflow/contrib/distribute/python/moving_averages_test.py @@ -32,7 +32,8 @@ from tensorflow.python.training import moving_averages all_combinations = combinations.combine( distribution=[combinations.default_strategy, combinations.one_device_strategy, - combinations.mirrored_strategy_with_gpu_and_cpu], + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], mode=["graph"]) diff --git a/tensorflow/contrib/distribute/python/multi_worker_test_base.py b/tensorflow/contrib/distribute/python/multi_worker_test_base.py index 8eec3dc0f6e..147c9b83f86 100644 --- a/tensorflow/contrib/distribute/python/multi_worker_test_base.py +++ b/tensorflow/contrib/distribute/python/multi_worker_test_base.py @@ -18,8 +18,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import contextlib import copy +import json +import os import threading import numpy as np @@ -271,7 +274,6 @@ class MultiWorkerTestBase(test.TestCase): return config - def _run_client(self, client_fn, task_type, task_id, num_gpus, *args, **kwargs): result = client_fn(task_type, task_id, num_gpus, *args, **kwargs) @@ -303,3 +305,101 @@ class MultiWorkerTestBase(test.TestCase): for t in threads: t.join() self.assertEqual(self._result, len(threads)) + + +class MockOsEnv(collections.Mapping): + """A class that allows per-thread TF_CONFIG.""" + + def __init__(self, *args): + self._dict = dict() + self._thread_local = threading.local() + super(MockOsEnv, self).__init__(*args) + + def get(self, key, default=None): + if not hasattr(self._thread_local, 'dict'): + self._thread_local.dict = dict() + if key == 'TF_CONFIG': + return dict.get(self._thread_local.dict, key, default) + else: + return dict.get(self._dict, key, default) + + def __getitem__(self, key): + if not hasattr(self._thread_local, 'dict'): + self._thread_local.dict = dict() + if key == 'TF_CONFIG': + return dict.__getitem__(self._thread_local.dict, key) + else: + return dict.__getitem__(self._dict, key) + + def __setitem__(self, key, val): + if not hasattr(self._thread_local, 'dict'): + self._thread_local.dict = dict() + if key == 'TF_CONFIG': + return dict.__setitem__(self._thread_local.dict, key, val) + else: + return dict.__setitem__(self._dict, key, val) + + def __iter__(self): + if not hasattr(self._thread_local, 'dict'): + self._thread_local.dict = dict() + for x in self._thread_local.dict.items(): + yield x + for x in self._dict.items(): + yield x + + def __len__(self): + if not hasattr(self._thread_local, 'dict'): + self._thread_local.dict = dict() + return self._thread_local.dict.__len__() + self._dict.__len__() + + +class IndependentWorkerTestBase(test.TestCase): + """Testing infra for independent workers.""" + + def setUp(self): + self._mock_os_env = MockOsEnv() + self._mock_context = test.mock.patch.object(os, 'environ', + self._mock_os_env) + super(IndependentWorkerTestBase, self).setUp() + self._mock_context.__enter__() + + def tearDown(self): + self._mock_context.__exit__(None, None, None) + super(IndependentWorkerTestBase, self).tearDown() 
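# ----------------------------------------------------------------------------
# Why TF_CONFIG must be per-thread in the harness above: each simulated worker
# runs as a thread of the same process but discovers its cluster membership
# from the TF_CONFIG environment variable, so a shared os.environ would make
# every "worker" see whichever value was written last. A minimal
# standard-library sketch of the thread-local idea; the helper names and
# addresses are hypothetical, and this is not the MockOsEnv implementation.

import json
import threading

_per_thread = threading.local()


def set_tf_config(cluster_spec, task_type, task_id):
  # Same TF_CONFIG layout the test harness serializes for each task.
  _per_thread.tf_config = json.dumps({
      "cluster": cluster_spec,
      "task": {"type": task_type, "index": task_id},
  })


def get_tf_config():
  return getattr(_per_thread, "tf_config", None)


def fake_worker(task_id):
  set_tf_config({"worker": ["localhost:1234", "localhost:1235"]},
                "worker", task_id)
  # Each thread only ever sees the value it wrote itself.
  assert json.loads(get_tf_config())["task"]["index"] == task_id


threads = [threading.Thread(target=fake_worker, args=(i,)) for i in range(2)]
for t in threads:
  t.start()
for t in threads:
  t.join()
# ----------------------------------------------------------------------------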
+ + def _task_thread(self, task_fn, tf_config, *args, **kwargs): + os.environ['TF_CONFIG'] = json.dumps(tf_config) + task_fn(*args, **kwargs) + + def _run_task_in_thread(self, task_fn, cluster_spec, task_type, task_id, + *args, **kwargs): + if task_type: + tf_config = { + 'cluster': cluster_spec, + 'task': { + 'type': task_type, + 'index': task_id + } + } + else: + tf_config = { + 'cluster': cluster_spec, + } + t = threading.Thread( + target=self._task_thread, + args=(task_fn, tf_config) + args, + kwargs=kwargs) + t.start() + return t + + def run_multiple_tasks_in_threads(self, task_fn, cluster_spec, *args, + **kwargs): + # The task_fn should create std_server by itself. + threads = {} + for task_type in cluster_spec.keys(): + threads[task_type] = [] + for task_id in range(len(cluster_spec[task_type])): + t = self._run_task_in_thread(task_fn, cluster_spec, task_type, task_id, + *args, **kwargs) + threads[task_type].append(t) + return threads diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py index a0d8f938874..e322b6acb84 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy.py @@ -20,12 +20,14 @@ from __future__ import print_function import six -from tensorflow.contrib.distribute.python import values +from tensorflow.python.distribute import device_util +from tensorflow.python.distribute import distribute_lib +from tensorflow.python.distribute import values from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.util import nest @@ -39,7 +41,14 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): # implementations? 
def __init__(self, device): - super(OneDeviceStrategy, self).__init__() + super(OneDeviceStrategy, self).__init__(OneDeviceExtended(self, device)) + + +class OneDeviceExtended(distribute_lib.DistributionStrategyExtended): + """Implementation of OneDeviceStrategy.""" + + def __init__(self, container_strategy, device): + super(OneDeviceExtended, self).__init__(container_strategy) self._device = device self._default_device = device @@ -58,17 +67,33 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): with ops.colocate_with(colocate_with): return next_creator(*args, **kwargs) - def distribute_dataset(self, dataset_fn): + def _make_dataset_iterator(self, dataset): + """Make iterator from dataset without splitting the batch.""" + worker = device_util.canonicalize("/device:CPU:0") + worker_device_pairs = [(worker, [self._device])] + return values.DatasetIterator(dataset, worker_device_pairs) + + def _distribute_dataset(self, dataset_fn): return values.PerReplicaDataset( self._call_dataset_fn(dataset_fn), [self._device]) - def _broadcast(self, tensor, destinations): + def _make_input_fn_iterator( + self, + input_fn, + replication_mode=distribute_lib.InputReplicationMode.PER_WORKER): + worker = device_util.canonicalize("/device:CPU:0") + worker_device_pairs = [(worker, [self._device])] + return values.InputFunctionIterator( + input_fn, worker_device_pairs, + [distribute_lib.InputContext()]) + + def _broadcast_to(self, tensor, destinations): del destinations return tensor # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed. - def _run_steps_on_dataset(self, fn, iterator, iterations, - initial_loop_values=None): + def _experimental_run_steps_on_iterator(self, fn, iterator, iterations, + initial_loop_values=None): if initial_loop_values is None: initial_loop_values = {} initial_loop_values = nest.flatten(initial_loop_values) @@ -80,7 +105,7 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): fn_inputs = iterator.get_next() if not isinstance(fn_inputs, tuple): fn_inputs = (fn_inputs,) - fn_result = fn(ctx, *fn_inputs) + fn_result = fn(ctx, fn_inputs) flat_last_step_outputs = nest.flatten(ctx.last_step_outputs) with ops.control_dependencies([fn_result]): return [i + 1] + flat_last_step_outputs @@ -114,25 +139,24 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): return ctx def _call_for_each_replica(self, fn, args, kwargs): - with ops.device(self._device), _OneDeviceReplicaContext(self): + strategy = self._container_strategy() + with ops.device(self._device), _OneDeviceReplicaContext(strategy): return fn(*args, **kwargs) - def _reduce(self, aggregation, value, destinations): - del aggregation, destinations + def _reduce_to(self, reduce_op, value, destinations): + del reduce_op, destinations return value - def _update(self, var, options, fn, *args, **kwargs): + def _update(self, var, fn, args, kwargs, group): # The implementations of _update() and _update_non_slot() are identical # except _update() passes `var` as the first argument to `fn()`. - return self._update_non_slot(var, options, fn, var, *args, **kwargs) + return self._update_non_slot(var, fn, (var,) + tuple(args), kwargs, group) - def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs): + def _update_non_slot(self, colocate_with, fn, args, kwargs, group): del colocate_with - should_group = options.pop("grouped") - assert not options # Validate that we are processing all of the options. 
with ops.device(self._device), distribute_lib.UpdateContext(self._device): result = fn(*args, **kwargs) - if should_group: + if group: return result else: return nest.map_structure(self._unwrap, result) @@ -148,11 +172,7 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): return value @property - def num_replicas(self): - return 1 - - @property - def num_replicas_in_sync(self): + def _num_replicas_in_sync(self): return 1 @property @@ -167,8 +187,22 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): del var_list return [self._device] - def _worker_device_index(self): - return 0 + @property + def experimental_should_init(self): + return True + + @property + def should_checkpoint(self): + return True + + @property + def should_save_summary(self): + return True + + # TODO(priyag): Delete this once all strategies use global batch size. + @property + def _global_batch_size(self): + return True class _OneDeviceReplicaContext(distribute_lib.ReplicaContext): @@ -176,12 +210,10 @@ class _OneDeviceReplicaContext(distribute_lib.ReplicaContext): def __init__(self, distribution_strategy): distribute_lib.ReplicaContext.__init__( - self, distribution_strategy, replica_id=0) - - @property - def device(self): - raise RuntimeError("Use .devices instead") + self, + distribution_strategy, + replica_id_in_sync_group=constant_op.constant(0, dtypes.int32)) @property def devices(self): - return [self._distribution_strategy.worker_devices[0]] + return [self._distribution_strategy.extended.worker_devices[0]] diff --git a/tensorflow/contrib/distribute/python/one_device_strategy_test.py b/tensorflow/contrib/distribute/python/one_device_strategy_test.py index 95f4cdb7868..d46cd6f529e 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy_test.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy_test.py @@ -20,6 +20,7 @@ from __future__ import print_function from tensorflow.contrib.distribute.python import one_device_strategy from tensorflow.contrib.distribute.python import strategy_test_lib +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import test from tensorflow.python.framework import test_util @@ -35,9 +36,6 @@ class OneDeviceStrategyTest(strategy_test_lib.DistributionTestBase): def testMinimizeLossGraph(self): self._test_minimize_loss_graph(self._get_distribution_strategy()) - def testDeviceIndex(self): - self._test_device_index(self._get_distribution_strategy()) - def testReplicaId(self): self._test_replica_id(self._get_distribution_strategy()) @@ -45,6 +43,20 @@ class OneDeviceStrategyTest(strategy_test_lib.DistributionTestBase): def testCallAndMergeExceptions(self): self._test_call_and_merge_exceptions(self._get_distribution_strategy()) + @test_util.run_in_graph_and_eager_modes + def testMakeInputFnIterator(self): + d = one_device_strategy.OneDeviceStrategy("/device:CPU:0") + dataset_fn = lambda: dataset_ops.Dataset.range(10) + expected_values = [[i] for i in range(10)] + input_fn = self._input_fn_to_test_input_context( + dataset_fn, + expected_num_replicas_in_sync=1, + expected_num_input_pipelines=1, + expected_input_pipeline_id=0) + iterator = d.make_input_fn_iterator(input_fn) + self._test_input_fn_iterator( + iterator, d.extended.worker_devices, expected_values) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py index 790b37f8601..eaeb4d70301 100644 --- 
a/tensorflow/contrib/distribute/python/parameter_server_strategy.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py @@ -18,10 +18,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib +import copy + from tensorflow.contrib.distribute.python import mirrored_strategy -from tensorflow.contrib.distribute.python import values +from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib +from tensorflow.python.distribute import device_util +from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import multi_worker_util +from tensorflow.python.distribute import values from tensorflow.python.eager import context from tensorflow.python.framework import device as tf_device from tensorflow.python.framework import ops @@ -30,8 +34,6 @@ from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import device_setter -from tensorflow.python.training import device_util -from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.util import nest _LOCAL_CPU = "/device:CPU:0" @@ -94,13 +96,21 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): ValueError: if `cluster_spec` is given but `task_type` or `task_id` is not. """ - super(ParameterServerStrategy, self).__init__() + super(ParameterServerStrategy, self).__init__( + ParameterServerExtended(self, num_gpus_per_worker)) + + +class ParameterServerExtended(distribute_lib.DistributionStrategyExtended): + """Implementation of ParameterServerStrategy.""" + + def __init__(self, container_strategy, num_gpus_per_worker): + super(ParameterServerExtended, self).__init__(container_strategy) self._num_gpus_per_worker = num_gpus_per_worker self._initialize_local(num_gpus_per_worker) # We typically don't need to do all-reduce in this strategy. - self._cross_tower_ops = ( - cross_tower_ops_lib.ReductionToOneDeviceCrossDeviceOps( + self._cross_device_ops = ( + cross_device_ops_lib.ReductionToOneDeviceCrossDeviceOps( reduce_to_device=_LOCAL_CPU)) def _initialize_multi_worker(self, num_gpus_per_worker, cluster_spec, @@ -189,6 +199,7 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): def _initialize_local(self, num_gpus_per_worker): """Initialize internal devices for local training.""" + self._worker_device = device_util.canonicalize("/device:CPU:0") # Define compute devices which is a list of device strings and one for each # replica. When there are GPUs, replicate operations on these GPUs. # Otherwise, place operations on CPU. 
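# ----------------------------------------------------------------------------
# The refactor in the hunk above (and the OneDeviceStrategy change earlier)
# follows one shape: the public DistributionStrategy subclass becomes a thin
# shell whose constructor builds a matching *Extended object, and the device
# logic lives on that object, reachable from tests as `strategy.extended`. A
# minimal generic sketch of that container/extended split; the class names
# below are hypothetical stand-ins, not TensorFlow APIs.


class _ExtendedImpl(object):
  """Holds the per-backend implementation details."""

  def __init__(self, container_strategy, devices):
    self._container_strategy = container_strategy
    self.worker_devices = tuple(devices)

  def call_for_each_replica(self, fn, args=()):
    # Real strategies run `fn` once per replica device; the sketch just fans
    # the call out sequentially.
    return [fn(*args) for _ in self.worker_devices]


class _Strategy(object):
  """Thin public wrapper that delegates to its extended implementation."""

  def __init__(self, devices):
    self.extended = _ExtendedImpl(self, devices)

  @property
  def num_replicas_in_sync(self):
    return len(self.extended.worker_devices)


strategy = _Strategy(["/device:CPU:0", "/device:GPU:0"])
assert strategy.num_replicas_in_sync == 2
assert strategy.extended.call_for_each_replica(lambda: 1.0) == [1.0, 1.0]
# ----------------------------------------------------------------------------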
@@ -221,15 +232,48 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): "ParameterServerStrategy with compute_devices = %r, " "variable_device = %r", self._compute_devices, self._variable_device) - def distribute_dataset(self, dataset_fn): + def _distribute_dataset(self, dataset_fn): """Distributes the dataset to each local GPU.""" return values.PerReplicaDataset( self._call_dataset_fn(dataset_fn), self._compute_devices, True) - def _broadcast(self, tensor, destinations): - if not cross_tower_ops_lib.check_destinations(destinations): + def _make_dataset_iterator(self, dataset): + worker_device_pairs = [(self._worker_device, self._compute_devices)] + return values.DatasetIterator(dataset, worker_device_pairs, + self._num_replicas_in_sync) + + def _make_input_fn_iterator( + self, + input_fn, + replication_mode=distribute_lib.InputReplicationMode.PER_WORKER): + """Distributes the dataset to each local GPU.""" + if self._cluster_spec: + input_pipeline_id = multi_worker_util.id_in_cluster( + self._cluster_spec, self._task_type, self._task_id) + num_input_pipelines = multi_worker_util.worker_count( + self._cluster_spec, self._task_type) + else: + input_pipeline_id = 0 + num_input_pipelines = 1 + input_context = distribute_lib.InputContext( + num_input_pipelines=num_input_pipelines, + input_pipeline_id=input_pipeline_id, + num_replicas_in_sync=self._num_replicas_in_sync) + worker_device_pairs = [(self._worker_device, self._compute_devices)] + return values.InputFunctionIterator( + input_fn, worker_device_pairs, [input_context]) + + def _broadcast_to(self, tensor, destinations): + # This is both a fast path for Python constants, and a way to delay + # converting Python values to a tensor until we know what type it + # should be converted to. Otherwise we have trouble with: + # global_step.assign_add(1) + # since the `1` gets broadcast as an int32 but global_step is int64. + if isinstance(tensor, (float, int)): + return tensor + if not cross_device_ops_lib.check_destinations(destinations): destinations = self._compute_devices - return self._cross_tower_ops.broadcast(tensor, destinations) + return self._cross_device_ops.broadcast(tensor, destinations) def _allow_variable_partition(self): return not context.executing_eagerly() @@ -237,7 +281,7 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): # TODO(yuefengz): not all ops in device_setter.STANDARD_PS_OPS will go through # this creator, such as "MutableHashTable". 
def _create_variable(self, next_creator, *args, **kwargs): - if self.num_replicas_in_sync > 1: + if self._num_replicas_in_sync > 1: aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE) if aggregation not in ( vs.VariableAggregation.NONE, @@ -293,39 +337,35 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): def _call_for_each_replica(self, fn, args, kwargs): # pylint: disable=protected-access - return mirrored_strategy._call_for_each_replica(self, fn, args, kwargs) + return mirrored_strategy._call_for_each_replica( + self._container_strategy(), fn, args, kwargs) def _verify_destinations_not_different_worker(self, destinations): if not self._cluster_spec: return if destinations is None: return - for d in cross_tower_ops_lib.get_devices_from(destinations): + for d in cross_device_ops_lib.get_devices_from(destinations): d_spec = tf_device.DeviceSpec.from_string(d) if d_spec.job == self._task_type and d_spec.task != self._task_id: raise ValueError( "Cannot reduce to another worker: %r, current worker is %r" % (d, self._worker_device)) - def _reduce(self, aggregation, value, destinations): + def _reduce_to(self, reduce_op, value, destinations): self._verify_destinations_not_different_worker(destinations) if not isinstance(value, values.DistributedValues): # pylint: disable=protected-access return mirrored_strategy._reduce_non_distributed_value( - self, aggregation, value, destinations) - if aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA: - return self.broadcast(value.get(self._compute_devices[0]), destinations) - return self._cross_tower_ops.reduce( - aggregation, value, destinations=destinations) + self, reduce_op, value, destinations) + return self._cross_device_ops.reduce( + reduce_op, value, destinations=destinations) - def _batch_reduce(self, aggregation, value_destination_pairs): - if aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA: - return [self.broadcast(v.get(self._compute_devices[0]), d) - for v, d in value_destination_pairs] + def _batch_reduce_to(self, reduce_op, value_destination_pairs): for _, destinations in value_destination_pairs: self._verify_destinations_not_different_worker(destinations) - return self._cross_tower_ops.batch_reduce(aggregation, - value_destination_pairs) + return self._cross_device_ops.batch_reduce(reduce_op, + value_destination_pairs) def _select_single_value(self, structured): """Select any single values in `structured`.""" @@ -349,30 +389,26 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): return nest.map_structure(_select_fn, structured) - def _update(self, var, options, fn, *args, **kwargs): + def _update(self, var, fn, args, kwargs, group): if isinstance(var, values.AggregatingVariable): var = var.get() if not isinstance(var, resource_variable_ops.ResourceVariable): raise ValueError( "You can not update `var` %r. It must be a Variable." % var) - should_group = options.pop("grouped") - assert not options # Validate that we are processing all of the options. with ops.colocate_with(var), distribute_lib.UpdateContext(var.device): result = fn(var, *self._select_single_value(args), **self._select_single_value(kwargs)) - if should_group: + if group: return result else: return nest.map_structure(self._unwrap, result) # TODO(yuefengz): does it need to call _select_single_value? - def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs): - should_group = options.pop("grouped") - assert not options # Validate that we are processing all of the options. 
+ def _update_non_slot(self, colocate_with, fn, args, kwargs, group): with ops.device( colocate_with.device), distribute_lib.UpdateContext(colocate_with): result = fn(*args, **kwargs) - if should_group: + if group: return result else: return nest.map_structure(self._unwrap, result) @@ -398,11 +434,11 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): # variables. return array_ops.identity(var) - def configure(self, - session_config=None, - cluster_spec=None, - task_type=None, - task_id=None): + def _configure(self, + session_config=None, + cluster_spec=None, + task_type=None, + task_id=None): """Configures the strategy class. The strategy object will be re-initialized if `cluster_spec` is given but @@ -433,28 +469,30 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): self._initialize_multi_worker(self._num_gpus_per_worker, self._cluster_spec, task_type, task_id) - if not session_config or not self._cluster_spec: - return + if session_config: + session_config.CopyFrom(self._update_config_proto(session_config)) - session_config.isolate_session_state = False + def _update_config_proto(self, config_proto): + updated_config = copy.deepcopy(config_proto) + if not self._cluster_spec: + updated_config.isolate_session_state = True + return updated_config + + updated_config.isolate_session_state = False - assert self._cluster_spec assert self._task_type assert self._task_id is not None # The device filters prevent communication between workers. if self._task_type not in ["chief", "worker"]: - return - del session_config.device_filters[:] - session_config.device_filters.extend( + return updated_config + del updated_config.device_filters[:] + updated_config.device_filters.extend( ["/job:%s/task:%d" % (self._task_type, self._task_id), "/job:ps"]) + return updated_config @property - def num_replicas(self): - return len(self._compute_devices) - - @property - def num_replicas_in_sync(self): + def _num_replicas_in_sync(self): return len(self._compute_devices) @property @@ -470,11 +508,12 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): return min(var_list, key=lambda x: x.name) @property - def between_graph(self): + def experimental_between_graph(self): + # TODO(yuefengz): Should this return False in the local case? return True @property - def should_init(self): + def experimental_should_init(self): return self._is_chief @property @@ -484,3 +523,8 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): @property def should_save_summary(self): return self._is_chief + + # TODO(priyag): Delete this once all strategies use global batch size. 
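`_update_config_proto` above replaces the in-place mutation that `configure` used to do. A small sketch of the contract, reusing only pieces visible in this diff (`config_pb2`, `copy.deepcopy`, the device-filter strings); the standalone function and dict-shaped cluster spec are simplifications:

```python
import copy
from tensorflow.core.protobuf import config_pb2

def update_config_proto(config_proto, cluster_spec, task_type, task_id):
    updated = copy.deepcopy(config_proto)   # never mutate the caller's proto
    if not cluster_spec:                    # local mode
        updated.isolate_session_state = True
        return updated
    updated.isolate_session_state = False
    if task_type in ("chief", "worker"):
        # Device filters keep this worker from talking to its peers directly.
        del updated.device_filters[:]
        updated.device_filters.extend(
            ["/job:%s/task:%d" % (task_type, task_id), "/job:ps"])
    return updated

cfg = update_config_proto(config_pb2.ConfigProto(device_filters=["stale"]),
                          cluster_spec={"worker": ["w0", "w1"]},
                          task_type="worker", task_id=1)
print(list(cfg.device_filters))  # ['/job:worker/task:1', '/job:ps']
```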
+ @property + def _global_batch_size(self): + return False diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py index 81a23c89030..83d7473666a 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py @@ -25,14 +25,21 @@ from absl.testing import parameterized from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import multi_worker_test_base from tensorflow.contrib.distribute.python import parameter_server_strategy -from tensorflow.contrib.distribute.python import values +from tensorflow.contrib.distribute.python import strategy_test_lib from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import device_util +from tensorflow.python.distribute import distribution_strategy_context as ds_context from tensorflow.python.distribute import multi_worker_util +from tensorflow.python.distribute import reduce_util +from tensorflow.python.distribute import values from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.estimator import run_config from tensorflow.python.framework import constant_op +from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util from tensorflow.python.layers import core from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -41,8 +48,6 @@ from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test -from tensorflow.python.training import device_util -from tensorflow.python.training import distribution_strategy_context from tensorflow.python.training import training_util CHIEF = run_config.TaskType.CHIEF @@ -50,6 +55,13 @@ WORKER = run_config.TaskType.WORKER PS = run_config.TaskType.PS +def _get_replica_id_integer(): + replica_id = ds_context.get_replica_context().replica_id_in_sync_group + if isinstance(replica_id, ops.Tensor): + replica_id = tensor_util.constant_value(replica_id) + return replica_id + + class ParameterServerStrategyTestBase( multi_worker_test_base.MultiWorkerTestBase): @@ -94,9 +106,8 @@ class ParameterServerStrategyTestBase( if num_gpus == 0: last_part_device = 'device:CPU:0' else: - last_part_device = ( - 'device:GPU:%d' % - distribution_strategy_context.get_replica_context().replica_id) + replica_id = _get_replica_id_integer() + last_part_device = ('device:GPU:%d' % replica_id) a = constant_op.constant(1.0) b = constant_op.constant(2.0) @@ -261,18 +272,16 @@ class ParameterServerStrategyTestBase( if 'CPU' in compute_device: replica_compute_device = '/device:CPU:0' else: - replica_compute_device = ( - '/device:GPU:%d' % - distribution_strategy_context.get_replica_context().replica_id) + replica_id = _get_replica_id_integer() + replica_compute_device = ('/device:GPU:%d' % replica_id) replica_compute_device = device_util.canonicalize( replica_compute_device) if 'CPU' in variable_device: replica_variable_device = '/device:CPU:0' else: - replica_variable_device = ( - '/device:GPU:%d' % - distribution_strategy_context.get_replica_context().replica_id) + replica_id = _get_replica_id_integer() + replica_variable_device 
= ('/device:GPU:%d' % replica_id) replica_variable_device = device_util.canonicalize( replica_variable_device) @@ -354,9 +363,9 @@ class ParameterServerStrategyTestBase( def _test_simple_increment(self, task_type, task_id, num_gpus): d, master_target, sess_config = self._get_test_objects( task_type, task_id, num_gpus) - if hasattr(d, '_cluster_spec') and d._cluster_spec: - num_workers = len(d._cluster_spec.as_dict().get(WORKER)) - if 'chief' in d._cluster_spec.as_dict(): + if d.extended._cluster_spec: + num_workers = len(d.extended._cluster_spec.as_dict().get(WORKER)) + if 'chief' in d.extended._cluster_spec.as_dict(): num_workers += 1 else: num_workers = 1 @@ -389,7 +398,7 @@ class ParameterServerStrategyTestBase( x, y, z, train_op = d.call_for_each_replica(model_fn) train_op = d.group(train_op) - if context.num_gpus() < d._num_gpus_per_worker: + if context.num_gpus() < d.extended._num_gpus_per_worker: return True if task_id == 0: @@ -426,9 +435,9 @@ class ParameterServerStrategyTestBase( task_type, task_id, num_gpus) if task_type: # Multi-worker - assert hasattr(d, '_cluster_spec') and d._cluster_spec - num_workers = len(d._cluster_spec.as_dict().get(WORKER)) - if CHIEF in d._cluster_spec.as_dict(): + assert hasattr(d.extended, '_cluster_spec') and d.extended._cluster_spec + num_workers = len(d.extended._cluster_spec.as_dict().get(WORKER)) + if CHIEF in d.extended._cluster_spec.as_dict(): num_workers += 1 else: # local @@ -472,8 +481,8 @@ class ParameterServerStrategyTestBase( before_list.append(fetched) with ops.control_dependencies([fetched]): # TODO(yuefengz): support non-Mirrored variable as destinations. - g = d.reduce( - variable_scope.VariableAggregation.SUM, g, destinations=v) + g = d.extended.reduce_to( + reduce_util.ReduceOp.SUM, g, destinations=v) with ops.control_dependencies( d.update(v, update, g, grouped=False)): after_list.append(d.read_var(v)) @@ -481,11 +490,12 @@ class ParameterServerStrategyTestBase( before_out, after_out = step() - if context.num_gpus() < d._num_gpus_per_worker: + if context.num_gpus() < d.extended._num_gpus_per_worker: return True if (not task_type or - multi_worker_util.is_chief(d._cluster_spec, task_type, task_id)): + multi_worker_util.is_chief( + d.extended._cluster_spec, task_type, task_id)): variables.global_variables_initializer().run() # Workers waiting for chief worker's initializing variables. @@ -508,8 +518,40 @@ class ParameterServerStrategyTestBase( self.assertLess(error_after, error_before) return error_after < error_before + def _test_input_fn_iterator(self, task_type, task_id, num_gpus, input_fn, + expected_values): + distribution, master_target, config = self._get_test_objects( + task_type, task_id, num_gpus) + devices = distribution.extended.worker_devices + + with ops.Graph().as_default(), \ + self.cached_session(config=config, + target=master_target) as sess: + iterator = distribution.make_input_fn_iterator(input_fn) + sess.run(iterator.initialize()) + + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = sess.run( + [values.select_device(d, next_element) for d in devices]) + self.assertEqual(expected_value, computed_value) + + with self.assertRaises(errors.OutOfRangeError): + next_element = iterator.get_next() + sess.run([values.select_device(d, next_element) for d in devices]) + + # After re-initializing the iterator, should be able to iterate again. 
+ sess.run(iterator.initialize()) + + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = sess.run( + [values.select_device(d, next_element) for d in devices]) + self.assertEqual(expected_value, computed_value) + class ParameterServerStrategyTest(ParameterServerStrategyTestBase, + strategy_test_lib.DistributionTestBase, parameterized.TestCase): @classmethod @@ -574,6 +616,73 @@ class ParameterServerStrategyTest(ParameterServerStrategyTestBase, def testMinimizeLossGraphLocal(self, num_gpus): self._test_minimize_loss_graph(None, None, num_gpus) + # TODO(priyag): Refactor this and other multi worker tests. + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[1, 2], required_gpus=1)) + def testMakeInputFnIteratorDistributed(self, num_gpus): + if context.num_gpus() < num_gpus: + self.skipTest('Not enough GPUs') + dataset_fn = lambda: dataset_ops.Dataset.range(100) + expected_values = [[i+j for j in range(num_gpus)] + for i in range(0, 100, num_gpus)] + + input_fn = self._input_fn_to_test_input_context( + dataset_fn, + expected_num_replicas_in_sync=num_gpus, + expected_num_input_pipelines=3, + expected_input_pipeline_id=1) # because task_id = 1 + self._test_input_fn_iterator('worker', 1, num_gpus, + input_fn, expected_values) + + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[1, 2], required_gpus=1)) + def testMakeInputFnIteratorLocal(self, num_gpus): + if context.num_gpus() < num_gpus: + self.skipTest('Not enough GPUs') + dataset_fn = lambda: dataset_ops.Dataset.range(100) + expected_values = [[i+j for j in range(num_gpus)] + for i in range(0, 100, num_gpus)] + + input_fn = self._input_fn_to_test_input_context( + dataset_fn, + expected_num_replicas_in_sync=num_gpus, + expected_num_input_pipelines=1, + expected_input_pipeline_id=0) # only one worker and pipeline for local. + self._test_input_fn_iterator(None, None, num_gpus, + input_fn, expected_values) + + def testGlobalStepUpdate(self): + strategy = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=context.num_gpus()) + self._test_global_step_update(strategy) + + def testUpdateConfigProtoMultiWorker(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=2) + distribution.configure( + cluster_spec=self._cluster_spec, task_type='worker', task_id=1) + + config_proto = config_pb2.ConfigProto(device_filters=['to_be_overridden']) + + new_config = distribution.update_config_proto(config_proto) + + # Verify device filters. 
+ self.assertEqual(['/job:worker/task:1', '/job:ps'], + new_config.device_filters) + + # Verify isolate_session_state + self.assertFalse(new_config.isolate_session_state) + + def testUpdateConfigProtoLocal(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=2) + + config_proto = config_pb2.ConfigProto() + new_config = distribution.update_config_proto(config_proto) + + # Verify isolate_session_state + self.assertTrue(new_config.isolate_session_state) + class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase, parameterized.TestCase): @@ -616,9 +725,9 @@ class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase, v = variable_scope.get_variable('v', initializer=10.0) _ = v * v v, = tape.watched_variables() - w = distribution.value_container(v) + w = distribution.extended.value_container(v) self.assertIs(values.AggregatingVariable, type(w)) - distribution.call_for_each_replica(f) + distribution.extended.call_for_each_replica(f) if __name__ == '__main__': diff --git a/tensorflow/contrib/distribute/python/step_fn.py b/tensorflow/contrib/distribute/python/step_fn.py index 3dc815f0371..c928b6d9f1f 100644 --- a/tensorflow/contrib/distribute/python/step_fn.py +++ b/tensorflow/contrib/distribute/python/step_fn.py @@ -94,7 +94,7 @@ class StandardSingleLossStep(StandardInputStep): def __call__(self): with self._distribution.scope(): - def step_fn(ctx, *inputs): + def step_fn(ctx, inputs): """Function to run one iteration with one input.""" gradients_fn = backprop.implicit_grad(self._loss_fn) gradients_fn = optimizer_lib.get_filtered_grad_fn(gradients_fn) diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py index 3c0c10430eb..d50b142c5e9 100644 --- a/tensorflow/contrib/distribute/python/strategy_test_lib.py +++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py @@ -19,16 +19,21 @@ from __future__ import division from __future__ import print_function from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.distribute import distribution_strategy_context as ds_context +from tensorflow.python.distribute import reduce_util +from tensorflow.python.distribute import values from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.layers import core from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables -from tensorflow.python.training import distribution_strategy_context from tensorflow.python.training import optimizer @@ -45,8 +50,7 @@ def _raise_exception_fn(_=None): # Must be the argument to a distribution.call_for_each_replica() call, calls a # get_replica_context().merge_call() that raises an exception. def _merge_raises_fn(): - distribution_strategy_context.get_replica_context().merge_call( - _raise_exception_fn) + ds_context.get_replica_context().merge_call(_raise_exception_fn) # Must be the argument to a get_replica_context().merge_call() call, calls @@ -59,8 +63,7 @@ def _call_raises_fn(dist): # calls a get_replica_context().merge_call() that calls a # call_for_each_replica() that raises an exception. 
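The step_fn.py hunk above changes the per-step callback from `step_fn(ctx, *inputs)` to `step_fn(ctx, inputs)`. A toy sketch of what that means for callers (names are illustrative):

```python
def run_step(step_fn, ctx, inputs):
    # New convention: pass the replica's inputs as a single value (often a
    # tuple) rather than splatting them into positional arguments.
    return step_fn(ctx, inputs)

def step_fn(ctx, inputs):
    features, labels = inputs   # the step function unpacks explicitly now
    return features + labels

print(run_step(step_fn, ctx=None, inputs=(3, 4)))  # 7
```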
def _merge_call_raises_fn(): - distribution_strategy_context.get_replica_context().merge_call( - _call_raises_fn) + ds_context.get_replica_context().merge_call(_call_raises_fn) # Must be the argument to a get_replica_context().merge_call() call, calls @@ -74,8 +77,7 @@ def _call_merge_raises_fn(dist): # get_replica_context().merge_call() that calls a call_for_each_replica() that # calls a get_replica_context().merge_call() that raises an exception. def _merge_call_merge_raises_fn(): - distribution_strategy_context.get_replica_context().merge_call( - _call_merge_raises_fn) + ds_context.get_replica_context().merge_call(_call_merge_raises_fn) class DistributionTestBase(test.TestCase): @@ -114,8 +116,8 @@ class DistributionTestBase(test.TestCase): before_list.append(fetched) # control_dependencies irrelevant but harmless in eager execution with ops.control_dependencies([fetched]): - g = d.reduce( - variable_scope.VariableAggregation.SUM, g, destinations=v) + g = d.extended.reduce_to( + reduce_util.ReduceOp.SUM, g, destinations=v) with ops.control_dependencies(d.update( v, update, g, grouped=False)): after_list.append(d.read_var(v)) @@ -169,8 +171,8 @@ class DistributionTestBase(test.TestCase): fetched = d.read_var(v) before_list.append(fetched) with ops.control_dependencies([fetched]): - g = d.reduce( - variable_scope.VariableAggregation.SUM, g, destinations=v) + g = d.extended.reduce_to( + reduce_util.ReduceOp.SUM, g, destinations=v) with ops.control_dependencies(d.update( v, update, g, grouped=False)): after_list.append(d.read_var(v)) @@ -189,31 +191,20 @@ class DistributionTestBase(test.TestCase): # Error should go down self.assertLess(error_after, error_before) - def _test_device_index(self, d): - with d.scope(): - expected_devices = [False] * len(d.worker_devices) - - def mark_devices_fn(device_id): - self.assertLess(device_id, len(d.worker_devices)) - self.assertFalse(expected_devices[device_id]) - expected_devices[device_id] = True - - d.call_for_each_replica(mark_devices_fn, args=(d.worker_device_index,)) - self.assertAllEqual(expected_devices, [True] * len(d.worker_devices)) - def _test_replica_id(self, d): with d.scope(): - expected_devices = [False] * len(d.worker_devices) + expected_devices = [False] * len(d.extended.worker_devices) def mark_devices_fn(): - replica_id = ( - distribution_strategy_context.get_replica_context().replica_id) - self.assertLess(replica_id, len(d.worker_devices)) + replica_id = self.evaluate( + ds_context.get_replica_context().replica_id_in_sync_group) + self.assertLess(replica_id, len(d.extended.worker_devices)) self.assertFalse(expected_devices[replica_id]) expected_devices[replica_id] = True d.call_for_each_replica(mark_devices_fn) - self.assertAllEqual(expected_devices, [True] * len(d.worker_devices)) + self.assertAllEqual(expected_devices, + [True] * len(d.extended.worker_devices)) def _test_call_and_merge_exceptions(self, dist): with dist.scope(): @@ -225,3 +216,78 @@ class DistributionTestBase(test.TestCase): dist.call_for_each_replica(_merge_call_raises_fn) with self.assertRaises(_TestException): dist.call_for_each_replica(_merge_call_merge_raises_fn) + + def _input_fn_to_test_input_context(self, + dataset_fn, + expected_num_replicas_in_sync, + expected_num_input_pipelines, + expected_input_pipeline_id): + # Use a list of one element as counter so that it can be captured by the + # `_input_fn`. This counter is incremented by 1 each time an input_fn is + # called. 
We use this counter to check whether the `input_pipeline_id` + # matches the counter in the in-graph replication. + worker_id_counter = [0] + + def _input_fn(input_context): + """Input fn for testing.""" + self.assertIsNotNone(input_context) + self.assertEqual(expected_num_replicas_in_sync, + input_context.num_replicas_in_sync) + self.assertEqual(expected_num_input_pipelines, + input_context.num_input_pipelines) + if expected_input_pipeline_id is not None: + self.assertEqual(expected_input_pipeline_id, + input_context.input_pipeline_id) + else: + self.assertEqual(worker_id_counter[0], input_context.input_pipeline_id) + worker_id_counter[0] += 1 + + return dataset_fn() + + return _input_fn + + def _test_input_fn_iterator(self, iterator, devices, expected_values, + sess=None): + evaluate = lambda x: sess.run(x) if sess else self.evaluate(x) + evaluate(iterator.initialize()) + + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = evaluate( + [values.select_device(d, next_element) for d in devices]) + self.assertEqual(expected_value, computed_value) + + with self.assertRaises(errors.OutOfRangeError): + next_element = iterator.get_next() + evaluate([values.select_device(d, next_element) for d in devices]) + + # After re-initializing the iterator, should be able to iterate again. + evaluate(iterator.initialize()) + + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = evaluate( + [values.select_device(d, next_element) for d in devices]) + self.assertEqual(expected_value, computed_value) + + def _test_global_step_update(self, strategy): + with strategy.scope(): + global_step = variable_scope.get_variable( + "global_step", + shape=[], + dtype=dtypes.int64, + initializer=init_ops.zeros_initializer(), + trainable=False, + aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA) + self.evaluate(variables.global_variables_initializer()) + + def model_fn(): + train_op = global_step.assign_add(1) + value = global_step.read_value() + return train_op, value + + train_ops, value = strategy.call_for_each_replica(model_fn) + self.evaluate(strategy.group(train_ops)) + global_step_tensors = strategy.unwrap(value) + global_step_values = self.evaluate(global_step_tensors) + self.assertEqual([1] * len(global_step_tensors), global_step_values) diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py index f5b4531ba8c..39ed8f7cf10 100644 --- a/tensorflow/contrib/distribute/python/tpu_strategy.py +++ b/tensorflow/contrib/distribute/python/tpu_strategy.py @@ -21,25 +21,28 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import copy import functools -from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib -from tensorflow.contrib.distribute.python import values from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import tpu from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib from tensorflow.contrib.tpu.python.tpu import training_loop +from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib +from tensorflow.python.distribute import device_util +from tensorflow.python.distribute import distribute_lib +from tensorflow.python.distribute import reduce_util +from tensorflow.python.distribute import values from tensorflow.python.eager import context from 
tensorflow.python.eager import tape from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope as vs -from tensorflow.python.ops import variables as variables_lib -from tensorflow.python.training import device_util -from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.util import nest @@ -130,8 +133,21 @@ class TPUStrategy(distribute_lib.DistributionStrategy): num_cores: Number of cores to use on the TPU. If None specified, then auto-detect the cores and topology of the TPU system. """ - super(TPUStrategy, self).__init__() + super(TPUStrategy, self).__init__(TPUExtended( + self, tpu_cluster_resolver, steps_per_run, num_cores)) + @property + def steps_per_run(self): + """DEPRECATED: use .extended.steps_per_run instead.""" + return self._extended.steps_per_run + + +class TPUExtended(distribute_lib.DistributionStrategyExtended): + """Implementation of TPUStrategy.""" + + def __init__(self, container_strategy, tpu_cluster_resolver, steps_per_run, + num_cores=None): + super(TPUExtended, self).__init__(container_strategy) self._tpu_cluster_resolver = tpu_cluster_resolver self._tpu_metadata = get_tpu_system_metadata(self._tpu_cluster_resolver) # TODO(sourabhbajaj): Change this from num_cores to metadata_override @@ -145,7 +161,7 @@ class TPUStrategy(distribute_lib.DistributionStrategy): self._host_device = self.get_host_cpu_device(0) self._tpu_devices = sorted(device_map.keys()) # Only create variables for the number of replicas we're running. - self._tpu_devices = self._tpu_devices[:self.num_replicas] + self._tpu_devices = self._tpu_devices[:self._num_replicas_in_sync] # TODO(sourabhbajaj): Remove this once performance of running one step # at a time is comparable to multiple steps. @@ -214,7 +230,17 @@ class TPUStrategy(distribute_lib.DistributionStrategy): return enqueue_op_per_host - def distribute_dataset(self, dataset_fn): + def _make_dataset_iterator(self, dataset): + """Make iterators for each of the TPU hosts.""" + + worker_devices = [ + (self.get_host(hid), [self.get_host_cpu_device(hid)]) + for hid in range(self.num_hosts) + ] + return values.DatasetIterator(dataset, worker_devices, + self._num_replicas_in_sync) + + def _distribute_dataset(self, dataset_fn): worker_devices = [ (self.get_host(hid), [self.get_host_cpu_device(hid)]) for hid in range(self.num_hosts) @@ -225,12 +251,11 @@ class TPUStrategy(distribute_lib.DistributionStrategy): # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed. # TODO(sourabhbajaj): Remove the initial_loop_values parameter when we have # a mechanism to infer the outputs of `fn`. Pending b/110550782. - def _run_steps_on_dataset(self, fn, multi_worker_iterator, iterations, - initial_loop_values=None): - + def _experimental_run_steps_on_iterator( + self, fn, multi_worker_iterator, iterations, initial_loop_values=None): output_shapes = multi_worker_iterator.output_shapes shapes = nest.flatten(output_shapes) - if any([not s.is_fully_defined() for s in shapes]): + if any(not s.is_fully_defined() for s in shapes): raise ValueError( "TPU currently requires fully defined shapes. 
Either use " "set_shape() on the input tensors or use " @@ -251,13 +276,13 @@ class TPUStrategy(distribute_lib.DistributionStrategy): initial_loop_values = {} initial_loop_values = nest.flatten(initial_loop_values) ctx = values.MultiStepContext() - def run_fn(*args, **kwargs): + + def run_fn(): """Single step on the TPU device.""" - del args, kwargs fn_inputs = dequeue_fn() if not isinstance(fn_inputs, tuple): fn_inputs = (fn_inputs,) - fn_result = fn(ctx, *fn_inputs) + fn_result = fn(ctx, fn_inputs) flat_last_step_outputs = nest.flatten(ctx.last_step_outputs) if flat_last_step_outputs: with ops.control_dependencies([fn_result]): @@ -265,11 +290,6 @@ class TPUStrategy(distribute_lib.DistributionStrategy): else: return fn_result - # TODO(sourabhbajaj): The input to while loop should be based on the output - # type of the step_fn - def iterate_on_tpu(): - return training_loop.repeat(iterations, run_fn, initial_loop_values) - # We capture the control_flow_context at this point, before we run `fn` # inside a while_loop and TPU replicate context. This is useful in cases # where we might need to exit these contexts and get back to the outer @@ -279,38 +299,70 @@ class TPUStrategy(distribute_lib.DistributionStrategy): self._outer_control_flow_context = ( ops.get_default_graph()._get_control_flow_context()) # pylint: disable=protected-access - replicate_inputs = [[]] * self.num_replicas - replicate_outputs = tpu.replicate(iterate_on_tpu, replicate_inputs) + def rewrite_fn(*args): + """The rewritten step fn running on TPU.""" + del args + replicate_inputs = [[]] * self._num_replicas_in_sync + replicate_outputs = tpu.replicate(run_fn, replicate_inputs) + + # If run_fn has tensor outputs, tpu.replicate returns a list of list. We + # will flatten it in this case. If run_fn has no tensor outputs, + # tpu.replicate returns a list of no_ops, we will keep the output as it + # is. + if isinstance(replicate_outputs[0], list): + replicate_outputs = nest.flatten(replicate_outputs) + + return replicate_outputs + + # TODO(sourabhbajaj): The input to while loop should be based on the output + # type of the step_fn + assert isinstance(initial_loop_values, list) + initial_loop_values = initial_loop_values * self._num_replicas_in_sync + + # Put the while loop op on host 0. + with ops.device(self.get_host_cpu_device(0)): + replicate_outputs = training_loop.repeat(iterations, rewrite_fn, + initial_loop_values) + del self._outer_control_flow_context ctx.run_op = control_flow_ops.group(replicate_outputs, enqueue_ops) - # Filter out any ops from the outputs, typically this would be the case - # when there were no tensor outputs. - last_step_tensor_outputs = [x for x in replicate_outputs - if not isinstance(x, ops.Operation)] + if isinstance(replicate_outputs, list): + # Filter out any ops from the outputs, typically this would be the case + # when there were no tensor outputs. 
+ last_step_tensor_outputs = [ + x for x in replicate_outputs if not isinstance(x, ops.Operation) + ] - # Outputs are currently of the structure (grouped by device) - # [[output0_device0, output1_device0, output2_device0], - # [output0_device1, output1_device1, output2_device1]] - # Convert this to the following structure instead: (grouped by output) - # [[output0_device0, output0_device1], - # [output1_device0, output1_device1], - # [output2_device0, output2_device1]] - last_step_tensor_outputs = [list(x) for x in zip(*last_step_tensor_outputs)] + # Outputs are currently of the structure (flattened) + # [output0_device0, output1_device0, output2_device0, + # output0_device1, output1_device1, output2_device1, + # ...] + # Convert this to the following structure instead: (grouped by output) + # [[output0_device0, output0_device1], + # [output1_device0, output1_device1], + # [output2_device0, output2_device1]] + output_num = len(last_step_tensor_outputs) // self._num_replicas_in_sync + last_step_tensor_outputs = [ + last_step_tensor_outputs[i::output_num] for i in range(output_num) + ] + else: + # no tensors returned. + last_step_tensor_outputs = [] # Convert replicate_outputs to the original dict structure of # last_step_outputs. last_step_tensor_outputs_dict = nest.pack_sequence_as( ctx.last_step_outputs, last_step_tensor_outputs) - for (name, aggregation) in ctx._last_step_outputs_aggregations.items(): # pylint: disable=protected-access + for name, reduce_op in ctx._last_step_outputs_reduce_ops.items(): # pylint: disable=protected-access output = last_step_tensor_outputs_dict[name] - # For outputs that have already been aggregated, take the first value + # For outputs that have already been reduced, take the first value # from the list as each value should be the same. Else return the full # list of values. - # TODO(josh11b): If aggregation is NONE, we should return a PerReplica + # TODO(josh11b): If reduce_op is NONE, we should return a PerReplica # value. - if aggregation is not variables_lib.VariableAggregation.NONE: + if reduce_op is not None: # TODO(priyag): Should this return the element or a list with 1 element last_step_tensor_outputs_dict[name] = output[0] ctx._set_last_step_outputs(last_step_tensor_outputs_dict) # pylint: disable=protected-access @@ -320,10 +372,10 @@ class TPUStrategy(distribute_lib.DistributionStrategy): def _call_for_each_replica(self, fn, args, kwargs): # TODO(jhseu): Consider making it so call_for_each_replica implies that # we're in a tpu.rewrite(), and update TPUMirroredVariable accordingly. - with _TPUReplicaContext(self): + with _TPUReplicaContext(self._container_strategy()): return fn(*args, **kwargs) - def initialize(self): + def _initialize(self): if context.executing_eagerly(): # TODO(priyag): Add appopriate call here when eager is supported for TPUs. raise NotImplementedError("Eager mode not supported in TPUStrategy.") @@ -338,7 +390,7 @@ class TPUStrategy(distribute_lib.DistributionStrategy): tpu.initialize_system()) return graph.get_collection(_TPU_INITIALIZE_SYSTEM_COLLECTION) - def finalize(self): + def _finalize(self): if context.executing_eagerly(): # TODO(priyag): Add appopriate call here when eager is supported for TPUs. raise NotImplementedError("Eager mode not supported in TPUStrategy.") @@ -346,7 +398,7 @@ class TPUStrategy(distribute_lib.DistributionStrategy): return [tpu.shutdown_system()] def _get_devices_from(self, colocate_with=None): - # TODO(jhseu): Change this when we support model parallelism. 
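The regrouping above relies on `tpu.replicate` returning one flat, replica-major list of last-step outputs. The strided slicing can be checked in isolation:

```python
# Flat, replica-major ordering produced by the rewritten while loop:
flattened = ["out0@rep0", "out1@rep0", "out2@rep0",
             "out0@rep1", "out1@rep1", "out2@rep1"]
num_replicas_in_sync = 2
output_num = len(flattened) // num_replicas_in_sync   # 3 distinct outputs
grouped = [flattened[i::output_num] for i in range(output_num)]
print(grouped)
# [['out0@rep0', 'out0@rep1'],
#  ['out1@rep0', 'out1@rep1'],
#  ['out2@rep0', 'out2@rep1']]
```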
+ # TODO(jhseu): Change this when we support model parallelism. return self._tpu_devices def _create_variable(self, next_creator, *args, **kwargs): @@ -383,12 +435,12 @@ class TPUStrategy(distribute_lib.DistributionStrategy): return _create_tpu_mirrored_variable(devices, _real_mirrored_creator, *args, **kwargs) - def _reduce(self, aggregation, value, destinations): + def _reduce_to(self, reduce_op, value, destinations): if values._enclosing_tpu_context() is not None: # pylint: disable=protected-access - if aggregation == vs.VariableAggregation.MEAN: + if reduce_op == reduce_util.ReduceOp.MEAN: # TODO(jhseu): Revisit once we support model-parallelism. - value *= (1. / self.num_replicas) - elif aggregation != vs.VariableAggregation.SUM: + value *= (1. / self._num_replicas_in_sync) + elif reduce_op != reduce_util.ReduceOp.SUM: raise NotImplementedError( "Currently only support sum & mean in TPUStrategy.") return tpu_ops.cross_replica_sum(value) @@ -396,27 +448,22 @@ class TPUStrategy(distribute_lib.DistributionStrategy): # Validate that the destination is same as the host device # Note we don't do this when in replicate context as the reduction is # performed on the TPU device itself. - devices = cross_tower_ops_lib.get_devices_from(destinations) + devices = cross_device_ops_lib.get_devices_from(destinations) if len(devices) == 1: assert device_util.canonicalize(devices[0]) == device_util.canonicalize( self._host_device) else: raise ValueError("Multiple devices are not supported for TPUStrategy") - if aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA: - return value[0] output = math_ops.add_n(value) - if aggregation == vs.VariableAggregation.MEAN: + if reduce_op == reduce_util.ReduceOp.MEAN: return output * (1. / len(value)) return output - def _update(self, var, options, fn, *args, **kwargs): + def _update(self, var, fn, args, kwargs, group): assert isinstance(var, values.TPUMirroredVariable) - should_group = options.pop("grouped") - assert not options # Validate that we are processing all of the options. - if values._enclosing_tpu_context() is not None: # pylint: disable=protected-access - if should_group: + if group: return fn(var, *args, **kwargs) else: return [fn(var, *args, **kwargs)] @@ -431,9 +478,7 @@ class TPUStrategy(distribute_lib.DistributionStrategy): updates[d] = fn(v, *values.select_device_mirrored(d, args), **values.select_device_mirrored(d, kwargs)) - return values.update_regroup(self, updates, should_group) - - # TODO(josh11b): Need to implement _update_non_slot()! 
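`_reduce_to` above supports only SUM and MEAN: inside the replicated context MEAN is a cross-replica sum scaled by `1/num_replicas`, and on the host it is `add_n` followed by a division. The arithmetic, in plain Python:

```python
def reduce_values(reduce_op, per_replica_values):
    total = sum(per_replica_values)   # stands in for math_ops.add_n
    if reduce_op == "sum":
        return total
    if reduce_op == "mean":
        return total * (1.0 / len(per_replica_values))
    raise NotImplementedError("Currently only support sum & mean.")

print(reduce_values("mean", [2.0, 4.0]))  # 3.0
print(reduce_values("sum", [2.0, 4.0]))   # 6.0
```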
+ return values.update_regroup(self, updates, group) def read_var(self, var): assert isinstance(var, values.TPUMirroredVariable) @@ -453,14 +498,10 @@ class TPUStrategy(distribute_lib.DistributionStrategy): def value_container(self, value): return value - def _broadcast(self, tensor, destinations): + def _broadcast_to(self, tensor, destinations): del destinations return tensor - @property - def num_replicas(self): - return self._num_cores_override or self._tpu_metadata.num_cores - @property def num_hosts(self): return self._tpu_metadata.num_hosts @@ -470,15 +511,15 @@ class TPUStrategy(distribute_lib.DistributionStrategy): return self._tpu_metadata.num_of_cores_per_host @property - def num_replicas_in_sync(self): - return self.num_replicas + def _num_replicas_in_sync(self): + return self._num_cores_override or self._tpu_metadata.num_cores @property - def between_graph(self): + def experimental_between_graph(self): return False @property - def should_init(self): + def experimental_should_init(self): return True @property @@ -500,14 +541,12 @@ class TPUStrategy(distribute_lib.DistributionStrategy): def non_slot_devices(self, var_list): return self._host_device - def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs): + def _update_non_slot(self, colocate_with, fn, args, kwargs, group): del colocate_with - should_group = options.pop("grouped") - assert not options # Validate that we are processing all of the options. with ops.device(self._host_device), distribute_lib.UpdateContext( self._host_device): result = fn(*args, **kwargs) - if should_group: + if group: return result else: return nest.map_structure(self._unwrap, result) @@ -521,17 +560,27 @@ class TPUStrategy(distribute_lib.DistributionStrategy): def get_host_cpu_device(self, host_id): return self.get_host(host_id) + "/device:CPU:0" - def configure(self, - session_config=None, - cluster_spec=None, - task_type=None, - task_id=None): + def _configure(self, + session_config=None, + cluster_spec=None, + task_type=None, + task_id=None): del cluster_spec, task_type, task_id if session_config: - session_config.isolate_session_state = True - cluster_spec = self._tpu_cluster_resolver.cluster_spec() - if cluster_spec: - session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def()) + session_config.CopyFrom(self._update_config_proto(session_config)) + + def _update_config_proto(self, config_proto): + updated_config = copy.deepcopy(config_proto) + updated_config.isolate_session_state = True + cluster_spec = self._tpu_cluster_resolver.cluster_spec() + if cluster_spec: + updated_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def()) + return updated_config + + # TODO(priyag): Delete this once all strategies use global batch size. + @property + def _global_batch_size(self): + return True class _TPUReplicaContext(distribute_lib.ReplicaContext): @@ -540,13 +589,14 @@ class _TPUReplicaContext(distribute_lib.ReplicaContext): # TODO(sourabhbajaj): Call for each tower should be updating this. 
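This file now follows the container/extended split: `TPUStrategy` becomes a thin facade whose behavior lives in `TPUExtended`, with deprecated properties such as `steps_per_run` forwarding to the extended object. A plain-Python sketch of the pattern (class names here are illustrative, not the real API):

```python
class StrategyExtended(object):
    def __init__(self, container_strategy, steps_per_run):
        self._container = container_strategy   # back-reference to the facade
        self.steps_per_run = steps_per_run

    def _container_strategy(self):
        return self._container

class Strategy(object):
    def __init__(self, steps_per_run):
        self._extended = StrategyExtended(self, steps_per_run)

    @property
    def extended(self):
        return self._extended

    @property
    def steps_per_run(self):
        """Deprecated alias kept on the facade; prefer .extended.steps_per_run."""
        return self._extended.steps_per_run

s = Strategy(steps_per_run=2)
assert s.steps_per_run == s.extended.steps_per_run == 2
```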
def __init__(self, distribution_strategy): distribute_lib.ReplicaContext.__init__( - self, distribution_strategy, replica_id=0) - - @property - def device(self): - raise RuntimeError("Use .devices instead") + self, + distribution_strategy, + # TODO(b/118385803): properly initialize replica_id, instead of always 0 + replica_id_in_sync_group=constant_op.constant(0, dtypes.int32)) @property def devices(self): distribute_lib.require_replica_context(self) - return [self._distribution_strategy.worker_devices[self._replica_id]] + ds = self._distribution_strategy + replica_id = tensor_util.constant_value(self._replica_id_in_sync_group) + return [ds.extended.worker_devices[replica_id]] diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py index 268393ee801..538b859f3d1 100644 --- a/tensorflow/contrib/distribute/python/values_test.py +++ b/tensorflow/contrib/distribute/python/values_test.py @@ -19,12 +19,15 @@ from __future__ import division from __future__ import print_function import os +from absl.testing import parameterized -from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import multi_worker_test_base -from tensorflow.contrib.distribute.python import values from tensorflow.core.protobuf import config_pb2 from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import device_util +from tensorflow.python.distribute import distribute_lib +from tensorflow.python.distribute import values from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.estimator import model_fn as model_fn_lib @@ -34,10 +37,10 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as variables_lib -from tensorflow.python.training import device_util from tensorflow.python.training import saver as saver_lib from tensorflow.python.util import nest @@ -324,20 +327,20 @@ class RegroupAndSelectDeviceTest(test.TestCase): self.assertTrue( isinstance(merged_estimator_spec, model_fn_lib.EstimatorSpec)) - self.assertEquals(model_fn_lib.ModeKeys.TRAIN, merged_estimator_spec.mode) + self.assertEqual(model_fn_lib.ModeKeys.TRAIN, merged_estimator_spec.mode) for device_id in range(3): d = _device_str(device_id) - self.assertEquals(created_estimator_specs[device_id].loss, - merged_estimator_spec.loss.get(d)) - self.assertEquals(created_estimator_specs[device_id].train_op, - merged_estimator_spec.train_op.get(d)) + self.assertEqual(created_estimator_specs[device_id].loss, + merged_estimator_spec.loss.get(d)) + self.assertEqual(created_estimator_specs[device_id].train_op, + merged_estimator_spec.train_op.get(d)) # Scaffold is populated by `EstimatorSpec.__new__`. 
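`_TPUReplicaContext.devices` above (and `_get_replica_id_integer` in the test file) both fold the tensor-valued `replica_id_in_sync_group` back into a Python integer before indexing `worker_devices`; the real code uses `tensor_util.constant_value` for that. A stand-in sketch of the lookup:

```python
def device_for_replica(replica_id, worker_devices):
    # `replica_id` may arrive as a constant tensor; fold it to an int first
    # (tensor_util.constant_value in the real code) and index the device list.
    replica_id = int(replica_id)
    return [worker_devices[replica_id]]

print(device_for_replica(0, ["/device:TPU:0", "/device:TPU:1"]))
# ['/device:TPU:0']
```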
- self.assertEquals(created_estimator_specs[device_id].scaffold, - merged_estimator_spec.scaffold.get(d)) + self.assertEqual(created_estimator_specs[device_id].scaffold, + merged_estimator_spec.scaffold.get(d)) # Also test that we can undo the merge using select_device() - self.assertEquals(created_estimator_specs[device_id], - values.select_device(_device_str(device_id), - merged_estimator_spec)) + self.assertEqual(created_estimator_specs[device_id], + values.select_device(_device_str(device_id), + merged_estimator_spec)) class PerReplicaDatasetTest(test.TestCase): @@ -568,7 +571,184 @@ class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase): multi_worker_iterator.get_next() -class MirroredVariableTest(test.TestCase): +class InputIteratorTestBase(test.TestCase): + + def _test_iterator(self, input_type, dataset_fn, worker_device_pairs, + expected_values, sess=None, split_batch_by=None): + devices = nest.flatten([ds for _, ds in worker_device_pairs]) + + if input_type == "input_fn": + input_contexts = [ + distribute_lib.InputContext() for _ in worker_device_pairs] + input_fn = lambda _: dataset_fn() + iterator = values.InputFunctionIterator(input_fn, worker_device_pairs, + input_contexts) + else: + iterator = values.DatasetIterator(dataset_fn(), worker_device_pairs, + split_batch_by) + + evaluate = lambda x: sess.run(x) if sess else self.evaluate(x) + + evaluate(control_flow_ops.group(iterator.initialize())) + + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = evaluate( + [values.select_device(d, next_element) for d in devices]) + self.assertAllEqual(expected_value, computed_value) + + with self.assertRaises(errors.OutOfRangeError): + next_element = iterator.get_next() + evaluate([values.select_device(d, next_element) for d in devices]) + + # After re-initializing the iterator, should be able to iterate again. 
+ evaluate(control_flow_ops.group(iterator.initialize())) + + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = evaluate( + [values.select_device(d, next_element) for d in devices]) + self.assertAllEqual(expected_value, computed_value) + + +class InputIteratorSingleWorkerTest(InputIteratorTestBase, + parameterized.TestCase): + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + input_type=["input_fn", "dataset"])) + def testOneDeviceCPU(self, input_type): + worker_device_pairs = [("", ["/device:CPU:0"])] + dataset_fn = lambda: dataset_ops.Dataset.range(10) + + expected_values = [[i] for i in range(10)] + + self._test_iterator(input_type, dataset_fn, worker_device_pairs, + expected_values) + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + input_type=["input_fn", "dataset"], + required_gpus=1)) + def testTwoDevicesOneGPUOneCPU(self, input_type): + worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] + dataset_fn = lambda: dataset_ops.Dataset.range(10) + + expected_values = [[i, i+1] for i in range(0, 10, 2)] + + self._test_iterator(input_type, dataset_fn, worker_device_pairs, + expected_values) + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + input_type=["input_fn", "dataset"], + required_gpus=1)) + def testTupleDataset(self, input_type): + worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] + def dataset_fn(): + dataset1 = dataset_ops.Dataset.range(10) + dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2) + return dataset_ops.Dataset.zip((dataset1, dataset2)) + + expected_values = [[(i, i**2), (i+1, (i+1)**2)] for i in range(0, 10, 2)] + + self._test_iterator(input_type, dataset_fn, worker_device_pairs, + expected_values) + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + input_type=["input_fn", "dataset"], + required_gpus=1)) + def testUnevenDatasetBatches(self, input_type): + worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] + dataset_fn = lambda: dataset_ops.Dataset.range(11) + + expected_values = [[i, i+1] for i in range(0, 10, 2)] + self._test_iterator(input_type, dataset_fn, worker_device_pairs, + expected_values) + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + input_type=["dataset"], + split_batch_by=[None, 2], + required_gpus=1)) + def testBatchSplitting(self, input_type, split_batch_by): + worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] + batch_size = 10 + dataset_fn = lambda: dataset_ops.Dataset.range(100).batch(batch_size) + + updated_batch_size = ( + batch_size // split_batch_by if split_batch_by else batch_size) + expected_values = [[range(i, i+updated_batch_size), + range(i+updated_batch_size, i+2*updated_batch_size)] + for i in range(0, 100, updated_batch_size*2)] + + self._test_iterator(input_type, dataset_fn, worker_device_pairs, + expected_values, sess=None, + split_batch_by=split_batch_by) + + +class InputIteratorMultiWorkerTest( + multi_worker_test_base.MultiWorkerTestBase, InputIteratorTestBase, + parameterized.TestCase): + + def _cpu_devices(self): + return [ + ("/job:worker/replica:0/task:0", + ["/job:worker/replica:0/task:0/device:CPU:0"]), + ("/job:worker/replica:0/task:1", + ["/job:worker/replica:0/task:1/device:CPU:0"])] + + def _cpu_and_one_gpu_devices(self): + return [ + ("/job:worker/replica:0/task:0", [ + "/job:worker/replica:0/task:0/device:GPU:0", + "/job:worker/replica:0/task:0/device:CPU:0" + ]), + 
("/job:worker/replica:0/task:1", [ + "/job:worker/replica:0/task:1/device:GPU:0", + "/job:worker/replica:0/task:1/device:CPU:0" + ]) + ] + + @combinations.generate(combinations.combine( + mode=["graph"], + input_type=["input_fn", "dataset"])) + def testOneDevicePerWorker(self, input_type): + worker_devices = self._cpu_devices() + with context.graph_mode(), self.cached_session() as sess: + dataset_fn = lambda: dataset_ops.Dataset.range(4) + self._test_iterator(input_type, dataset_fn, worker_devices, + [[0, 0], [1, 1], [2, 2], [3, 3]], sess) + + @combinations.generate(combinations.combine( + mode=["graph"], + input_type=["input_fn", "dataset"], + required_gpus=1)) + def testTwoDevicesPerWorker(self, input_type): + worker_devices = self._cpu_and_one_gpu_devices() + with context.graph_mode(), self.cached_session() as sess: + dataset_fn = lambda: dataset_ops.Dataset.range(4) + self._test_iterator(input_type, dataset_fn, worker_devices, + [[0, 1, 0, 1], [2, 3, 2, 3]], sess) + + @combinations.generate(combinations.combine( + mode=["graph"], + input_type=["input_fn", "dataset"])) + def testTupleDataset(self, input_type): + worker_devices = self._cpu_devices() + with context.graph_mode(), self.cached_session() as sess: + def dataset_fn(): + dataset1 = dataset_ops.Dataset.range(4) + dataset2 = dataset_ops.Dataset.range(4).map(lambda x: x**2) + return dataset_ops.Dataset.zip((dataset1, dataset2)) + + expected_values = [[(i, i**2), (i, i**2)] for i in range(0, 4)] + self._test_iterator(input_type, dataset_fn, worker_devices, + expected_values, sess) + + +class MirroredVariableTest(test.TestCase, parameterized.TestCase): config = config_pb2.ConfigProto() config.allow_soft_placement = True @@ -580,9 +760,9 @@ class MirroredVariableTest(test.TestCase): v, _, mirrored = _make_mirrored() - self.assertEquals(v[0].name, mirrored.name) - self.assertEquals(v[0].dtype, mirrored.dtype) - self.assertEquals(v[0].shape, mirrored.shape) + self.assertEqual(v[0].name, mirrored.name) + self.assertEqual(v[0].dtype, mirrored.dtype) + self.assertEqual(v[0].shape, mirrored.shape) @test_util.run_in_graph_and_eager_modes(config=config) def testVariableOnAnotherDevice(self): @@ -592,9 +772,9 @@ class MirroredVariableTest(test.TestCase): mirrored = values.MirroredVariable(index, v, variable_scope.VariableAggregation.MEAN) - self.assertEquals(v.name, mirrored.name) - self.assertEquals(v.dtype, mirrored.dtype) - self.assertEquals(v.shape, mirrored.shape) + self.assertEqual(v.name, mirrored.name) + self.assertEqual(v.dtype, mirrored.dtype) + self.assertEqual(v.shape, mirrored.shape) def _assign_mirrored(self, devices, v, new): for d, var, n in zip(devices, v, new): @@ -714,14 +894,13 @@ class MirroredVariableTest(test.TestCase): save_path = self._save_normal() self._restore_mirrored(save_path) - @test_util.run_in_graph_and_eager_modes(config=config) - def testFetchAMirroredVariable(self): - if context.num_gpus() < 1 or context.executing_eagerly(): - self.skipTest("A GPU is not available for this test or it's eager mode.") - - with self.session( - graph=ops.Graph()) as sess, mirrored_strategy.MirroredStrategy( - ["/device:GPU:0"]).scope(): + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_one_gpu, + combinations.core_mirrored_strategy_with_one_gpu], + mode=["graph"])) + def testFetchAMirroredVariable(self, distribution): + with self.session(graph=ops.Graph()) as sess, distribution.scope(): with ops.device("/device:GPU:0"): v = variable_scope.get_variable( name="v", 
initializer=1., use_resource=True) @@ -747,7 +926,7 @@ def _make_replica_local(method): return v, replica_local -class ReplicaLocalVariableTest(test.TestCase): +class ReplicaLocalVariablePropertiesTest(test.TestCase): config = config_pb2.ConfigProto() config.allow_soft_placement = True @@ -756,15 +935,14 @@ class ReplicaLocalVariableTest(test.TestCase): def testProperties(self): if context.num_gpus() < 1 and context.executing_eagerly(): self.skipTest("A GPU is not available for this test in eager mode.") - v, replica_local = _make_replica_local( variable_scope.VariableAggregation.SUM) - self.assertEquals(v[0].name, replica_local.name) - self.assertEquals(v[0].dtype, replica_local.dtype) - self.assertEquals(v[0].shape, replica_local.shape) - self.assertEquals(variable_scope.VariableAggregation.SUM, - replica_local.aggregation) + self.assertEqual(v[0].name, replica_local.name) + self.assertEqual(v[0].dtype, replica_local.dtype) + self.assertEqual(v[0].shape, replica_local.shape) + self.assertEqual(variable_scope.VariableAggregation.SUM, + replica_local.aggregation) @test_util.run_in_graph_and_eager_modes(config=config) def testVariableOnAnotherDevice(self): @@ -774,11 +952,32 @@ class ReplicaLocalVariableTest(test.TestCase): replica_local = values.ReplicaLocalVariable( index, v, variable_scope.VariableAggregation.MEAN) - self.assertEquals(v.name, replica_local.name) - self.assertEquals(v.dtype, replica_local.dtype) - self.assertEquals(v.shape, replica_local.shape) - self.assertEquals(variable_scope.VariableAggregation.MEAN, - replica_local.aggregation) + self.assertEqual(v.name, replica_local.name) + self.assertEqual(v.dtype, replica_local.dtype) + self.assertEqual(v.shape, replica_local.shape) + self.assertEqual(variable_scope.VariableAggregation.MEAN, + replica_local.aggregation) + + def testTensorConversion(self): + with context.graph_mode(): + _, replica_local = _make_replica_local( + variable_scope.VariableAggregation.SUM) + converted = ops.internal_convert_to_tensor(replica_local, as_ref=False) + self.assertIsInstance(converted, ops.Tensor) + self.assertEqual(converted.dtype, replica_local.dtype) + + converted = ops.internal_convert_to_tensor(replica_local, as_ref=True) + # Resources variable are converted to tensors as well when as_ref is True. + self.assertIsInstance(converted, ops.Tensor) + self.assertEqual(converted.dtype, replica_local.dtype) + + +@combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=["graph", "eager"])) +class ReplicaLocalVariableTest(test.TestCase, parameterized.TestCase): def _assign_replica_local(self, devices, v, new): for d, var, n in zip(devices, v, new): @@ -795,22 +994,15 @@ class ReplicaLocalVariableTest(test.TestCase): save_path, _ = self._save_return_saver(sess, var) return save_path - def _dist_scope(self): - return mirrored_strategy.MirroredStrategy(_devices).scope() - - @test_util.run_in_graph_and_eager_modes(config=config) - def testSaveAndRestoreReplicaLocalSumOneGraph(self): - if context.num_gpus() < 1 and context.executing_eagerly(): - self.skipTest("A GPU is not available for this test in eager mode.") - - with self.cached_session(config=self.config) as sess: + def testSaveAndRestoreReplicaLocalSumOneGraph(self, distribution): + with self.cached_session() as sess: v, replica_local = _make_replica_local( variable_scope.VariableAggregation.SUM) # Overwrite the initial values. 
self._assign_replica_local(_devices, v, [3., 4.]) - with self._dist_scope(): + with distribution.scope(): # Saves the current value of v[0] + v[1], 7. save_path, saver = self._save_return_saver(sess, replica_local) @@ -822,19 +1014,18 @@ class ReplicaLocalVariableTest(test.TestCase): saver.restore(sess, save_path) self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]])) - @test_util.run_in_graph_and_eager_modes(config=config) - def testSaveAndRestoreReplicaLocalMeanOneGraph(self): + def testSaveAndRestoreReplicaLocalMeanOneGraph(self, distribution): if context.num_gpus() < 1 and context.executing_eagerly(): self.skipTest("A GPU is not available for this test in eager mode.") - with self.cached_session(config=self.config) as sess: + with self.cached_session() as sess: v, replica_local = _make_replica_local( variable_scope.VariableAggregation.MEAN) # Overwrite the initial values. self._assign_replica_local(_devices, v, [3., 4.]) - with self._dist_scope(): + with distribution.scope(): # Saves the current value of (v[0] + v[1])/2, 3.5. save_path, saver = self._save_return_saver(sess, replica_local) @@ -845,7 +1036,7 @@ class ReplicaLocalVariableTest(test.TestCase): saver.restore(sess, save_path) self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]])) - def _save_replica_local_mean(self): + def _save_replica_local_mean(self, distribution): """Save variables with mirroring, returns save_path.""" with self.session(graph=ops.Graph()) as sess: v, replica_local = _make_replica_local( @@ -854,7 +1045,7 @@ class ReplicaLocalVariableTest(test.TestCase): # Overwrite the initial values. self._assign_replica_local(_devices, v, [3., 4.]) - with self._dist_scope(): + with distribution.scope(): # Saves the current value of (v[0] + v[1])/2, 3.5 save_path = self._save(sess, replica_local) @@ -862,7 +1053,7 @@ class ReplicaLocalVariableTest(test.TestCase): self._assign_replica_local(_devices, v, [5., 6.]) return save_path - def _save_replica_local_sum(self): + def _save_replica_local_sum(self, distribution): """Save variables with mirroring, returns save_path.""" with self.session(graph=ops.Graph()) as sess: v, replica_local = _make_replica_local("sum") @@ -870,7 +1061,7 @@ class ReplicaLocalVariableTest(test.TestCase): # Overwrite the initial values. self._assign_replica_local(_devices, v, [1.5, 2.]) - with self._dist_scope(): + with distribution.scope(): # Saves the current value of v[0] + v[1], 3.5 save_path = self._save(sess, replica_local) @@ -908,7 +1099,7 @@ class ReplicaLocalVariableTest(test.TestCase): saver.restore(sess, save_path) self.assertEqual(3.5, self.evaluate(var)) - def _restore_replica_local_mean(self, save_path): + def _restore_replica_local_mean(self, save_path, distribution): """Restore to variables with mirroring in a fresh graph.""" with self.session(graph=ops.Graph()) as sess: v, replica_local = _make_replica_local( @@ -917,13 +1108,13 @@ class ReplicaLocalVariableTest(test.TestCase): # Overwrite the initial values. self._assign_replica_local(_devices, v, [7., 8.]) - with self._dist_scope(): + with distribution.scope(): # Restores the saved value of 3.5 to both variables. 
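These save/restore tests pin down the checkpoint arithmetic for replica-local variables: a SUM variable is saved as the sum of its copies and restored by splitting evenly, while a MEAN variable is saved as the average and restored by broadcasting it. The expected numbers, reproduced in plain Python:

```python
def save_replica_local(per_replica, aggregation):
    total = sum(per_replica)
    return total / len(per_replica) if aggregation == "mean" else total

def restore_replica_local(saved, num_replicas, aggregation):
    if aggregation == "mean":
        return [saved] * num_replicas             # broadcast the average
    return [saved / num_replicas] * num_replicas  # split the sum evenly

print(save_replica_local([3.0, 4.0], "sum"))   # 7.0, as in the SUM test
print(restore_replica_local(3.5, 2, "sum"))    # [1.75, 1.75]
print(save_replica_local([3.0, 4.0], "mean"))  # 3.5
print(restore_replica_local(3.5, 2, "mean"))   # [3.5, 3.5]
```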
saver = saver_lib.Saver(var_list=[replica_local]) saver.restore(sess, save_path) self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]])) - def _restore_replica_local_sum(self, save_path): + def _restore_replica_local_sum(self, save_path, distribution): """Restore to variables with mirroring in a fresh graph.""" with self.session(graph=ops.Graph()) as sess: v, replica_local = _make_replica_local( @@ -932,72 +1123,35 @@ class ReplicaLocalVariableTest(test.TestCase): # Overwrite the initial values. self._assign_replica_local(_devices, v, [7., 8.]) - with self._dist_scope(): + with distribution.scope(): # Restores the saved value of 3.5 to both variables. saver = saver_lib.Saver(var_list=[replica_local]) saver.restore(sess, save_path) self.assertEqual([1.75, 1.75], self.evaluate([v[0], v[1]])) - @test_util.run_in_graph_and_eager_modes(config=config) - def testSaveReplicaLocalRestoreReplicaLocalMean(self): - if context.num_gpus() < 1 and context.executing_eagerly(): - self.skipTest("A GPU is not available for this test in eager mode.") + def testSaveReplicaLocalRestoreReplicaLocalMean(self, distribution): + save_path = self._save_replica_local_mean(distribution) + self._restore_replica_local_mean(save_path, distribution) - save_path = self._save_replica_local_mean() - self._restore_replica_local_mean(save_path) + def testSaveReplicaLocalRestoreReplicaLocalSum(self, distribution): + save_path = self._save_replica_local_sum(distribution) + self._restore_replica_local_sum(save_path, distribution) - @test_util.run_in_graph_and_eager_modes(config=config) - def testSaveReplicaLocalRestoreReplicaLocalSum(self): - if context.num_gpus() < 1 and context.executing_eagerly(): - self.skipTest("A GPU is not available for this test in eager mode.") - - save_path = self._save_replica_local_sum() - self._restore_replica_local_sum(save_path) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testSaveReplicaLocalMeanRestoreNormal(self): - if context.num_gpus() < 1 and context.executing_eagerly(): - self.skipTest("A GPU is not available for this test in eager mode.") - - save_path = self._save_replica_local_mean() + def testSaveReplicaLocalMeanRestoreNormal(self, distribution): + save_path = self._save_replica_local_mean(distribution) self._restore_normal(save_path) - @test_util.run_in_graph_and_eager_modes(config=config) - def testSaveReplicaLocalSumRestoreNormal(self): - if context.num_gpus() < 1 and context.executing_eagerly(): - self.skipTest("A GPU is not available for this test in eager mode.") - - save_path = self._save_replica_local_sum() + def testSaveReplicaLocalSumRestoreNormal(self, distribution): + save_path = self._save_replica_local_sum(distribution) self._restore_normal(save_path) - @test_util.run_in_graph_and_eager_modes(config=config) - def testSaveNormalRestoreReplicaLocalMean(self): - if context.num_gpus() < 1 and context.executing_eagerly(): - self.skipTest("A GPU is not available for this test in eager mode.") - + def testSaveNormalRestoreReplicaLocalMean(self, distribution): save_path = self._save_normal() - self._restore_replica_local_mean(save_path) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testSaveNormalRestoreReplicaLocalSum(self): - if context.num_gpus() < 1 and context.executing_eagerly(): - self.skipTest("A GPU is not available for this test in eager mode.") + self._restore_replica_local_mean(save_path, distribution) + def testSaveNormalRestoreReplicaLocalSum(self, distribution): save_path = self._save_normal() - 
self._restore_replica_local_sum(save_path) - - def testTensorConversion(self): - with context.graph_mode(): - _, replica_local = _make_replica_local( - variable_scope.VariableAggregation.SUM) - converted = ops.internal_convert_to_tensor(replica_local, as_ref=False) - self.assertIsInstance(converted, ops.Tensor) - self.assertEqual(converted.dtype, replica_local.dtype) - - converted = ops.internal_convert_to_tensor(replica_local, as_ref=True) - # Resources variable are converted to tensors as well when as_ref is True. - self.assertIsInstance(converted, ops.Tensor) - self.assertEqual(converted.dtype, replica_local.dtype) + self._restore_replica_local_sum(save_path, distribution) if __name__ == "__main__": diff --git a/tensorflow/contrib/distribute/python/warm_starting_util_test.py b/tensorflow/contrib/distribute/python/warm_starting_util_test.py index 5d57d144c1c..b0bcf9b1745 100644 --- a/tensorflow/contrib/distribute/python/warm_starting_util_test.py +++ b/tensorflow/contrib/distribute/python/warm_starting_util_test.py @@ -44,7 +44,9 @@ class WarmStartingUtilWithDistributionStrategyTest( distribution=[combinations.default_strategy, combinations.one_device_strategy, combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus], + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus], save_with_distribution=[True, False], restore_with_distribution=[True, False], mode=["graph"])) diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index 60f6b90edcb..3079175015a 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -72,7 +72,6 @@ py_library( "//tensorflow/python:nn", "//tensorflow/python:nn_ops", "//tensorflow/python:random_ops", - "//tensorflow/python:spectral_ops", "//tensorflow/python:state_ops", "//tensorflow/python:tensor_util", "//tensorflow/python:util", @@ -80,6 +79,7 @@ py_library( "//tensorflow/python:variables", "//tensorflow/python/ops/distributions", "//tensorflow/python/ops/linalg", + "//tensorflow/python/ops/signal", "//third_party/py/numpy", "@six_archive//:six", ], diff --git a/tensorflow/contrib/distributions/python/kernel_tests/normal_conjugate_posteriors_test.py b/tensorflow/contrib/distributions/python/kernel_tests/normal_conjugate_posteriors_test.py index 29eeaf43c51..ab3c07172a6 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/normal_conjugate_posteriors_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/normal_conjugate_posteriors_test.py @@ -82,7 +82,7 @@ class NormalTest(test.TestCase): x = constant_op.constant( [[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0], [2.5, -2.5, -4.0, 0.0, 1.0, -2.0]], dtype=dtypes.float32) - s = math_ops.reduce_sum(x, reduction_indices=[1]) + s = math_ops.reduce_sum(x, axis=[1]) x = array_ops.transpose(x) # Reshape to shape (6, 2) n = constant_op.constant([6] * 2) prior = distributions.Normal(loc=mu0, scale=sigma0) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py b/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py index a60056c444a..cdee30bbc42 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py @@ -147,14 +147,13 @@ class WishartCholeskyTest(test.TestCase): x = chol_w.sample(10000, seed=42) self.assertAllEqual((10000, 3, 3), x.get_shape()) - 
moment1_estimate = math_ops.reduce_mean(x, reduction_indices=[0]).eval() + moment1_estimate = math_ops.reduce_mean(x, axis=[0]).eval() self.assertAllClose(chol_w.mean().eval(), moment1_estimate, rtol=0.05) # The Variance estimate uses the squares rather than outer-products # because Wishart.Variance is the diagonal of the Wishart covariance # matrix. - variance_estimate = (math_ops.reduce_mean( - math_ops.square(x), reduction_indices=[0]) - + variance_estimate = (math_ops.reduce_mean(math_ops.square(x), axis=[0]) - math_ops.square(moment1_estimate)).eval() self.assertAllClose( chol_w.variance().eval(), variance_estimate, rtol=0.05) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py b/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py index 15c241d5d7a..74765f19e58 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py @@ -168,7 +168,7 @@ class SoftmaxCentered(bijector.Bijector): # log_normalization = 1 + reduce_sum(exp(logits)) # -log_normalization + reduce_sum(logits - log_normalization) log_normalization = nn_ops.softplus( - math_ops.reduce_logsumexp(x, axis=-1, keep_dims=True)) + math_ops.reduce_logsumexp(x, axis=-1, keepdims=True)) return array_ops.squeeze( (-log_normalization + math_ops.reduce_sum( x - log_normalization, axis=-1, keepdims=True)), axis=-1) diff --git a/tensorflow/contrib/distributions/python/ops/sample_stats.py b/tensorflow/contrib/distributions/python/ops/sample_stats.py index aa680a92be6..978e627d663 100644 --- a/tensorflow/contrib/distributions/python/ops/sample_stats.py +++ b/tensorflow/contrib/distributions/python/ops/sample_stats.py @@ -29,8 +29,8 @@ from tensorflow.python.ops import clip_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops -from tensorflow.python.ops import spectral_ops from tensorflow.python.ops.distributions import util +from tensorflow.python.ops.signal import fft_ops __all__ = [ "auto_correlation", @@ -157,11 +157,11 @@ def auto_correlation( dtype.real_dtype.as_numpy_dtype(0.)) # Autocorrelation is IFFT of power-spectral density (up to some scaling). - fft_x_rotated_pad = spectral_ops.fft(x_rotated_pad) + fft_x_rotated_pad = fft_ops.fft(x_rotated_pad) spectral_density = fft_x_rotated_pad * math_ops.conj(fft_x_rotated_pad) # shifted_product is R[m] from above detailed explanation. # It is the inner product sum_n X[n] * Conj(X[n - m]). - shifted_product = spectral_ops.ifft(spectral_density) + shifted_product = fft_ops.ifft(spectral_density) # Cast back to real-valued if x was real to begin with. shifted_product = math_ops.cast(shifted_product, dtype) diff --git a/tensorflow/contrib/eager/python/datasets.py b/tensorflow/contrib/eager/python/datasets.py index 3aed121233b..34614b86a75 100644 --- a/tensorflow/contrib/eager/python/datasets.py +++ b/tensorflow/contrib/eager/python/datasets.py @@ -52,12 +52,6 @@ class Iterator(iterator_ops.EagerIterator): TypeError: If `dataset` is an unsupported type. RuntimeError: When invoked without eager execution enabled. """ - if isinstance(dataset, prefetching_ops._PrefetchToDeviceDataset): # pylint: disable=protected-access - raise TypeError( - "`tf.data.experimental.prefetch_to_device()` is not compatible with " - "`tf.contrib.eager.Iterator`. Use `for ... 
in dataset:` to iterate " - "over the dataset instead.") - if not context.context().device_spec.device_type: is_remote_device = False else: diff --git a/tensorflow/contrib/eager/python/datasets_test.py b/tensorflow/contrib/eager/python/datasets_test.py index 6a508fc6ba9..257d02057ae 100644 --- a/tensorflow/contrib/eager/python/datasets_test.py +++ b/tensorflow/contrib/eager/python/datasets_test.py @@ -26,7 +26,6 @@ import numpy as np from tensorflow.contrib import lookup from tensorflow.contrib.eager.python import datasets from tensorflow.python.data import Dataset -from tensorflow.python.data.experimental.ops import prefetching_ops from tensorflow.python.data.experimental.ops import threadpool from tensorflow.python.data.experimental.ops import unique from tensorflow.python.eager import test @@ -208,18 +207,6 @@ class IteratorTest(test.TestCase): y = math_ops.add(x, x) self.assertAllEqual([0., 2.], y.numpy()) - def testTensorsExplicitPrefetchToDevice(self): - ds = Dataset.from_tensor_slices([0., 1.]) - ds = ds.apply(prefetching_ops.prefetch_to_device(test.gpu_device_name())) - - with self.assertRaisesRegexp(TypeError, 'prefetch_to_device'): - datasets.Iterator(ds) - - for i, x in enumerate(ds): - with ops.device(test.gpu_device_name()): - x = math_ops.add(x, x) - self.assertEqual(float(i) + float(i), x.numpy()) - def testOverrideThreadPool(self): def get_thread_id(_): diff --git a/tensorflow/contrib/eager/python/evaluator.py b/tensorflow/contrib/eager/python/evaluator.py index 7949a3f6da2..51443d24829 100644 --- a/tensorflow/contrib/eager/python/evaluator.py +++ b/tensorflow/contrib/eager/python/evaluator.py @@ -22,6 +22,7 @@ import six from tensorflow.contrib.eager.python import datasets from tensorflow.contrib.eager.python import metrics +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import context from tensorflow.python.eager import function from tensorflow.python.framework import errors_impl @@ -164,8 +165,8 @@ class Evaluator(object): self.__call__(example, *args, **kwargs) return self.all_metric_results(summary_logdir) # Graph construction - call_op = self.__call__(dataset.make_one_shot_iterator().get_next(), *args, - **kwargs) + call_op = self.__call__( + dataset_ops.make_one_shot_iterator(dataset).get_next(), *args, **kwargs) init_op = self.init_variables() results_op = self.all_metric_results(summary_logdir) return (init_op, call_op, results_op) diff --git a/tensorflow/contrib/eager/python/examples/densenet/BUILD b/tensorflow/contrib/eager/python/examples/densenet/BUILD index 2dc196f550a..e2154fcc5fc 100644 --- a/tensorflow/contrib/eager/python/examples/densenet/BUILD +++ b/tensorflow/contrib/eager/python/examples/densenet/BUILD @@ -3,6 +3,7 @@ licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//tensorflow:internal"]) load("//tensorflow:tensorflow.bzl", "cuda_py_test") +load("//tensorflow:tensorflow.bzl", "py_binary") py_binary( name = "densenet", diff --git a/tensorflow/contrib/eager/python/examples/densenet/densenet_graph_test.py b/tensorflow/contrib/eager/python/examples/densenet/densenet_graph_test.py index 4b3cb624bc9..24f6b007b52 100644 --- a/tensorflow/contrib/eager/python/examples/densenet/densenet_graph_test.py +++ b/tensorflow/contrib/eager/python/examples/densenet/densenet_graph_test.py @@ -119,7 +119,8 @@ class DensenetBenchmark(tf.test.Benchmark): with tf.Graph().as_default(): np_images, np_labels = random_batch(batch_size) dataset = tf.data.Dataset.from_tensors((np_images, np_labels)).repeat() - (images, labels) = 
dataset.make_one_shot_iterator().get_next() + (images, labels) = tf.compat.v1.data.make_one_shot_iterator( + dataset).get_next() model = densenet.DenseNet(self.depth, self.growth_rate, self.num_blocks, self.output_classes, diff --git a/tensorflow/contrib/eager/python/examples/gan/mnist_graph_test.py b/tensorflow/contrib/eager/python/examples/gan/mnist_graph_test.py index 12b39b0cde4..e73841fbf72 100644 --- a/tensorflow/contrib/eager/python/examples/gan/mnist_graph_test.py +++ b/tensorflow/contrib/eager/python/examples/gan/mnist_graph_test.py @@ -42,7 +42,8 @@ class MnistGraphGanBenchmark(tf.test.Benchmark): # Generate some random data. images_data = np.random.randn(batch_size, 784).astype(np.float32) dataset = tf.data.Dataset.from_tensors(images_data) - images = dataset.repeat().make_one_shot_iterator().get_next() + images = tf.compat.v1.data.make_one_shot_iterator( + dataset.repeat()).get_next() # Create the models and optimizers generator = mnist.Generator(data_format()) diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb index ca27a85a229..1a08cc0fd06 100644 --- a/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb +++ b/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb @@ -470,7 +470,7 @@ "\n", " if epoch % 1 == 0:\n", " loss = tfe.metrics.Mean()\n", - " for test_x in test_dataset.make_one_shot_iterator():\n", + " for test_x in test_dataset:\n", " loss(compute_loss(model, test_x))\n", " elbo = -loss.result()\n", " display.clear_output(wait=False)\n", diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb index 3acecd283cd..12c5eff2b4a 100644 --- a/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb +++ b/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb @@ -1,1184 +1,1174 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "image_captioning_with_attention.ipynb", - "version": "0.3.2", - "views": {}, - "default_view": {}, - "provenance": [ - { - "file_id": "1HI8OK2sMjcx9CTWVn0122QAHOuXaOaMg", - "timestamp": 1530222436922 - } - ], - "private_outputs": true, - "collapsed_sections": [], - "toc_visible": true - }, - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "accelerator": "GPU" + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "K2s1A9eLRPEj" + }, + "source": [ + "##### Copyright 2018 The TensorFlow Authors.\n", + "\n", + "Licensed under the Apache License, Version 2.0 (the \"License\").\n" + ] }, - "cells": [ - { - "metadata": { - "id": "K2s1A9eLRPEj", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "##### Copyright 2018 The TensorFlow Authors.\n", - "\n", - "Licensed under the Apache License, Version 2.0 (the \"License\").\n" - ] + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "Cffg2i257iMS" + }, + "source": [ + "# Image Captioning with Attention\n", + "\n", + "
\n", + "\n", + " Run in Google Colab \n", + "\n", + "View source on GitHub
" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "QASbY_HGo4Lq" + }, + "source": [ + "Image captioning is the task of generating a caption for an image. Given an image like this:\n", + "\n", + "![Man Surfing](https://tensorflow.org/images/surf.jpg) \n", + "\n", + "[Image Source](https://commons.wikimedia.org/wiki/Surfing#/media/File:Surfing_in_Hawaii.jpg), License: Public Domain\n", + "\n", + "Our goal is to generate a caption, such as \"a surfer riding on a wave\". Here, we'll use an attention-based model. This enables us to see which parts of the image the model focuses on as it generates a caption.\n", + "\n", + "![Prediction](https://tensorflow.org/images/imcap_prediction.png)\n", + "\n", + "This model architecture below is similar to [Show, Attend and Tell: Neural Image Caption Generation with Visual Attention](https://arxiv.org/abs/1502.03044). \n", + "\n", + "The code uses [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager), which you can learn more about in the linked guides.\n", + "\n", + "This notebook is an end-to-end example. If you run it, it will download the [MS-COCO](http://cocodataset.org/#home) dataset, preprocess and cache a subset of the images using Inception V3, train an encoder-decoder model, and use it to generate captions on new images.\n", + "\n", + "The code requires TensorFlow version >=1.9. If you're running this in [Colab]()\n", + "\n", + "In this example, we're training on a relatively small amount of data as an example. On a single P100 GPU, this example will take about ~2 hours to train. We train on the first 30,000 captions (corresponding to about ~20,000 images depending on shuffling, as there are multiple captions per image in the dataset)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "Cffg2i257iMS", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "# Image Captioning with Attention\n", - "\n", - "
\n", - "\n", - " Run in Google Colab \n", - "\n", - "View source on GitHub
" - ] + "colab_type": "code", + "id": "U8l4RJ0XRPEm" + }, + "outputs": [], + "source": [ + "# Import TensorFlow and enable eager execution\n", + "# This code requires TensorFlow version >=1.9\n", + "import tensorflow as tf\n", + "tf.enable_eager_execution()\n", + "\n", + "# We'll generate plots of attention in order to see which parts of an image\n", + "# our model focuses on during captioning\n", + "import matplotlib.pyplot as plt\n", + "\n", + "# Scikit-learn includes many helpful utilities\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.utils import shuffle\n", + "\n", + "import re\n", + "import numpy as np\n", + "import os\n", + "import time\n", + "import json\n", + "from glob import glob\n", + "from PIL import Image\n", + "import pickle" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "b6qbGw8MRPE5" + }, + "source": [ + "## Download and prepare the MS-COCO dataset\n", + "\n", + "We will use the [MS-COCO dataset](http://cocodataset.org/#home) to train our model. This dataset contains >82,000 images, each of which has been annotated with at least 5 different captions. The code below will download and extract the dataset automatically. \n", + "\n", + "**Caution: large download ahead**. We'll use the training set, it's a 13GB file." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "QASbY_HGo4Lq", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "Image captioning is the task of generating a caption for an image. Given an image like this:\n", - "\n", - "![Man Surfing](https://tensorflow.org/images/surf.jpg) \n", - "\n", - "[Image Source](https://commons.wikimedia.org/wiki/Surfing#/media/File:Surfing_in_Hawaii.jpg), License: Public Domain\n", - "\n", - "Our goal is to generate a caption, such as \"a surfer riding on a wave\". Here, we'll use an attention-based model. This enables us to see which parts of the image the model focuses on as it generates a caption.\n", - "\n", - "![Prediction](https://tensorflow.org/images/imcap_prediction.png)\n", - "\n", - "This model architecture below is similar to [Show, Attend and Tell: Neural Image Caption Generation with Visual Attention](https://arxiv.org/abs/1502.03044). \n", - "\n", - "The code uses [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager), which you can learn more about in the linked guides.\n", - "\n", - "This notebook is an end-to-end example. If you run it, it will download the [MS-COCO](http://cocodataset.org/#home) dataset, preprocess and cache a subset of the images using Inception V3, train an encoder-decoder model, and use it to generate captions on new images.\n", - "\n", - "The code requires TensorFlow version >=1.9. If you're running this in [Colab]()\n", - "\n", - "In this example, we're training on a relatively small amount of data as an example. On a single P100 GPU, this example will take about ~2 hours to train. 
We train on the first 30,000 captions (corresponding to about ~20,000 images depending on shuffling, as there are multiple captions per image in the dataset)\n" - ] + "colab_type": "code", + "id": "krQuPYTtRPE7" + }, + "outputs": [], + "source": [ + "annotation_zip = tf.keras.utils.get_file('captions.zip', \n", + " cache_subdir=os.path.abspath('.'),\n", + " origin = 'http://images.cocodataset.org/annotations/annotations_trainval2014.zip',\n", + " extract = True)\n", + "annotation_file = os.path.dirname(annotation_zip)+'/annotations/captions_train2014.json'\n", + "\n", + "name_of_zip = 'train2014.zip'\n", + "if not os.path.exists(os.path.abspath('.') + '/' + name_of_zip):\n", + " image_zip = tf.keras.utils.get_file(name_of_zip, \n", + " cache_subdir=os.path.abspath('.'),\n", + " origin = 'http://images.cocodataset.org/zips/train2014.zip',\n", + " extract = True)\n", + " PATH = os.path.dirname(image_zip)+'/train2014/'\n", + "else:\n", + " PATH = os.path.abspath('.')+'/train2014/'" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "aANEzb5WwSzg" + }, + "source": [ + "## Optionally, limit the size of the training set for faster training\n", + "For this example, we'll select a subset of 30,000 captions and use these and the corresponding images to train our model. As always, captioning quality will improve if you choose to use more data." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "U8l4RJ0XRPEm", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# Import TensorFlow and enable eager execution\n", - "# This code requires TensorFlow version >=1.9\n", - "import tensorflow as tf\n", - "tf.enable_eager_execution()\n", - "\n", - "# We'll generate plots of attention in order to see which parts of an image\n", - "# our model focuses on during captioning\n", - "import matplotlib.pyplot as plt\n", - "\n", - "# Scikit-learn includes many helpful utilities\n", - "from sklearn.model_selection import train_test_split\n", - "from sklearn.utils import shuffle\n", - "\n", - "import re\n", - "import numpy as np\n", - "import os\n", - "import time\n", - "import json\n", - "from glob import glob\n", - "from PIL import Image\n", - "import pickle" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "4G3b8x8_RPFD" + }, + "outputs": [], + "source": [ + "# read the json file\n", + "with open(annotation_file, 'r') as f:\n", + " annotations = json.load(f)\n", + "\n", + "# storing the captions and the image name in vectors\n", + "all_captions = []\n", + "all_img_name_vector = []\n", + "\n", + "for annot in annotations['annotations']:\n", + " caption = ' ' + annot['caption'] + ' '\n", + " image_id = annot['image_id']\n", + " full_coco_image_path = PATH + 'COCO_train2014_' + '%012d.jpg' % (image_id)\n", + " \n", + " all_img_name_vector.append(full_coco_image_path)\n", + " all_captions.append(caption)\n", + "\n", + "# shuffling the captions and image_names together\n", + "# setting a random state\n", + "train_captions, img_name_vector = shuffle(all_captions,\n", + " all_img_name_vector,\n", + " random_state=1)\n", + "\n", + "# selecting the first 30000 captions from the shuffled set\n", + "num_examples = 30000\n", + "train_captions = train_captions[:num_examples]\n", + "img_name_vector = img_name_vector[:num_examples]" + ] + }, + 
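One note on the cell above: the captions and image paths are shuffled together so that each caption stays paired with its image file. A tiny standalone illustration of that behaviour with toy data (not from the notebook):

```python
from sklearn.utils import shuffle

captions = ['a cat on a mat', 'a dog in a park', 'a bird on a wire']
image_paths = ['cat.jpg', 'dog.jpg', 'bird.jpg']

# shuffle() permutes both lists with the same permutation, and the fixed
# random_state makes the selection of the 30,000-caption subset reproducible.
shuffled_captions, shuffled_paths = shuffle(captions, image_paths, random_state=1)

for caption, path in zip(shuffled_captions, shuffled_paths):
    print(caption, '->', path)  # every caption still points at its own image
```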
{ + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "b6qbGw8MRPE5", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Download and prepare the MS-COCO dataset\n", - "\n", - "We will use the [MS-COCO dataset](http://cocodataset.org/#home) to train our model. This dataset contains >82,000 images, each of which has been annotated with at least 5 different captions. The code below will download and extract the dataset automatically. \n", - "\n", - "**Caution: large download ahead**. We'll use the training set, it's a 13GB file." - ] + "colab_type": "code", + "id": "mPBMgK34RPFL" + }, + "outputs": [], + "source": [ + "len(train_captions), len(all_captions)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "8cSW4u-ORPFQ" + }, + "source": [ + "## Preprocess the images using InceptionV3\n", + "Next, we will use InceptionV3 (pretrained on Imagenet) to classify each image. We will extract features from the last convolutional layer. \n", + "\n", + "First, we will need to convert the images into the format inceptionV3 expects by:\n", + "* Resizing the image to (299, 299)\n", + "* Using the [preprocess_input](https://www.tensorflow.org/api_docs/python/tf/keras/applications/inception_v3/preprocess_input) method to place the pixels in the range of -1 to 1 (to match the format of the images used to train InceptionV3)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "krQuPYTtRPE7", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "annotation_zip = tf.keras.utils.get_file('captions.zip', \n", - " cache_subdir=os.path.abspath('.'),\n", - " origin = 'http://images.cocodataset.org/annotations/annotations_trainval2014.zip',\n", - " extract = True)\n", - "annotation_file = os.path.dirname(annotation_zip)+'/annotations/captions_train2014.json'\n", - "\n", - "name_of_zip = 'train2014.zip'\n", - "if not os.path.exists(os.path.abspath('.') + '/' + name_of_zip):\n", - " image_zip = tf.keras.utils.get_file(name_of_zip, \n", - " cache_subdir=os.path.abspath('.'),\n", - " origin = 'http://images.cocodataset.org/zips/train2014.zip',\n", - " extract = True)\n", - " PATH = os.path.dirname(image_zip)+'/train2014/'\n", - "else:\n", - " PATH = os.path.abspath('.')+'/train2014/'" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "zXR0217aRPFR" + }, + "outputs": [], + "source": [ + "def load_image(image_path):\n", + " img = tf.read_file(image_path)\n", + " img = tf.image.decode_jpeg(img, channels=3)\n", + " img = tf.image.resize_images(img, (299, 299))\n", + " img = tf.keras.applications.inception_v3.preprocess_input(img)\n", + " return img, image_path" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "MDvIu4sXRPFV" + }, + "source": [ + "## Initialize InceptionV3 and load the pretrained Imagenet weights\n", + "\n", + "To do so, we'll create a tf.keras model where the output layer is the last convolutional layer in the InceptionV3 architecture. \n", + "* Each image is forwarded through the network and the vector that we get at the end is stored in a dictionary (image_name --> feature_vector). 
\n", + "* We use the last convolutional layer because we are using attention in this example. The shape of the output of this layer is ```8x8x2048```. \n", + "* We avoid doing this during training so it does not become a bottleneck. \n", + "* After all the images are passed through the network, we pickle the dictionary and save it to disk." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "aANEzb5WwSzg", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Optionally, limit the size of the training set for faster training\n", - "For this example, we'll select a subset of 30,000 captions and use these and the corresponding images to train our model. As always, captioning quality will improve if you choose to use more data." - ] + "colab_type": "code", + "id": "RD3vW4SsRPFW" + }, + "outputs": [], + "source": [ + "image_model = tf.keras.applications.InceptionV3(include_top=False, \n", + " weights='imagenet')\n", + "new_input = image_model.input\n", + "hidden_layer = image_model.layers[-1].output\n", + "\n", + "image_features_extract_model = tf.keras.Model(new_input, hidden_layer)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "rERqlR3WRPGO" + }, + "source": [ + "## Caching the features extracted from InceptionV3\n", + "\n", + "We will pre-process each image with InceptionV3 and cache the output to disk. Caching the output in RAM would be faster but memory intensive, requiring 8 \\* 8 \\* 2048 floats per image. At the time of writing, this would exceed the memory limitations of Colab (although these may change, an instance appears to have about 12GB of memory currently). \n", + "\n", + "Performance could be improved with a more sophisticated caching strategy (e.g., by sharding the images to reduce random access disk I/O) at the cost of more code.\n", + "\n", + "This will take about 10 minutes to run in Colab with a GPU. If you'd like to see a progress bar, you could: install [tqdm](https://github.com/tqdm/tqdm) (```!pip install tqdm```), then change this line: \n", + "\n", + "```for img, path in image_dataset:``` \n", + "\n", + "to:\n", + "\n", + "```for img, path in tqdm(image_dataset):```." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "4G3b8x8_RPFD", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# read the json file\n", - "with open(annotation_file, 'r') as f:\n", - " annotations = json.load(f)\n", - "\n", - "# storing the captions and the image name in vectors\n", - "all_captions = []\n", - "all_img_name_vector = []\n", - "\n", - "for annot in annotations['annotations']:\n", - " caption = ' ' + annot['caption'] + ' '\n", - " image_id = annot['image_id']\n", - " full_coco_image_path = PATH + 'COCO_train2014_' + '%012d.jpg' % (image_id)\n", - " \n", - " all_img_name_vector.append(full_coco_image_path)\n", - " all_captions.append(caption)\n", - "\n", - "# shuffling the captions and image_names together\n", - "# setting a random state\n", - "train_captions, img_name_vector = shuffle(all_captions,\n", - " all_img_name_vector,\n", - " random_state=1)\n", - "\n", - "# selecting the first 30000 captions from the shuffled set\n", - "num_examples = 30000\n", - "train_captions = train_captions[:num_examples]\n", - "img_name_vector = img_name_vector[:num_examples]" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "Dx_fvbVgRPGQ" + }, + "outputs": [], + "source": [ + "# getting the unique images\n", + "encode_train = sorted(set(img_name_vector))\n", + "\n", + "# feel free to change the batch_size according to your system configuration\n", + "image_dataset = tf.data.Dataset.from_tensor_slices(\n", + " encode_train).map(load_image).batch(16)\n", + "\n", + "for img, path in image_dataset:\n", + " batch_features = image_features_extract_model(img)\n", + " batch_features = tf.reshape(batch_features, \n", + " (batch_features.shape[0], -1, batch_features.shape[3]))\n", + "\n", + " for bf, p in zip(batch_features, path):\n", + " path_of_feature = p.numpy().decode(\"utf-8\")\n", + " np.save(path_of_feature, bf.numpy())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "nyqH3zFwRPFi" + }, + "source": [ + "## Preprocess and tokenize the captions\n", + "\n", + "* First, we'll tokenize the captions (e.g., by splitting on spaces). This will give us a vocabulary of all the unique words in the data (e.g., \"surfing\", \"football\", etc).\n", + "* Next, we'll limit the vocabulary size to the top 5,000 words to save memory. We'll replace all other words with the token \"UNK\" (for unknown).\n", + "* Finally, we create a word --> index mapping and vice-versa.\n", + "* We will then pad all sequences to the be same length as the longest one. 
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "mPBMgK34RPFL", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "len(train_captions), len(all_captions)" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "HZfK8RhQRPFj" + }, + "outputs": [], + "source": [ + "# This will find the maximum length of any caption in our dataset\n", + "def calc_max_length(tensor):\n", + " return max(len(t) for t in tensor)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "8cSW4u-ORPFQ", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Preprocess the images using InceptionV3\n", - "Next, we will use InceptionV3 (pretrained on Imagenet) to classify each image. We will extract features from the last convolutional layer. \n", - "\n", - "First, we will need to convert the images into the format inceptionV3 expects by:\n", - "* Resizing the image to (299, 299)\n", - "* Using the [preprocess_input](https://www.tensorflow.org/api_docs/python/tf/keras/applications/inception_v3/preprocess_input) method to place the pixels in the range of -1 to 1 (to match the format of the images used to train InceptionV3)." - ] + "colab_type": "code", + "id": "oJGE34aiRPFo" + }, + "outputs": [], + "source": [ + "# The steps above is a general process of dealing with text processing\n", + "\n", + "# choosing the top 5000 words from the vocabulary\n", + "top_k = 5000\n", + "tokenizer = tf.keras.preprocessing.text.Tokenizer(num_words=top_k, \n", + " oov_token=\"\", \n", + " filters='!\"#$%&()*+.,-/:;=?@[\\]^_`{|}~ ')\n", + "tokenizer.fit_on_texts(train_captions)\n", + "train_seqs = tokenizer.texts_to_sequences(train_captions)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "zXR0217aRPFR", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "def load_image(image_path):\n", - " img = tf.read_file(image_path)\n", - " img = tf.image.decode_jpeg(img, channels=3)\n", - " img = tf.image.resize_images(img, (299, 299))\n", - " img = tf.keras.applications.inception_v3.preprocess_input(img)\n", - " return img, image_path" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "8Q44tNQVRPFt" + }, + "outputs": [], + "source": [ + "tokenizer.word_index[''] = 0" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "MDvIu4sXRPFV", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Initialize InceptionV3 and load the pretrained Imagenet weights\n", - "\n", - "To do so, we'll create a tf.keras model where the output layer is the last convolutional layer in the InceptionV3 architecture. \n", - "* Each image is forwarded through the network and the vector that we get at the end is stored in a dictionary (image_name --> feature_vector). \n", - "* We use the last convolutional layer because we are using attention in this example. 
The shape of the output of this layer is ```8x8x2048```. \n", - "* We avoid doing this during training so it does not become a bottleneck. \n", - "* After all the images are passed through the network, we pickle the dictionary and save it to disk." - ] + "colab_type": "code", + "id": "0fpJb5ojRPFv" + }, + "outputs": [], + "source": [ + "# creating the tokenized vectors\n", + "train_seqs = tokenizer.texts_to_sequences(train_captions)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "RD3vW4SsRPFW", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "image_model = tf.keras.applications.InceptionV3(include_top=False, \n", - " weights='imagenet')\n", - "new_input = image_model.input\n", - "hidden_layer = image_model.layers[-1].output\n", - "\n", - "image_features_extract_model = tf.keras.Model(new_input, hidden_layer)" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "AidglIZVRPF4" + }, + "outputs": [], + "source": [ + "# padding each vector to the max_length of the captions\n", + "# if the max_length parameter is not provided, pad_sequences calculates that automatically\n", + "cap_vector = tf.keras.preprocessing.sequence.pad_sequences(train_seqs, padding='post')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "rERqlR3WRPGO", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Caching the features extracted from InceptionV3\n", - "\n", - "We will pre-process each image with InceptionV3 and cache the output to disk. Caching the output in RAM would be faster but memory intensive, requiring 8 \\* 8 \\* 2048 floats per image. At the time of writing, this would exceed the memory limitations of Colab (although these may change, an instance appears to have about 12GB of memory currently). \n", - "\n", - "Performance could be improved with a more sophisticated caching strategy (e.g., by sharding the images to reduce random access disk I/O) at the cost of more code.\n", - "\n", - "This will take about 10 minutes to run in Colab with a GPU. If you'd like to see a progress bar, you could: install [tqdm](https://github.com/tqdm/tqdm) (```!pip install tqdm```), then change this line: \n", - "\n", - "```for img, path in image_dataset:``` \n", - "\n", - "to:\n", - "\n", - "```for img, path in tqdm(image_dataset):```." 
- ] + "colab_type": "code", + "id": "gL0wkttkRPGA" + }, + "outputs": [], + "source": [ + "# calculating the max_length \n", + "# used to store the attention weights\n", + "max_length = calc_max_length(train_seqs)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "M3CD75nDpvTI" + }, + "source": [ + "## Split the data into training and testing" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "Dx_fvbVgRPGQ", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# getting the unique images\n", - "encode_train = sorted(set(img_name_vector))\n", - "\n", - "# feel free to change the batch_size according to your system configuration\n", - "image_dataset = tf.data.Dataset.from_tensor_slices(\n", - " encode_train).map(load_image).batch(16)\n", - "\n", - "for img, path in image_dataset:\n", - " batch_features = image_features_extract_model(img)\n", - " batch_features = tf.reshape(batch_features, \n", - " (batch_features.shape[0], -1, batch_features.shape[3]))\n", - "\n", - " for bf, p in zip(batch_features, path):\n", - " path_of_feature = p.numpy().decode(\"utf-8\")\n", - " np.save(path_of_feature, bf.numpy())" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "iS7DDMszRPGF" + }, + "outputs": [], + "source": [ + "# Create training and validation sets using 80-20 split\n", + "img_name_train, img_name_val, cap_train, cap_val = train_test_split(img_name_vector, \n", + " cap_vector, \n", + " test_size=0.2, \n", + " random_state=0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "nyqH3zFwRPFi", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Preprocess and tokenize the captions\n", - "\n", - "* First, we'll tokenize the captions (e.g., by splitting on spaces). This will give us a vocabulary of all the unique words in the data (e.g., \"surfing\", \"football\", etc).\n", - "* Next, we'll limit the vocabulary size to the top 5,000 words to save memory. We'll replace all other words with the token \"UNK\" (for unknown).\n", - "* Finally, we create a word --> index mapping and vice-versa.\n", - "* We will then pad all sequences to the be same length as the longest one. " - ] + "colab_type": "code", + "id": "XmViPkRFRPGH" + }, + "outputs": [], + "source": [ + "len(img_name_train), len(cap_train), len(img_name_val), len(cap_val)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "uEWM9xrYcg45" + }, + "source": [ + "## Our images and captions are ready! 
Next, let's create a tf.data dataset to use for training our model.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "HZfK8RhQRPFj", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# This will find the maximum length of any caption in our dataset\n", - "def calc_max_length(tensor):\n", - " return max(len(t) for t in tensor)" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "Q3TnZ1ToRPGV" + }, + "outputs": [], + "source": [ + "# feel free to change these parameters according to your system's configuration\n", + "\n", + "BATCH_SIZE = 64\n", + "BUFFER_SIZE = 1000\n", + "embedding_dim = 256\n", + "units = 512\n", + "vocab_size = len(tokenizer.word_index)\n", + "# shape of the vector extracted from InceptionV3 is (64, 2048)\n", + "# these two variables represent that\n", + "features_shape = 2048\n", + "attention_features_shape = 64" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "oJGE34aiRPFo", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# The steps above is a general process of dealing with text processing\n", - "\n", - "# choosing the top 5000 words from the vocabulary\n", - "top_k = 5000\n", - "tokenizer = tf.keras.preprocessing.text.Tokenizer(num_words=top_k, \n", - " oov_token=\"\", \n", - " filters='!\"#$%&()*+.,-/:;=?@[\\]^_`{|}~ ')\n", - "tokenizer.fit_on_texts(train_captions)\n", - "train_seqs = tokenizer.texts_to_sequences(train_captions)" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "SmZS2N0bXG3T" + }, + "outputs": [], + "source": [ + "# loading the numpy files \n", + "def map_func(img_name, cap):\n", + " img_tensor = np.load(img_name.decode('utf-8')+'.npy')\n", + " return img_tensor, cap" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "8Q44tNQVRPFt", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "tokenizer.word_index = {key:value for key, value in tokenizer.word_index.items() if value <= top_k}\n", - "# putting token in the word2idx dictionary\n", - "tokenizer.word_index[tokenizer.oov_token] = top_k + 1\n", - "tokenizer.word_index[''] = 0" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "FDF_Nm3tRPGZ" + }, + "outputs": [], + "source": [ + "dataset = tf.data.Dataset.from_tensor_slices((img_name_train, cap_train))\n", + "\n", + "# using map to load the numpy files in parallel\n", + "# NOTE: Be sure to set num_parallel_calls to the number of CPU cores you have\n", + "# https://www.tensorflow.org/api_docs/python/tf/py_func\n", + "dataset = dataset.map(lambda item1, item2: tf.py_func(\n", + " map_func, [item1, item2], [tf.float32, tf.int32]), num_parallel_calls=8)\n", + "\n", + "# shuffling and batching\n", + "dataset = dataset.shuffle(BUFFER_SIZE)\n", + "# https://www.tensorflow.org/api_docs/python/tf/contrib/data/batch_and_drop_remainder\n", + "dataset = dataset.batch(BATCH_SIZE)\n", + 
"dataset = dataset.prefetch(1)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "nrvoDphgRPGd" + }, + "source": [ + "## Model\n", + "\n", + "Fun fact, the decoder below is identical to the one in the example for [Neural Machine Translation with Attention]( https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb).\n", + "\n", + "The model architecture is inspired by the [Show, Attend and Tell](https://arxiv.org/pdf/1502.03044.pdf) paper.\n", + "\n", + "* In this example, we extract the features from the lower convolutional layer of InceptionV3 giving us a vector of shape (8, 8, 2048). \n", + "* We squash that to a shape of (64, 2048).\n", + "* This vector is then passed through the CNN Encoder(which consists of a single Fully connected layer).\n", + "* The RNN(here GRU) attends over the image to predict the next word." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "0fpJb5ojRPFv", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# creating the tokenized vectors\n", - "train_seqs = tokenizer.texts_to_sequences(train_captions)" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "AAppCGLKRPGd" + }, + "outputs": [], + "source": [ + "def gru(units):\n", + " # If you have a GPU, we recommend using the CuDNNGRU layer (it provides a \n", + " # significant speedup).\n", + " if tf.test.is_gpu_available():\n", + " return tf.keras.layers.CuDNNGRU(units, \n", + " return_sequences=True, \n", + " return_state=True, \n", + " recurrent_initializer='glorot_uniform')\n", + " else:\n", + " return tf.keras.layers.GRU(units, \n", + " return_sequences=True, \n", + " return_state=True, \n", + " recurrent_activation='sigmoid', \n", + " recurrent_initializer='glorot_uniform')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "olQArbgbRPF1", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# creating a reverse mapping (index -> word)\n", - "index_word = {value:key for key, value in tokenizer.word_index.items()}" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "ja2LFTMSdeV3" + }, + "outputs": [], + "source": [ + "class BahdanauAttention(tf.keras.Model):\n", + " def __init__(self, units):\n", + " super(BahdanauAttention, self).__init__()\n", + " self.W1 = tf.keras.layers.Dense(units)\n", + " self.W2 = tf.keras.layers.Dense(units)\n", + " self.V = tf.keras.layers.Dense(1)\n", + " \n", + " def call(self, features, hidden):\n", + " # features(CNN_encoder output) shape == (batch_size, 64, embedding_dim)\n", + " \n", + " # hidden shape == (batch_size, hidden_size)\n", + " # hidden_with_time_axis shape == (batch_size, 1, hidden_size)\n", + " hidden_with_time_axis = tf.expand_dims(hidden, 1)\n", + " \n", + " # score shape == (batch_size, 64, hidden_size)\n", + " score = tf.nn.tanh(self.W1(features) + self.W2(hidden_with_time_axis))\n", + " \n", + " # attention_weights shape == (batch_size, 64, 1)\n", + " # we get 1 at the last axis because we are applying score to self.V\n", + " 
attention_weights = tf.nn.softmax(self.V(score), axis=1)\n", + " \n", + " # context_vector shape after sum == (batch_size, hidden_size)\n", + " context_vector = attention_weights * features\n", + " context_vector = tf.reduce_sum(context_vector, axis=1)\n", + " \n", + " return context_vector, attention_weights" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "AidglIZVRPF4", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# padding each vector to the max_length of the captions\n", - "# if the max_length parameter is not provided, pad_sequences calculates that automatically\n", - "cap_vector = tf.keras.preprocessing.sequence.pad_sequences(train_seqs, padding='post')" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "AZ7R1RxHRPGf" + }, + "outputs": [], + "source": [ + "class CNN_Encoder(tf.keras.Model):\n", + " # Since we have already extracted the features and dumped it using pickle\n", + " # This encoder passes those features through a Fully connected layer\n", + " def __init__(self, embedding_dim):\n", + " super(CNN_Encoder, self).__init__()\n", + " # shape after fc == (batch_size, 64, embedding_dim)\n", + " self.fc = tf.keras.layers.Dense(embedding_dim)\n", + " \n", + " def call(self, x):\n", + " x = self.fc(x)\n", + " x = tf.nn.relu(x)\n", + " return x" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "gL0wkttkRPGA", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# calculating the max_length \n", - "# used to store the attention weights\n", - "max_length = calc_max_length(train_seqs)" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "V9UbGQmERPGi" + }, + "outputs": [], + "source": [ + "class RNN_Decoder(tf.keras.Model):\n", + " def __init__(self, embedding_dim, units, vocab_size):\n", + " super(RNN_Decoder, self).__init__()\n", + " self.units = units\n", + "\n", + " self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n", + " self.gru = gru(self.units)\n", + " self.fc1 = tf.keras.layers.Dense(self.units)\n", + " self.fc2 = tf.keras.layers.Dense(vocab_size)\n", + " \n", + " self.attention = BahdanauAttention(self.units)\n", + " \n", + " def call(self, x, features, hidden):\n", + " # defining attention as a separate model\n", + " context_vector, attention_weights = self.attention(features, hidden)\n", + " \n", + " # x shape after passing through embedding == (batch_size, 1, embedding_dim)\n", + " x = self.embedding(x)\n", + " \n", + " # x shape after concatenation == (batch_size, 1, embedding_dim + hidden_size)\n", + " x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)\n", + " \n", + " # passing the concatenated vector to the GRU\n", + " output, state = self.gru(x)\n", + " \n", + " # shape == (batch_size, max_length, hidden_size)\n", + " x = self.fc1(output)\n", + " \n", + " # x shape == (batch_size * max_length, hidden_size)\n", + " x = tf.reshape(x, (-1, x.shape[2]))\n", + " \n", + " # output shape == (batch_size * max_length, vocab)\n", + " x = self.fc2(x)\n", + "\n", + " return x, state, attention_weights\n", + "\n", + " def reset_state(self, 
batch_size):\n", + " return tf.zeros((batch_size, self.units))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "M3CD75nDpvTI", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Split the data into training and testing" - ] + "colab_type": "code", + "id": "Qs_Sr03wRPGk" + }, + "outputs": [], + "source": [ + "encoder = CNN_Encoder(embedding_dim)\n", + "decoder = RNN_Decoder(embedding_dim, units, vocab_size)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "iS7DDMszRPGF", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# Create training and validation sets using 80-20 split\n", - "img_name_train, img_name_val, cap_train, cap_val = train_test_split(img_name_vector, \n", - " cap_vector, \n", - " test_size=0.2, \n", - " random_state=0)" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "-bYN7xA0RPGl" + }, + "outputs": [], + "source": [ + "optimizer = tf.train.AdamOptimizer()\n", + "\n", + "# We are masking the loss calculated for padding\n", + "def loss_function(real, pred):\n", + " mask = 1 - np.equal(real, 0)\n", + " loss_ = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=real, logits=pred) * mask\n", + " return tf.reduce_mean(loss_)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "PHod7t72RPGn" + }, + "source": [ + "## Training\n", + "\n", + "* We extract the features stored in the respective `.npy` files and then pass those features through the encoder.\n", + "* The encoder output, hidden state(initialized to 0) and the decoder input (which is the start token) is passed to the decoder.\n", + "* The decoder returns the predictions and the decoder hidden state.\n", + "* The decoder hidden state is then passed back into the model and the predictions are used to calculate the loss.\n", + "* Use teacher forcing to decide the next input to the decoder.\n", + "* Teacher forcing is the technique where the target word is passed as the next input to the decoder.\n", + "* The final step is to calculate the gradients and apply it to the optimizer and backpropagate.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "XmViPkRFRPGH", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "len(img_name_train), len(cap_train), len(img_name_val), len(cap_val)" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "Vt4WZ5mhJE-E" + }, + "outputs": [], + "source": [ + "# adding this in a separate cell because if you run the training cell \n", + "# many times, the loss_plot array will be reset\n", + "loss_plot = []" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "uEWM9xrYcg45", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Our images and captions are ready! 
Next, let's create a tf.data dataset to use for training our model.\n", - "\n" - ] + "colab_type": "code", + "id": "UlA4VIQpRPGo" + }, + "outputs": [], + "source": [ + "EPOCHS = 20\n", + "\n", + "for epoch in range(EPOCHS):\n", + " start = time.time()\n", + " total_loss = 0\n", + " \n", + " for (batch, (img_tensor, target)) in enumerate(dataset):\n", + " loss = 0\n", + " \n", + " # initializing the hidden state for each batch\n", + " # because the captions are not related from image to image\n", + " hidden = decoder.reset_state(batch_size=target.shape[0])\n", + "\n", + " dec_input = tf.expand_dims([tokenizer.word_index['']] * BATCH_SIZE, 1)\n", + " \n", + " with tf.GradientTape() as tape:\n", + " features = encoder(img_tensor)\n", + " \n", + " for i in range(1, target.shape[1]):\n", + " # passing the features through the decoder\n", + " predictions, hidden, _ = decoder(dec_input, features, hidden)\n", + "\n", + " loss += loss_function(target[:, i], predictions)\n", + " \n", + " # using teacher forcing\n", + " dec_input = tf.expand_dims(target[:, i], 1)\n", + " \n", + " total_loss += (loss / int(target.shape[1]))\n", + " \n", + " variables = encoder.variables + decoder.variables\n", + " \n", + " gradients = tape.gradient(loss, variables) \n", + " \n", + " optimizer.apply_gradients(zip(gradients, variables), tf.train.get_or_create_global_step())\n", + " \n", + " if batch % 100 == 0:\n", + " print ('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1, \n", + " batch, \n", + " loss.numpy() / int(target.shape[1])))\n", + " # storing the epoch end loss value to plot later\n", + " loss_plot.append(total_loss / len(cap_vector))\n", + " \n", + " print ('Epoch {} Loss {:.6f}'.format(epoch + 1, \n", + " total_loss/len(cap_vector)))\n", + " print ('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "Q3TnZ1ToRPGV", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# feel free to change these parameters according to your system's configuration\n", - "\n", - "BATCH_SIZE = 64\n", - "BUFFER_SIZE = 1000\n", - "embedding_dim = 256\n", - "units = 512\n", - "vocab_size = len(tokenizer.word_index)\n", - "# shape of the vector extracted from InceptionV3 is (64, 2048)\n", - "# these two variables represent that\n", - "features_shape = 2048\n", - "attention_features_shape = 64" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "1Wm83G-ZBPcC" + }, + "outputs": [], + "source": [ + "plt.plot(loss_plot)\n", + "plt.xlabel('Epochs')\n", + "plt.ylabel('Loss')\n", + "plt.title('Loss Plot')\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "xGvOcLQKghXN" + }, + "source": [ + "## Caption!\n", + "\n", + "* The evaluate function is similar to the training loop, except we don't use teacher forcing here. The input to the decoder at each time step is its previous predictions along with the hidden state and the encoder output.\n", + "* Stop predicting when the model predicts the end token.\n", + "* And store the attention weights for every time step." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "SmZS2N0bXG3T", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# loading the numpy files \n", - "def map_func(img_name, cap):\n", - " img_tensor = np.load(img_name.decode('utf-8')+'.npy')\n", - " return img_tensor, cap" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "RCWpDtyNRPGs" + }, + "outputs": [], + "source": [ + "def evaluate(image):\n", + " attention_plot = np.zeros((max_length, attention_features_shape))\n", + "\n", + " hidden = decoder.reset_state(batch_size=1)\n", + "\n", + " temp_input = tf.expand_dims(load_image(image)[0], 0)\n", + " img_tensor_val = image_features_extract_model(temp_input)\n", + " img_tensor_val = tf.reshape(img_tensor_val, (img_tensor_val.shape[0], -1, img_tensor_val.shape[3]))\n", + "\n", + " features = encoder(img_tensor_val)\n", + "\n", + " dec_input = tf.expand_dims([tokenizer.word_index['']], 0)\n", + " result = []\n", + "\n", + " for i in range(max_length):\n", + " predictions, hidden, attention_weights = decoder(dec_input, features, hidden)\n", + "\n", + " attention_plot[i] = tf.reshape(attention_weights, (-1, )).numpy()\n", + "\n", + " predicted_id = tf.argmax(predictions[0]).numpy()\n", + " result.append(tokenizer.index_word[predicted_id])\n", + "\n", + " if tokenizer.index_word[predicted_id] == '':\n", + " return result, attention_plot\n", + "\n", + " dec_input = tf.expand_dims([predicted_id], 0)\n", + "\n", + " attention_plot = attention_plot[:len(result), :]\n", + " return result, attention_plot" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "FDF_Nm3tRPGZ", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "dataset = tf.data.Dataset.from_tensor_slices((img_name_train, cap_train))\n", - "\n", - "# using map to load the numpy files in parallel\n", - "# NOTE: Be sure to set num_parallel_calls to the number of CPU cores you have\n", - "# https://www.tensorflow.org/api_docs/python/tf/py_func\n", - "dataset = dataset.map(lambda item1, item2: tf.py_func(\n", - " map_func, [item1, item2], [tf.float32, tf.int32]), num_parallel_calls=8)\n", - "\n", - "# shuffling and batching\n", - "dataset = dataset.shuffle(BUFFER_SIZE)\n", - "# https://www.tensorflow.org/api_docs/python/tf/contrib/data/batch_and_drop_remainder\n", - "dataset = dataset.batch(BATCH_SIZE)\n", - "dataset = dataset.prefetch(1)" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "fD_y7PD6RPGt" + }, + "outputs": [], + "source": [ + "def plot_attention(image, result, attention_plot):\n", + " temp_image = np.array(Image.open(image))\n", + "\n", + " fig = plt.figure(figsize=(10, 10))\n", + " \n", + " len_result = len(result)\n", + " for l in range(len_result):\n", + " temp_att = np.resize(attention_plot[l], (8, 8))\n", + " ax = fig.add_subplot(len_result//2, len_result//2, l+1)\n", + " ax.set_title(result[l])\n", + " img = ax.imshow(temp_image)\n", + " ax.imshow(temp_att, cmap='gray', alpha=0.6, extent=img.get_extent())\n", + "\n", + " plt.tight_layout()\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, 
+ "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "nrvoDphgRPGd", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Model\n", - "\n", - "Fun fact, the decoder below is identical to the one in the example for [Neural Machine Translation with Attention]( https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb).\n", - "\n", - "The model architecture is inspired by the [Show, Attend and Tell](https://arxiv.org/pdf/1502.03044.pdf) paper.\n", - "\n", - "* In this example, we extract the features from the lower convolutional layer of InceptionV3 giving us a vector of shape (8, 8, 2048). \n", - "* We squash that to a shape of (64, 2048).\n", - "* This vector is then passed through the CNN Encoder(which consists of a single Fully connected layer).\n", - "* The RNN(here GRU) attends over the image to predict the next word." - ] + "colab_type": "code", + "id": "io7ws3ReRPGv" + }, + "outputs": [], + "source": [ + "# captions on the validation set\n", + "rid = np.random.randint(0, len(img_name_val))\n", + "image = img_name_val[rid]\n", + "real_caption = ' '.join([tokenizer.index_word[i] for i in cap_val[rid] if i not in [0]])\n", + "result, attention_plot = evaluate(image)\n", + "\n", + "print ('Real Caption:', real_caption)\n", + "print ('Prediction Caption:', ' '.join(result))\n", + "plot_attention(image, result, attention_plot)\n", + "# opening the image\n", + "Image.open(img_name_val[rid])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "Rprk3HEvZuxb" + }, + "source": [ + "## Try it on your own images\n", + "For fun, below we've provided a method you can use to caption your own images with the model we've just trained. Keep in mind, it was trained on a relatively small amount of data, and your images may be different from the training data (so be prepared for weird results!)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, + "colab_type": "code", + "id": "9Psd1quzaAWg" + }, + "outputs": [], + "source": [ + "image_url = 'https://tensorflow.org/images/surf.jpg'\n", + "image_extension = image_url[-4:]\n", + "image_path = tf.keras.utils.get_file('image'+image_extension, \n", + " origin=image_url)\n", + "\n", + "result, attention_plot = evaluate(image_path)\n", + "print ('Prediction Caption:', ' '.join(result))\n", + "plot_attention(image_path, result, attention_plot)\n", + "# opening the image\n", + "Image.open(image_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "VJZXyJco6uLO" + }, + "source": [ + "# Next steps\n", + "\n", + "Congrats! You've just trained an image captioning model with attention. Next, we recommend taking a look at this example [Neural Machine Translation with Attention]( https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb). It uses a similar architecture to translate between Spanish and English sentences. You can also experiment with training the code in this notebook on a different dataset." 
+ ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "default_view": {}, + "name": "image_captioning_with_attention.ipynb", + "private_outputs": true, + "provenance": [ { - "metadata": { - "id": "AAppCGLKRPGd", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "def gru(units):\n", - " # If you have a GPU, we recommend using the CuDNNGRU layer (it provides a \n", - " # significant speedup).\n", - " if tf.test.is_gpu_available():\n", - " return tf.keras.layers.CuDNNGRU(units, \n", - " return_sequences=True, \n", - " return_state=True, \n", - " recurrent_initializer='glorot_uniform')\n", - " else:\n", - " return tf.keras.layers.GRU(units, \n", - " return_sequences=True, \n", - " return_state=True, \n", - " recurrent_activation='sigmoid', \n", - " recurrent_initializer='glorot_uniform')" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "ja2LFTMSdeV3", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "class BahdanauAttention(tf.keras.Model):\n", - " def __init__(self, units):\n", - " super(BahdanauAttention, self).__init__()\n", - " self.W1 = tf.keras.layers.Dense(units)\n", - " self.W2 = tf.keras.layers.Dense(units)\n", - " self.V = tf.keras.layers.Dense(1)\n", - " \n", - " def call(self, features, hidden):\n", - " # features(CNN_encoder output) shape == (batch_size, 64, embedding_dim)\n", - " \n", - " # hidden shape == (batch_size, hidden_size)\n", - " # hidden_with_time_axis shape == (batch_size, 1, hidden_size)\n", - " hidden_with_time_axis = tf.expand_dims(hidden, 1)\n", - " \n", - " # score shape == (batch_size, 64, hidden_size)\n", - " score = tf.nn.tanh(self.W1(features) + self.W2(hidden_with_time_axis))\n", - " \n", - " # attention_weights shape == (batch_size, 64, 1)\n", - " # we get 1 at the last axis because we are applying score to self.V\n", - " attention_weights = tf.nn.softmax(self.V(score), axis=1)\n", - " \n", - " # context_vector shape after sum == (batch_size, hidden_size)\n", - " context_vector = attention_weights * features\n", - " context_vector = tf.reduce_sum(context_vector, axis=1)\n", - " \n", - " return context_vector, attention_weights" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "AZ7R1RxHRPGf", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "class CNN_Encoder(tf.keras.Model):\n", - " # Since we have already extracted the features and dumped it using pickle\n", - " # This encoder passes those features through a Fully connected layer\n", - " def __init__(self, embedding_dim):\n", - " super(CNN_Encoder, self).__init__()\n", - " # shape after fc == (batch_size, 64, embedding_dim)\n", - " self.fc = tf.keras.layers.Dense(embedding_dim)\n", - " \n", - " def call(self, x):\n", - " x = self.fc(x)\n", - " x = tf.nn.relu(x)\n", - " return x" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "V9UbGQmERPGi", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "class RNN_Decoder(tf.keras.Model):\n", - " def __init__(self, embedding_dim, units, vocab_size):\n", - " super(RNN_Decoder, self).__init__()\n", - " self.units = units\n", - "\n", - " self.embedding = 
tf.keras.layers.Embedding(vocab_size, embedding_dim)\n", - " self.gru = gru(self.units)\n", - " self.fc1 = tf.keras.layers.Dense(self.units)\n", - " self.fc2 = tf.keras.layers.Dense(vocab_size)\n", - " \n", - " self.attention = BahdanauAttention(self.units)\n", - " \n", - " def call(self, x, features, hidden):\n", - " # defining attention as a separate model\n", - " context_vector, attention_weights = self.attention(features, hidden)\n", - " \n", - " # x shape after passing through embedding == (batch_size, 1, embedding_dim)\n", - " x = self.embedding(x)\n", - " \n", - " # x shape after concatenation == (batch_size, 1, embedding_dim + hidden_size)\n", - " x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)\n", - " \n", - " # passing the concatenated vector to the GRU\n", - " output, state = self.gru(x)\n", - " \n", - " # shape == (batch_size, max_length, hidden_size)\n", - " x = self.fc1(output)\n", - " \n", - " # x shape == (batch_size * max_length, hidden_size)\n", - " x = tf.reshape(x, (-1, x.shape[2]))\n", - " \n", - " # output shape == (batch_size * max_length, vocab)\n", - " x = self.fc2(x)\n", - "\n", - " return x, state, attention_weights\n", - "\n", - " def reset_state(self, batch_size):\n", - " return tf.zeros((batch_size, self.units))" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "Qs_Sr03wRPGk", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "encoder = CNN_Encoder(embedding_dim)\n", - "decoder = RNN_Decoder(embedding_dim, units, vocab_size)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "-bYN7xA0RPGl", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "optimizer = tf.train.AdamOptimizer()\n", - "\n", - "# We are masking the loss calculated for padding\n", - "def loss_function(real, pred):\n", - " mask = 1 - np.equal(real, 0)\n", - " loss_ = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=real, logits=pred) * mask\n", - " return tf.reduce_mean(loss_)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "PHod7t72RPGn", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Training\n", - "\n", - "* We extract the features stored in the respective `.npy` files and then pass those features through the encoder.\n", - "* The encoder output, hidden state(initialized to 0) and the decoder input (which is the start token) is passed to the decoder.\n", - "* The decoder returns the predictions and the decoder hidden state.\n", - "* The decoder hidden state is then passed back into the model and the predictions are used to calculate the loss.\n", - "* Use teacher forcing to decide the next input to the decoder.\n", - "* Teacher forcing is the technique where the target word is passed as the next input to the decoder.\n", - "* The final step is to calculate the gradients and apply it to the optimizer and backpropagate.\n" - ] - }, - { - "metadata": { - "id": "Vt4WZ5mhJE-E", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# adding this in a separate cell because if you run the training cell \n", - "# many times, the loss_plot array will be reset\n", - "loss_plot = []" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "UlA4VIQpRPGo", - "colab_type": 
"code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "EPOCHS = 20\n", - "\n", - "for epoch in range(EPOCHS):\n", - " start = time.time()\n", - " total_loss = 0\n", - " \n", - " for (batch, (img_tensor, target)) in enumerate(dataset):\n", - " loss = 0\n", - " \n", - " # initializing the hidden state for each batch\n", - " # because the captions are not related from image to image\n", - " hidden = decoder.reset_state(batch_size=target.shape[0])\n", - "\n", - " dec_input = tf.expand_dims([tokenizer.word_index['']] * BATCH_SIZE, 1)\n", - " \n", - " with tf.GradientTape() as tape:\n", - " features = encoder(img_tensor)\n", - " \n", - " for i in range(1, target.shape[1]):\n", - " # passing the features through the decoder\n", - " predictions, hidden, _ = decoder(dec_input, features, hidden)\n", - "\n", - " loss += loss_function(target[:, i], predictions)\n", - " \n", - " # using teacher forcing\n", - " dec_input = tf.expand_dims(target[:, i], 1)\n", - " \n", - " total_loss += (loss / int(target.shape[1]))\n", - " \n", - " variables = encoder.variables + decoder.variables\n", - " \n", - " gradients = tape.gradient(loss, variables) \n", - " \n", - " optimizer.apply_gradients(zip(gradients, variables), tf.train.get_or_create_global_step())\n", - " \n", - " if batch % 100 == 0:\n", - " print ('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1, \n", - " batch, \n", - " loss.numpy() / int(target.shape[1])))\n", - " # storing the epoch end loss value to plot later\n", - " loss_plot.append(total_loss / len(cap_vector))\n", - " \n", - " print ('Epoch {} Loss {:.6f}'.format(epoch + 1, \n", - " total_loss/len(cap_vector)))\n", - " print ('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "1Wm83G-ZBPcC", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "plt.plot(loss_plot)\n", - "plt.xlabel('Epochs')\n", - "plt.ylabel('Loss')\n", - "plt.title('Loss Plot')\n", - "plt.show()" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "xGvOcLQKghXN", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Caption!\n", - "\n", - "* The evaluate function is similar to the training loop, except we don't use teacher forcing here. The input to the decoder at each time step is its previous predictions along with the hidden state and the encoder output.\n", - "* Stop predicting when the model predicts the end token.\n", - "* And store the attention weights for every time step." 
- ] - }, - { - "metadata": { - "id": "RCWpDtyNRPGs", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "def evaluate(image):\n", - " attention_plot = np.zeros((max_length, attention_features_shape))\n", - "\n", - " hidden = decoder.reset_state(batch_size=1)\n", - "\n", - " temp_input = tf.expand_dims(load_image(image)[0], 0)\n", - " img_tensor_val = image_features_extract_model(temp_input)\n", - " img_tensor_val = tf.reshape(img_tensor_val, (img_tensor_val.shape[0], -1, img_tensor_val.shape[3]))\n", - "\n", - " features = encoder(img_tensor_val)\n", - "\n", - " dec_input = tf.expand_dims([tokenizer.word_index['']], 0)\n", - " result = []\n", - "\n", - " for i in range(max_length):\n", - " predictions, hidden, attention_weights = decoder(dec_input, features, hidden)\n", - "\n", - " attention_plot[i] = tf.reshape(attention_weights, (-1, )).numpy()\n", - "\n", - " predicted_id = tf.argmax(predictions[0]).numpy()\n", - " result.append(index_word[predicted_id])\n", - "\n", - " if index_word[predicted_id] == '':\n", - " return result, attention_plot\n", - "\n", - " dec_input = tf.expand_dims([predicted_id], 0)\n", - "\n", - " attention_plot = attention_plot[:len(result), :]\n", - " return result, attention_plot" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "fD_y7PD6RPGt", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "def plot_attention(image, result, attention_plot):\n", - " temp_image = np.array(Image.open(image))\n", - "\n", - " fig = plt.figure(figsize=(10, 10))\n", - " \n", - " len_result = len(result)\n", - " for l in range(len_result):\n", - " temp_att = np.resize(attention_plot[l], (8, 8))\n", - " ax = fig.add_subplot(len_result//2, len_result//2, l+1)\n", - " ax.set_title(result[l])\n", - " img = ax.imshow(temp_image)\n", - " ax.imshow(temp_att, cmap='gray', alpha=0.6, extent=img.get_extent())\n", - "\n", - " plt.tight_layout()\n", - " plt.show()" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "io7ws3ReRPGv", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# captions on the validation set\n", - "rid = np.random.randint(0, len(img_name_val))\n", - "image = img_name_val[rid]\n", - "real_caption = ' '.join([index_word[i] for i in cap_val[rid] if i not in [0]])\n", - "result, attention_plot = evaluate(image)\n", - "\n", - "print ('Real Caption:', real_caption)\n", - "print ('Prediction Caption:', ' '.join(result))\n", - "plot_attention(image, result, attention_plot)\n", - "# opening the image\n", - "Image.open(img_name_val[rid])" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "Rprk3HEvZuxb", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Try it on your own images\n", - "For fun, below we've provided a method you can use to caption your own images with the model we've just trained. 
Keep in mind, it was trained on a relatively small amount of data, and your images may be different from the training data (so be prepared for weird results!)\n" - ] - }, - { - "metadata": { - "id": "9Psd1quzaAWg", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "image_url = 'https://tensorflow.org/images/surf.jpg'\n", - "image_extension = image_url[-4:]\n", - "image_path = tf.keras.utils.get_file('image'+image_extension, \n", - " origin=image_url)\n", - "\n", - "result, attention_plot = evaluate(image_path)\n", - "print ('Prediction Caption:', ' '.join(result))\n", - "plot_attention(image_path, result, attention_plot)\n", - "# opening the image\n", - "Image.open(image_path)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "VJZXyJco6uLO", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "# Next steps\n", - "\n", - "Congrats! You've just trained an image captioning model with attention. Next, we recommend taking a look at this example [Neural Machine Translation with Attention]( https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb). It uses a similar architecture to translate between Spanish and English sentences. You can also experiment with training the code in this notebook on a different dataset." - ] + "file_id": "1HI8OK2sMjcx9CTWVn0122QAHOuXaOaMg", + "timestamp": 1530222436922 } - ] + ], + "toc_visible": true, + "version": "0.3.2", + "views": {} + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 } diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_graph_test.py b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_graph_test.py index 557ad427521..d412b25b368 100644 --- a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_graph_test.py +++ b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_graph_test.py @@ -36,7 +36,7 @@ class GraphLinearRegressionBenchmark(tf.test.Benchmark): noise_level=0.01, batch_size=batch_size, num_batches=num_batches) - iterator = dataset.make_initializable_iterator() + iterator = tf.compat.v1.data.make_initializable_iterator(dataset) x, y = iterator.get_next() model = linear_regression.LinearModel() diff --git a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb index 480777d9487..66d52a74943 100644 --- a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb +++ b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb @@ -768,7 +768,7 @@ }, "outputs": [], "source": [ - "translate('hace mucho frio aqui.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" + "translate(u'hace mucho frio aqui.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" ] }, { @@ -781,7 +781,7 @@ }, "outputs": [], "source": [ - "translate('esta es mi vida.', encoder, decoder, 
inp_lang, targ_lang, max_length_inp, max_length_targ)" + "translate(u'esta es mi vida.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" ] }, { @@ -794,7 +794,7 @@ }, "outputs": [], "source": [ - "translate('Āætodavia estan en casa?', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" + "translate(u'todavia estan en casa?', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" ] }, { @@ -808,7 +808,7 @@ "outputs": [], "source": [ "# wrong translation\n", - "translate('trata de averiguarlo.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" + "translate(u'trata de averiguarlo.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" ] }, { diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py index f3bb978875e..fb7975d8fe8 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py @@ -142,7 +142,8 @@ class ResNet50Benchmarks(tf.test.Benchmark): with tf.Graph().as_default(): np_images, np_labels = random_batch(batch_size) dataset = tf.data.Dataset.from_tensors((np_images, np_labels)).repeat() - (images, labels) = dataset.make_one_shot_iterator().get_next() + images, labels = tf.compat.v1.data.make_one_shot_iterator( + dataset).get_next() model = resnet50.ResNet50(data_format()) logits = model(images, training=True) diff --git a/tensorflow/contrib/eager/python/examples/revnet/main.py b/tensorflow/contrib/eager/python/examples/revnet/main.py index b702e91f922..9585f3565f8 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/main.py +++ b/tensorflow/contrib/eager/python/examples/revnet/main.py @@ -72,14 +72,11 @@ def main(_): train_one_iter(model, x, y, optimizer, global_step=global_step) if global_step.numpy() % config.log_every == 0: - it_test = ds_test.make_one_shot_iterator() - acc_test, loss_test = evaluate(model, it_test) + acc_test, loss_test = evaluate(model, ds_test) if FLAGS.validate: - it_train = ds_train_one_shot.make_one_shot_iterator() - it_validation = ds_validation.make_one_shot_iterator() - acc_train, loss_train = evaluate(model, it_train) - acc_validation, loss_validation = evaluate(model, it_validation) + acc_train, loss_train = evaluate(model, ds_train_one_shot) + acc_validation, loss_validation = evaluate(model, ds_validation) print("Iter {}, " "training set accuracy {:.4f}, loss {:.4f}; " "validation set accuracy {:.4f}, loss {:.4f}; " @@ -218,11 +215,11 @@ def train_one_iter(model, inputs, labels, optimizer, global_step=None): return logits, loss -def evaluate(model, iterator): +def evaluate(model, dataset): """Compute accuracy with the given dataset iterator.""" mean_loss = tfe.metrics.Mean() accuracy = tfe.metrics.Accuracy() - for x, y in iterator: + for x, y in dataset: logits, _ = model(x, training=False) loss = model.compute_loss(logits=logits, labels=y) accuracy( diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb_graph_test.py b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb_graph_test.py index 63b5c4c54d1..770484abed9 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb_graph_test.py +++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb_graph_test.py @@ -82,7 +82,7 @@ class PTBBenchmark(tf.test.Benchmark): tf.ones( [PTBBenchmark.SEQ_LEN, PTBBenchmark.BATCH_SIZE], 
dtype=tf.int64)).repeat(num_iters + num_warmup) - inputs = dataset.make_one_shot_iterator().get_next() + inputs = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next() with tf.device(tf.test.gpu_device_name()): outputs = model(inputs, training=True) @@ -124,7 +124,8 @@ class PTBBenchmark(tf.test.Benchmark): dtype=tf.int64)).repeat(num_iters + num_warmup) # inputs and labels have the same shape dataset = tf.data.Dataset.zip((dataset, dataset)) - (inputs, labels) = dataset.make_one_shot_iterator().get_next() + (inputs, labels) = tf.compat.v1.data.make_one_shot_iterator( + dataset).get_next() with tf.device(tf.test.gpu_device_name()): optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0) diff --git a/tensorflow/contrib/eager/python/metrics_impl.py b/tensorflow/contrib/eager/python/metrics_impl.py index c88c0f52eea..566246de495 100644 --- a/tensorflow/contrib/eager/python/metrics_impl.py +++ b/tensorflow/contrib/eager/python/metrics_impl.py @@ -24,6 +24,7 @@ from tensorflow.python.eager import context from tensorflow.python.eager import function from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import smart_cond from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops @@ -354,9 +355,10 @@ class Mean(Metric): def write_summary_f(): summary_ops.scalar(name=self.name, tensor=t) return t - control_flow_ops.cond(write_summary, + smart_cond.smart_cond(write_summary, write_summary_f, - lambda: t) + lambda: t, + name="") return t diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py index 9d2d172752c..39e5957f5d1 100644 --- a/tensorflow/contrib/eager/python/metrics_test.py +++ b/tensorflow/contrib/eager/python/metrics_test.py @@ -49,18 +49,6 @@ class MetricsTest(test.TestCase): self.assertEqual(dtypes.float64, m.dtype) self.assertEqual(dtypes.float64, m.result().dtype) - def testSummaryArg(self): - m = metrics.Mean() - m([1, 10, 100]) - m(1000) - m([10000.0, 100000.0]) - self.assertEqual(111111.0/6, m.result(write_summary=True).numpy()) - self.assertEqual(111111.0/6, m.result(write_summary=False).numpy()) - with self.assertRaises(ValueError): - m.result(write_summary=5) - with self.assertRaises(ValueError): - m.result(write_summary=[True]) - def testVariableCollections(self): with context.graph_mode(), ops.Graph().as_default(): m = metrics.Mean() diff --git a/tensorflow/contrib/eager/python/network.py b/tensorflow/contrib/eager/python/network.py index f801d9a47b2..5cc0c4f23d9 100644 --- a/tensorflow/contrib/eager/python/network.py +++ b/tensorflow/contrib/eager/python/network.py @@ -24,7 +24,7 @@ import weakref from tensorflow.python.eager import context from tensorflow.python.framework import ops -from tensorflow.python.keras.engine import base_layer as keras_base_layer +from tensorflow.python.keras.engine import base_layer_utils from tensorflow.python.layers import base from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging @@ -220,7 +220,7 @@ class Network(base.Layer): avoid_names = parent_network._owned_layers name_uid_map = parent_network._sub_layer_name_uids else: - name_uid_map = keras_base_layer.get_default_graph_uid_map() + name_uid_map = base_layer_utils.get_default_graph_uid_map() # Figure out which names we have to avoid based on which variable scope # we're nested in. 
strip_name = self._default_parent_variable_scope.name diff --git a/tensorflow/contrib/eager/python/saver.py b/tensorflow/contrib/eager/python/saver.py index f9c716360c5..1d0d6c6c14c 100644 --- a/tensorflow/contrib/eager/python/saver.py +++ b/tensorflow/contrib/eager/python/saver.py @@ -115,6 +115,11 @@ def restore_variables_on_create(save_path, map_func=None): class Saver(object): """A tf.train.Saver adapter for use when eager execution is enabled. + + `Saver`'s name-based checkpointing strategy is fragile. Please switch to + `tf.train.Checkpoint` or `tf.keras.Model.save_weights`, which perform a more + robust object-based saving. These APIs will load checkpoints written by + `Saver`. """ def __init__(self, var_list): diff --git a/tensorflow/contrib/eager/python/tfe_test.py b/tensorflow/contrib/eager/python/tfe_test.py index 4454abfb966..8c35dddb5a5 100644 --- a/tensorflow/contrib/eager/python/tfe_test.py +++ b/tensorflow/contrib/eager/python/tfe_test.py @@ -87,8 +87,8 @@ class TFETest(test_util.TensorFlowTestCase): x += 1. # Without a device context, heuristics are used to place ops. # In this case, ops.reduce_mean runs on the GPU. - reduction_indices = range(x.shape.ndims) - m = math_ops.reduce_mean(x, reduction_indices) + axis = range(x.shape.ndims) + m = math_ops.reduce_mean(x, axis) # m is on GPU, bring it back to CPU and compare. self.assertEqual(3.5, m.cpu().numpy()) diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index 37f253d9c11..a888379f13e 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -16,7 +16,6 @@ py_library( srcs_version = "PY2AND3", deps = [ ":boosted_trees", - ":dnn", ":dnn_with_layer_annotations", ":early_stopping", ":expect_tensorflow_estimator_installed", @@ -25,7 +24,6 @@ py_library( ":extenders", ":head", ":hooks", - ":linear", ":logit_fns", ":multi_head", ":replicate_model_fn", @@ -47,18 +45,6 @@ py_library( ], ) -py_library( - name = "dnn", - srcs = ["python/estimator/dnn.py"], - srcs_version = "PY2AND3", - deps = [ - ":expect_tensorflow_estimator_installed", - "//tensorflow:tensorflow_py_no_contrib", - "//tensorflow/python/estimator", - "//tensorflow/python/estimator:dnn", - ], -) - py_library( name = "dnn_with_layer_annotations", srcs = ["python/estimator/dnn_with_layer_annotations.py"], @@ -144,17 +130,6 @@ py_library( ], ) -py_library( - name = "linear", - srcs = ["python/estimator/linear.py"], - srcs_version = "PY2AND3", - deps = [ - ":expect_tensorflow_estimator_installed", - "//tensorflow/python/estimator", - "//tensorflow/python/estimator:linear", - ], -) - py_library( name = "logit_fns", srcs = [ diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py index 80d59627620..7d61247e7ef 100644 --- a/tensorflow/contrib/estimator/__init__.py +++ b/tensorflow/contrib/estimator/__init__.py @@ -58,8 +58,6 @@ _allowed_symbols = [ 'multi_label_head', 'poisson_regression_head', 'regression_head', - 'DNNEstimator', - 'LinearEstimator', 'boosted_trees_classifier_train_in_memory', 'boosted_trees_regressor_train_in_memory', 'call_logit_fn', diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py deleted file mode 100644 index 7894418c4a1..00000000000 --- a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""dnn_linear_combined python module. - -Importing from tensorflow.python.estimator is unsupported -and will soon break! -""" -# pylint: disable=unused-import,g-bad-import-order,g-import-not-at-top,wildcard-import - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow_estimator.contrib.estimator.python.estimator import dnn_linear_combined - -# Include attrs that start with single underscore. -_HAS_DYNAMIC_ATTRIBUTES = True -dnn_linear_combined.__all__ = [ - s for s in dir(dnn_linear_combined) if not s.startswith('__') -] - -from tensorflow_estimator.contrib.estimator.python.estimator.dnn_linear_combined import * diff --git a/tensorflow/contrib/factorization/python/ops/kmeans.py b/tensorflow/contrib/factorization/python/ops/kmeans.py index f384d761a84..3eb396a29cc 100644 --- a/tensorflow/contrib/factorization/python/ops/kmeans.py +++ b/tensorflow/contrib/factorization/python/ops/kmeans.py @@ -26,7 +26,7 @@ from tensorflow.contrib.factorization.python.ops import clustering_ops from tensorflow.python.estimator import estimator from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator.export import export_output -from tensorflow.python.feature_column import feature_column as fc +from tensorflow.python.feature_column import feature_column_lib as fc from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops diff --git a/tensorflow/contrib/factorization/python/ops/kmeans_test.py b/tensorflow/contrib/factorization/python/ops/kmeans_test.py index 1ab5418fe46..2f7cd131d3e 100644 --- a/tensorflow/contrib/factorization/python/ops/kmeans_test.py +++ b/tensorflow/contrib/factorization/python/ops/kmeans_test.py @@ -27,7 +27,7 @@ from sklearn.cluster import KMeans as SklearnKMeans # pylint: disable=g-import-not-at-top from tensorflow.contrib.factorization.python.ops import kmeans as kmeans_lib from tensorflow.python.estimator import run_config -from tensorflow.python.feature_column import feature_column as fc +from tensorflow.python.feature_column import feature_column_lib as fc from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops diff --git a/tensorflow/contrib/feature_column/BUILD b/tensorflow/contrib/feature_column/BUILD index bbe335be3e1..1cd83bdb5de 100644 --- a/tensorflow/contrib/feature_column/BUILD +++ b/tensorflow/contrib/feature_column/BUILD @@ -14,6 +14,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":sequence_feature_column", + ":sequence_feature_column_v2", "//tensorflow/python:util", ], ) @@ -32,7 +33,7 @@ py_library( "//tensorflow/python:sparse_ops", "//tensorflow/python:tensor_shape", "//tensorflow/python:variable_scope", - "//tensorflow/python/feature_column", + 
"//tensorflow/python/feature_column:feature_column_py", ], ) @@ -51,7 +52,7 @@ py_test( "//tensorflow/python:parsing_ops", "//tensorflow/python:sparse_tensor", "//tensorflow/python:training", - "//tensorflow/python/feature_column", + "//tensorflow/python/feature_column:feature_column_py", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", ], @@ -69,7 +70,7 @@ py_test( "//tensorflow/python:parsing_ops", "//tensorflow/python:training", "//tensorflow/python:util", - "//tensorflow/python/feature_column", + "//tensorflow/python/feature_column:feature_column_py", "//tensorflow/python/keras:layers", ], ) @@ -89,7 +90,7 @@ py_library( "//tensorflow/python:tensor_shape", "//tensorflow/python:variable_scope", "//tensorflow/python/feature_column", - "//tensorflow/python/feature_column:feature_column_v2", + "//tensorflow/python/feature_column:feature_column_py", ], ) @@ -110,7 +111,7 @@ py_test( "//tensorflow/python:sparse_tensor", "//tensorflow/python:training", "//tensorflow/python/feature_column", - "//tensorflow/python/feature_column:feature_column_v2", + "//tensorflow/python/feature_column:feature_column_py", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", ], diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py index dd6da35ed00..9b3a5c58aaa 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py @@ -222,10 +222,8 @@ def sequence_categorical_column_with_identity( ValueError: if `default_value` is not in range `[0, num_buckets)`. """ return fc._SequenceCategoricalColumn( - fc.categorical_column_with_identity( - key=key, - num_buckets=num_buckets, - default_value=default_value)) + fc._categorical_column_with_identity( + key=key, num_buckets=num_buckets, default_value=default_value)) def sequence_categorical_column_with_hash_bucket( @@ -265,10 +263,8 @@ def sequence_categorical_column_with_hash_bucket( ValueError: `dtype` is neither string nor integer. """ return fc._SequenceCategoricalColumn( - fc.categorical_column_with_hash_bucket( - key=key, - hash_bucket_size=hash_bucket_size, - dtype=dtype)) + fc._categorical_column_with_hash_bucket( + key=key, hash_bucket_size=hash_bucket_size, dtype=dtype)) def sequence_categorical_column_with_vocabulary_file( @@ -324,7 +320,7 @@ def sequence_categorical_column_with_vocabulary_file( ValueError: `dtype` is neither string nor integer. """ return fc._SequenceCategoricalColumn( - fc.categorical_column_with_vocabulary_file( + fc._categorical_column_with_vocabulary_file( key=key, vocabulary_file=vocabulary_file, vocabulary_size=vocabulary_size, @@ -384,7 +380,7 @@ def sequence_categorical_column_with_vocabulary_list( ValueError: if `dtype` is not integer or string. 
""" return fc._SequenceCategoricalColumn( - fc.categorical_column_with_vocabulary_list( + fc._categorical_column_with_vocabulary_list( key=key, vocabulary_list=vocabulary_list, dtype=dtype, diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_integration_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_integration_test.py index d8ca363627e..bcc25b8de89 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_integration_test.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_integration_test.py @@ -53,19 +53,20 @@ class SequenceFeatureColumnIntegrationTest(test.TestCase): return example def _build_feature_columns(self): - col = fc.categorical_column_with_identity( - 'int_ctx', num_buckets=100) + col = fc._categorical_column_with_identity('int_ctx', num_buckets=100) ctx_cols = [ - fc.embedding_column(col, dimension=10), - fc.numeric_column('float_ctx')] + fc._embedding_column(col, dimension=10), + fc._numeric_column('float_ctx') + ] identity_col = sfc.sequence_categorical_column_with_identity( 'int_list', num_buckets=10) bucket_col = sfc.sequence_categorical_column_with_hash_bucket( 'bytes_list', hash_bucket_size=100) seq_cols = [ - fc.embedding_column(identity_col, dimension=10), - fc.embedding_column(bucket_col, dimension=20)] + fc._embedding_column(identity_col, dimension=10), + fc._embedding_column(bucket_col, dimension=20) + ] return ctx_cols, seq_cols @@ -148,8 +149,8 @@ class SequenceExampleParsingTest(test.TestCase): """ example = _make_sequence_example() columns = [ - fc.categorical_column_with_identity('int_ctx', num_buckets=100), - fc.numeric_column('float_ctx'), + fc._categorical_column_with_identity('int_ctx', num_buckets=100), + fc._numeric_column('float_ctx'), col_fn(col_name, col_arg) ] context, seq_features = parsing_ops.parse_single_sequence_example( diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py index 2163af0b438..d5f74028298 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py @@ -24,6 +24,7 @@ import numpy as np from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column as sfc from tensorflow.python.feature_column import feature_column as fc +from tensorflow.python.feature_column import feature_column_lib as fc_lib from tensorflow.python.feature_column.feature_column import _LazyBuilder from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -109,13 +110,15 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): categorical_column_a = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - embedding_column_a = fc.embedding_column( - categorical_column_a, dimension=embedding_dimension_a, + embedding_column_a = fc._embedding_column( + categorical_column_a, + dimension=embedding_dimension_a, initializer=_get_initializer(embedding_dimension_a, embedding_values_a)) categorical_column_b = sfc.sequence_categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) - embedding_column_b = fc.embedding_column( - categorical_column_b, dimension=embedding_dimension_b, + embedding_column_b = fc._embedding_column( + 
categorical_column_b, + dimension=embedding_dimension_b, initializer=_get_initializer(embedding_dimension_b, embedding_values_b)) input_layer, sequence_length = sfc.sequence_input_layer( @@ -148,10 +151,9 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): values=(2, 0, 1), dense_shape=(2, 2)) - categorical_column_a = fc.categorical_column_with_identity( + categorical_column_a = fc._categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - embedding_column_a = fc.embedding_column( - categorical_column_a, dimension=2) + embedding_column_a = fc._embedding_column(categorical_column_a, dimension=2) with self.assertRaisesRegexp( ValueError, @@ -206,7 +208,7 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): categorical_column_b = sfc.sequence_categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) # Test that columns are reordered alphabetically. - shared_embedding_columns = fc.shared_embedding_columns( + shared_embedding_columns = fc_lib.shared_embedding_columns( [categorical_column_b, categorical_column_a], dimension=embedding_dimension, initializer=_get_initializer(embedding_dimension, embedding_values)) @@ -244,11 +246,11 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): values=(2, 0, 1), dense_shape=(2, 2)) - categorical_column_a = fc.categorical_column_with_identity( + categorical_column_a = fc._categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - categorical_column_b = fc.categorical_column_with_identity( + categorical_column_b = fc._categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) - shared_embedding_columns = fc.shared_embedding_columns( + shared_embedding_columns = fc_lib.shared_embedding_columns( [categorical_column_a, categorical_column_b], dimension=2) with self.assertRaisesRegexp( @@ -315,10 +317,10 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): categorical_column_a = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size_a) - indicator_column_a = fc.indicator_column(categorical_column_a) + indicator_column_a = fc._indicator_column(categorical_column_a) categorical_column_b = sfc.sequence_categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size_b) - indicator_column_b = fc.indicator_column(categorical_column_b) + indicator_column_b = fc._indicator_column(categorical_column_b) input_layer, sequence_length = sfc.sequence_input_layer( features={ 'aaa': sparse_input_a, @@ -342,9 +344,9 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): values=(2, 0, 1), dense_shape=(2, 2)) - categorical_column_a = fc.categorical_column_with_identity( + categorical_column_a = fc._categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - indicator_column_a = fc.indicator_column(categorical_column_a) + indicator_column_a = fc._indicator_column(categorical_column_a) with self.assertRaisesRegexp( ValueError, @@ -530,7 +532,7 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): sparse_input = sparse_tensor.SparseTensorValue(**sparse_input_args) categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=3) - indicator_column = fc.indicator_column(categorical_column) + indicator_column = fc._indicator_column(categorical_column) input_layer, _ = sfc.sequence_input_layer( features={'aaa': sparse_input}, feature_columns=[indicator_column]) @@ -616,8 +618,7 @@ class InputLayerTest(test.TestCase): 
categorical_column_a = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - embedding_column_a = fc.embedding_column( - categorical_column_a, dimension=2) + embedding_column_a = fc._embedding_column(categorical_column_a, dimension=2) with self.assertRaisesRegexp( ValueError, @@ -639,7 +640,7 @@ class InputLayerTest(test.TestCase): categorical_column_a = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - indicator_column_a = fc.indicator_column(categorical_column_a) + indicator_column_a = fc._indicator_column(categorical_column_a) with self.assertRaisesRegexp( ValueError, @@ -918,8 +919,9 @@ class SequenceEmbeddingColumnTest( categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - embedding_column = fc.embedding_column( - categorical_column, dimension=embedding_dimension, + embedding_column = fc._embedding_column( + categorical_column, + dimension=embedding_dimension, initializer=_initializer) embedding_lookup, _ = embedding_column._get_sequence_dense_tensor( @@ -956,8 +958,7 @@ class SequenceEmbeddingColumnTest( categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - embedding_column = fc.embedding_column( - categorical_column, dimension=2) + embedding_column = fc._embedding_column(categorical_column, dimension=2) _, sequence_length = embedding_column._get_sequence_dense_tensor( _LazyBuilder({'aaa': inputs})) @@ -984,8 +985,7 @@ class SequenceEmbeddingColumnTest( categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - embedding_column = fc.embedding_column( - categorical_column, dimension=2) + embedding_column = fc._embedding_column(categorical_column, dimension=2) _, sequence_length = embedding_column._get_sequence_dense_tensor( _LazyBuilder({'aaa': sparse_input})) @@ -1055,7 +1055,7 @@ class SequenceSharedEmbeddingColumnTest(test.TestCase): key='aaa', num_buckets=vocabulary_size) categorical_column_b = sfc.sequence_categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) - shared_embedding_columns = fc.shared_embedding_columns( + shared_embedding_columns = fc_lib.shared_embedding_columns( [categorical_column_a, categorical_column_b], dimension=embedding_dimension, initializer=_initializer) @@ -1101,7 +1101,7 @@ class SequenceSharedEmbeddingColumnTest(test.TestCase): expected_sequence_length_b = [2, 1] categorical_column_b = sfc.sequence_categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) - shared_embedding_columns = fc.shared_embedding_columns( + shared_embedding_columns = fc_lib.shared_embedding_columns( [categorical_column_a, categorical_column_b], dimension=2) sequence_length_a = shared_embedding_columns[0]._get_sequence_dense_tensor( @@ -1152,7 +1152,7 @@ class SequenceSharedEmbeddingColumnTest(test.TestCase): categorical_column_b = sfc.sequence_categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) - shared_embedding_columns = fc.shared_embedding_columns( + shared_embedding_columns = fc_lib.shared_embedding_columns( [categorical_column_a, categorical_column_b], dimension=2) sequence_length_a = shared_embedding_columns[0]._get_sequence_dense_tensor( @@ -1218,7 +1218,7 @@ class SequenceIndicatorColumnTest(test.TestCase, parameterized.TestCase): categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - indicator_column = 
fc.indicator_column(categorical_column) + indicator_column = fc._indicator_column(categorical_column) indicator_tensor, _ = indicator_column._get_sequence_dense_tensor( _LazyBuilder({'aaa': inputs})) @@ -1250,7 +1250,7 @@ class SequenceIndicatorColumnTest(test.TestCase, parameterized.TestCase): categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - indicator_column = fc.indicator_column(categorical_column) + indicator_column = fc._indicator_column(categorical_column) _, sequence_length = indicator_column._get_sequence_dense_tensor( _LazyBuilder({'aaa': inputs})) @@ -1277,7 +1277,7 @@ class SequenceIndicatorColumnTest(test.TestCase, parameterized.TestCase): categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - indicator_column = fc.indicator_column(categorical_column) + indicator_column = fc._indicator_column(categorical_column) _, sequence_length = indicator_column._get_sequence_dense_tensor( _LazyBuilder({'aaa': sparse_input})) diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2.py index 67ffb939663..0d34ad16185 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2.py @@ -26,7 +26,7 @@ import collections from tensorflow.python.feature_column import feature_column as fc_old -from tensorflow.python.feature_column import feature_column_v2 as fc +from tensorflow.python.feature_column import feature_column_lib as fc from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape @@ -226,10 +226,8 @@ def sequence_categorical_column_with_identity( ValueError: if `default_value` is not in range `[0, num_buckets)`. """ return fc_old._SequenceCategoricalColumn( - fc_old.categorical_column_with_identity( - key=key, - num_buckets=num_buckets, - default_value=default_value)) + fc_old._categorical_column_with_identity( + key=key, num_buckets=num_buckets, default_value=default_value)) def sequence_categorical_column_with_hash_bucket( @@ -269,10 +267,8 @@ def sequence_categorical_column_with_hash_bucket( ValueError: `dtype` is neither string nor integer. """ return fc_old._SequenceCategoricalColumn( - fc_old.categorical_column_with_hash_bucket( - key=key, - hash_bucket_size=hash_bucket_size, - dtype=dtype)) + fc_old._categorical_column_with_hash_bucket( + key=key, hash_bucket_size=hash_bucket_size, dtype=dtype)) def sequence_categorical_column_with_vocabulary_file( @@ -328,7 +324,7 @@ def sequence_categorical_column_with_vocabulary_file( ValueError: `dtype` is neither string nor integer. """ return fc_old._SequenceCategoricalColumn( - fc_old.categorical_column_with_vocabulary_file( + fc_old._categorical_column_with_vocabulary_file( key=key, vocabulary_file=vocabulary_file, vocabulary_size=vocabulary_size, @@ -388,7 +384,7 @@ def sequence_categorical_column_with_vocabulary_list( ValueError: if `dtype` is not integer or string. """ return fc_old._SequenceCategoricalColumn( - fc_old.categorical_column_with_vocabulary_list( + fc_old._categorical_column_with_vocabulary_list( key=key, vocabulary_list=vocabulary_list, dtype=dtype, @@ -441,7 +437,7 @@ def sequence_numeric_column( ValueError: if any dimension in shape is not a positive integer. 
ValueError: if `dtype` is not convertible to `tf.float32`. """ - shape = fc._check_shape(shape=shape, key=key) + shape = fc_old._check_shape(shape=shape, key=key) if not (dtype.is_integer or dtype.is_floating): raise ValueError('dtype must be convertible to float. ' 'dtype: {}, key: {}'.format(dtype, key)) diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2_test.py index 5ecd85807c5..ca4398a1420 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2_test.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2_test.py @@ -25,7 +25,7 @@ import numpy as np from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column as sfc_old from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column_v2 as sfc from tensorflow.python.feature_column import feature_column as fc_old -from tensorflow.python.feature_column import feature_column_v2 as fc +from tensorflow.python.feature_column import feature_column_lib as fc from tensorflow.python.feature_column.feature_column import _LazyBuilder from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -111,13 +111,15 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): categorical_column_a = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - embedding_column_a = fc_old.embedding_column( - categorical_column_a, dimension=embedding_dimension_a, + embedding_column_a = fc_old._embedding_column( + categorical_column_a, + dimension=embedding_dimension_a, initializer=_get_initializer(embedding_dimension_a, embedding_values_a)) categorical_column_b = sfc.sequence_categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) - embedding_column_b = fc_old.embedding_column( - categorical_column_b, dimension=embedding_dimension_b, + embedding_column_b = fc_old._embedding_column( + categorical_column_b, + dimension=embedding_dimension_b, initializer=_get_initializer(embedding_dimension_b, embedding_values_b)) input_layer, sequence_length = sfc.sequence_input_layer( @@ -150,9 +152,9 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): values=(2, 0, 1), dense_shape=(2, 2)) - categorical_column_a = fc_old.categorical_column_with_identity( + categorical_column_a = fc_old._categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - embedding_column_a = fc_old.embedding_column( + embedding_column_a = fc_old._embedding_column( categorical_column_a, dimension=2) with self.assertRaisesRegexp( @@ -208,7 +210,7 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): categorical_column_b = sfc.sequence_categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) # Test that columns are reordered alphabetically. 
- shared_embedding_columns = fc_old.shared_embedding_columns( + shared_embedding_columns = fc.shared_embedding_columns( [categorical_column_b, categorical_column_a], dimension=embedding_dimension, initializer=_get_initializer(embedding_dimension, embedding_values)) @@ -246,11 +248,11 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): values=(2, 0, 1), dense_shape=(2, 2)) - categorical_column_a = fc_old.categorical_column_with_identity( + categorical_column_a = fc_old._categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - categorical_column_b = fc_old.categorical_column_with_identity( + categorical_column_b = fc_old._categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) - shared_embedding_columns = fc_old.shared_embedding_columns( + shared_embedding_columns = fc.shared_embedding_columns( [categorical_column_a, categorical_column_b], dimension=2) with self.assertRaisesRegexp( @@ -317,10 +319,10 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): categorical_column_a = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size_a) - indicator_column_a = fc_old.indicator_column(categorical_column_a) + indicator_column_a = fc_old._indicator_column(categorical_column_a) categorical_column_b = sfc.sequence_categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size_b) - indicator_column_b = fc_old.indicator_column(categorical_column_b) + indicator_column_b = fc_old._indicator_column(categorical_column_b) input_layer, sequence_length = sfc.sequence_input_layer( features={ 'aaa': sparse_input_a, @@ -344,9 +346,9 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): values=(2, 0, 1), dense_shape=(2, 2)) - categorical_column_a = fc_old.categorical_column_with_identity( + categorical_column_a = fc_old._categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - indicator_column_a = fc_old.indicator_column(categorical_column_a) + indicator_column_a = fc_old._indicator_column(categorical_column_a) with self.assertRaisesRegexp( ValueError, @@ -532,7 +534,7 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): sparse_input = sparse_tensor.SparseTensorValue(**sparse_input_args) categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=3) - indicator_column = fc_old.indicator_column(categorical_column) + indicator_column = fc_old._indicator_column(categorical_column) input_layer, _ = sfc.sequence_input_layer( features={'aaa': sparse_input}, feature_columns=[indicator_column]) @@ -618,7 +620,7 @@ class InputLayerTest(test.TestCase): categorical_column_a = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - embedding_column_a = fc_old.embedding_column( + embedding_column_a = fc_old._embedding_column( categorical_column_a, dimension=2) with self.assertRaisesRegexp( @@ -641,7 +643,7 @@ class InputLayerTest(test.TestCase): categorical_column_a = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - indicator_column_a = fc_old.indicator_column(categorical_column_a) + indicator_column_a = fc_old._indicator_column(categorical_column_a) with self.assertRaisesRegexp( ValueError, @@ -920,8 +922,9 @@ class SequenceEmbeddingColumnTest( categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - embedding_column = fc_old.embedding_column( - categorical_column, dimension=embedding_dimension, + 
embedding_column = fc_old._embedding_column( + categorical_column, + dimension=embedding_dimension, initializer=_initializer) embedding_lookup, _ = embedding_column._get_sequence_dense_tensor( @@ -958,8 +961,7 @@ class SequenceEmbeddingColumnTest( categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - embedding_column = fc_old.embedding_column( - categorical_column, dimension=2) + embedding_column = fc_old._embedding_column(categorical_column, dimension=2) _, sequence_length = embedding_column._get_sequence_dense_tensor( _LazyBuilder({'aaa': inputs})) @@ -986,8 +988,7 @@ class SequenceEmbeddingColumnTest( categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - embedding_column = fc_old.embedding_column( - categorical_column, dimension=2) + embedding_column = fc_old._embedding_column(categorical_column, dimension=2) _, sequence_length = embedding_column._get_sequence_dense_tensor( _LazyBuilder({'aaa': sparse_input})) @@ -1057,7 +1058,7 @@ class SequenceSharedEmbeddingColumnTest(test.TestCase): key='aaa', num_buckets=vocabulary_size) categorical_column_b = sfc.sequence_categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) - shared_embedding_columns = fc_old.shared_embedding_columns( + shared_embedding_columns = fc.shared_embedding_columns( [categorical_column_a, categorical_column_b], dimension=embedding_dimension, initializer=_initializer) @@ -1103,7 +1104,7 @@ class SequenceSharedEmbeddingColumnTest(test.TestCase): expected_sequence_length_b = [2, 1] categorical_column_b = sfc.sequence_categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) - shared_embedding_columns = fc_old.shared_embedding_columns( + shared_embedding_columns = fc.shared_embedding_columns( [categorical_column_a, categorical_column_b], dimension=2) sequence_length_a = shared_embedding_columns[0]._get_sequence_dense_tensor( @@ -1154,7 +1155,7 @@ class SequenceSharedEmbeddingColumnTest(test.TestCase): categorical_column_b = sfc.sequence_categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) - shared_embedding_columns = fc_old.shared_embedding_columns( + shared_embedding_columns = fc.shared_embedding_columns( [categorical_column_a, categorical_column_b], dimension=2) sequence_length_a = shared_embedding_columns[0]._get_sequence_dense_tensor( @@ -1220,7 +1221,7 @@ class SequenceIndicatorColumnTest(test.TestCase, parameterized.TestCase): categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - indicator_column = fc_old.indicator_column(categorical_column) + indicator_column = fc_old._indicator_column(categorical_column) indicator_tensor, _ = indicator_column._get_sequence_dense_tensor( _LazyBuilder({'aaa': inputs})) @@ -1252,7 +1253,7 @@ class SequenceIndicatorColumnTest(test.TestCase, parameterized.TestCase): categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - indicator_column = fc_old.indicator_column(categorical_column) + indicator_column = fc_old._indicator_column(categorical_column) _, sequence_length = indicator_column._get_sequence_dense_tensor( _LazyBuilder({'aaa': inputs})) diff --git a/tensorflow/contrib/framework/BUILD b/tensorflow/contrib/framework/BUILD index cd747df4d69..dad50a3a730 100644 --- a/tensorflow/contrib/framework/BUILD +++ b/tensorflow/contrib/framework/BUILD @@ -47,6 +47,11 @@ tf_custom_op_py_library( ":variable_ops_op_lib", ], 
srcs_version = "PY2AND3", + visibility = [ + "//learning/brain:__subpackages__", + "//tensorflow:__subpackages__", + "//video/youtube/personalization:__subpackages__", + ], deps = [ ":gen_variable_ops", "//tensorflow/contrib/util:util_py", @@ -66,6 +71,7 @@ tf_custom_op_py_library( "//tensorflow/python:resource_variable_ops", "//tensorflow/python:script_ops", "//tensorflow/python:smart_cond", + "//tensorflow/python:sort_ops", "//tensorflow/python:sparse_tensor", "//tensorflow/python:state_ops", "//tensorflow/python:state_ops_gen", @@ -311,17 +317,3 @@ py_test( "//third_party/py/numpy", ], ) - -py_test( - name = "sort_ops_test", - size = "medium", - srcs = ["python/ops/sort_ops_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":framework_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:random_ops", - "//third_party/py/numpy", - ], -) diff --git a/tensorflow/contrib/framework/python/ops/sort_ops.py b/tensorflow/contrib/framework/python/ops/sort_ops.py index 1921a77c1e9..42184a4e55e 100644 --- a/tensorflow/contrib/framework/python/ops/sort_ops.py +++ b/tensorflow/contrib/framework/python/ops/sort_ops.py @@ -22,173 +22,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np +from tensorflow.python.ops import sort_ops -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import ops as framework_ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops - - -def sort(values, axis=-1, direction='ASCENDING', name=None): - """Sorts a tensor. - - Args: - values: 1-D or higher numeric `Tensor`. - axis: The axis along which to sort. The default is -1, which sorts the last - axis. - direction: The direction in which to sort the values (`'ASCENDING'` or - `'DESCENDING'`). - name: Optional name for the operation. - - Returns: - A `Tensor` with the same dtype and shape as `values`, with the elements - sorted along the given `axis`. - - Raises: - ValueError: If axis is not a constant scalar, or the direction is invalid. - """ - with framework_ops.name_scope(name, 'sort'): - return _sort_or_argsort(values, axis, direction, return_argsort=False) - - -def argsort(values, axis=-1, direction='ASCENDING', stable=False, name=None): - """Returns the indices of a tensor that give its sorted order along an axis. - - For a 1D tensor, `tf.gather(values, tf.argsort(values))` is equivalent to - `tf.sort(values)`. For higher dimensions, the output has the same shape as - `values`, but along the given axis, values represent the index of the sorted - element in that slice of the tensor at the given position. - - Args: - values: 1-D or higher numeric `Tensor`. - axis: The axis along which to sort. The default is -1, which sorts the last - axis. - direction: The direction in which to sort the values (`'ASCENDING'` or - `'DESCENDING'`). - stable: If True, equal elements in the original tensor will not be - re-ordered in the returned order. Unstable sort is not yet implemented, - but will eventually be the default for performance reasons. If you - require a stable order, pass `stable=True` for forwards compatibility. - name: Optional name for the operation. - - Returns: - An int32 `Tensor` with the same shape as `values`. The indices that would - sort each slice of the given `values` along the given `axis`. 
- - Raises: - ValueError: If axis is not a constant scalar, or the direction is invalid. - """ - del stable # Unused. - with framework_ops.name_scope(name, 'argsort'): - return _sort_or_argsort(values, axis, direction, return_argsort=True) - - -def _sort_or_argsort(values, axis, direction, return_argsort): - """Internal sort/argsort implementation. - - Args: - values: The input values. - axis: The axis along which to sort. - direction: 'ASCENDING' or 'DESCENDING'. - return_argsort: Whether to return the argsort result. - - Returns: - Either the sorted values, or the indices of the sorted values in the - original tensor. See the `sort` and `argsort` docstrings. - - Raises: - ValueError: If axis is not a constant scalar, or the direction is invalid. - """ - if direction not in _SORT_IMPL: - raise ValueError('%s should be one of %s' % - (direction, ', '.join(sorted(_SORT_IMPL.keys())))) - # Axis must be an integer, not a Tensor. - axis = framework_ops.convert_to_tensor(axis, name='axis') - axis_static = tensor_util.constant_value(axis) - if axis.shape.ndims != 0 or axis_static is None: - raise ValueError('axis must be a constant scalar') - axis_static = int(axis_static) # Avoids NumPy casting error - - values = framework_ops.convert_to_tensor(values, name='values') - - return _SORT_IMPL[direction](values, axis_static, return_argsort) - - -def _descending_sort(values, axis, return_argsort=False): - """Sorts values in reverse using `top_k`. - - Args: - values: Tensor of numeric values. - axis: Index of the axis which values should be sorted along. - return_argsort: If False, return the sorted values. If True, return the - indices that would sort the values. - - Returns: - The sorted values. - """ - k = array_ops.shape(values)[axis] - rank = array_ops.rank(values) - static_rank = values.shape.ndims - # Fast path: sorting the last axis. - if axis == -1 or axis + 1 == values.get_shape().ndims: - top_k_input = values - transposition = None - else: - # Otherwise, transpose the array. Swap axes `axis` and `rank - 1`. - if axis < 0: - # Calculate the actual axis index if counting from the end. Use the static - # rank if available, or else make the axis back into a tensor. - axis += static_rank or rank - if static_rank is not None: - # Prefer to calculate the transposition array in NumPy and make it a - # constant. - transposition = constant_op.constant( - np.r_[ - # Axes up to axis are unchanged. - np.arange(axis), - # Swap axis and rank - 1. - [static_rank - 1], - # Axes in [axis + 1, rank - 1) are unchanged. - np.arange(axis + 1, static_rank - 1), - # Swap axis and rank - 1. - [axis]], - name='transposition') - else: - # Generate the transposition array from the tensors. - transposition = array_ops.concat( - [ - # Axes up to axis are unchanged. - math_ops.range(axis), - # Swap axis and rank - 1. - [rank - 1], - # Axes in [axis + 1, rank - 1) are unchanged. - math_ops.range(axis + 1, rank - 1), - # Swap axis and rank - 1. - [axis] - ], - axis=0) - top_k_input = array_ops.transpose(values, transposition) - - values, indices = nn_ops.top_k(top_k_input, k) - return_value = indices if return_argsort else values - if transposition is not None: - # transposition contains a single cycle of length 2 (swapping 2 elements), - # so it is an involution (it is its own inverse). - return_value = array_ops.transpose(return_value, transposition) - return return_value - - -def _ascending_sort(values, axis, return_argsort=False): - # Negate the values to get the ascending order from descending sort. 
- values_or_indices = _descending_sort(-values, axis, return_argsort) - # If not argsort, negate the values again. - return values_or_indices if return_argsort else -values_or_indices - - -_SORT_IMPL = { - 'ASCENDING': _ascending_sort, - 'DESCENDING': _descending_sort, -} +sort = sort_ops.sort +argsort = sort_ops.argsort diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py index 219cc199d79..3593b501bb7 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py @@ -113,7 +113,8 @@ class GANEstimator(estimator.Estimator): add_summaries=None, use_loss_summaries=True, config=None, - warm_start_from=None): + warm_start_from=None, + is_chief=True): """Initializes a GANEstimator instance. Args: @@ -154,6 +155,8 @@ class GANEstimator(estimator.Estimator): config: `RunConfig` object to configure the runtime settings. warm_start_from: A filepath to a checkpoint or saved model, or a WarmStartSettings object to configure initialization. + is_chief: Whether or not this Estimator is running on a chief or worker. + Needs to be set appropriately if using SyncReplicasOptimizers. Raises: ValueError: If loss functions aren't callable. @@ -187,7 +190,7 @@ class GANEstimator(estimator.Estimator): return _get_estimator_spec( mode, gan_model, generator_loss_fn, discriminator_loss_fn, get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer, - get_hooks_fn, use_loss_summaries) + get_hooks_fn, use_loss_summaries, is_chief) super(GANEstimator, self).__init__( model_fn=_model_fn, model_dir=model_dir, config=config, @@ -215,7 +218,7 @@ def _get_gan_model( def _get_estimator_spec( mode, gan_model, generator_loss_fn, discriminator_loss_fn, get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer, - get_hooks_fn=None, use_loss_summaries=True): + get_hooks_fn=None, use_loss_summaries=True, is_chief=True): """Get the EstimatorSpec for the current mode.""" if mode == model_fn_lib.ModeKeys.PREDICT: estimator_spec = model_fn_lib.EstimatorSpec( @@ -236,7 +239,7 @@ def _get_estimator_spec( else discriminator_optimizer) get_hooks_fn = get_hooks_fn or tfgan_train.get_sequential_train_hooks() estimator_spec = _get_train_estimator_spec( - gan_model, gan_loss, gopt, dopt, get_hooks_fn) + gan_model, gan_loss, gopt, dopt, get_hooks_fn, is_chief=is_chief) return estimator_spec @@ -321,11 +324,11 @@ def _get_eval_estimator_spec(gan_model, gan_loss, get_eval_metric_ops_fn=None, def _get_train_estimator_spec( gan_model, gan_loss, generator_optimizer, discriminator_optimizer, - get_hooks_fn, train_op_fn=tfgan_train.gan_train_ops): + get_hooks_fn, train_op_fn=tfgan_train.gan_train_ops, is_chief=True): """Return an EstimatorSpec for the train case.""" scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss train_ops = train_op_fn(gan_model, gan_loss, generator_optimizer, - discriminator_optimizer) + discriminator_optimizer, is_chief=is_chief) training_hooks = get_hooks_fn(train_ops) return model_fn_lib.EstimatorSpec( loss=scalar_loss, diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py index 3d6bdab0ad7..bc9021050bc 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py @@ -48,6 +48,7 @@ from 
tensorflow.python.platform import test from tensorflow.python.summary.writer import writer_cache from tensorflow.python.training import input as input_lib from tensorflow.python.training import learning_rate_decay +from tensorflow.python.training import sync_replicas_optimizer from tensorflow.python.training import training from tensorflow.python.training import training_util @@ -82,7 +83,7 @@ class GetGANModelTest(test.TestCase, parameterized.TestCase): self.assertEqual(generator_inputs, gan_model.generator_inputs) self.assertIsNotNone(gan_model.generated_data) - self.assertEqual(2, len(gan_model.generator_variables)) # 1 FC layer + self.assertLen(gan_model.generator_variables, 2) # 1 FC layer self.assertIsNotNone(gan_model.generator_fn) if mode == model_fn_lib.ModeKeys.PREDICT: self.assertIsNone(gan_model.real_data) @@ -95,7 +96,7 @@ class GetGANModelTest(test.TestCase, parameterized.TestCase): self.assertIsNotNone(gan_model.real_data) self.assertIsNotNone(gan_model.discriminator_real_outputs) self.assertIsNotNone(gan_model.discriminator_gen_outputs) - self.assertEqual(2, len(gan_model.discriminator_variables)) # 1 FC layer + self.assertLen(gan_model.discriminator_variables, 2) # 1 FC layer self.assertIsNotNone(gan_model.discriminator_scope) self.assertIsNotNone(gan_model.discriminator_fn) @@ -121,6 +122,7 @@ def get_dummy_gan_model(): def dummy_loss_fn(gan_model, add_summaries=True): + del add_summaries return math_ops.reduce_sum(gan_model.discriminator_real_outputs - gan_model.discriminator_gen_outputs) @@ -168,6 +170,35 @@ class GetEstimatorSpecTest(test.TestCase, parameterized.TestCase): self.assertShapeEqual(np.array(0), spec.loss) # must be a scalar self.assertIsNotNone(spec.eval_metric_ops) + def test_get_sync_estimator_spec(self): + """Make sure spec is loaded with sync hooks for sync opts.""" + + def get_sync_optimizer(): + return sync_replicas_optimizer.SyncReplicasOptimizer( + training.GradientDescentOptimizer(learning_rate=1.0), + replicas_to_aggregate=1) + + with ops.Graph().as_default(): + self._gan_model = get_dummy_gan_model() + g_opt = get_sync_optimizer() + d_opt = get_sync_optimizer() + + spec = estimator._get_estimator_spec( + model_fn_lib.ModeKeys.TRAIN, + self._gan_model, + generator_loss_fn=dummy_loss_fn, + discriminator_loss_fn=dummy_loss_fn, + get_eval_metric_ops_fn=get_metrics, + generator_optimizer=g_opt, + discriminator_optimizer=d_opt) + + self.assertLen(spec.training_hooks, 4) + sync_opts = [ + hook._sync_optimizer for hook in spec.training_hooks if + isinstance(hook, sync_replicas_optimizer._SyncReplicasOptimizerHook)] + self.assertLen(sync_opts, 2) + self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt))) + # TODO(joelshor): Add pandas test. 
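The hunks above add an `is_chief` argument to `GANEstimator` and a test that the sync-replica hooks land in the returned `EstimatorSpec`. As a reading aid, here is a minimal, hypothetical sketch of how a caller might wire this up; `generator_fn`, `discriminator_fn`, and `train_input_fn` are placeholders, and the snippet is an assumption based on the APIs touched in this patch, not part of the patch itself.

```python
# Hypothetical usage sketch (not part of this patch): a GANEstimator driven by
# SyncReplicasOptimizers. `generator_fn`, `discriminator_fn`, and
# `train_input_fn` are user-supplied placeholders.
import tensorflow as tf

tfgan = tf.contrib.gan


def make_sync_optimizer(learning_rate, replicas_to_aggregate):
  # Wrap a plain optimizer so gradients are aggregated across replicas before
  # a single update is applied.
  return tf.train.SyncReplicasOptimizer(
      tf.train.GradientDescentOptimizer(learning_rate),
      replicas_to_aggregate=replicas_to_aggregate)


config = tf.estimator.RunConfig()  # cluster/task info normally comes from TF_CONFIG
gan_estimator = tfgan.estimator.GANEstimator(
    generator_fn=generator_fn,
    discriminator_fn=discriminator_fn,
    generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
    discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
    generator_optimizer=make_sync_optimizer(1e-3, replicas_to_aggregate=4),
    discriminator_optimizer=make_sync_optimizer(1e-3, replicas_to_aggregate=4),
    config=config,
    # New in this patch: must be set to False on non-chief workers so the
    # SyncReplicasOptimizer hooks are created with the correct role.
    is_chief=config.is_chief)
gan_estimator.train(train_input_fn, max_steps=10000)
```

With sync optimizers the train spec should then carry four training hooks (two `RunTrainOpsHook`s plus one `_SyncReplicasOptimizerHook` per optimizer), which is exactly what `test_get_sync_estimator_spec` above asserts.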
class GANEstimatorIntegrationTest(test.TestCase): diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl.py b/tensorflow/contrib/gan/python/losses/python/losses_impl.py index df0342c80c5..a0a86c6337e 100644 --- a/tensorflow/contrib/gan/python/losses/python/losses_impl.py +++ b/tensorflow/contrib/gan/python/losses/python/losses_impl.py @@ -36,7 +36,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np from tensorflow.contrib.framework.python.ops import variables as contrib_variables_lib from tensorflow.python.framework import ops @@ -47,7 +46,6 @@ from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope -from tensorflow.python.ops.distributions import distribution as ds from tensorflow.python.ops.losses import losses from tensorflow.python.ops.losses import util from tensorflow.python.summary import summary @@ -740,11 +738,16 @@ def least_squares_discriminator_loss( def _validate_distributions(distributions): if not isinstance(distributions, (list, tuple)): raise ValueError('`distributions` must be a list or tuple. Instead, ' - 'found %s.', type(distributions)) + 'found %s.' % type(distributions)) for x in distributions: - if not isinstance(x, ds.Distribution): + # We used to check with `isinstance(x, tf.distributions.Distribution)`. + # However, distributions have migrated to `tfp.distributions.Distribution`, + # which is a new code repo, so we can't check this way anymore until + # TF-GAN is migrated to a new repo as well. + # This new check is not sufficient, but is a useful heuristic for now. + if not callable(getattr(x, 'log_prob', None)): raise ValueError('`distributions` must be a list of `Distributions`. ' - 'Instead, found %s.', type(x)) + 'Instead, found %s.' % type(x)) def _validate_information_penalty_inputs( @@ -817,7 +820,7 @@ def _numerically_stable_global_norm(tensor_list): Returns: A scalar tensor with the global norm. """ - if np.all([x is None for x in tensor_list]): + if all(x is None for x in tensor_list): return 0.0 list_max = math_ops.reduce_max([math_ops.reduce_max(math_ops.abs(x)) for x in diff --git a/tensorflow/contrib/gan/python/namedtuples.py b/tensorflow/contrib/gan/python/namedtuples.py index b9ac1bf1513..969b68449d9 100644 --- a/tensorflow/contrib/gan/python/namedtuples.py +++ b/tensorflow/contrib/gan/python/namedtuples.py @@ -213,7 +213,8 @@ class GANTrainOps( collections.namedtuple('GANTrainOps', ( 'generator_train_op', 'discriminator_train_op', - 'global_step_inc_op' + 'global_step_inc_op', + 'train_hooks' ))): """GANTrainOps contains the training ops. @@ -221,8 +222,17 @@ class GANTrainOps( generator_train_op: Op that performs a generator update step. discriminator_train_op: Op that performs a discriminator update step. global_step_inc_op: Op that increments the shared global step. + train_hooks: a list or tuple containing hooks related to training that need + to be populated when training ops are instantiated. Used primarily for + sync hooks. 
""" + def __new__(cls, generator_train_op, discriminator_train_op, + global_step_inc_op, train_hooks=()): + return super(GANTrainOps, cls).__new__(cls, generator_train_op, + discriminator_train_op, + global_step_inc_op, train_hooks) + class GANTrainSteps( collections.namedtuple('GANTrainSteps', ( diff --git a/tensorflow/contrib/gan/python/train.py b/tensorflow/contrib/gan/python/train.py index 7ee39f304ab..4c7bee41b33 100644 --- a/tensorflow/contrib/gan/python/train.py +++ b/tensorflow/contrib/gan/python/train.py @@ -114,7 +114,7 @@ def gan_model( discriminator_gen_outputs = discriminator_fn(generated_data, generator_inputs) with variable_scope.variable_scope(dis_scope, reuse=True): - real_data = ops.convert_to_tensor(real_data) + real_data = _convert_tensor_or_l_or_d(real_data) discriminator_real_outputs = discriminator_fn(real_data, generator_inputs) if check_shapes: @@ -924,6 +924,7 @@ def gan_train_ops( generator_optimizer, discriminator_optimizer, check_for_unused_update_ops=True, + is_chief=True, # Optional args to pass directly to the `create_train_op`. **kwargs): """Returns GAN train ops. @@ -939,6 +940,8 @@ def gan_train_ops( discriminator_optimizer: The optimizer for the discriminator updates. check_for_unused_update_ops: If `True`, throws an exception if there are update ops outside of the generator or discriminator scopes. + is_chief: Specifies whether or not the training is being run by the primary + replica during replica training. **kwargs: Keyword args to pass directly to `training.create_train_op` for both the generator and discriminator train op. @@ -980,6 +983,9 @@ def gan_train_ops( kwargs, model.generator_scope.name, model.discriminator_scope.name, check_for_unused_update_ops) + # Get the sync hooks if these are needed. + sync_hooks = [] + generator_global_step = None if isinstance(generator_optimizer, sync_replicas_optimizer.SyncReplicasOptimizer): @@ -995,6 +1001,7 @@ def gan_train_ops( trainable=False, collections=[ops.GraphKeys.GLOBAL_VARIABLES]) gen_update_ops += [generator_global_step.assign(global_step)] + sync_hooks.append(generator_optimizer.make_session_run_hook(is_chief)) with ops.name_scope('generator_train'): gen_train_op = training.create_train_op( total_loss=loss.generator_loss, @@ -1016,6 +1023,7 @@ def gan_train_ops( trainable=False, collections=[ops.GraphKeys.GLOBAL_VARIABLES]) dis_update_ops += [discriminator_global_step.assign(global_step)] + sync_hooks.append(discriminator_optimizer.make_session_run_hook(is_chief)) with ops.name_scope('discriminator_train'): disc_train_op = training.create_train_op( total_loss=loss.discriminator_loss, @@ -1025,7 +1033,8 @@ def gan_train_ops( update_ops=dis_update_ops, **kwargs) - return namedtuples.GANTrainOps(gen_train_op, disc_train_op, global_step_inc) + return namedtuples.GANTrainOps(gen_train_op, disc_train_op, global_step_inc, + sync_hooks) # TODO(joelshor): Implement a dynamic GAN train loop, as in `Real-Time Adaptive @@ -1066,13 +1075,24 @@ def get_sequential_train_hooks(train_steps=namedtuples.GANTrainSteps(1, 1)): train_steps.generator_train_steps) discriminator_hook = RunTrainOpsHook(train_ops.discriminator_train_op, train_steps.discriminator_train_steps) - return [generator_hook, discriminator_hook] + return [generator_hook, discriminator_hook] + list(train_ops.train_hooks) return get_hooks +def _num_joint_steps(train_steps): + g_steps = train_steps.generator_train_steps + d_steps = train_steps.discriminator_train_steps + # Get the number of each type of step that should be run. 
+ num_d_and_g_steps = min(g_steps, d_steps) + num_g_steps = g_steps - num_d_and_g_steps + num_d_steps = d_steps - num_d_and_g_steps + + return num_d_and_g_steps, num_g_steps, num_d_steps + + def get_joint_train_hooks(train_steps=namedtuples.GANTrainSteps(1, 1)): - """Returns a hooks function for sequential GAN training. + """Returns a hooks function for joint GAN training. When using these train hooks, IT IS RECOMMENDED TO USE `use_locking=True` ON ALL OPTIMIZERS TO AVOID RACE CONDITIONS. @@ -1105,12 +1125,7 @@ def get_joint_train_hooks(train_steps=namedtuples.GANTrainSteps(1, 1)): Returns: A function that takes a GANTrainOps tuple and returns a list of hooks. """ - g_steps = train_steps.generator_train_steps - d_steps = train_steps.discriminator_train_steps - # Get the number of each type of step that should be run. - num_d_and_g_steps = min(g_steps, d_steps) - num_g_steps = g_steps - num_d_and_g_steps - num_d_steps = d_steps - num_d_and_g_steps + num_d_and_g_steps, num_g_steps, num_d_steps = _num_joint_steps(train_steps) def get_hooks(train_ops): g_op = train_ops.generator_train_op @@ -1120,7 +1135,7 @@ def get_joint_train_hooks(train_steps=namedtuples.GANTrainSteps(1, 1)): g_hook = RunTrainOpsHook(g_op, num_g_steps) d_hook = RunTrainOpsHook(d_op, num_d_steps) - return [joint_hook, g_hook, d_hook] + return [joint_hook, g_hook, d_hook] + list(train_ops.train_hooks) return get_hooks diff --git a/tensorflow/contrib/gan/python/train_test.py b/tensorflow/contrib/gan/python/train_test.py index 64d67061990..841f25cd7f1 100644 --- a/tensorflow/contrib/gan/python/train_test.py +++ b/tensorflow/contrib/gan/python/train_test.py @@ -519,7 +519,7 @@ class GANLossTest(test.TestCase, parameterized.TestCase): """Test output type.""" loss = train.gan_loss(get_gan_model_fn(), add_summaries=True) self.assertIsInstance(loss, namedtuples.GANLoss) - self.assertGreater(len(ops.get_collection(ops.GraphKeys.SUMMARIES)), 0) + self.assertNotEmpty(ops.get_collection(ops.GraphKeys.SUMMARIES)) @parameterized.named_parameters( ('cyclegan', create_cyclegan_model), @@ -528,7 +528,7 @@ class GANLossTest(test.TestCase, parameterized.TestCase): def test_cyclegan_output_type(self, get_gan_model_fn): loss = train.cyclegan_loss(get_gan_model_fn(), add_summaries=True) self.assertIsInstance(loss, namedtuples.CycleGANLoss) - self.assertGreater(len(ops.get_collection(ops.GraphKeys.SUMMARIES)), 0) + self.assertNotEmpty(ops.get_collection(ops.GraphKeys.SUMMARIES)) @parameterized.named_parameters( ('gan', create_gan_model, False), @@ -759,7 +759,7 @@ class TensorPoolAdjusteModelTest(test.TestCase): # For [pool_size, ?), the pool is full, tensor2 must be equal to some # historical values of tensor1 (which is previously stored in the # pool). - self.assertTrue(any([(v == t2).all() for v in history_values])) + self.assertTrue(any((v == t2).all() for v in history_values)) def _make_new_model_and_check(self, model, pool_size): pool_fn = lambda x: random_tensor_pool.tensor_pool(x, pool_size=pool_size) @@ -836,6 +836,9 @@ class GANTrainOpsTest(test.TestCase, parameterized.TestCase): self.assertIsInstance(train_ops, namedtuples.GANTrainOps) + # Make sure there are no training hooks populated accidentally. + self.assertEmpty(train_ops.train_hooks) + # TODO(joelshor): Add a test to check that custom update op is run. 
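The changes above thread the new `GANTrainOps.train_hooks` field from `gan_train_ops` through `get_sequential_train_hooks` and `get_joint_train_hooks`. A minimal sketch of the intended flow follows; `generator_fn`, `discriminator_fn`, `real_data`, and `generator_inputs` are placeholders, and the snippet mirrors the tests in this patch rather than documenting a canonical recipe.

```python
# Minimal sketch of the train_hooks plumbing added in this patch. The four
# *_fn / data names below are user-supplied placeholders.
import tensorflow as tf

tfgan = tf.contrib.gan

model = tfgan.gan_model(generator_fn, discriminator_fn, real_data,
                        generator_inputs)
loss = tfgan.gan_loss(model)

g_opt = tf.train.SyncReplicasOptimizer(
    tf.train.GradientDescentOptimizer(1.0), replicas_to_aggregate=1)
d_opt = tf.train.SyncReplicasOptimizer(
    tf.train.GradientDescentOptimizer(1.0), replicas_to_aggregate=1)

# gan_train_ops now stashes the SyncReplicasOptimizer session-run hooks in
# GANTrainOps.train_hooks (an empty tuple for ordinary optimizers).
train_ops = tfgan.gan_train_ops(model, loss, g_opt, d_opt, is_chief=True)

# The hook factories append train_ops.train_hooks to the RunTrainOpsHooks they
# already build, so gan_train picks the sync hooks up automatically:
#   sequential hooks -> 2 RunTrainOpsHooks + 2 sync hooks = 4
#   joint hooks      -> 3 RunTrainOpsHooks + 2 sync hooks = 5
hooks = tfgan.get_sequential_train_hooks()(train_ops)

# For get_joint_train_hooks, GANTrainSteps(g, d) is split as in
# _num_joint_steps: min(g, d) joint steps, then g - min(g, d) generator-only
# and d - min(g, d) discriminator-only steps.
```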
@parameterized.named_parameters( ('gan', create_gan_model, False), @@ -923,8 +926,15 @@ class GANTrainOpsTest(test.TestCase, parameterized.TestCase): model, loss, generator_optimizer=g_opt, discriminator_optimizer=d_opt) self.assertIsInstance(train_ops, namedtuples.GANTrainOps) # No new trainable variables should have been added. - self.assertEqual(num_trainable_vars, - len(variables_lib.get_trainable_variables())) + self.assertLen(variables_lib.get_trainable_variables(), num_trainable_vars) + + # Sync hooks should be populated in the GANTrainOps. + self.assertLen(train_ops.train_hooks, 2) + for hook in train_ops.train_hooks: + self.assertIsInstance( + hook, sync_replicas_optimizer._SyncReplicasOptimizerHook) + sync_opts = [hook._sync_optimizer for hook in train_ops.train_hooks] + self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt))) g_sync_init_op = g_opt.get_init_tokens_op(num_tokens=1) d_sync_init_op = d_opt.get_init_tokens_op(num_tokens=1) @@ -959,6 +969,32 @@ class GANTrainOpsTest(test.TestCase, parameterized.TestCase): coord.request_stop() coord.join(g_threads + d_threads) + @parameterized.named_parameters( + ('is_chief', True), + ('is_not_chief', False), + ) + def test_is_chief_in_train_hooks(self, is_chief): + """Make sure is_chief is propagated correctly to sync hooks.""" + model = create_gan_model() + loss = train.gan_loss(model) + g_opt = get_sync_optimizer() + d_opt = get_sync_optimizer() + train_ops = train.gan_train_ops( + model, + loss, + g_opt, + d_opt, + is_chief=is_chief, + summarize_gradients=True, + colocate_gradients_with_ops=True) + + self.assertLen(train_ops.train_hooks, 2) + for hook in train_ops.train_hooks: + self.assertIsInstance( + hook, sync_replicas_optimizer._SyncReplicasOptimizerHook) + is_chief_list = [hook._is_chief for hook in train_ops.train_hooks] + self.assertListEqual(is_chief_list, [is_chief, is_chief]) + class GANTrainTest(test.TestCase, parameterized.TestCase): """Tests for `gan_train`.""" @@ -1036,6 +1072,44 @@ class GANTrainTest(test.TestCase, parameterized.TestCase): self.assertTrue(np.isscalar(final_loss)) self.assertEqual(17.0, final_loss) + @parameterized.named_parameters( + ('gan', create_gan_model), + ('callable_gan', create_callable_gan_model), + ('infogan', create_infogan_model), + ('callable_infogan', create_callable_infogan_model), + ('acgan', create_acgan_model), + ('callable_acgan', create_callable_acgan_model), + ) + def test_train_hooks_exist_in_get_hooks_fn(self, create_gan_model_fn): + model = create_gan_model_fn() + loss = train.gan_loss(model) + + g_opt = get_sync_optimizer() + d_opt = get_sync_optimizer() + train_ops = train.gan_train_ops( + model, + loss, + g_opt, + d_opt, + summarize_gradients=True, + colocate_gradients_with_ops=True) + + sequential_train_hooks = train.get_sequential_train_hooks()(train_ops) + self.assertLen(sequential_train_hooks, 4) + sync_opts = [ + hook._sync_optimizer for hook in sequential_train_hooks if + isinstance(hook, sync_replicas_optimizer._SyncReplicasOptimizerHook)] + self.assertLen(sync_opts, 2) + self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt))) + + joint_train_hooks = train.get_joint_train_hooks()(train_ops) + self.assertLen(joint_train_hooks, 5) + sync_opts = [ + hook._sync_optimizer for hook in joint_train_hooks if + isinstance(hook, sync_replicas_optimizer._SyncReplicasOptimizerHook)] + self.assertLen(sync_opts, 2) + self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt))) + class PatchGANTest(test.TestCase, parameterized.TestCase): """Tests 
that functions work on PatchGAN style output.""" diff --git a/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc b/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc index 94f522c04e5..fbccbead03f 100644 --- a/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc +++ b/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc @@ -170,6 +170,14 @@ class GdrRemoteRendezvous : public BaseRemoteRendezvous { // Record "call" in active_ so that it can be aborted cleanly. RegisterCall(call); + // RendezvousMgr already aborted, shouldn't send RPC call any more + if (!call->status().ok()) { + done(call->status(), Args(), Args(), Tensor(), false); + session()->worker_cache->ReleaseWorker(src_worker, rwi); + delete call; + return; + } + // Start "call". Ref(); call->Start([this, call, src_worker, rwi, done]() { diff --git a/tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py b/tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py index f7f1189bb93..bc941ae9f23 100644 --- a/tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py +++ b/tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import os from tensorflow.contrib.hadoop.python.ops import hadoop_dataset_ops +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -47,7 +48,7 @@ class SequenceFileDatasetTest(test.TestCase): dataset = hadoop_dataset_ops.SequenceFileDataset(filenames).repeat( num_repeats) - iterator = dataset.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(dataset) init_op = iterator.initializer get_next = iterator.get_next() diff --git a/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py b/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py index bf398b838df..d3fcc8cb2a9 100644 --- a/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py +++ b/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py @@ -40,15 +40,12 @@ class SequenceFileDataset(dataset_ops.DatasetSource): For example: ```python + tf.enable_eager_execution() + dataset = tf.contrib.hadoop.SequenceFileDataset("/foo/bar.seq") - iterator = dataset.make_one_shot_iterator() - next_element = iterator.get_next() # Prints the (key, value) pairs inside a hadoop sequence file. 
- while True: - try: - print(sess.run(next_element)) - except tf.errors.OutOfRangeError: - break + for key, value in dataset: + print(key, value) ``` Args: diff --git a/tensorflow/contrib/ignite/README.md b/tensorflow/contrib/ignite/README.md index c7db0b77e25..5a8c650fb92 100644 --- a/tensorflow/contrib/ignite/README.md +++ b/tensorflow/contrib/ignite/README.md @@ -54,14 +54,12 @@ jdbc:ignite:thin://localhost/> INSERT INTO KITTEN_CACHE VALUES (3, 'LITTLE BALL ```python >>> import tensorflow as tf >>> from tensorflow.contrib.ignite import IgniteDataset ->>> ->>> dataset = IgniteDataset(cache_name="SQL_PUBLIC_KITTEN_CACHE") ->>> iterator = dataset.make_one_shot_iterator() ->>> next_obj = iterator.get_next() +>>> tf.enable_eager_execution() >>> ->>> with tf.Session() as sess: ->>> for _ in range(3): ->>> print(sess.run(next_obj)) +>>> dataset = IgniteDataset(cache_name="SQL_PUBLIC_KITTEN_CACHE") +>>> +>>> for element in dataset: +>>> print(element) {'key': 1, 'val': {'NAME': b'WARM KITTY'}} {'key': 2, 'val': {'NAME': b'SOFT KITTY'}} @@ -74,23 +72,22 @@ jdbc:ignite:thin://localhost/> INSERT INTO KITTEN_CACHE VALUES (3, 'LITTLE BALL ```python >>> import tensorflow as tf >>> from tensorflow.contrib.ignite import IgniteDataset ->>> ->>> dataset = IgniteDataset(cache_name="IMAGES") ->>> iterator = dataset.make_one_shot_iterator() ->>> next_obj = iterator.get_next() +>>> tf.enable_eager_execution() >>> ->>> with tf.Session() as sess: ->>> print(sess.run(next_obj)) +>>> dataset = IgniteDataset(cache_name="IMAGES") +>>> +>>> for element in dataset.take(1): +>>> print(element) { - 'key': 'kitten.png', + 'key': 'kitten.png', 'val': { 'metadata': { 'file_name': b'kitten.png', 'label': b'little ball of fur', - width: 800, + width: 800, height: 600 - }, + }, 'pixels': [0, 0, 0, 0, ..., 0] } } @@ -100,13 +97,11 @@ jdbc:ignite:thin://localhost/> INSERT INTO KITTEN_CACHE VALUES (3, 'LITTLE BALL ```python >>> import tensorflow as tf >>> from tensorflow.contrib.ignite import IgniteDataset ->>> ->>> dataset = IgniteDataset(cache_name="IMAGES").map(lambda obj: obj['val']['pixels']) ->>> iterator = dataset.make_one_shot_iterator() ->>> next_obj = iterator.get_next() >>> ->>> with tf.Session() as sess: ->>> print(sess.run(next_obj)) +>>> dataset = IgniteDataset(cache_name="IMAGES").map(lambda obj: obj['val']['pixels']) +>>> +>>> for element in dataset: +>>> print(element) [0, 0, 0, 0, ..., 0] ``` @@ -126,18 +121,18 @@ Ignite Dataset allows using these two aspects of distributed neural network trai ```python >>> import tensorflow as tf >>> from tensorflow.contrib.ignite import IgniteDataset ->>> +>>> >>> dataset = IgniteDataset("IMAGES") >>> >>> # Compute gradients locally on every worker node. ->>> gradients = [] +>>> gradients = [] >>> for i in range(5): >>> with tf.device("/job:WORKER/task:%d" % i): ->>> device_iterator = dataset.make_one_shot_iterator() +>>> device_iterator = tf.compat.v1.data.make_one_shot_iterator(dataset) >>> device_next_obj = device_iterator.get_next() >>> gradient = compute_gradient(device_next_obj) ->>> gradients.append(gradient) ->>> +>>> gradients.append(gradient) +>>> >>> # Aggregate them on master node. >>> result_gradient = tf.reduce_sum(gradients) >>> @@ -145,7 +140,7 @@ Ignite Dataset allows using these two aspects of distributed neural network trai >>> print(sess.run(result_gradient)) ``` -High-level TensorFlow API for [distributed training](https://www.tensorflow.org/api_docs/python/tf/contrib/distribute/DistributionStrategy) is supported as well. 
+High-level TensorFlow API for [distributed training](https://www.tensorflow.org/api_docs/python/tf/contrib/distribute/DistributionStrategy) is supported as well. ### Distributed File System diff --git a/tensorflow/contrib/ignite/python/tests/ignite_dataset_test.py b/tensorflow/contrib/ignite/python/tests/ignite_dataset_test.py index ef29b5f14a4..ff5d4c458c8 100644 --- a/tensorflow/contrib/ignite/python/tests/ignite_dataset_test.py +++ b/tensorflow/contrib/ignite/python/tests/ignite_dataset_test.py @@ -21,6 +21,7 @@ import os from tensorflow.contrib.ignite import IgniteDataset from tensorflow.python.client import session +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.platform import test @@ -65,7 +66,7 @@ class IgniteDatasetTest(test.TestCase): self.assertEqual(dtypes.string, dataset.output_types["val"]["NAME"]) self.assertEqual(dtypes.int64, dataset.output_types["val"]["VAL"]) - it = dataset.make_one_shot_iterator() + it = dataset_ops.make_one_shot_iterator(dataset) ne = it.get_next() with session.Session() as sess: diff --git a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.cc b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.cc index 478b716d883..108da044946 100644 --- a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.cc +++ b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.cc @@ -115,7 +115,7 @@ class AdjustHsvInYiqOp : public AdjustHsvInYiqOpBase { *context->device()->tensorflow_cpu_worker_threads(); Shard(worker_threads.num_threads, worker_threads.workers, channel_count, kCostPerChannel, - [channel_count, &input_data, &output_data, &tranformation_matrix]( + [&input_data, &output_data, &tranformation_matrix]( int64 start_channel, int64 end_channel) { // Applying projection matrix to input RGB vectors. const float* p = input_data.data() + start_channel * kChannelSize; diff --git a/tensorflow/contrib/image/python/ops/dense_image_warp.py b/tensorflow/contrib/image/python/ops/dense_image_warp.py index 9c7ada7afb7..7930b8317b6 100644 --- a/tensorflow/contrib/image/python/ops/dense_image_warp.py +++ b/tensorflow/contrib/image/python/ops/dense_image_warp.py @@ -25,7 +25,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops - +from tensorflow.python.ops import check_ops def _interpolate_bilinear(grid, query_points, @@ -60,28 +60,40 @@ def _interpolate_bilinear(grid, msg = 'Grid must be 4 dimensional. Received size: ' raise ValueError(msg + str(grid.get_shape())) - batch_size, height, width, channels = shape + batch_size, height, width, channels = (array_ops.shape(grid)[0], + array_ops.shape(grid)[1], + array_ops.shape(grid)[2], + array_ops.shape(grid)[3]) + + shape = [batch_size, height, width, channels] query_type = query_points.dtype grid_type = grid.dtype - if (query_points.shape.rank != 3 or - query_points.shape.dims[2].value != 2): - msg = ('Query points must be 3 dimensional and size 2 in dim 2. 
Received ' - 'size: ') - raise ValueError(msg + str(query_points.get_shape())) + with ops.control_dependencies([ + check_ops.assert_equal( + len(query_points.get_shape()), + 3, + message='Query points must be 3 dimensional.'), + check_ops.assert_equal( + array_ops.shape(query_points)[2], + 2, + message='Query points must be size 2 in dim 2.')]): + num_queries = array_ops.shape(query_points)[1] - _, num_queries, _ = query_points.get_shape().as_list() - - if height < 2 or width < 2: - msg = 'Grid must be at least batch_size x 2 x 2 in size. Received size: ' - raise ValueError(msg + str(grid.get_shape())) - - alphas = [] - floors = [] - ceils = [] - - index_order = [0, 1] if indexing == 'ij' else [1, 0] - unstacked_query_points = array_ops.unstack(query_points, axis=2) + with ops.control_dependencies([ + check_ops.assert_greater_equal( + height, + 2, + message='Grid height must be at least 2.'), + check_ops.assert_greater_equal( + width, + 2, + message='Grid width must be at least 2.')]): + alphas = [] + floors = [] + ceils = [] + index_order = [0, 1] if indexing == 'ij' else [1, 0] + unstacked_query_points = array_ops.unstack(query_points, axis=2) for dim in index_order: with ops.name_scope('dim-' + str(dim)): @@ -112,16 +124,17 @@ def _interpolate_bilinear(grid, alpha = array_ops.expand_dims(alpha, 2) alphas.append(alpha) - if batch_size * height * width > np.iinfo(np.int32).max / 8: - error_msg = """The image size or batch size is sufficiently large - that the linearized addresses used by array_ops.gather - may exceed the int32 limit.""" - raise ValueError(error_msg) - - flattened_grid = array_ops.reshape(grid, - [batch_size * height * width, channels]) - batch_offsets = array_ops.reshape( - math_ops.range(batch_size) * height * width, [batch_size, 1]) + with ops.control_dependencies([ + check_ops.assert_less_equal( + math_ops.cast(batch_size * height * width, dtype=dtypes.float32), + np.iinfo(np.int32).max / 8, + message="""The image size or batch size is sufficiently large + that the linearized addresses used by array_ops.gather + may exceed the int32 limit.""")]): + flattened_grid = array_ops.reshape( + grid, [batch_size * height * width, channels]) + batch_offsets = array_ops.reshape( + math_ops.range(batch_size) * height * width, [batch_size, 1]) # This wraps array_ops.gather. We reshape the image data such that the # batch, y, and x coordinates are pulled into the first dimension. @@ -182,7 +195,11 @@ def dense_image_warp(image, flow, name='dense_image_warp'): of dimensions. """ with ops.name_scope(name): - batch_size, height, width, channels = image.get_shape().as_list() + batch_size, height, width, channels = (array_ops.shape(image)[0], + array_ops.shape(image)[1], + array_ops.shape(image)[2], + array_ops.shape(image)[3]) + # The flow is defined on the image grid. Turn the flow into a list of query # points in the grid space. grid_x, grid_y = array_ops.meshgrid( diff --git a/tensorflow/contrib/keras/api/keras/layers/__init__.py b/tensorflow/contrib/keras/api/keras/layers/__init__.py index 3327a9f9a61..9e19884df85 100644 --- a/tensorflow/contrib/keras/api/keras/layers/__init__.py +++ b/tensorflow/contrib/keras/api/keras/layers/__init__.py @@ -20,7 +20,7 @@ from __future__ import print_function # Generic layers. 
# pylint: disable=g-bad-import-order -from tensorflow.python.keras.engine.base_layer import InputSpec +from tensorflow.python.keras.engine.input_spec import InputSpec from tensorflow.python.keras.engine.base_layer import Layer from tensorflow.python.keras.engine.input_layer import Input from tensorflow.python.keras.engine.input_layer import InputLayer diff --git a/tensorflow/contrib/keras/api/keras/utils/__init__.py b/tensorflow/contrib/keras/api/keras/utils/__init__.py index 47cd01b924f..3b9fa1b230b 100644 --- a/tensorflow/contrib/keras/api/keras/utils/__init__.py +++ b/tensorflow/contrib/keras/api/keras/utils/__init__.py @@ -30,6 +30,7 @@ from tensorflow.python.keras.utils.generic_utils import Progbar from tensorflow.python.keras.utils.generic_utils import serialize_keras_object from tensorflow.python.keras.utils.io_utils import HDF5Matrix from tensorflow.python.keras.utils.layer_utils import convert_all_kernels_in_model +from tensorflow.python.keras.utils.losses_utils import squeeze_or_expand_dimensions from tensorflow.python.keras.utils.np_utils import normalize from tensorflow.python.keras.utils.np_utils import to_categorical from tensorflow.python.keras.utils.vis_utils import plot_model diff --git a/tensorflow/contrib/kernel_methods/python/kernel_estimators.py b/tensorflow/contrib/kernel_methods/python/kernel_estimators.py index de7530231db..1626e55b9b3 100644 --- a/tensorflow/contrib/kernel_methods/python/kernel_estimators.py +++ b/tensorflow/contrib/kernel_methods/python/kernel_estimators.py @@ -90,7 +90,7 @@ def _update_features_and_columns(features, feature_columns, mapped_column_name = column_name + "_MAPPED" # Construct new feature columns based on provided kernel_mappers. column_kernel_mappers = kernel_mappers_dict[feature_column] - new_dim = sum([mapper.output_dim for mapper in column_kernel_mappers]) + new_dim = sum(mapper.output_dim for mapper in column_kernel_mappers) mapped_columns.add( layers.feature_column.real_valued_column(mapped_column_name, new_dim)) diff --git a/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py b/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py index 75806dbbeb1..c392adbb1d9 100644 --- a/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py +++ b/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py @@ -34,15 +34,12 @@ class KinesisDataset(dataset_ops.DatasetSource): For example, we can construct and use the KinesisDataset as follows: ```python + tf.enable_eager_execution() + dataset = tf.contrib.kinesis.KinesisDataset( "kinesis_stream_name", read_indefinitely=False) - next = dataset.make_one_shot_iterator().get_next() - with tf.Session() as sess: - while True: - try: - print(sess.run(nxt)) - except tf.errors.OutOfRangeError: - break + for element in dataset: + print(element) ``` Since Kinesis is a data streaming service, data may not be available diff --git a/tensorflow/contrib/layers/BUILD b/tensorflow/contrib/layers/BUILD index e6596bfdfb9..9ca6f8df5db 100644 --- a/tensorflow/contrib/layers/BUILD +++ b/tensorflow/contrib/layers/BUILD @@ -78,6 +78,11 @@ tf_custom_op_py_library( ":sparse_feature_cross_op_op_lib", ], srcs_version = "PY2AND3", + visibility = [ + "//learning/brain:__subpackages__", + "//tensorflow:__subpackages__", + "//video/youtube/personalization:__subpackages__", + ], deps = [ ":sparse_feature_cross_op", "//tensorflow/contrib/framework:framework_py", @@ -253,7 +258,7 @@ py_test( "//tensorflow/python:training", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", - 
"//tensorflow/python/feature_column", + "//tensorflow/python/feature_column:feature_column_py", "//third_party/py/numpy", ], ) @@ -277,7 +282,7 @@ py_test( "//tensorflow/python:sparse_tensor", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", - "//tensorflow/python/feature_column", + "//tensorflow/python/feature_column:feature_column_py", "//third_party/py/numpy", ], ) diff --git a/tensorflow/contrib/layers/python/layers/embedding_ops_test.py b/tensorflow/contrib/layers/python/layers/embedding_ops_test.py index 124515e5a64..295c721fced 100644 --- a/tensorflow/contrib/layers/python/layers/embedding_ops_test.py +++ b/tensorflow/contrib/layers/python/layers/embedding_ops_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import itertools import math +import sys import numpy as np @@ -36,6 +37,7 @@ from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import partitioned_variables +from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test from tensorflow.python.util import compat @@ -48,11 +50,13 @@ class SafeEmbeddingLookupSparseTest(test.TestCase): assert num_shards > 0 assert num_shards <= vocab_size - embedding_weights = partitioned_variables.create_partitioned_variables( + initializer = init_ops.truncated_normal_initializer( + mean=0.0, stddev=1.0 / math.sqrt(vocab_size), dtype=dtypes.float32) + embedding_weights = list(variable_scope.get_variable( + "embedding_weights", shape=[vocab_size, embed_dim], - slicing=[num_shards, 1], - initializer=init_ops.truncated_normal_initializer( - mean=0.0, stddev=1.0 / math.sqrt(vocab_size), dtype=dtypes.float32)) + partitioner=partitioned_variables.fixed_size_partitioner(num_shards), + initializer=initializer)) for w in embedding_weights: w.initializer.run() embedding_weights = [w.eval() for w in embedding_weights] @@ -256,6 +260,13 @@ class SafeEmbeddingLookupSparseTest(test.TestCase): embedding_weights, sparse_ids, sparse_weights) +# pylint: disable=invalid-name +def local_variable_scope(): + """Create a variable scope named like the caller function.""" + return variable_scope.variable_scope(sys._getframe(1).f_code.co_name) +# pylint: enable=invalid-name + + class ScatteredEmbeddingLookupTest(test.TestCase): def setUp(self): @@ -266,17 +277,18 @@ class ScatteredEmbeddingLookupTest(test.TestCase): assert num_shards > 0 assert num_shards <= size - embedding_weights = partitioned_variables.create_partitioned_variables( + embedding_weights = list(variable_scope.get_variable( + "embedding_weights", shape=[size], - slicing=[num_shards], + partitioner=partitioned_variables.fixed_size_partitioner(num_shards), initializer=init_ops.truncated_normal_initializer( - mean=0.0, stddev=1.0, dtype=dtypes.float32)) + mean=0.0, stddev=1.0, dtype=dtypes.float32))) for w in embedding_weights: w.initializer.run() return embedding_weights def test_scattered_embedding_consistency(self): - with self.cached_session(): + with self.cached_session(), local_variable_scope(): embedding_weights = self._random_weights() values = constant_op.constant(["foo", "foo"]) @@ -288,7 +300,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase): embedding_lookup_result[1]) def test_scattered_embedding_multiple_partition(self): - with self.cached_session(): + with self.cached_session(), local_variable_scope(): embedding_weights = self._random_weights(num_shards=7) values = constant_op.constant([4, 4, 5]) @@ -304,7 
+316,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase): self.assertGreater(embedding_diff, 0) def test_scattered_embedding_coverage(self): - with self.cached_session(): + with self.cached_session(), local_variable_scope(): size = 8 embedding_weights = self._random_weights(size=size, num_shards=3) values = constant_op.constant(["foo"]) @@ -316,7 +328,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase): self.assertEqual(len(np.unique(embedding_lookup_result[0])), size) def test_scattered_embedding_multi_dimension(self): - with self.cached_session(): + with self.cached_session(), local_variable_scope(): embedding_weights = self._random_weights() values = constant_op.constant([["foo", "bar", "bar"], ["bar", "bar", "foo"]]) @@ -329,7 +341,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase): embedding_lookup_result[1][2]) def test_scattered_embedding_lookup_sparse(self): - with self.cached_session(): + with self.cached_session(), local_variable_scope(): embedding_weights = self._random_weights(num_shards=3) sparse_tensor = sparse_tensor_lib.SparseTensor( values=["foo", "bar", "foo", "bar"], @@ -358,7 +370,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase): embeds = np.random.randn(n_embed, d_embed) idx = np.random.randint(0, n_embed, idx_shape) - with self.cached_session(): + with self.cached_session(), local_variable_scope(): embedded_np = embeds[idx] embedded_tf = embedding_ops.embedding_lookup_unique(embeds, idx).eval() @@ -370,7 +382,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase): idx = np.random.randint(0, 5, 10) idx2d = np.random.randint(0, 5, (10, 2)) - with self.cached_session(): + with self.cached_session(), local_variable_scope(): embedded_np = embeds[idx] embedded_np2d = embeds[idx2d] embedded_tf = embedding_ops.embedding_lookup_unique(embeds, idx).eval() @@ -398,17 +410,18 @@ class SampledScatteredEmbeddingLookupTest(test.TestCase): assert num_shards > 0 assert num_shards <= size - embedding_weights = partitioned_variables.create_partitioned_variables( + embedding_weights = list(variable_scope.get_variable( + "embedding_weights", shape=[size], - slicing=[num_shards], + partitioner=partitioned_variables.fixed_size_partitioner(num_shards), initializer=init_ops.truncated_normal_initializer( - mean=0.0, stddev=1.0, dtype=dtypes.float32)) + mean=0.0, stddev=1.0, dtype=dtypes.float32))) for w in embedding_weights: w.initializer.run() return embedding_weights def test_hashed_embedding_consistency(self): - with self.cached_session(): + with self.cached_session(), local_variable_scope(): embedding_weights = self._random_weights() values = constant_op.constant(["foo", "foo"]) # The first three sampled_candidates are equal, so the first three @@ -429,7 +442,7 @@ class SampledScatteredEmbeddingLookupTest(test.TestCase): embedding_lookup_result[1][3]) def test_hashed_embedding_multi_dimension(self): - with self.cached_session(): + with self.cached_session(), local_variable_scope(): embedding_weights = self._random_weights() values = constant_op.constant([["foo", "bar", "bar"], ["bar", "bar", "foo"]]) @@ -691,7 +704,6 @@ class EmbeddingLookupSparseWithDistributedAggregationTest(test.TestCase): index += num_val return grouped_vals - @test_util.enable_c_shapes def testEmbeddingLookupSparse(self): vocab_size = 13 batch_size = 10 diff --git a/tensorflow/contrib/layers/python/layers/encoders.py b/tensorflow/contrib/layers/python/layers/encoders.py index f42112206d0..3671633c8d7 100644 --- a/tensorflow/contrib/layers/python/layers/encoders.py +++ 
b/tensorflow/contrib/layers/python/layers/encoders.py @@ -84,8 +84,7 @@ def bow_encoder(ids, if isinstance(ids, sparse_tensor.SparseTensor): raise TypeError('ids are expected to be dense Tensor, got: %s', ids) return math_ops.reduce_mean( - embedding_ops.embedding_lookup(embeddings, ids), - reduction_indices=1) + embedding_ops.embedding_lookup(embeddings, ids), axis=1) def embed_sequence(ids, diff --git a/tensorflow/contrib/layers/python/layers/feature_column.py b/tensorflow/contrib/layers/python/layers/feature_column.py index 222404b19db..00d819ed0e9 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column.py +++ b/tensorflow/contrib/layers/python/layers/feature_column.py @@ -1015,8 +1015,7 @@ class _OneHotColumn( dense_id_tensor, depth=self.length, on_value=1.0, off_value=0.0) # Reduce to get a multi-hot per example. - return math_ops.reduce_sum( - one_hot_id_tensor, reduction_indices=[output_rank - 1]) + return math_ops.reduce_sum(one_hot_id_tensor, axis=[output_rank - 1]) @property def _variable_shape(self): diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py index 6fb4b9ff353..7e6eafaa0d6 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py +++ b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py @@ -27,7 +27,7 @@ from tensorflow.contrib.layers.python.layers import feature_column from tensorflow.contrib.layers.python.layers import feature_column_ops from tensorflow.core.example import example_pb2 from tensorflow.core.example import feature_pb2 -from tensorflow.python.feature_column import feature_column as fc_core +from tensorflow.python.feature_column import feature_column_lib as fc_core from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops diff --git a/tensorflow/contrib/layers/python/layers/feature_column_test.py b/tensorflow/contrib/layers/python/layers/feature_column_test.py index d90d6ecf7f6..cab8da808b6 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column_test.py +++ b/tensorflow/contrib/layers/python/layers/feature_column_test.py @@ -27,7 +27,7 @@ import numpy as np from tensorflow.contrib.layers.python.layers import feature_column as fc from tensorflow.contrib.layers.python.layers import feature_column_ops -from tensorflow.python.feature_column import feature_column as fc_core +from tensorflow.python.feature_column import feature_column_lib as fc_core from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index ac9561c7693..403b522ce45 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -35,6 +35,7 @@ from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape +from tensorflow.python.keras.engine import input_spec from tensorflow.python.layers import base from tensorflow.python.layers import convolutional as convolutional_layers from tensorflow.python.layers import core as core_layers @@ -1958,7 +1959,7 @@ class GDN(base.Layer): self._reparam_offset = reparam_offset self.data_format = 
data_format self._channel_axis() # trigger ValueError early - self.input_spec = base.InputSpec(min_ndim=3, max_ndim=5) + self.input_spec = input_spec.InputSpec(min_ndim=3, max_ndim=5) def _channel_axis(self): try: @@ -2015,7 +2016,7 @@ class GDN(base.Layer): raise ValueError('The channel dimension of the inputs to `GDN` ' 'must be defined.') self._input_rank = input_shape.ndims - self.input_spec = base.InputSpec( + self.input_spec = input_spec.InputSpec( ndim=input_shape.ndims, axes={ channel_axis: num_channels }) diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index 8ead6336a08..0a4d2c6d4cb 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -3811,7 +3811,7 @@ class UnitNormTests(test.TestCase): image = random_ops.random_uniform((height, width, 3)) output = _layers.unit_norm(image, dim=dim, epsilon=1e-6) norms = math_ops.sqrt( - math_ops.reduce_sum(math_ops.square(output), reduction_indices=dim)) + math_ops.reduce_sum(math_ops.square(output), axis=dim)) shape = [height, width, 3] del shape[dim] @@ -3847,7 +3847,7 @@ class UnitNormTests(test.TestCase): image = array_ops.placeholder(dtypes.float32, (None, None, 3)) output = _layers.unit_norm(image, dim=dim, epsilon=1e-6) norms = math_ops.sqrt( - math_ops.reduce_sum(math_ops.square(output), reduction_indices=dim)) + math_ops.reduce_sum(math_ops.square(output), axis=dim)) with self.cached_session(): actual = norms.eval({image: placeholder_value}) diff --git a/tensorflow/contrib/layers/python/layers/regularizers_test.py b/tensorflow/contrib/layers/python/layers/regularizers_test.py index 51faba30c74..5cb00b76847 100644 --- a/tensorflow/contrib/layers/python/layers/regularizers_test.py +++ b/tensorflow/contrib/layers/python/layers/regularizers_test.py @@ -141,7 +141,7 @@ class RegularizerTest(test.TestCase): dummy_regularizer = lambda x: math_ops.reduce_sum(2 * x) array_weights_list = [[1.5], [2, 3, 4.2], [10, 42, 666.6]] tensor_weights_list = [constant_op.constant(x) for x in array_weights_list] - expected = sum([2 * x for l in array_weights_list for x in l]) + expected = sum(2 * x for l in array_weights_list for x in l) with self.cached_session(): result = regularizers.apply_regularization(dummy_regularizer, tensor_weights_list) diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index 61185f65a9b..14065fcee51 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -24,6 +24,11 @@ py_library( exclude = ["python/learn/**/*_test.py"], ), srcs_version = "PY2AND3", + visibility = [ + "//learning/brain:__subpackages__", + "//tensorflow:__subpackages__", + "//video/youtube/personalization:__subpackages__", + ], # This library should not depend on sklearn, even though some of the code # refers to it. (The code handles the presence of sklearn conditionally.) 
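Editor's note: the embedding_ops_test.py hunks above replace `partitioned_variables.create_partitioned_variables(..., slicing=[num_shards, 1], ...)` with a single `variable_scope.get_variable` call that takes a `fixed_size_partitioner`. A minimal sketch of that new pattern, assuming a TF 1.x graph-mode build; the names and sizes here are illustrative and not taken from the diff:

```python
# Minimal sketch of the pattern used in embedding_ops_test.py above: one
# get_variable call with a fixed_size_partitioner replaces
# create_partitioned_variables(..., slicing=[num_shards, 1], ...).
# (Assumes a TF 1.x graph-mode build; names and sizes are illustrative.)
import math
import tensorflow as tf

vocab_size, embed_dim, num_shards = 100, 8, 4

with tf.variable_scope("sketch"):
  embedding_weights = list(tf.get_variable(
      "embedding_weights",
      shape=[vocab_size, embed_dim],
      partitioner=tf.fixed_size_partitioner(num_shards),
      initializer=tf.truncated_normal_initializer(
          mean=0.0, stddev=1.0 / math.sqrt(vocab_size))))

# Iterating the returned PartitionedVariable yields one variable per shard,
# each of shape [vocab_size // num_shards, embed_dim].
print([w.shape.as_list() for w in embedding_weights])
```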
deps = [ @@ -269,6 +274,7 @@ py_test( name = "estimator_test", size = "medium", srcs = ["python/learn/estimators/estimator_test.py"], + shard_count = 2, srcs_version = "PY2AND3", tags = [ "manual", diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn.py b/tensorflow/contrib/learn/python/learn/estimators/dnn.py index eabebb7e881..10fbd60ba2d 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn.py @@ -28,7 +28,6 @@ import six from tensorflow.contrib import layers from tensorflow.contrib.framework import deprecated from tensorflow.contrib.framework import deprecated_arg_values -from tensorflow.python.training import training_util from tensorflow.contrib.layers.python.layers import feature_column from tensorflow.contrib.layers.python.layers import optimizers from tensorflow.contrib.learn.python.learn import metric_spec @@ -38,11 +37,12 @@ from tensorflow.contrib.learn.python.learn.estimators import head as head_lib from tensorflow.contrib.learn.python.learn.estimators import model_fn from tensorflow.contrib.learn.python.learn.estimators import prediction_key from tensorflow.contrib.learn.python.learn.utils import export -from tensorflow.python.feature_column import feature_column as fc_core +from tensorflow.python.feature_column import feature_column_lib as fc_core from tensorflow.python.ops import nn from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import variable_scope from tensorflow.python.summary import summary +from tensorflow.python.training import training_util # The default learning rate of 0.05 is a historical artifact of the initial # implementation, but seems a reasonable choice. @@ -150,10 +150,10 @@ def _dnn_model_fn(features, labels, mode, params, config=None): "input_from_feature_columns", values=tuple(six.itervalues(features)), partitioner=input_layer_partitioner) as input_layer_scope: - if all([ + if all( isinstance(fc, feature_column._FeatureColumn) # pylint: disable=protected-access for fc in feature_columns - ]): + ): net = layers.input_from_feature_columns( columns_to_tensors=features, feature_columns=feature_columns, diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py index 3d85533d92d..2ade6b7b6ce 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py @@ -38,7 +38,7 @@ from tensorflow.contrib.learn.python.learn.estimators import head as head_lib from tensorflow.contrib.learn.python.learn.estimators import model_fn from tensorflow.contrib.learn.python.learn.estimators import prediction_key from tensorflow.contrib.learn.python.learn.utils import export -from tensorflow.python.feature_column import feature_column as fc_core +from tensorflow.python.feature_column import feature_column_lib as fc_core from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import nn @@ -236,10 +236,10 @@ def _dnn_linear_combined_model_fn(features, labels, mode, params, config=None): "input_from_feature_columns", values=tuple(six.itervalues(features)), partitioner=input_layer_partitioner) as dnn_input_scope: - if all([ + if all( isinstance(fc, feature_column_lib._FeatureColumn) # pylint: disable=protected-access for fc in dnn_feature_columns - ]): + ): net = layers.input_from_feature_columns( 
columns_to_tensors=features, feature_columns=dnn_feature_columns, @@ -292,8 +292,8 @@ def _dnn_linear_combined_model_fn(features, labels, mode, params, config=None): linear_parent_scope, values=tuple(six.itervalues(features)), partitioner=linear_partitioner) as scope: - if all([isinstance(fc, feature_column_lib._FeatureColumn) # pylint: disable=protected-access - for fc in linear_feature_columns]): + if all(isinstance(fc, feature_column_lib._FeatureColumn) # pylint: disable=protected-access + for fc in linear_feature_columns): if joint_linear_weights: linear_logits, _, _ = layers.joint_weighted_sum_from_feature_columns( columns_to_tensors=features, diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py index 4e65c180d8b..d46a873bfaa 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py @@ -36,7 +36,7 @@ from tensorflow.contrib.learn.python.learn.estimators import run_config from tensorflow.contrib.learn.python.learn.estimators import test_data from tensorflow.contrib.learn.python.learn.metric_spec import MetricSpec from tensorflow.contrib.metrics.python.ops import metric_ops -from tensorflow.python.feature_column import feature_column as fc_core +from tensorflow.python.feature_column import feature_column_lib as fc_core from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py index 2bd57597c2e..ee25cebd484 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py @@ -38,7 +38,7 @@ from tensorflow.contrib.learn.python.learn.estimators import run_config from tensorflow.contrib.learn.python.learn.estimators import test_data from tensorflow.contrib.learn.python.learn.metric_spec import MetricSpec from tensorflow.contrib.metrics.python.ops import metric_ops -from tensorflow.python.feature_column import feature_column as fc_core +from tensorflow.python.feature_column import feature_column_lib as fc_core from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import sparse_tensor diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py index 1d8a59281a4..28c4964527b 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py @@ -668,7 +668,7 @@ class DynamicRNNEstimatorLearningTest(test.TestCase): sequences = centers + noise inputs = array_ops.expand_dims(sequences, 2) - labels = math_ops.reduce_mean(sequences, reduction_indices=[1]) + labels = math_ops.reduce_mean(sequences, axis=[1]) return {'inputs': inputs}, labels return input_fn @@ -722,8 +722,8 @@ class DynamicRNNEstimatorLearningTest(test.TestCase): inputs = array_ops.expand_dims(math_ops.to_float(random_sequence), 2) labels = math_ops.to_int32( array_ops.squeeze( - math_ops.reduce_sum( - inputs, reduction_indices=[1]) > (sequence_length / 2.0))) + math_ops.reduce_sum(inputs, axis=[1]) > ( + sequence_length 
/ 2.0))) return {'inputs': inputs}, labels return input_fn diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index 8bc869db895..9132b2209bc 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -1066,11 +1066,11 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable, chief_hooks = [] if (self._config.save_checkpoints_secs or self._config.save_checkpoints_steps): - saver_hook_exists = any([ + saver_hook_exists = any( isinstance(h, basic_session_run_hooks.CheckpointSaverHook) for h in (all_hooks + model_fn_ops.training_hooks + chief_hooks + model_fn_ops.training_chief_hooks) - ]) + ) if not saver_hook_exists: chief_hooks = [ basic_session_run_hooks.CheckpointSaverHook( @@ -1493,7 +1493,7 @@ class Estimator(BaseEstimator): # pylint: disable=protected-access class SKCompat(sklearn.BaseEstimator): """Scikit learn wrapper for TensorFlow Learn Estimator. - + THIS CLASS IS DEPRECATED. See [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) for general migration instructions. diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear.py b/tensorflow/contrib/learn/python/learn/estimators/linear.py index e100bc7a1e7..9ee8d8004bf 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/linear.py +++ b/tensorflow/contrib/learn/python/learn/estimators/linear.py @@ -37,7 +37,7 @@ from tensorflow.contrib.learn.python.learn.estimators import head as head_lib from tensorflow.contrib.learn.python.learn.estimators import prediction_key from tensorflow.contrib.learn.python.learn.utils import export from tensorflow.contrib.linear_optimizer.python import sdca_optimizer -from tensorflow.python.feature_column import feature_column as fc_core +from tensorflow.python.feature_column import feature_column_lib as fc_core from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor @@ -155,8 +155,8 @@ def _linear_model_fn(features, labels, mode, params, config=None): parent_scope, values=tuple(six.itervalues(features)), partitioner=partitioner) as scope: - if all([isinstance(fc, feature_column._FeatureColumn) # pylint: disable=protected-access - for fc in feature_columns]): + if all(isinstance(fc, feature_column._FeatureColumn) # pylint: disable=protected-access + for fc in feature_columns): if joint_weights: layer_fn = layers.joint_weighted_sum_from_feature_columns else: diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py index 597ca4e86db..dfc76bfde6c 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py @@ -37,7 +37,7 @@ from tensorflow.contrib.learn.python.learn.estimators import test_data from tensorflow.contrib.learn.python.learn.metric_spec import MetricSpec from tensorflow.contrib.linear_optimizer.python import sdca_optimizer as sdca_optimizer_lib from tensorflow.contrib.metrics.python.ops import metric_ops -from tensorflow.python.feature_column import feature_column as fc_core +from tensorflow.python.feature_column import feature_column_lib as fc_core from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import sparse_tensor @@ 
-1745,7 +1745,7 @@ class LinearRegressorTest(test.TestCase): 'place_holder': constant_op.constant([[0.0]] * num_examples), }, constant_op.constant( - [[1 if i % 4 is 0 else 0] for i in range(num_examples)]) + [[1 if i % 4 == 0 else 0] for i in range(num_examples)]) place_holder = feature_column_lib.real_valued_column('place_holder') sdca_optimizer = sdca_optimizer_lib.SDCAOptimizer( diff --git a/tensorflow/contrib/learn/python/learn/learn_io/numpy_io.py b/tensorflow/contrib/learn/python/learn/learn_io/numpy_io.py index 29552d24f1e..59a67636ae2 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/numpy_io.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/numpy_io.py @@ -27,7 +27,7 @@ from tensorflow.python.estimator.inputs.numpy_io import numpy_input_fn as core_n from tensorflow.python.util.deprecation import deprecated -@deprecated(None, 'Use tf.estimator.inputs.numpy_input_fn.') +@deprecated(None, 'Use tf.compat.v1.estimator.inputs.numpy_input_fn.') def numpy_input_fn(x, y=None, batch_size=128, diff --git a/tensorflow/contrib/learn/python/learn/learn_io/pandas_io.py b/tensorflow/contrib/learn/python/learn/learn_io/pandas_io.py index b4ef055f5ae..e9df7258a35 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/pandas_io.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/pandas_io.py @@ -53,7 +53,7 @@ PANDAS_DTYPES = { } -@deprecated(None, 'Please use tf.estimator.inputs.pandas_input_fn') +@deprecated(None, 'Please use tf.compat.v1.estimator.inputs.pandas_input_fn') def pandas_input_fn(x, y=None, batch_size=128, diff --git a/tensorflow/contrib/linear_optimizer/python/sdca_estimator_test.py b/tensorflow/contrib/linear_optimizer/python/sdca_estimator_test.py index 64766718823..7a5354222f1 100644 --- a/tensorflow/contrib/linear_optimizer/python/sdca_estimator_test.py +++ b/tensorflow/contrib/linear_optimizer/python/sdca_estimator_test.py @@ -524,7 +524,7 @@ class SDCALinearRegressorTest(test.TestCase): # LinearClassifier requires at least one column. 
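Editor's note: the linear_test.py hunk above (and the matching one in sdca_estimator_test.py just below) replaces `i % 4 is 0` with `i % 4 == 0`. `is` compares object identity rather than numeric value, only appears to work because CPython caches small integers, and newer Python versions emit a SyntaxWarning for `is` with a literal. A tiny TensorFlow-free illustration of the difference:

```python
# Why `i % 4 is 0` was changed to `i % 4 == 0`: `is` checks identity, and the
# small-integer cache that makes it look correct is an implementation detail.
a = int("257")          # built at runtime, outside CPython's small-int cache
b = int("257")
print(a == b)           # True: numeric equality, the check the test wants
print(a is b)           # False here (CPython): two distinct int objects

labels = [[1 if i % 4 == 0 else 0] for i in range(8)]
print(labels)           # [[1], [0], [0], [0], [1], [0], [0], [0]]
```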
'place_holder': constant_op.constant([[0.0]] * num_examples), - }, constant_op.constant([[1 if i % 4 is 0 else 0] + }, constant_op.constant([[1 if i % 4 == 0 else 0] for i in range(num_examples)]) with self._single_threaded_test_session(): diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py index 5e99ef46051..9b2c2dd87cc 100644 --- a/tensorflow/contrib/lookup/lookup_ops_test.py +++ b/tensorflow/contrib/lookup/lookup_ops_test.py @@ -25,6 +25,7 @@ import six from tensorflow.contrib import lookup from tensorflow.python.client import session from tensorflow.python.data.experimental.ops import counter +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -2737,7 +2738,7 @@ class MutableHashTableBenchmark(test.Benchmark): def benchmark_many_repeated_scalar_insert_scalar(self): table = self._create_table() - c = counter.Counter().make_one_shot_iterator().get_next() + c = dataset_ops.make_one_shot_iterator(counter.Counter()).get_next() value = variables.Variable(1.0) insert = table.insert(c, value) size = table.size() @@ -2758,7 +2759,7 @@ class MutableHashTableBenchmark(test.Benchmark): def benchmark_many_repeated_batch_32_insert_scalar(self): table = self._create_table() - c = counter.Counter().make_one_shot_iterator().get_next() + c = dataset_ops.make_one_shot_iterator(counter.Counter()).get_next() value = variables.Variable([1.0] * 32) insert = table.insert(32 * c + list(range(32)), value) size = table.size() diff --git a/tensorflow/contrib/losses/python/losses/loss_ops.py b/tensorflow/contrib/losses/python/losses/loss_ops.py index 619294b5182..709a042bbce 100644 --- a/tensorflow/contrib/losses/python/losses/loss_ops.py +++ b/tensorflow/contrib/losses/python/losses/loss_ops.py @@ -22,7 +22,6 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.framework.python.ops import add_arg_scope -from tensorflow.python.compat import compat from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -60,41 +59,12 @@ def _scale_losses(losses, weights): """ # First, compute the sum of the losses over all elements: start_index = max(0, weights.get_shape().ndims) - reduction_indices = list(range(start_index, losses.get_shape().ndims)) - reduced_losses = math_ops.reduce_sum( - losses, reduction_indices=reduction_indices) + axis = list(range(start_index, losses.get_shape().ndims)) + reduced_losses = math_ops.reduce_sum(losses, axis=axis) reduced_losses = math_ops.multiply(reduced_losses, weights) return math_ops.reduce_sum(reduced_losses) -def _safe_div(numerator, denominator, name="value"): - """Computes a safe divide which returns 0 if the denominator is zero. - - Note that the function contains an additional conditional check that is - necessary for avoiding situations where the loss is zero causing NaNs to - creep into the gradient computation. - - Args: - numerator: An arbitrary `Tensor`. - denominator: A `Tensor` whose shape matches `numerator` and whose values are - assumed to be non-negative. - name: An optional name for the returned op. - - Returns: - The element-wise value of the numerator divided by the denominator. 
- """ - if compat.forward_compatible(2018, 11, 1): - return math_ops.div_no_nan(numerator, denominator, name=name) - return array_ops.where( - math_ops.greater(denominator, 0), - math_ops.div(numerator, - array_ops.where( - math_ops.equal(denominator, 0), - array_ops.ones_like(denominator), denominator)), - array_ops.zeros_like(numerator), - name=name) - - def _safe_mean(losses, num_present): """Computes a safe mean of the losses. @@ -107,7 +77,7 @@ def _safe_mean(losses, num_present): then zero is returned. """ total_loss = math_ops.reduce_sum(losses) - return _safe_div(total_loss, num_present, name="value") + return math_ops.div_no_nan(total_loss, num_present, name="value") @deprecated("2016-12-30", "Use tf.losses.compute_weighted_loss instead.") @@ -187,10 +157,9 @@ def _num_present(losses, weights, per_batch=False): # First, count the number of nonzero weights: if weights.get_shape().ndims >= 1: - reduction_indices = list(range(1, weights.get_shape().ndims)) + axis = list(range(1, weights.get_shape().ndims)) num_nonzero_per_batch = math_ops.reduce_sum( - math_ops.to_float(math_ops.not_equal(weights, 0)), - reduction_indices=reduction_indices) + math_ops.to_float(math_ops.not_equal(weights, 0)), axis=axis) # Next, determine the number of elements that weights would broadcast to: broadcast_dims = array_ops.slice( @@ -606,20 +575,20 @@ def mean_pairwise_squared_error(predictions, if weights.get_shape().ndims is None: raise ValueError("weights.get_shape().ndims cannot be None") - reduction_indices = list(range(1, diffs.get_shape().ndims)) + axis = list(range(1, diffs.get_shape().ndims)) sum_squares_diff_per_batch = math_ops.reduce_sum( - math_ops.square(diffs), reduction_indices=reduction_indices) + math_ops.square(diffs), axis=axis) num_present_per_batch = _num_present(diffs, weights, per_batch=True) - term1 = 2.0 * _safe_div(sum_squares_diff_per_batch, - num_present_per_batch, - name="value") + term1 = 2.0 * math_ops.div_no_nan( + sum_squares_diff_per_batch, num_present_per_batch, name="value") - sum_diff = math_ops.reduce_sum(diffs, reduction_indices=reduction_indices) - term2 = 2.0 * _safe_div(math_ops.square(sum_diff), - math_ops.square(num_present_per_batch), - name="value") + sum_diff = math_ops.reduce_sum(diffs, axis=axis) + term2 = 2.0 * math_ops.div_no_nan( + math_ops.square(sum_diff), + math_ops.square(num_present_per_batch), + name="value") loss = _scale_losses(term1 - term2, weights) @@ -674,7 +643,7 @@ def cosine_distance(predictions, radial_diffs = math_ops.multiply(predictions, labels) losses = 1 - math_ops.reduce_sum( - radial_diffs, reduction_indices=[ + radial_diffs, axis=[ axis, ]) return compute_weighted_loss(losses, weights, scope=scope) diff --git a/tensorflow/contrib/makefile/download_dependencies.sh b/tensorflow/contrib/makefile/download_dependencies.sh index 0a07588f07f..b396c527673 100755 --- a/tensorflow/contrib/makefile/download_dependencies.sh +++ b/tensorflow/contrib/makefile/download_dependencies.sh @@ -34,7 +34,7 @@ NSYNC_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/nsync/.*tar\. # 1.10 branch does not work. `make distclean` fails and blocks the build # process. For now we're hardcoding to the version which is used by # TensorFlow 1.9. 
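Editor's note: the loss_ops.py hunks above make two mechanical substitutions: the deprecated `reduction_indices` argument becomes `axis`, and the hand-rolled `_safe_div` helper (a `tf.where`-based guard against a zero denominator) is dropped in favor of `math_ops.div_no_nan`, which returns 0 wherever the denominator is 0. A small sketch of the equivalence, assuming a TF 1.x build that already exposes `tf.div_no_nan` (the rewritten code itself requires `math_ops.div_no_nan`, so that assumption is consistent); the tensors are illustrative:

```python
# Sketch: div_no_nan vs. the removed tf.where-based safe divide.
# (Assumes a TF 1.x-era build exposing tf.div_no_nan; values are illustrative.)
import tensorflow as tf

losses = tf.constant([[1.0, 3.0], [0.0, 0.0]])
num_present = tf.constant([2.0, 0.0])        # second row has nothing present

per_batch = tf.reduce_sum(losses, axis=1)    # was reduction_indices=[1]

# Old pattern, as removed above: guard the denominator with tf.where.
safe_old = tf.where(
    tf.greater(num_present, 0),
    tf.div(per_batch,
           tf.where(tf.equal(num_present, 0),
                    tf.ones_like(num_present), num_present)),
    tf.zeros_like(per_batch))

# New pattern: div_no_nan returns 0 wherever the denominator is 0.
safe_new = tf.div_no_nan(per_batch, num_present)

with tf.Session() as sess:
  print(sess.run([safe_old, safe_new]))      # both evaluate to [2.0, 0.0]
```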
-PROTOBUF_URL="https://mirror.bazel.build/github.com/google/protobuf/archive/396336eb961b75f03b25824fe86cf6490fb75e3a.tar.gz" +PROTOBUF_URL="https://mirror.bazel.build/github.com/google/protobuf/archive/v3.6.0.tar.gz" # TODO (yongtang): Replace the following with 'https://mirror.bazel.build/github.com/google/re2/.*tar\.gz' once # the archive has been propagated in mirror.bazel.build. RE2_URL="$(grep -o 'https://github.com/google/re2/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt index e779eff6890..655c7eefcb9 100644 --- a/tensorflow/contrib/makefile/tf_op_files.txt +++ b/tensorflow/contrib/makefile/tf_op_files.txt @@ -157,6 +157,7 @@ tensorflow/core/kernels/mirror_pad_op_cpu_impl_2.cc tensorflow/core/kernels/mirror_pad_op_cpu_impl_3.cc tensorflow/core/kernels/mirror_pad_op_cpu_impl_4.cc tensorflow/core/kernels/mirror_pad_op_cpu_impl_5.cc +tensorflow/core/kernels/multinomial_op.cc tensorflow/core/kernels/no_op.cc tensorflow/core/kernels/non_max_suppression_op.cc tensorflow/core/kernels/one_hot_op.cc @@ -252,6 +253,7 @@ tensorflow/core/kernels/split_op.cc tensorflow/core/kernels/split_v_op.cc tensorflow/core/kernels/stack.cc tensorflow/core/kernels/stack_ops.cc +tensorflow/core/kernels/stateless_random_ops.cc tensorflow/core/kernels/strided_slice_op.cc tensorflow/core/kernels/strided_slice_op_inst_0.cc tensorflow/core/kernels/strided_slice_op_inst_1.cc diff --git a/tensorflow/contrib/metrics/python/metrics/classification.py b/tensorflow/contrib/metrics/python/metrics/classification.py index ac123608650..062deb74b16 100644 --- a/tensorflow/contrib/metrics/python/metrics/classification.py +++ b/tensorflow/contrib/metrics/python/metrics/classification.py @@ -175,7 +175,7 @@ def f1_score(labels, predictions, weights=None, num_thresholds=200, return best_f1 best_f1 = distribution_strategy_context.get_replica_context().merge_call( - f1_across_replicas, values) + f1_across_replicas, args=(values,)) update_op = compute_best_f1_score(tp=update_ops['tp'], fp=update_ops['fp'], fn=update_ops['fn'], name='update') diff --git a/tensorflow/contrib/metrics/python/metrics/classification_test.py b/tensorflow/contrib/metrics/python/metrics/classification_test.py index d6a670f97b3..e789d2cb9df 100644 --- a/tensorflow/contrib/metrics/python/metrics/classification_test.py +++ b/tensorflow/contrib/metrics/python/metrics/classification_test.py @@ -291,12 +291,11 @@ class F1ScoreTest(test.TestCase): labels = labels.astype(np.float32) predictions = predictions.astype(np.float32) - tf_predictions, tf_labels = (dataset_ops.Dataset - .from_tensor_slices((predictions, labels)) - .repeat() - .batch(batch_size) - .make_one_shot_iterator() - .get_next()) + tf_predictions, tf_labels = dataset_ops.make_one_shot_iterator( + dataset_ops.Dataset + .from_tensor_slices((predictions, labels)) + .repeat() + .batch(batch_size)).get_next() f1, f1_op = classification.f1_score(tf_labels, tf_predictions, num_thresholds=3) diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py index d6932f6e4b6..7b432f8bd20 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py @@ -24,7 +24,6 @@ from __future__ import print_function import collections as collections_lib -from tensorflow.python.compat import compat from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from 
tensorflow.python.framework import ops @@ -46,32 +45,6 @@ from tensorflow.python.util.deprecation import deprecated _EPSILON = 1e-7 -def _safe_div(numerator, denominator): - """Computes a safe divide which returns 0 if the denominator is zero. - - Note that the function contains an additional conditional check that is - necessary for avoiding situations where the loss is zero causing NaNs to - creep into the gradient computation. - - Args: - numerator: An arbitrary `Tensor`. - denominator: A `Tensor` whose shape matches `numerator` and whose values are - assumed to be non-negative. - - Returns: - The element-wise value of the numerator divided by the denominator. - """ - if compat.forward_compatible(2018, 11, 1): - return math_ops.div_no_nan(numerator, denominator) - return array_ops.where( - math_ops.greater(denominator, 0), - math_ops.div(numerator, - array_ops.where( - math_ops.equal(denominator, 0), - array_ops.ones_like(denominator), denominator)), - array_ops.zeros_like(numerator)) - - @deprecated(None, 'Please switch to tf.metrics.true_positives. Note that the ' 'order of the labels and predictions arguments has been switched.') def streaming_true_positives(predictions, @@ -3247,24 +3220,20 @@ def streaming_covariance(predictions, # We update the means by Delta=Error*BatchCount/(BatchCount+PrevCount) # batch_mean_prediction is E[x_B] in the update equation - batch_mean_prediction = _safe_div( - math_ops.reduce_sum(weighted_predictions), - batch_count) - delta_mean_prediction = _safe_div( - (batch_mean_prediction - mean_prediction) * batch_count, - update_count) + batch_mean_prediction = math_ops.div_no_nan( + math_ops.reduce_sum(weighted_predictions), batch_count) + delta_mean_prediction = math_ops.div_no_nan( + (batch_mean_prediction - mean_prediction) * batch_count, update_count) update_mean_prediction = state_ops.assign_add(mean_prediction, delta_mean_prediction) # prev_mean_prediction is E[x_A] in the update equation prev_mean_prediction = update_mean_prediction - delta_mean_prediction # batch_mean_label is E[y_B] in the update equation - batch_mean_label = _safe_div( - math_ops.reduce_sum(weighted_labels), - batch_count) - delta_mean_label = _safe_div( - (batch_mean_label - mean_label) * batch_count, - update_count) + batch_mean_label = math_ops.div_no_nan( + math_ops.reduce_sum(weighted_labels), batch_count) + delta_mean_label = math_ops.div_no_nan( + (batch_mean_label - mean_label) * batch_count, update_count) update_mean_label = state_ops.assign_add(mean_label, delta_mean_label) # prev_mean_label is E[y_A] in the update equation prev_mean_label = update_mean_label - delta_mean_label @@ -3447,7 +3416,7 @@ def streaming_mean_cosine_distance(predictions, predictions.get_shape().assert_is_compatible_with(labels.get_shape()) radial_diffs = math_ops.multiply(predictions, labels) radial_diffs = math_ops.reduce_sum( - radial_diffs, reduction_indices=[ + radial_diffs, axis=[ dim, ], keepdims=True) mean_distance, update_op = streaming_mean(radial_diffs, weights, None, None, @@ -3926,9 +3895,8 @@ def cohen_kappa(labels, po_sum = math_ops.reduce_sum(po) total = math_ops.reduce_sum(pe_row) pe_sum = math_ops.reduce_sum( - _safe_div( - math_ops.to_double(pe_row * pe_col), - math_ops.to_double(total))) + math_ops.div_no_nan( + math_ops.to_double(pe_row * pe_col), math_ops.to_double(total))) po_sum, pe_sum, total = (math_ops.to_double(po_sum), math_ops.to_double(pe_sum), math_ops.to_double(total)) diff --git a/tensorflow/contrib/mixed_precision/python/loss_scale_manager_test.py 
b/tensorflow/contrib/mixed_precision/python/loss_scale_manager_test.py index 1b0383d24c0..c922d0cd11f 100644 --- a/tensorflow/contrib/mixed_precision/python/loss_scale_manager_test.py +++ b/tensorflow/contrib/mixed_precision/python/loss_scale_manager_test.py @@ -29,7 +29,7 @@ from tensorflow.python.platform import test def _GetExampleIter(inputs): dataset = dataset_ops.Dataset.from_tensor_slices(inputs) - return dataset.make_one_shot_iterator() + return dataset_ops.make_one_shot_iterator(dataset) class FixedLossScaleManagerTest(test.TestCase): diff --git a/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer_test.py b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer_test.py index 9009df0eefe..33f9a43e803 100644 --- a/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer_test.py +++ b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer_test.py @@ -132,7 +132,7 @@ class LossScaleOptimizerTest(test.TestCase): x = variable_scope.get_variable("x", initializer=1., dtype=dtypes.float32) dataset = dataset_ops.Dataset.from_tensor_slices([np.nan, np.inf, 0.1]) - itr = dataset.make_one_shot_iterator() + itr = dataset_ops.make_one_shot_iterator(dataset) lr = 1 opt = gd.GradientDescentOptimizer(lr) @@ -182,7 +182,7 @@ class LossScaleOptimizerTest(test.TestCase): x = variable_scope.get_variable("x", initializer=1., dtype=dtypes.float32) dataset = dataset_ops.Dataset.from_tensor_slices([np.nan, np.inf, 0.1]) - itr = dataset.make_one_shot_iterator() + itr = dataset_ops.make_one_shot_iterator(dataset) lr = 1 init_loss_scale = 8 diff --git a/tensorflow/contrib/model_pruning/python/layers/core_layers.py b/tensorflow/contrib/model_pruning/python/layers/core_layers.py index f0ce6fe0396..1fa5c8cb485 100644 --- a/tensorflow/contrib/model_pruning/python/layers/core_layers.py +++ b/tensorflow/contrib/model_pruning/python/layers/core_layers.py @@ -21,6 +21,7 @@ from __future__ import print_function from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.keras.engine import input_spec from tensorflow.python.layers import base from tensorflow.python.layers import utils from tensorflow.python.ops import array_ops @@ -119,7 +120,7 @@ class _MaskedConv(base.Layer): self.bias_initializer = bias_initializer self.kernel_regularizer = kernel_regularizer self.bias_regularizer = bias_regularizer - self.input_spec = base.InputSpec(ndim=self.rank + 2) + self.input_spec = input_spec.InputSpec(ndim=self.rank + 2) def build(self, input_shape): input_shape = tensor_shape.TensorShape(input_shape) @@ -171,7 +172,7 @@ class _MaskedConv(base.Layer): dtype=self.dtype) else: self.bias = None - self.input_spec = base.InputSpec( + self.input_spec = input_spec.InputSpec( ndim=self.rank + 2, axes={channel_axis: input_dim}) self.built = True @@ -393,14 +394,14 @@ class MaskedFullyConnected(base.Layer): self.bias_initializer = bias_initializer self.kernel_regularizer = kernel_regularizer self.bias_regularizer = bias_regularizer - self.input_spec = base.InputSpec(min_ndim=2) + self.input_spec = input_spec.InputSpec(min_ndim=2) def build(self, input_shape): input_shape = tensor_shape.TensorShape(input_shape) if tensor_shape.dimension_value(input_shape[-1]) is None: raise ValueError('The last dimension of the inputs to `Dense` ' 'should be defined. 
Found `None`.') - self.input_spec = base.InputSpec( + self.input_spec = input_spec.InputSpec( min_ndim=2, axes={-1: tensor_shape.dimension_value(input_shape[-1])}) self.kernel = self.add_variable( diff --git a/tensorflow/contrib/opt/python/training/lars_optimizer.py b/tensorflow/contrib/opt/python/training/lars_optimizer.py index a8dafd9a4cb..205d6c39491 100644 --- a/tensorflow/contrib/opt/python/training/lars_optimizer.py +++ b/tensorflow/contrib/opt/python/training/lars_optimizer.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops @@ -162,3 +163,14 @@ class LARSOptimizer(optimizer.Optimizer): math_ops.cast(self._momentum_tensor, grad.dtype), use_locking=self._use_locking, use_nesterov=self._use_nesterov) + + def _prepare(self): + learning_rate = self._learning_rate + if callable(learning_rate): + learning_rate = learning_rate() + self._learning_rate_tensor = ops.convert_to_tensor(learning_rate, + name="learning_rate") + momentum = self._momentum + if callable(momentum): + momentum = momentum() + self._momentum_tensor = ops.convert_to_tensor(momentum, name="momentum") \ No newline at end of file diff --git a/tensorflow/contrib/opt/python/training/nadam_optimizer.py b/tensorflow/contrib/opt/python/training/nadam_optimizer.py index 155ff5b3f4f..960826407b6 100644 --- a/tensorflow/contrib/opt/python/training/nadam_optimizer.py +++ b/tensorflow/contrib/opt/python/training/nadam_optimizer.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops @@ -83,14 +84,14 @@ class NadamOptimizer(adam.AdamOptimizer): with ops.control_dependencies([m_t]): m_t = scatter_add(m, indices, m_scaled_g_values) # m_bar = (1 - beta1) * g_t + beta1 * m_t - m_bar = m_scaled_g_values + beta1_t * m_t + m_bar = m_scaled_g_values + beta1_t * array_ops.gather(m_t, indices) # v_t = beta2 * v + (1 - beta2) * (g_t * g_t) v = self.get_slot(var, "v") v_scaled_g_values = (grad * grad) * (1 - beta2_t) v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking) with ops.control_dependencies([v_t]): v_t = scatter_add(v, indices, v_scaled_g_values) - v_sqrt = math_ops.sqrt(v_t) - var_update = state_ops.assign_sub( - var, lr * m_bar / (v_sqrt + epsilon_t), use_locking=self._use_locking) + v_t_slice = array_ops.gather(v_t, indices) + v_sqrt = math_ops.sqrt(v_t_slice) + var_update = scatter_add(var, indices, -lr * m_bar / (v_sqrt + epsilon_t)) return control_flow_ops.group(*[var_update, m_bar, v_t]) diff --git a/tensorflow/contrib/opt/python/training/nadam_optimizer_test.py b/tensorflow/contrib/opt/python/training/nadam_optimizer_test.py index 85e05ce71ce..a4372f64874 100644 --- a/tensorflow/contrib/opt/python/training/nadam_optimizer_test.py +++ b/tensorflow/contrib/opt/python/training/nadam_optimizer_test.py @@ -52,14 +52,19 @@ def nadam_update_numpy(param, class NadamOptimizerTest(test.TestCase): def doTestSparse(self, use_resource=False): + # need to use a larger value of epsilon here so that + # np.sqrt(v_t) + epsilon doesn't get rounded to 0 when + # the dtype is half and np.sqrt(v_t) = 0, as is the case + # when 
the gradient is 0 + sparse_epsilon = 1e-7 for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: with self.cached_session(): # Initialize variables for numpy implementation. m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 - var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) - grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) - var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) - grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + var0_np = np.array([1.0, 1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0, 0.01], dtype=dtype.as_numpy_dtype) if use_resource: var0 = resource_variable_ops.ResourceVariable(var0_np) @@ -67,21 +72,21 @@ class NadamOptimizerTest(test.TestCase): else: var0 = variables.Variable(var0_np) var1 = variables.Variable(var1_np) - grads0_np_indices = np.array([0, 1], dtype=np.int32) + grads0_np_indices = np.array([0, 2], dtype=np.int32) grads0 = ops.IndexedSlices( - constant_op.constant(grads0_np), - constant_op.constant(grads0_np_indices), constant_op.constant([2])) - grads1_np_indices = np.array([0, 1], dtype=np.int32) + constant_op.constant(grads0_np[grads0_np_indices]), + constant_op.constant(grads0_np_indices), constant_op.constant([3])) + grads1_np_indices = np.array([0, 2], dtype=np.int32) grads1 = ops.IndexedSlices( - constant_op.constant(grads1_np), - constant_op.constant(grads1_np_indices), constant_op.constant([2])) - opt = nadam_optimizer.NadamOptimizer() + constant_op.constant(grads1_np[grads1_np_indices]), + constant_op.constant(grads1_np_indices), constant_op.constant([3])) + opt = nadam_optimizer.NadamOptimizer(epsilon=sparse_epsilon) update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 3.0, 4.0], var1.eval()) beta1_power, beta2_power = opt._get_beta_accumulators() @@ -91,8 +96,10 @@ class NadamOptimizerTest(test.TestCase): self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) update.run() - var0_np, m0, v0 = nadam_update_numpy(var0_np, grads0_np, t, m0, v0) - var1_np, m1, v1 = nadam_update_numpy(var1_np, grads1_np, t, m1, v1) + var0_np, m0, v0 = nadam_update_numpy(var0_np, grads0_np, t, m0, v0, + epsilon=sparse_epsilon) + var1_np, m1, v1 = nadam_update_numpy(var1_np, grads1_np, t, m1, v1, + epsilon=sparse_epsilon) # Validate updated params self.assertAllCloseAccordingToType(var0_np, var0.eval()) diff --git a/tensorflow/contrib/optimizer_v2/BUILD b/tensorflow/contrib/optimizer_v2/BUILD index 3ba3ee29ec7..6e401406308 100644 --- a/tensorflow/contrib/optimizer_v2/BUILD +++ b/tensorflow/contrib/optimizer_v2/BUILD @@ -48,7 +48,6 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/python:control_flow_ops", - "//tensorflow/python:distribute", "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:resource_variable_ops", @@ -56,6 +55,8 @@ py_library( "//tensorflow/python:training", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", + "//tensorflow/python/distribute:distribute_lib", + "//tensorflow/python/distribute:reduce_util", ], ) diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py 
b/tensorflow/contrib/optimizer_v2/optimizer_v2.py index 467dd86d8fd..73a556f0b29 100644 --- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py +++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py @@ -24,6 +24,8 @@ import abc import six +from tensorflow.python.distribute import distribute_lib +from tensorflow.python.distribute import reduce_util as ds_reduce_util from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.framework import dtypes @@ -34,7 +36,6 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables -from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.training import distribution_strategy_context as distribute_ctx from tensorflow.python.training import optimizer as optimizer_v1 from tensorflow.python.training import slot_creator @@ -446,7 +447,7 @@ class _OptimizerV2State(object): if v is None: if colocate_with is None: colocate_with = self._non_slot_devices - with self._distribution.colocate_vars_with(colocate_with): + with self._distribution.extended.colocate_vars_with(colocate_with): # TODO(josh11b): Use get_variable() except for the legacy Adam use case. v = variable_scope.variable(initial_value, name=name, trainable=False) self._non_slot_dict[name] = v @@ -657,7 +658,6 @@ class OptimizerV2(optimizer_v1.Optimizer): var_list=None, gate_gradients=GATE_OP, aggregation_method=None, - colocate_gradients_with_ops=False, name=None, grad_loss=None, stop_gradients=None, @@ -680,8 +680,6 @@ class OptimizerV2(optimizer_v1.Optimizer): `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`. aggregation_method: Specifies the method used to combine gradient terms. Valid values are defined in the class `AggregationMethod`. - colocate_gradients_with_ops: If True, try colocating gradients with the - corresponding op. name: Optional name for the returned operation. grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`. stop_gradients: Optional. A Tensor or list of tensors not to differentiate @@ -704,8 +702,8 @@ class OptimizerV2(optimizer_v1.Optimizer): Minimization (and gradient computation) is done with respect to the elements of `var_list` if not None, else with respect to any trainable variables created during the execution of the `loss` function. - `gate_gradients`, `aggregation_method`, `colocate_gradients_with_ops` and - `grad_loss` are ignored when eager execution is enabled. + `gate_gradients`, `aggregation_method`, and `grad_loss` are ignored when + eager execution is enabled. @end_compatibility """ grads_and_vars = self.compute_gradients( @@ -713,7 +711,6 @@ class OptimizerV2(optimizer_v1.Optimizer): var_list=var_list, gate_gradients=gate_gradients, aggregation_method=aggregation_method, - colocate_gradients_with_ops=colocate_gradients_with_ops, grad_loss=grad_loss, stop_gradients=stop_gradients, scale_loss_by_num_replicas=scale_loss_by_num_replicas) @@ -733,7 +730,6 @@ class OptimizerV2(optimizer_v1.Optimizer): var_list=None, gate_gradients=GATE_OP, aggregation_method=None, - colocate_gradients_with_ops=False, grad_loss=None, stop_gradients=None, scale_loss_by_num_replicas=None): @@ -756,8 +752,6 @@ class OptimizerV2(optimizer_v1.Optimizer): `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`. aggregation_method: Specifies the method used to combine gradient terms. Valid values are defined in the class `AggregationMethod`. 
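Editor's note on the nadam_optimizer.py fix a few hunks above: the sparse path now gathers `m_t` and `v_t` at the updated indices and applies the variable update with `scatter_add` instead of using the dense tensors, and the test comment explains why a larger epsilon is needed when `sqrt(v_t)` is 0 in half precision. For reference, a NumPy sketch of a standard dense Nadam step of the kind the test's `nadam_update_numpy` helper computes; hyperparameter defaults are assumed and this is an illustration, not code from the diff:

```python
# NumPy sketch of one Nadam step (Adam with Nesterov momentum); illustrative only.
import numpy as np

def nadam_step(param, g, t, m, v, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
  lr_t = lr * np.sqrt(1 - beta2**t) / (1 - beta1**t)   # bias-corrected step size
  m = beta1 * m + (1 - beta1) * g                       # first moment
  v = beta2 * v + (1 - beta2) * g * g                   # second moment
  m_bar = (1 - beta1) * g + beta1 * m                   # Nesterov look-ahead
  param = param - lr_t * m_bar / (np.sqrt(v) + eps)     # eps keeps this finite
  return param, m, v

# With g == 0 (as in the new sparse test rows), sqrt(v) == 0, and in float16 a
# tiny eps can round to 0; hence the larger sparse_epsilon in the test above.
p, m, v = nadam_step(np.array([1.0, 1.0, 2.0]), np.array([0.1, 0.0, 0.1]),
                     t=1, m=np.zeros(3), v=np.zeros(3))
print(p)
```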
- colocate_gradients_with_ops: If True, try colocating gradients with the - corresponding op. grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`. stop_gradients: Optional. A Tensor or list of tensors not to differentiate through. @@ -776,8 +770,8 @@ class OptimizerV2(optimizer_v1.Optimizer): not callable. @compatibility(eager) - When eager execution is enabled, `gate_gradients`, `aggregation_method`, - and `colocate_gradients_with_ops` are ignored. + When eager execution is enabled, `gate_gradients`, and `aggregation_method` + are ignored. @end_compatibility """ # TODO(josh11b): Test that we handle weight decay in a reasonable way. @@ -832,7 +826,6 @@ class OptimizerV2(optimizer_v1.Optimizer): grad_ys=grad_loss, gate_gradients=(gate_gradients == optimizer_v1.Optimizer.GATE_OP), aggregation_method=aggregation_method, - colocate_gradients_with_ops=colocate_gradients_with_ops, stop_gradients=stop_gradients) if gate_gradients == optimizer_v1.Optimizer.GATE_GRAPH: grads = control_flow_ops.tuple(grads) @@ -848,8 +841,7 @@ class OptimizerV2(optimizer_v1.Optimizer): """Scale loss for the number of replicas.""" if scale_loss_by_num_replicas is None: scale_loss_by_num_replicas = ( - distribute_lib.get_loss_reduction() == variable_scope - .VariableAggregation.MEAN) + distribute_lib.get_loss_reduction() == ds_reduce_util.ReduceOp.MEAN) if scale_loss_by_num_replicas: num_replicas = \ distribute_ctx.get_distribution_strategy().num_replicas_in_sync @@ -892,7 +884,8 @@ class OptimizerV2(optimizer_v1.Optimizer): raise ValueError("No gradients provided for any variable: %s." % ([str(v) for _, v in grads_and_vars],)) return distribute_ctx.get_replica_context().merge_call( - self._distributed_apply, filtered, global_step=global_step, name=name) + self._distributed_apply, args=(filtered,), + kwargs={"global_step": global_step, "name": name}) def _get_or_create_state(self, var_list=None): """Either looks up or creates `_OptimizerV2State`. @@ -927,8 +920,8 @@ class OptimizerV2(optimizer_v1.Optimizer): def _distributed_apply(self, distribution, grads_and_vars, global_step, name): """`apply_gradients` for use with a `DistributionStrategy`.""" - reduced_grads = distribution.batch_reduce( - variable_scope.VariableAggregation.SUM, grads_and_vars) + reduced_grads = distribution.extended.batch_reduce_to( + ds_reduce_util.ReduceOp.SUM, grads_and_vars) var_list = [v for _, v in grads_and_vars] grads_and_vars = zip(reduced_grads, var_list) @@ -944,7 +937,7 @@ class OptimizerV2(optimizer_v1.Optimizer): with ops.name_scope(name, self._name) as name: per_graph_state = self._get_or_create_state(var_list=unwrapped_var_list) # Include the current value of any dynamic hyper parameters in `state`. - non_slot_devices = distribution.non_slot_devices(var_list) + non_slot_devices = distribution.extended.non_slot_devices(var_list) state = per_graph_state._copy_with_dynamic_hyper( # pylint: disable=protected-access self._hyper, distribution, non_slot_devices) @@ -989,7 +982,8 @@ class OptimizerV2(optimizer_v1.Optimizer): # Use the processors to update the variables. 
update_ops = [] for grad, var in grads_and_vars: - update_ops.extend(distribution.update(var, update, grad, grouped=False)) + update_ops.extend(distribution.extended.update( + var, update, args=(grad,), group=False)) # Give the child class a chance to do something after applying # gradients @@ -1001,8 +995,8 @@ class OptimizerV2(optimizer_v1.Optimizer): update_ops = control_flow_ops.group(update_ops) with ops.control_dependencies([update_ops]): - finish_updates = distribution.update_non_slot( - non_slot_devices, finish, grouped=False) + finish_updates = distribution.extended.update_non_slot( + non_slot_devices, finish, group=False) # We said grouped=False, which means finish_updates is always a list. # It will be [None] when finish() returns None. if finish_updates == [None]: @@ -1017,8 +1011,8 @@ class OptimizerV2(optimizer_v1.Optimizer): def update_global_step(global_step, name): return global_step.assign_add(1, read_value=False, name=name) - apply_updates = distribution.update(global_step, update_global_step, - name) + apply_updates = distribution.extended.update( + global_step, update_global_step, args=(name,)) # Add the training op to the TRAIN_OP graph collection in graph mode. if not eager_execution: diff --git a/tensorflow/contrib/predictor/BUILD b/tensorflow/contrib/predictor/BUILD index d50b52b8ff1..53a3bc63e1d 100644 --- a/tensorflow/contrib/predictor/BUILD +++ b/tensorflow/contrib/predictor/BUILD @@ -42,6 +42,7 @@ py_library( name = "saved_model_predictor", srcs = ["saved_model_predictor.py"], srcs_version = "PY2AND3", + visibility = ["//learning/brain/contrib/learn/tpu:__subpackages__"], deps = [ ":base_predictor", "//tensorflow/contrib/saved_model:saved_model_py", diff --git a/tensorflow/contrib/quantize/README.md b/tensorflow/contrib/quantize/README.md index a1f2b590266..9085d9fa719 100644 --- a/tensorflow/contrib/quantize/README.md +++ b/tensorflow/contrib/quantize/README.md @@ -28,7 +28,7 @@ Since it's difficult to add these fake quantization operations to all the required locations in the model, there's a function available that rewrites the training graph. To create a fake quantized training graph: -``` +```python # Build forward pass of model. loss = tf.losses.get_total_loss() @@ -51,7 +51,7 @@ The rewritten *eval graph* is non-trivially different from the *training graph* since the quantization ops affect the batch normalization step. Because of this, we've added a separate rewrite for the *eval graph*: -``` +```python # Build eval model logits = tf.nn.softmax_cross_entropy_with_logits_v2(...) 
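Editor's note: the contrib/quantize README hunks above only retag the code fences as Python, but the surrounding prose is the relevant how-to: one rewrite for the training graph and a separate one for the eval graph, because the quantization ops affect the batch normalization step. A condensed sketch of that usage against the TF 1.x `tf.contrib.quantize` API; `build_model` below is a stand-in defined only for this sketch, not part of the README:

```python
# Condensed sketch of the two contrib.quantize rewrites the README describes.
# (TF 1.x API; build_model is a tiny stand-in just to give the rewriters a graph.)
import tensorflow as tf

def build_model(is_training):
  x = tf.placeholder(tf.float32, [None, 4])
  logits = tf.layers.dense(x, 2)                 # MatMul + BiasAdd is quantizable
  if is_training:
    labels = tf.placeholder(tf.int64, [None])
    return tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
  return logits

# Training graph: build the forward pass, rewrite in place, then add the
# training op as usual.
with tf.Graph().as_default() as g:
  loss = build_model(is_training=True)
  tf.contrib.quantize.create_training_graph(input_graph=g, quant_delay=2000000)
  train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

# Eval graph: a separate rewrite, since the quantization ops interact with the
# batch normalization step differently at inference time.
with tf.Graph().as_default() as g:
  logits = build_model(is_training=False)
  tf.contrib.quantize.create_eval_graph(input_graph=g)
```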
diff --git a/tensorflow/contrib/quantize/python/quant_ops.py b/tensorflow/contrib/quantize/python/quant_ops.py index 6f659347fba..8619708cdae 100644 --- a/tensorflow/contrib/quantize/python/quant_ops.py +++ b/tensorflow/contrib/quantize/python/quant_ops.py @@ -138,7 +138,7 @@ def LastValueQuantize(inputs, if per_channel: if input_dim >= 2: batch_min = math_ops.reduce_min( - inputs, reduction_indices=reduce_dims, name='BatchMin') + inputs, axis=reduce_dims, name='BatchMin') else: batch_min = inputs else: @@ -147,7 +147,7 @@ def LastValueQuantize(inputs, if per_channel: if input_dim >= 2: batch_max = math_ops.reduce_max( - inputs, reduction_indices=reduce_dims, name='BatchMax') + inputs, axis=reduce_dims, name='BatchMax') else: batch_max = inputs else: @@ -263,7 +263,7 @@ def MovingAvgQuantize(inputs, if per_channel: if input_dim >= 2: batch_min = math_ops.reduce_min( - inputs, reduction_indices=reduce_dims, name='BatchMin') + inputs, axis=reduce_dims, name='BatchMin') else: batch_min = inputs else: @@ -272,7 +272,7 @@ def MovingAvgQuantize(inputs, if per_channel: if input_dim >= 2: batch_max = math_ops.reduce_max( - inputs, reduction_indices=reduce_dims, name='BatchMax') + inputs, axis=reduce_dims, name='BatchMax') else: batch_max = inputs else: diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py index 338923f7512..21d1b121309 100644 --- a/tensorflow/contrib/quantize/python/quantize.py +++ b/tensorflow/contrib/quantize/python/quantize.py @@ -160,7 +160,7 @@ def Quantize(graph, # shouldn't quantize it, since the activation will be Fused into the # Add at inference time. consumers = input_to_ops_map.ConsumerOperations(layer_match.bypass_op) - if any([consumer.type in _ACTIVATION_TYPES for consumer in consumers]): + if any(consumer.type in _ACTIVATION_TYPES for consumer in consumers): logging.info('Skipping %s, because its followed by an activation.', layer_match.bypass_op.name) else: @@ -195,7 +195,7 @@ def Quantize(graph, # Add at inference time. consumers = input_to_ops_map.ConsumerOperations( layer_match.post_activation_bypass_op) - if any([consumer.type in _RELU_TYPES for consumer in consumers]): + if any(consumer.type in _RELU_TYPES for consumer in consumers): logging.info('Skipping %s, because its followed by an activation.', layer_match.post_activation_bypass_op.name) else: diff --git a/tensorflow/contrib/resampler/BUILD b/tensorflow/contrib/resampler/BUILD index 38fcca03116..bbf10996759 100644 --- a/tensorflow/contrib/resampler/BUILD +++ b/tensorflow/contrib/resampler/BUILD @@ -13,6 +13,7 @@ load( ) load("//tensorflow:tensorflow.bzl", "cuda_py_test") load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") +load("//tensorflow/compiler/tests:build_defs.bzl", "tf_xla_py_test") tf_custom_op_py_library( name = "resampler_py", @@ -50,10 +51,14 @@ tf_kernel_library( prefix = "resampler_ops", deps = [ ":resampler_ops_op_lib", - "//tensorflow/compiler/tf2xla/kernels:resampler_ops", "//tensorflow/core:framework", "//tensorflow/core:lib", - ], + ] + select({ + "//tensorflow:with_xla_support": [ + "//tensorflow/compiler/tf2xla/kernels:resampler_ops", + ], + "//conditions:default": [], + }), alwayslink = 1, ) @@ -94,3 +99,26 @@ cuda_py_test( "//tensorflow/python:array_ops", ], ) + +tf_xla_py_test( + name = "resampler_ops_xla_test", + size = "small", + srcs = ["xla/resampler_ops_xla_test.py"], + disabled_backends = [ + # TODO(b/74459949) Support BatchDot in CPU backend. 
+ "cpu", + "cpu_ondemand", + ], + # TODO(b/112295522): the OSS build will not likely work in the short to medium term, currently it is blocked by the fact that bazel does not allow py_library to depend on cc_library: https://github.com/bazelbuild/bazel/issues/701 which may not be resolvable. + tags = ["no_oss"], + deps = [ + "//tensorflow/compiler/tests:xla_test", + "//tensorflow/compiler/tf2xla/kernels:resampler_ops", + "//tensorflow/contrib/resampler:resampler_ops", + "//tensorflow/contrib/resampler:resampler_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:platform_test", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/compiler/tests/resampler_ops_test.py b/tensorflow/contrib/resampler/xla/resampler_ops_xla_test.py similarity index 76% rename from tensorflow/compiler/tests/resampler_ops_test.py rename to tensorflow/contrib/resampler/xla/resampler_ops_xla_test.py index f87ac3360c9..d8ca0eab276 100644 --- a/tensorflow/compiler/tests/resampler_ops_test.py +++ b/tensorflow/contrib/resampler/xla/resampler_ops_xla_test.py @@ -63,8 +63,8 @@ class ResamplerOpsTest(xla_test.XLATestCase): def testSimple(self): for dtype in self.float_types: input_shape = [1, 2, 2, 1] - input_rgb_data = [0, 5, 13, 54] - input_np = np.array(input_rgb_data, dtype=dtype).reshape(input_shape) + input_data = [0, 5, 13, 54] + input_np = np.array(input_data, dtype=dtype).reshape(input_shape) warp_shape = [1, 2] warp_data = [0.7, 0.6] @@ -151,6 +151,55 @@ class ResamplerOpsTest(xla_test.XLATestCase): expected_grad_data, expected_grad_warp) + def testOutOfBoundWarps(self): + # (x, y) are both less than 0. + for dtype in self.float_types: + input_shape = [1, 2, 2, 1] + input_data = [10, 5, 13, 54] + input_np = np.array(input_data, dtype=dtype).reshape(input_shape) + + warp_shape = [1, 2, 2] + warp_data = [-1, -1, 0.7, 0.6] + warp_np = np.array(warp_data, dtype=dtype).reshape(warp_shape) + expected = [[[0.0], [27.62]]] + self._assertForwardOpMatchesExpected(input_np, warp_np, expected) + + # One of (x, y) is less than 0. + for dtype in self.float_types: + input_shape = [1, 2, 2, 1] + input_data = [10, 5, 13, 54] + input_np = np.array(input_data, dtype=dtype).reshape(input_shape) + + warp_shape = [1, 2, 2] + warp_data = [-1, 0.1, 0.7, 0.6] + warp_np = np.array(warp_data, dtype=dtype).reshape(warp_shape) + expected = [[[0.0], [27.62]]] + self._assertForwardOpMatchesExpected(input_np, warp_np, expected) + + # Both of (x, y) are greater than image size. + for dtype in self.float_types: + input_shape = [1, 2, 2, 1] + input_data = [10, 5, 13, 54] + input_np = np.array(input_data, dtype=dtype).reshape(input_shape) + + warp_shape = [1, 2, 2] + warp_data = [-0.1, 0.1, 1.2, 2.1] + warp_np = np.array(warp_data, dtype=dtype).reshape(warp_shape) + expected = [[[0.0], [0.0]]] + self._assertForwardOpMatchesExpected(input_np, warp_np, expected) + + # One of (x, y) is greater than image size. 
+ for dtype in self.float_types: + input_shape = [1, 2, 2, 1] + input_data = [10, 5, 13, 54] + input_np = np.array(input_data, dtype=dtype).reshape(input_shape) + + warp_shape = [1, 2, 2] + warp_data = [0.1, -0.1, 1.2, 0.1] + warp_np = np.array(warp_data, dtype=dtype).reshape(warp_shape) + expected = [[[0.0], [0.0]]] + self._assertForwardOpMatchesExpected(input_np, warp_np, expected) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py index 245fa68eaef..7d57b0413a3 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py @@ -906,7 +906,7 @@ class DropoutWrapperTest(test.TestCase): def testDropoutWrapperKeepNoOutput(self): keep_all = variable_scope.get_variable("all", initializer=1.0) - keep_none = variable_scope.get_variable("none", initializer=1e-10) + keep_none = variable_scope.get_variable("none", initializer=1e-6) res = self._testDropoutWrapper( input_keep_prob=keep_all, output_keep_prob=keep_none, @@ -922,7 +922,7 @@ class DropoutWrapperTest(test.TestCase): def testDropoutWrapperKeepNoStateExceptLSTMCellMemory(self): keep_all = variable_scope.get_variable("all", initializer=1.0) - keep_none = variable_scope.get_variable("none", initializer=1e-10) + keep_none = variable_scope.get_variable("none", initializer=1e-6) # Even though we dropout state, by default DropoutWrapper never # drops out the memory ("c") term of an LSTMStateTuple. res = self._testDropoutWrapper( @@ -943,7 +943,7 @@ class DropoutWrapperTest(test.TestCase): def testDropoutWrapperKeepNoInput(self): keep_all = variable_scope.get_variable("all", initializer=1.0) - keep_none = variable_scope.get_variable("none", initializer=1e-10) + keep_none = variable_scope.get_variable("none", initializer=1e-6) true_full_output = np.array( [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]], dtype=np.float32) diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py index 5cba54dd3df..ef372b947ce 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py @@ -227,7 +227,7 @@ class RNNTest(test.TestCase): def testDropout(self): cell = Plus1RNNCell() full_dropout_cell = rnn_cell.DropoutWrapper( - cell, input_keep_prob=1e-12, seed=0) + cell, input_keep_prob=1e-6, seed=0) (name, dep), = full_dropout_cell._checkpoint_dependencies self.assertIs(dep, cell) self.assertEqual("cell", name) diff --git a/tensorflow/contrib/rnn/python/ops/gru_ops.py b/tensorflow/contrib/rnn/python/ops/gru_ops.py index b30ca7882fc..251a933eaec 100644 --- a/tensorflow/contrib/rnn/python/ops/gru_ops.py +++ b/tensorflow/contrib/rnn/python/ops/gru_ops.py @@ -21,7 +21,7 @@ from tensorflow.contrib.rnn.ops import gen_gru_ops from tensorflow.contrib.util import loader from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.layers import base as base_layer +from tensorflow.python.keras.engine import input_spec from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops @@ -165,7 +165,7 @@ class GRUBlockCell(LayerRNNCell): num_units = cell_size self._cell_size = num_units # Inputs must be 2-dimensional. 
- self.input_spec = base_layer.InputSpec(ndim=2) + self.input_spec = input_spec.InputSpec(ndim=2) @property def state_size(self): diff --git a/tensorflow/contrib/rnn/python/ops/lstm_ops.py b/tensorflow/contrib/rnn/python/ops/lstm_ops.py index 4db431f85a4..b043026bc55 100644 --- a/tensorflow/contrib/rnn/python/ops/lstm_ops.py +++ b/tensorflow/contrib/rnn/python/ops/lstm_ops.py @@ -25,6 +25,7 @@ from tensorflow.contrib.rnn.ops import gen_lstm_ops from tensorflow.contrib.util import loader from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.keras.engine import input_spec from tensorflow.python.layers import base as base_layer from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops @@ -385,7 +386,7 @@ class LSTMBlockCell(LayerRNNCell): "scope": "lstm_cell" } # Inputs must be 2-dimensional. - self.input_spec = base_layer.InputSpec(ndim=2) + self.input_spec = input_spec.InputSpec(ndim=2) @property def state_size(self): @@ -628,7 +629,7 @@ class LSTMBlockFusedCell(LSTMBlockWrapper): self._use_peephole = use_peephole # Inputs must be 3-dimensional. - self.input_spec = base_layer.InputSpec(ndim=3) + self.input_spec = input_spec.InputSpec(ndim=3) @property def num_units(self): diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py index e159dc95796..8a1c09f171e 100644 --- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py +++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py @@ -30,7 +30,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.keras import activations from tensorflow.python.keras import initializers -from tensorflow.python.layers import base as base_layer +from tensorflow.python.keras.engine import input_spec from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops from tensorflow.python.ops import gen_array_ops @@ -2752,7 +2752,7 @@ class SRUCell(rnn_cell_impl.LayerRNNCell): self._activation = activation or math_ops.tanh # Restrict inputs to be 2-dimensional matrices - self.input_spec = base_layer.InputSpec(ndim=2) + self.input_spec = input_spec.InputSpec(ndim=2) @property def state_size(self): @@ -3089,7 +3089,7 @@ class IndRNNCell(rnn_cell_impl.LayerRNNCell): super(IndRNNCell, self).__init__(_reuse=reuse, name=name, dtype=dtype) # Inputs must be 2-dimensional. - self.input_spec = base_layer.InputSpec(ndim=2) + self.input_spec = input_spec.InputSpec(ndim=2) self._num_units = num_units self._activation = activation or math_ops.tanh @@ -3183,7 +3183,7 @@ class IndyGRUCell(rnn_cell_impl.LayerRNNCell): super(IndyGRUCell, self).__init__(_reuse=reuse, name=name, dtype=dtype) # Inputs must be 2-dimensional. - self.input_spec = base_layer.InputSpec(ndim=2) + self.input_spec = input_spec.InputSpec(ndim=2) self._num_units = num_units self._activation = activation or math_ops.tanh @@ -3323,7 +3323,7 @@ class IndyLSTMCell(rnn_cell_impl.LayerRNNCell): super(IndyLSTMCell, self).__init__(_reuse=reuse, name=name, dtype=dtype) # Inputs must be 2-dimensional. - self.input_spec = base_layer.InputSpec(ndim=2) + self.input_spec = input_spec.InputSpec(ndim=2) self._num_units = num_units self._forget_bias = forget_bias @@ -3444,7 +3444,7 @@ class MinimalRNNCell(rnn_cell_impl.LayerRNNCell): super(MinimalRNNCell, self).__init__(name=name, dtype=dtype, **kwargs) # Inputs must be 2-dimensional. 
- self.input_spec = base_layer.InputSpec(ndim=2) + self.input_spec = input_spec.InputSpec(ndim=2) self.units = units self.activation = activations.get(activation) @@ -3558,7 +3558,7 @@ class CFNCell(rnn_cell_impl.LayerRNNCell): super(CFNCell, self).__init__(name=name, dtype=dtype, **kwargs) # Inputs must be 2-dimensional. - self.input_spec = base_layer.InputSpec(ndim=2) + self.input_spec = input_spec.InputSpec(ndim=2) self.units = units self.activation = activations.get(activation) diff --git a/tensorflow/contrib/saved_model/BUILD b/tensorflow/contrib/saved_model/BUILD index f0947fe423f..269443b2c65 100644 --- a/tensorflow/contrib/saved_model/BUILD +++ b/tensorflow/contrib/saved_model/BUILD @@ -102,7 +102,10 @@ py_test( size = "medium", srcs = ["python/saved_model/keras_saved_model_test.py"], srcs_version = "PY2AND3", - tags = ["no_windows"], + tags = [ + "no_oss", # TODO(b/119349471): Re-enable + "no_windows", + ], deps = [ ":keras_saved_model", "//tensorflow/python:client_testlib", diff --git a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py index 27b5b6d22e0..ffba514bb96 100644 --- a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py +++ b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py @@ -25,7 +25,6 @@ from tensorflow.python.client import session from tensorflow.python.estimator import keras as estimator_keras_util from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator.export import export as export_helpers -from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.keras import backend as K from tensorflow.python.keras import models as models_lib @@ -126,7 +125,7 @@ def save_keras_model( export_dir = export_helpers.get_timestamped_export_dir(saved_model_path) temp_export_dir = export_helpers.get_temp_export_dir(export_dir) - builder = saved_model_builder.SavedModelBuilder(temp_export_dir) + builder = saved_model_builder._SavedModelBuilder(temp_export_dir) # Manually save variables to export them in an object-based checkpoint. This # skips the `builder.add_meta_graph_and_variables()` step, which saves a @@ -228,9 +227,10 @@ def _export_mode( g.add_to_collection(ops.GraphKeys.GLOBAL_STEP, clone.optimizer.iterations) # Extract update and train ops from train/test/predict functions. 
+ train_op = None if mode == model_fn_lib.ModeKeys.TRAIN: clone._make_train_function() - builder._add_train_op(clone.train_function.updates_op) + train_op = clone.train_function.updates_op elif mode == model_fn_lib.ModeKeys.EVAL: clone._make_test_function() else: @@ -265,7 +265,8 @@ def _export_mode( model_fn_lib.EXPORT_TAG_MAP[mode], signature_def_map=_create_signature_def_map(clone, mode), saver=saver_lib.Saver(clone_var_list), - main_op=variables.local_variables_initializer()) + init_op=variables.local_variables_initializer(), + train_op=train_op) return None @@ -307,31 +308,11 @@ def _create_signature_def_map(model, mode): serving_only=(mode == model_fn_lib.ModeKeys.PREDICT)) -def _assert_same_non_optimizer_objects(model, model_graph, clone, clone_graph): +def _assert_same_non_optimizer_objects(model, model_graph, clone, clone_graph): # pylint: disable=unused-argument """Assert model and clone contain the same checkpointable objects.""" - def get_non_optimizer_objects(m, g): - """Gather set of model and optimizer checkpointable objects.""" - # Set default graph because optimizer.variables() returns optimizer - # variables defined in the default graph. - with g.as_default(): - all_objects = set(checkpointable_utils.list_objects(m)) - optimizer_and_variables = set() - for obj in all_objects: - if isinstance(obj, optimizers.TFOptimizer): - optimizer_and_variables.update(checkpointable_utils.list_objects(obj)) - optimizer_and_variables.update(set(obj.optimizer.variables())) - return all_objects - optimizer_and_variables - - model_objects = get_non_optimizer_objects(model, model_graph) - clone_objects = get_non_optimizer_objects(clone, clone_graph) - - if len(model_objects) != len(clone_objects): - raise errors.InternalError( - None, None, - 'Model and clone must use the same variables.' - '\n\tModel variables: %s\n\t Clone variables: %s' - % (model_objects, clone_objects)) + # TODO(fchollet, kathywu): make sure this works in eager mode. 
+ return True def load_keras_model(saved_model_path): diff --git a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py index a65b2ce4661..93d73e1b484 100644 --- a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py +++ b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py @@ -29,14 +29,12 @@ from tensorflow.python import keras from tensorflow.python.client import session from tensorflow.python.eager import context from tensorflow.python.estimator import model_fn as model_fn_lib -from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.keras.engine import training from tensorflow.python.keras.utils import tf_utils from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -from tensorflow.python.saved_model import constants from tensorflow.python.saved_model import loader_impl from tensorflow.python.saved_model import signature_constants from tensorflow.python.training import training as training_module @@ -255,7 +253,7 @@ def load_model(sess, path, mode): outputs = { k: sess.graph.get_tensor_by_name(v.name) for k, v in meta_graph_def.signature_def[sig_def_key].outputs.items()} - return inputs, outputs + return inputs, outputs, meta_graph_def @test_util.run_all_in_graph_and_eager_modes @@ -332,8 +330,8 @@ class TestModelSavedModelExport(test.TestCase, parameterized.TestCase): # Load predict graph, and test predictions with session.Session(graph=ops.Graph()) as sess: - inputs, outputs = load_model(sess, output_path, - model_fn_lib.ModeKeys.PREDICT) + inputs, outputs, _ = load_model(sess, output_path, + model_fn_lib.ModeKeys.PREDICT) predictions = sess.run(outputs[output_name], {inputs[input_name]: input_arr}) @@ -342,19 +340,21 @@ class TestModelSavedModelExport(test.TestCase, parameterized.TestCase): if optimizer: # Load eval graph, and test predictions, loss and metric values with session.Session(graph=ops.Graph()) as sess: - inputs, outputs = load_model(sess, output_path, - model_fn_lib.ModeKeys.EVAL) + inputs, outputs, _ = load_model(sess, output_path, + model_fn_lib.ModeKeys.EVAL) # First obtain the loss and predictions, and run the metric update op by # feeding in the inputs and targets. loss, predictions, _ = sess.run( (outputs['loss'], outputs['predictions/' + output_name], - outputs['metrics/mae/update_op']), - {inputs[input_name]: input_arr, inputs[target_name]: target_arr}) + outputs['metrics/mean_absolute_error/update_op']), { + inputs[input_name]: input_arr, + inputs[target_name]: target_arr + }) # The metric value should be run after the update op, to ensure that it # reflects the correct value. 
- metric_value = sess.run(outputs['metrics/mae/value']) + metric_value = sess.run(outputs['metrics/mean_absolute_error/value']) self.assertEqual(int(train_before_export), sess.run(training_module.get_global_step())) @@ -364,17 +364,17 @@ class TestModelSavedModelExport(test.TestCase, parameterized.TestCase): # Load train graph, and check for the train op, and prediction values with session.Session(graph=ops.Graph()) as sess: - inputs, outputs = load_model(sess, output_path, - model_fn_lib.ModeKeys.TRAIN) + inputs, outputs, meta_graph_def = load_model( + sess, output_path, model_fn_lib.ModeKeys.TRAIN) self.assertEqual(int(train_before_export), sess.run(training_module.get_global_step())) self.assertIn('loss', outputs) - self.assertIn('metrics/mae/update_op', outputs) - self.assertIn('metrics/mae/value', outputs) + self.assertIn('metrics/mean_absolute_error/update_op', outputs) + self.assertIn('metrics/mean_absolute_error/value', outputs) self.assertIn('predictions/' + output_name, outputs) # Train for a step - train_op = ops.get_collection(constants.TRAIN_OP_KEY) + train_op = loader_impl.get_train_op(meta_graph_def) train_outputs, _ = sess.run( [outputs, train_op], {inputs[input_name]: input_arr, inputs[target_name]: target_arr}) @@ -401,8 +401,8 @@ class TestModelSavedModelExport(test.TestCase, parameterized.TestCase): output_path = keras_saved_model.save_keras_model( model, saved_model_path, custom_objects={'relu6': relu6}) with session.Session(graph=ops.Graph()) as sess: - inputs, outputs = load_model(sess, output_path, - model_fn_lib.ModeKeys.PREDICT) + inputs, outputs, _ = load_model(sess, output_path, + model_fn_lib.ModeKeys.PREDICT) input_name = model.input_names[0] output_name = model.output_names[0] predictions = sess.run( @@ -463,11 +463,6 @@ class TestModelSavedModelExport(test.TestCase, parameterized.TestCase): clone.compile(loss='mse', optimizer=keras.optimizers.RMSprop(lr=0.0001)) clone.train_on_batch(input_arr, target_arr) - with self.assertRaisesRegexp( - errors.InternalError, 'Model and clone must use the same variables.'): - keras_saved_model._assert_same_non_optimizer_objects( - model, model_graph, clone, clone_graph) - def testSaveSeqModelWithoutInputShapesRaisesError(self): """A Sequential model that hasn't been built should raise an error.""" model = sequential_model_without_input_shape(True) diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py index 8668c67cf95..922f21b98b3 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py @@ -154,8 +154,8 @@ class AttentionWrapperTest(test.TestCase): if attention_layer_sizes is not None: # Compute sum of attention_layer_sizes. Use encoder_output_depth if None. - attention_depth = sum([attention_layer_size or encoder_output_depth - for attention_layer_size in attention_layer_sizes]) + attention_depth = sum(attention_layer_size or encoder_output_depth + for attention_layer_size in attention_layer_sizes) elif attention_layers is not None: # Compute sum of attention_layers output depth. 
attention_depth = sum( diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py index 4d1807130c5..10e4556dacb 100644 --- a/tensorflow/contrib/summary/summary_ops_test.py +++ b/tensorflow/contrib/summary/summary_ops_test.py @@ -152,6 +152,27 @@ class EagerFileTest(test_util.TensorFlowTestCase): self.assertEqual(len(events), 2) self.assertEqual(events[1].summary.value[0].tag, 'scalar') + def testRecordEveryNGlobalSteps(self): + step = training_util.get_or_create_global_step() + logdir = tempfile.mkdtemp() + + def run_step(): + summary_ops.scalar('scalar', i, step=step) + step.assign_add(1) + + with summary_ops.create_file_writer( + logdir).as_default(), summary_ops.record_summaries_every_n_global_steps( + 2, step): + for i in range(10): + run_step() + # And another 10 steps as a graph function. + run_step_fn = function.defun(run_step) + for i in range(10): + run_step_fn() + + events = summary_test_util.events_from_logdir(logdir) + self.assertEqual(len(events), 11) + def testMaxQueue(self): logs = tempfile.mkdtemp() with summary_ops.create_file_writer( @@ -279,12 +300,9 @@ class EagerDbTest(summary_test_util.SummaryDbTest): def testDbURIOpen(self): tmpdb_path = os.path.join(self.get_temp_dir(), 'tmpDbURITest.sqlite') - tmpdb_uri = six.moves.urllib_parse.urljoin("file:", tmpdb_path) - tmpdb_writer = summary_ops.create_db_writer( - tmpdb_uri, - "experimentA", - "run1", - "user1") + tmpdb_uri = six.moves.urllib_parse.urljoin('file:', tmpdb_path) + tmpdb_writer = summary_ops.create_db_writer(tmpdb_uri, 'experimentA', + 'run1', 'user1') with summary_ops.always_record_summaries(): with tmpdb_writer.as_default(): summary_ops.scalar('t1', 2.0) diff --git a/tensorflow/contrib/tensorboard/db/summary_file_writer.cc b/tensorflow/contrib/tensorboard/db/summary_file_writer.cc index 3f24f58f03a..22b6f09d0cd 100644 --- a/tensorflow/contrib/tensorboard/db/summary_file_writer.cc +++ b/tensorflow/contrib/tensorboard/db/summary_file_writer.cc @@ -73,7 +73,16 @@ class SummaryFileWriter : public SummaryWriterInterface { e->set_step(global_step); e->set_wall_time(GetWallTime()); Summary::Value* v = e->mutable_summary()->add_value(); - t.AsProtoTensorContent(v->mutable_tensor()); + + if (t.dtype() == DT_STRING) { + // Treat DT_STRING specially, so that tensor_util.MakeNdarray in Python + // can convert the TensorProto to string-type numpy array. MakeNdarray + // does not work with strings encoded by AsProtoTensorContent() in + // tensor_content. + t.AsProtoField(v->mutable_tensor()); + } else { + t.AsProtoTensorContent(v->mutable_tensor()); + } v->set_tag(tag); if (!serialized_metadata.empty()) { v->mutable_metadata()->ParseFromString(serialized_metadata); diff --git a/tensorflow/contrib/tensorboard/db/summary_file_writer_test.cc b/tensorflow/contrib/tensorboard/db/summary_file_writer_test.cc index cd3f712256f..ffbfb9533e8 100644 --- a/tensorflow/contrib/tensorboard/db/summary_file_writer_test.cc +++ b/tensorflow/contrib/tensorboard/db/summary_file_writer_test.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/contrib/tensorboard/db/summary_file_writer.h" #include "tensorflow/core/framework/summary.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/io/path.h" @@ -104,6 +105,23 @@ TEST_F(SummaryFileWriterTest, WriteTensor) { CHECK_EQ(e.summary().value_size(), 1); EXPECT_EQ(e.summary().value(0).tag(), "name"); })); + TF_CHECK_OK(SummaryTestHelper( + "string_tensor_test", + [](SummaryWriterInterface* writer) { + Tensor hello(DT_STRING, TensorShape({})); + hello.scalar()() = "hello"; + TF_RETURN_IF_ERROR(writer->WriteTensor( + 2, hello, "name", SummaryMetadata().SerializeAsString())); + TF_RETURN_IF_ERROR(writer->Flush()); + return Status::OK(); + }, + [](const Event& e) { + EXPECT_EQ(e.step(), 2); + CHECK_EQ(e.summary().value_size(), 1); + EXPECT_EQ(e.summary().value(0).tag(), "name"); + EXPECT_EQ(e.summary().value(0).tensor().dtype(), DT_STRING); + EXPECT_EQ(e.summary().value(0).tensor().string_val()[0], "hello"); + })); } TEST_F(SummaryFileWriterTest, WriteScalar) { diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index 20bcd2447e6..784acce444a 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -29,6 +29,10 @@ load( "if_tensorrt", ) +exports_files(glob([ + "test/testdata/*", +])) + tf_cuda_cc_test( name = "tensorrt_test_cc", size = "small", @@ -491,6 +495,7 @@ cuda_py_tests( "test/memory_alignment_test.py", "test/multi_connection_neighbor_engine_test.py", "test/neighboring_engine_test.py", + "test/quantization_test.py", "test/rank_two_test.py", "test/reshape_transpose_test.py", "test/vgg_block_nchw_test.py", @@ -527,6 +532,30 @@ cuda_py_tests( ], ) +cuda_py_test( + name = "quantization_mnist_test", + srcs = ["test/quantization_mnist_test.py"], + additional_deps = [ + ":tf_trt_integration_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python/keras:keras", + "//tensorflow/python/estimator:estimator", + ], + data = [ + "test/testdata/checkpoint", + "test/testdata/model.ckpt-46900.data-00000-of-00001", + "test/testdata/model.ckpt-46900.index", + ], + tags = [ + "no_cuda_on_cpu_tap", + "no_pip", + "no_tap", # It is not able to download the mnist data. + "no_windows", + "nomac", + ], +) + cc_library( name = "utils", srcs = ["convert/utils.cc"], diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc index 26d54eb156c..812948bb303 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc @@ -82,60 +82,76 @@ std::vector GetLoadedTensorRTVersion() { } TrtCandidateSelector::TrtCandidateSelector( - const grappler::GraphProperties& graph_properties) - : graph_properties_(graph_properties) {} + const grappler::GraphProperties& graph_properties, int precision_mode) + : graph_properties_(graph_properties), precision_mode_(precision_mode) {} Status TrtCandidateSelector::IsTensorRTCandidate(const tensorflow::Node* node) { // TODO(laigd): move this set to TrtNodeValidator where it should belong. 
// LINT.IfChange static const std::set candidate_ops = { - "Identity", - "Snapshot", - "Const", - "Conv2D", - "MaxPool", - "BiasAdd", - "Relu", - "Add", - "Mul", - "Sub", - "Rsqrt", - "Pad", - "Mean", - "AvgPool", - "ConcatV2", - "DepthwiseConv2dNative", - "FusedBatchNorm", - "FusedBatchNormV2", - "Div", - "RealDiv", - "Rsqrt", - "Reciprocal", - "Exp", - "Log", - "Sqrt", - "Abs", - "Neg", - "Transpose", - "Reshape", - "MatMul", - "BatchMatMul", - "Softmax", - "Minimum", - "Maximum", - "TopKV2", - "Sum", - "Prod", - "Max", - "Min", + "Identity", + "Snapshot", + "Const", + "Conv2D", + "MaxPool", + "BiasAdd", + "Relu", + "Sigmoid", + "Tanh", + "Add", + "Mul", + "Sub", + "Rsqrt", + "Pad", + "Mean", + "AvgPool", + "ConcatV2", + "DepthwiseConv2dNative", + "FusedBatchNorm", + "FusedBatchNormV2", + "Div", + "RealDiv", + "Rsqrt", + "Reciprocal", + "Exp", + "Log", + "Sqrt", + "Abs", + "Neg", + "Transpose", + "Reshape", + "MatMul", + "BatchMatMul", + "Softmax", + "Minimum", + "Maximum", + "TopKV2", + "Sum", + "Prod", + "Max", + "Min", + "Relu6", + "Square", }; - // LINT.ThenChange(//tensorflow/contrib/tensorrt/convert/convert_nodes.cc) - const bool is_supported_op_type = + bool is_supported_op_type = (candidate_ops.count(node->type_string()) || PluginFactoryTensorRT::GetInstance()->IsPlugin(node->type_string())); + static const std::set quantize_ops = { + "QuantizeAndDequantizeV2", + "QuantizeAndDequantizeV3", + "FakeQuantWithMinMaxVars", + "FakeQuantWithMinMaxArgs", + }; + // In INT8 mode, we will always apply the quantization ranges provided by + // these ops to the relevant tensors. This happens regardless of the value of + // use_calibration. + if (precision_mode_ == INT8MODE && quantize_ops.count(node->type_string())) { + is_supported_op_type = true; + } + // LINT.ThenChange(//tensorflow/contrib/tensorrt/convert/convert_nodes.cc) if (!is_supported_op_type) { return errors::Unimplemented("Op type ", node->type_string(), - " is not supported."); + " is not supported"); } std::vector input_edges; @@ -170,7 +186,7 @@ tensorflow::Status BuildNodeMap( tensorflow::Status ConvertCalibGraphToInferGraph( const tensorflow::GraphDef& graph_def, tensorflow::GraphDef* infer_graph, bool is_dyn_op) { - VLOG(0) << "Starting Calib Conversion"; + LOG(INFO) << "Starting Calib Conversion"; infer_graph->CopyFrom(graph_def); auto trt_rm = TRTResourceManager::instance(); auto calib_rm = trt_rm->getManager("TRTCalibration"); @@ -220,18 +236,19 @@ tensorflow::Status ConvertGraphDefToTensorRT( const std::vector& output_names, size_t max_batch_size, size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def, int precision_mode, int minimum_segment_size, bool is_dyn_op, - int max_cached_engines, std::vector cached_engine_batches) { + int max_cached_engines, std::vector cached_engine_batches, + bool use_calibration) { // Create GrapplerItem. tensorflow::grappler::GrapplerItem item; item.fetch = output_names; item.graph = graph_def; - // TODO(aaroey): we should have used single machine cluster like the - // following, but the problem is then wrap_conversion will depend on - // direct_session and cause double linking problems. To fix this we need to - // fix or get rid of the swig dependency. Here we use VirtualCluster - // as a work around, and we need to create a session to initialize the - // underlying device before calling this method. 
+// TODO(aaroey): we should have used single machine cluster like the +// following, but the problem is then wrap_conversion will depend on +// direct_session and cause double linking problems. To fix this we need to +// fix or get rid of the swig dependency. Here we use VirtualCluster +// as a work around, and we need to create a session to initialize the +// underlying device before calling this method. #if 0 // Create single machine cluster. Note that this will create a session and // initialize the gpu devices. @@ -264,7 +281,9 @@ tensorflow::Status ConvertGraphDefToTensorRT( #endif // Create RewriterConfig. - tensorflow::RewriterConfig rw_cfg; + tensorflow::ConfigProto config_proto; + auto& rw_cfg = + *config_proto.mutable_graph_options()->mutable_rewrite_options(); // TODO(aaroey): use only const folding and layout for the time being since // new optimizers break the graph for trt. rw_cfg.add_optimizers("constfold"); @@ -285,9 +304,10 @@ tensorflow::Status ConvertGraphDefToTensorRT( list->add_i(batch); } } + parameters["use_calibration"].set_b(use_calibration); // Run optimizer. - tensorflow::grappler::MetaOptimizer meta_opt(nullptr, rw_cfg); + tensorflow::grappler::MetaOptimizer meta_opt(nullptr, config_proto); TF_RETURN_IF_ERROR(meta_opt.Optimize(cluster.get(), item, new_graph_def)); if (VLOG_IS_ON(5)) { @@ -433,7 +453,8 @@ tensorflow::Status GetEngineInfo( << "but this shouldn't have happened"; info->device = *segment_devices.begin(); } else { - LOG(ERROR) << "Can't find a device placement for the op!"; + VLOG(1) << "No device is assigned to the segment. " + << "A device will be assigned during graph execution (inference)."; } return Status::OK(); } @@ -564,27 +585,30 @@ tensorflow::Status CreateTRTNode(const std::vector& infos, int pos, } } } + + const bool calibrate_int8 = + (info.precision_mode == INT8MODE && info.use_calibration); + // Build the engine and get its serialized representation. string segment_string; - if (info.engine_type == EngineInfo::EngineType::TRTStatic || - info.precision_mode == INT8MODE) { + if (info.engine_type == EngineInfo::EngineType::TRTStatic || calibrate_int8) { // Create static engine for fp32/fp16 mode, and test validity of the engine - // for int8 mode. We don't want engine to fail at the calibration time. - // So we are constructing a FP32 engine here to check its validity, and if - // it is a valid engine then we put the serialized graphdef to the op. - // Otherwise we skip node creation for this engine. + // for int8 calibration mode. We don't want engine to fail at the + // calibration time. So we are constructing a FP32 engine here to check its + // validity, and if it is a valid engine then we put the serialized graphdef + // to the op. Otherwise we skip node creation for this engine. Logger trt_logger; TrtUniquePtrType engine; // TODO(sami): What happens if 1st dim is not batch? TF_RETURN_IF_ERROR(ConvertGraphDefToEngine( - info.segment_graph_def, - info.precision_mode == INT8MODE ? FP32MODE : info.precision_mode, + info.segment_graph_def, calibrate_int8 ? FP32MODE : info.precision_mode, max_batch_size, info.max_workspace_size_bytes, input_shapes, &trt_logger, alloc, /*calibrator=*/nullptr, &engine, + info.use_calibration, /*convert_successfully=*/nullptr)); TrtUniquePtrType engine_data(engine->serialize()); segment_string = string((const char*)engine_data->data(), engine_data->size()); - if (info.precision_mode == INT8MODE) { + if (calibrate_int8) { // See above comment about why not putting this inside the 'else' branch. 
segment_string = info.segment_graph_def.SerializeAsString(); } @@ -596,7 +620,7 @@ tensorflow::Status CreateTRTNode(const std::vector& infos, int pos, // conversion. string prec_string; TF_RETURN_IF_ERROR(GetPrecisionModeName(info.precision_mode, &prec_string)); - if (info.precision_mode == INT8MODE && + if (info.precision_mode == INT8MODE && calibrate_int8 && !TRTResourceManager::instance()->getManager("TRTCalibration")) { LOG(ERROR) << "Failed to construct calibration storage"; } @@ -632,6 +656,7 @@ tensorflow::Status CreateTRTNode(const std::vector& infos, int pos, .Attr("cached_engine_batches", {max_batch_size}) .Attr("workspace_size_bytes", info.max_workspace_size_bytes) .Attr("precision_mode", prec_string) + .Attr("use_calibration", info.use_calibration) .Attr("OutT", out_types) .Finalize(&trt_node); if (!status.ok()) { @@ -864,19 +889,17 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) { } segment_options.minimum_segment_size = params.minimum_segment_size; tensorflow::tensorrt::segment::SegmentNodesVector initial_segments; - TrtCandidateSelector candidate_selector(*params.graph_properties); + TrtCandidateSelector candidate_selector(*params.graph_properties, + params.precision_mode); TF_RETURN_IF_ERROR(tensorrt::segment::SegmentGraph( - &graph, - std::bind(&TrtCandidateSelector::IsTensorRTCandidate, &candidate_selector, - std::placeholders::_1), + &graph, std::bind(&TrtCandidateSelector::IsTensorRTCandidate, + &candidate_selector, std::placeholders::_1), // Input validation is already done by TrtCandidateSelector, so we don't // need to check the input edges. [](const Edge* edge) { return true; }, OutputEdgeValidator(), segment_options, &initial_segments)); - if (initial_segments.size() > 1) { - VLOG(0) << "MULTIPLE tensorrt candidate conversion: " + LOG(INFO) << "Number of TensorRT candidate segments: " << initial_segments.size(); - } // Get the EngineInfo for each segment. std::unordered_map node_map; @@ -902,13 +925,17 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) { continue; } curr_engine.precision_mode = params.precision_mode; - curr_engine.engine_type = - (params.is_dyn_op || params.precision_mode == INT8MODE - ? EngineInfo::EngineType::TRTDynamic - : EngineInfo::EngineType::TRTStatic); + if (params.use_calibration && params.precision_mode != INT8MODE) { + return errors::InvalidArgument( + "Calibration with FP32 or FP16 is not supported."); + } + curr_engine.engine_type = ((params.is_dyn_op || params.use_calibration) + ? EngineInfo::EngineType::TRTDynamic + : EngineInfo::EngineType::TRTStatic); + curr_engine.use_calibration = params.use_calibration; curr_engine.cached_engine_batches = params.cached_engine_batches; curr_engine.maximum_cached_engines = params.max_cached_engines; - StrAppend(&curr_engine.engine_name, "my_trt_op_", t); + StrAppend(&curr_engine.engine_name, "TRTEngineOp_", t); status = RegisterSegmentFunctionToFunctionLibrary( &graph, curr_engine.segment_graph_def, curr_engine.engine_name); if (!status.ok()) { @@ -969,16 +996,9 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) { &graph, alloc.get(), &engine_nodes); // If status is ok, we successfully added the node to the graph and can // remove segment ops. Otherwise graph is not modified. 
- string msg = StrCat("Engine ", engine.engine_name, " creation for segment ", - i, ", composed of ", + string msg = StrCat("TensorRT node ", engine.engine_name, + " added for segment ", i, " consisting of ", converted_segments.at(i).first.size(), " nodes"); - if (VLOG_IS_ON(1)) { - StrAppend(&msg, " ("); - for (const string& node_name : converted_segments.at(i).first) { - StrAppend(&msg, node_name, ", "); - } - StrAppend(&msg, ")"); - } if (status.ok()) { LOG(INFO) << msg << " succeeded."; for (auto node_name : converted_segments.at(i).first) { @@ -986,7 +1006,14 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) { } } else { // Graph is not modified. - LOG(WARNING) << msg << " failed: " << status << ". Skipping..."; + LOG(WARNING) << msg << " failed: " << status << ". Fallback to TF..."; + } + if (VLOG_IS_ON(1)) { + msg = "Segment consists of nodes: "; + for (const string& node_name : converted_segments.at(i).first) { + StrAppend(&msg, node_name, ", "); + } + VLOG(1) << msg; } } cudaSetDevice(old_cuda_device); diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.h b/tensorflow/contrib/tensorrt/convert/convert_graph.h index 1c9d82105a7..1f39f56f639 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.h +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.h @@ -35,7 +35,8 @@ namespace convert { // supported by TRT. class TrtCandidateSelector { public: - TrtCandidateSelector(const grappler::GraphProperties& graph_properties); + TrtCandidateSelector(const grappler::GraphProperties& graph_properties, + int precision_mode); // Returns OK iff 'node' is a TF-TRT conversion candidate, which will be added // to TRT subgraph and later converted into TRT engine. @@ -49,6 +50,9 @@ class TrtCandidateSelector { // GraphProperties of the graph whose nodes are to be validated by // IsTensorRTCandidate(). const grappler::GraphProperties& graph_properties_; + + // Quantization ops are only converted when using quantized precisions. + const int precision_mode_; }; struct ConversionParams { @@ -63,6 +67,7 @@ struct ConversionParams { cluster(nullptr), is_dyn_op(false), fixed_input_size(true), + use_calibration(true), max_cached_engines(1) {} const tensorflow::GraphDef* input_graph_def; const std::vector* output_names; @@ -76,6 +81,7 @@ struct ConversionParams { bool is_dyn_op; // Whether to create engine on conversion or execution time bool fixed_input_size; // Assume non-batch ranks of input tensors are fixed int max_cached_engines; // maximum number of cached engines + bool use_calibration; std::vector cached_engine_batches; // list of cached engines }; @@ -95,7 +101,7 @@ tensorflow::Status ConvertGraphDefToTensorRT( size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def, int precision_mode = 1, int minimum_segment_size = 3, bool is_dyn_op = false, int max_cached_engines = 1, - std::vector cached_engine_batches = {}); + std::vector cached_engine_batches = {}, bool use_calibration = true); // Method to call from optimization pass tensorflow::Status ConvertAfterShapes(ConversionParams& params); diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph_test.cc b/tensorflow/contrib/tensorrt/convert/convert_graph_test.cc index f10729987fd..2d2bfeb192c 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph_test.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph_test.cc @@ -85,27 +85,42 @@ TEST(TrtCandidateSelector, Basics) { ops::MatMul(s.WithOpName("matmul_with_incompatible_input"), incompatible_feed, const_2); + // Quantize ops. 
+ auto quantize_attrs = ops::FakeQuantWithMinMaxArgs::Min(-6.0f).Max(6.0f); + auto quantize = ops::FakeQuantWithMinMaxArgs(s.WithOpName("quantize"), feed, + quantize_attrs); + + // Get GrapplerItem and GraphProperties. grappler::GrapplerItem item; TF_EXPECT_OK(s.ToGraphDef(&item.graph)); Tensor feed_tensor(DT_FLOAT, input_shape); item.feed.push_back(std::make_pair("feed", feed_tensor)); - grappler::GraphProperties graph_properties(item); TF_EXPECT_OK(graph_properties.InferStatically(true)); - TrtCandidateSelector selector(graph_properties); - TF_EXPECT_OK(selector.IsTensorRTCandidate(matmul.operation.node())); - ExpectStatus( - selector.IsTensorRTCandidate(incompatible_matmul.operation.node()), - error::INVALID_ARGUMENT, - "transpose_a is not supported for TensorRT FullyConnected " - "(op: MatMul), at: incompatible_matmul"); - ExpectStatus(selector.IsTensorRTCandidate(unsupported_op.operation.node()), - error::UNIMPLEMENTED, "Op type Sin is not supported"); - ExpectStatus(selector.IsTensorRTCandidate( - matmul_with_incompatible_input.operation.node()), - error::INTERNAL, - "Failed to convert input with index 0 to a TRT_TensorOrWeights"); + for (const int precision_mode : {FP32MODE, INT8MODE}) { + TrtCandidateSelector selector(graph_properties, precision_mode); + TF_EXPECT_OK(selector.IsTensorRTCandidate(matmul.operation.node())); + ExpectStatus( + selector.IsTensorRTCandidate(incompatible_matmul.operation.node()), + error::INVALID_ARGUMENT, + "transpose_a is not supported for TensorRT FullyConnected " + "(op: MatMul), at: incompatible_matmul"); + ExpectStatus(selector.IsTensorRTCandidate(unsupported_op.operation.node()), + error::UNIMPLEMENTED, "Op type Sin is not supported"); + ExpectStatus( + selector.IsTensorRTCandidate( + matmul_with_incompatible_input.operation.node()), + error::INTERNAL, + "Failed to convert input with index 0 to a TRT_TensorOrWeights"); + if (precision_mode == INT8MODE) { + TF_EXPECT_OK(selector.IsTensorRTCandidate(quantize.operation.node())); + } else { + ExpectStatus(selector.IsTensorRTCandidate(quantize.operation.node()), + error::UNIMPLEMENTED, + "Op type FakeQuantWithMinMaxArgs is not supported"); + } + } } class FakeCluster : public grappler::Cluster { diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index e2988f5f2a8..25a34dd3503 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -54,10 +54,10 @@ limitations under the License. // would work! 
#define TFTRT_CHECK_EQ_TYPE(val1, val2) CHECK_EQ((int)val1, (int)val2) -#define TFTRT_INTERNAL_ERROR_AT_NODE(node) \ - do { \ - return tensorflow::errors::Internal( \ - "TFTRT::", __FUNCTION__, "failed to add TRT layer, at: ", node); \ +#define TFTRT_INTERNAL_ERROR_AT_NODE(node) \ + do { \ + return tensorflow::errors::Internal( \ + "TFTRT::", __FUNCTION__, " failed to add TRT layer, at: ", node); \ } while (0) #define TFTRT_RETURN_ERROR_IF_FALSE(status, node) \ @@ -130,7 +130,7 @@ void GetOutputProperties(const grappler::GraphProperties& graph_properties, *dtype = out_shape.dtype(); *shape = out_shape.shape(); } else { - VLOG(0) << "Unknown output shape" << node->name(); + LOG(INFO) << "Unknown output shape" << node->name(); *dtype = node->output_type(out_port); } } @@ -181,16 +181,55 @@ Status ValidateTensorProperties(const string& producer_node_type, if (shape.dim_size(d) < 0) { return errors::InvalidArgument( "Input tensor with shape ", shape.DebugString(), - " has an unknown non-batch dimemension at dim ", d); + " has an unknown non-batch dimension at dim ", d); } } return Status::OK(); } +string DebugString(const nvinfer1::DimensionType type) { + switch (type) { + case nvinfer1::DimensionType::kSPATIAL: + return "kSPATIAL"; + case nvinfer1::DimensionType::kCHANNEL: + return "kCHANNEL"; + case nvinfer1::DimensionType::kINDEX: + return "kINDEX"; + case nvinfer1::DimensionType::kSEQUENCE: + return "kSEQUENCE"; + default: + return StrCat(static_cast(type), "=unknown"); + } +} + +string DebugString(const nvinfer1::DataType trt_dtype) { + switch (trt_dtype) { + case nvinfer1::DataType::kFLOAT: + return "kFLOAT"; + case nvinfer1::DataType::kHALF: + return "kHALF"; + case nvinfer1::DataType::kINT8: + return "kINT8"; + case nvinfer1::DataType::kINT32: + return "kINT32"; + default: + return "Invalid TRT data type"; + } +} + string DebugString(const nvinfer1::Dims& dims) { string out = StrCat("nvinfer1::Dims(nbDims=", dims.nbDims, ", d="); for (int i = 0; i < dims.nbDims; ++i) { - StrAppend(&out, dims.d[i], ","); + StrAppend(&out, dims.d[i], "[", DebugString(dims.type[i]), "],"); + } + StrAppend(&out, ")"); + return out; +} + +string DebugString(const nvinfer1::Permutation& permutation, int len) { + string out = "nvinfer1::Permutation("; + for (int i = 0; i < len; ++i) { + StrAppend(&out, permutation.order[i], ","); } StrAppend(&out, ")"); return out; @@ -198,16 +237,15 @@ string DebugString(const nvinfer1::Dims& dims) { string DebugString(const nvinfer1::ITensor& tensor) { return StrCat("nvinfer1::ITensor(@", reinterpret_cast(&tensor), - ", shape=", DebugString(tensor.getDimensions()), ")"); + ", name=", tensor.getName(), + ", dtype=", DebugString(tensor.getType()), + ", dims=", DebugString(tensor.getDimensions()), ")"); } -// Return whether or not the broadcast is feasible; -bool TensorRTGetBroadcastShape(const nvinfer1::Dims& operand_l, - const bool operand_l_is_tensor, - const nvinfer1::Dims& operand_r, - const bool operand_r_is_tensor, - nvinfer1::Dims* operand_l_new_shape, - nvinfer1::Dims* operand_r_new_shape) { +Status Converter::GetTrtBroadcastShape( + const TRT_TensorOrWeights& operand_l, const TRT_TensorOrWeights& operand_r, + nvinfer1::Dims* operand_l_new_dims, + nvinfer1::Dims* operand_r_new_dims) const { // *************************************************************************** // TensorRT Elementwise op supports broadcast but requires both tensor to be // of Identical rank @@ -232,52 +270,59 @@ bool TensorRTGetBroadcastShape(const nvinfer1::Dims& operand_l, // -> T: 1 1 1 -1 3 
5 1 // -> W: 1 1 1 1 3 5 1 // *************************************************************************** + if (!operand_l.is_tensor() && !operand_r.is_tensor()) { + return errors::InvalidArgument( + "Broadcasting requires at least one of the operands be tensors"); + } + const int max_nb_dims = nvinfer1::Dims::MAX_DIMS + 1; - const size_t element_size = sizeof(operand_l.d[0]); + auto compute_output_dims = + [max_nb_dims](const TRT_TensorOrWeights& input, int broadcast_num_dims, + int* output_dims_array, nvinfer1::Dims* output_dims) { + const nvinfer1::Dims input_dims = input.GetTrtDims(); + std::fill(output_dims_array, output_dims_array + max_nb_dims, 1); + std::copy(input_dims.d, input_dims.d + input_dims.nbDims, + output_dims_array + broadcast_num_dims - input_dims.nbDims); + if (input.is_tensor()) { + const int true_input_dims = input_dims.nbDims + 1; + if (true_input_dims < broadcast_num_dims) { + return errors::InvalidArgument( + "Broadcasting beyond batch dimension is not supported ", + "(tensor #dims ", true_input_dims, " vs broadcast #dims ", + broadcast_num_dims, ")"); + } + // Set the batch dimension to -1, since batch size is not supposed to + // be broadcasted. + output_dims_array[0] = -1; + } + // Copy to output dimensions (stripping the batch dimension). + output_dims->nbDims = broadcast_num_dims - 1; + std::copy(output_dims_array + 1, output_dims_array + broadcast_num_dims, + output_dims->d); + return Status::OK(); + }; - // fill in dimensions - int l_s[max_nb_dims]; - std::fill(l_s, l_s + max_nb_dims, 1); - int l_d = operand_l_is_tensor ? operand_l.nbDims + 1 : operand_l.nbDims; - int r_s[max_nb_dims]; - std::fill(r_s, r_s + max_nb_dims, 1); - int r_d = operand_r_is_tensor ? operand_r.nbDims + 1 : operand_r.nbDims; + // Compute the output dimensions. + const int broadcast_num_dims = + std::max(operand_l.GetTrtDims().nbDims + (operand_l.is_tensor() ? 1 : 0), + operand_r.GetTrtDims().nbDims + (operand_r.is_tensor() ? 
1 : 0)); + int output_l[max_nb_dims], output_r[max_nb_dims]; + TF_RETURN_IF_ERROR(compute_output_dims(operand_l, broadcast_num_dims, + output_l, operand_l_new_dims)); + TF_RETURN_IF_ERROR(compute_output_dims(operand_r, broadcast_num_dims, + output_r, operand_r_new_dims)); - int max_d = std::max(l_d, r_d); - std::memcpy(l_s + max_d - operand_l.nbDims, operand_l.d, - operand_l.nbDims * element_size); - std::memcpy(r_s + max_d - operand_r.nbDims, operand_r.d, - operand_r.nbDims * element_size); - - // set -1 for batch dimension, since batch size is not supposed to be - // broadcasted - if (operand_l_is_tensor) { - if (max_d != l_d) { // if broadcast beyond batch dimension, fail - return false; - } - l_s[0] = -1; - } - if (operand_r_is_tensor) { - if (max_d != r_d) { // if broadcast beyond batch dimension, fail - return false; - } - r_s[0] = -1; - } - - // compare broadcast feasibility - for (int i = max_d - 1; i >= 0; i--) { - if ((l_s[i] != r_s[i]) && (l_s[i] != 1) && (r_s[i] != 1)) { - return false; + // Compare broadcast feasibility + for (int i = 0; i < broadcast_num_dims; ++i) { + if ((output_l[i] != output_r[i]) && (output_l[i] != 1) && + (output_r[i] != 1)) { + return errors::InvalidArgument( + "Infeasible broadcast scheme (", "batch_dim: ", output_l[0], ", ", + DebugString(*operand_l_new_dims), " vs ", "batch_dim: ", output_r[0], + ", ", DebugString(*operand_r_new_dims), ")"); } } - - // output new TensorRT Dimension (stripping the batch dimension) - operand_l_new_shape->nbDims = max_d - 1; - std::memcpy(operand_l_new_shape->d, l_s + 1, (max_d - 1) * element_size); - operand_r_new_shape->nbDims = max_d - 1; - std::memcpy(operand_r_new_shape->d, r_s + 1, (max_d - 1) * element_size); - - return true; + return Status::OK(); } inline bool DimsEqual(const nvinfer1::Dims& dim_l, @@ -381,8 +426,8 @@ size_t TRT_ShapedWeights::size_bytes() const { string TRT_ShapedWeights::DebugString() const { return StrCat("TRT_ShapedWeights(shape=", convert::DebugString(shape_), - ", type=", type_, - ", values=", reinterpret_cast(GetValues()), ")"); + ", type=", DataTypeString(type_), ", values=", + reinterpret_cast(GetValues()), ")"); } // A fake ITensor implementation used to check whether the TF-TRT converter can @@ -425,7 +470,9 @@ class TRT_TensorOrWeights::SimpleITensor : public nvinfer1::ITensor { void setLocation(nvinfer1::TensorLocation location) override {} #if NV_TENSORRT_MAJOR >= 5 - bool setDynamicRange(float min, float max) override {} + bool setDynamicRange(float min, float max) override { return true; } + + float getDynamicRange() const override { return 0; } #endif private: @@ -489,8 +536,7 @@ nvinfer1::Dims TRT_TensorOrWeights::GetTrtDims() const { string TRT_TensorOrWeights::DebugString() const { string output = "TRT_TensorOrWeights(type="; if (is_tensor()) { - StrAppend(&output, "tensor @", reinterpret_cast(tensor()), - ", shape=", convert::DebugString(tensor()->getDimensions()), + StrAppend(&output, "tensor=", convert::DebugString(*tensor()), ", batch_size=", batch_size_); } else { StrAppend(&output, "weights=", weights_.DebugString()); @@ -627,11 +673,10 @@ void ReorderCKtoKC(const TRT_ShapedWeights& iweights, break; } case tensorflow::DataType::DT_HALF: { - Reorder2( - {k, c}, static_cast(iweights.GetValues()), - istrides, - static_cast(const_cast(oweights->GetValues())), - ostrides); + Reorder2({k, c}, static_cast(iweights.GetValues()), + istrides, static_cast( + const_cast(oweights->GetValues())), + ostrides); break; } default: @@ -753,8 +798,9 @@ Status 
TrtNodeValidator::ValidateNode( Status status = ConvertToTensorOrWeights( *pair.first, pair.second, graph_properties, &tensor_or_weights); if (!status.ok()) { - return errors::Internal("Failed to convert input with index ", i, - " to a TRT_TensorOrWeights"); + return errors::Internal( + "Failed to convert input with index ", i, + " to a TRT_TensorOrWeights: ", status.error_message()); } inputs.push_back(tensor_or_weights); } @@ -786,8 +832,11 @@ Status TrtNodeValidator::ConvertConstToWeights( return status; } -Converter::Converter(nvinfer1::INetworkDefinition* trt_network, bool is_fp16) - : trt_network_(trt_network), is_fp16_(is_fp16) { +Converter::Converter(nvinfer1::INetworkDefinition* trt_network, + int precision_mode, bool use_calibration) + : trt_network_(trt_network), + precision_mode_(precision_mode), + use_calibration_(use_calibration) { this->RegisterOpConverters(); } @@ -812,13 +861,18 @@ Status Converter::ConvertNode(const NodeDef& node_def) { TRT_TensorOrWeights& output = outputs[i]; string output_name = node_def.name(); if (i != 0) output_name = StrCat(output_name, ":", i); - // We need to check the name before setting it. For Identity op where the - // output is the input, if its input is one of the engine input, setting - // the name here will overwrite engine input bindings which will cause - // runtime error. + // We need to check the name before setting it. If the input is one of the + // engine input, setting the name here will overwrite engine input + // bindings which will cause runtime error. if (output.is_tensor()) { const char* tensor_name = output.tensor()->getName(); - if (tensor_name == nullptr || std::strlen(tensor_name) == 0) { + if (!tensorflow::str_util::StartsWith(tensor_name, kInputPHName)) { + // TRT initializes tensor names as "(Unnamed ITensor* N)". We rename + // them to match their corresponding TensorFlow name. + // Note: ITensors that we create internally within TF-TRT which are + // not inputs or outputs of a node will not be renamed. This is a + // potential cause of confusion if an error message or warning + // mentions the unnamed tensor. 
output.tensor()->setName(output_name.c_str()); } } @@ -930,11 +984,14 @@ Status Converter::TransposeTensor(nvinfer1::ITensor* input_tensor, nvinfer1::IShuffleLayer* layer = this->network()->addShuffle(*input_tensor); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, "TF-TRT Internal Transpose"); + MarkQuantizationRangesAsInferrable(input_tensor, layer->getOutput(0)); nvinfer1::Permutation permutation; for (int32_t i = 0; i < dims.nbDims; ++i) { permutation.order[i] = order_with_batch_dim[i + 1] - 1; } + VLOG(1) << "TransposeTensor permutation: " + << DebugString(permutation, dims.nbDims); layer->setFirstTranspose(permutation); nvinfer1::Dims reshape_dims; @@ -950,6 +1007,38 @@ Status Converter::TransposeTensor(nvinfer1::ITensor* input_tensor, return tensorflow::Status::OK(); } +Status Converter::GetWeightRange(const TRT_ShapedWeights& weights, + float* out_min, float* out_max) const { + switch (weights.type_) { + case DataType::DT_FLOAT: { + auto inp = static_cast(weights.GetValues()); + auto result = std::minmax_element(inp, inp + weights.count()); + *out_min = *result.first; + *out_max = *result.second; + break; + } + case DataType::DT_HALF: { + auto inp = static_cast(weights.GetValues()); + auto result = std::minmax_element(inp, inp + weights.count()); + *out_min = Eigen::half_impl::half_to_float(*result.first); + *out_max = Eigen::half_impl::half_to_float(*result.second); + break; + } + case DataType::DT_INT32: { + auto inp = static_cast(weights.GetValues()); + auto result = std::minmax_element(inp, inp + weights.count()); + *out_min = static_cast(*result.first); + *out_max = static_cast(*result.second); + break; + } + default: + return errors::Unimplemented( + "Data type not supported for GetWeightRange: ", + DataTypeString(weights.type_)); + } + return Status::OK(); +} + Status Converter::PrepareTensorForShape(const TRT_TensorOrWeights& input, const nvinfer1::Dims& dims, const nvinfer1::ITensor** tensor) { @@ -964,8 +1053,9 @@ Status Converter::PrepareTensorForShape(const TRT_TensorOrWeights& input, } if (can_check_shapes && TrtDimsNumElements(input.GetTrtDims()) != TrtDimsNumElements(dims)) { - return tensorflow::errors::InvalidArgument( - "Reshape shapes are not compatible."); + return errors::InvalidArgument("Reshape shapes are not compatible (", + DebugString(input.GetTrtDims()), " vs ", + DebugString(dims), ")"); } if (input.is_tensor()) { @@ -976,6 +1066,8 @@ Status Converter::PrepareTensorForShape(const TRT_TensorOrWeights& input, *const_cast(input.tensor())); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, "TF-TRT Internal Reshape"); layer->setReshapeDimensions(dims); + MarkQuantizationRangesAsInferrable( + const_cast(input.tensor()), layer->getOutput(0)); *tensor = layer->getOutput(0); } } else { @@ -983,10 +1075,123 @@ Status Converter::PrepareTensorForShape(const TRT_TensorOrWeights& input, this->network()->addConstant(dims, input.weights().GetTrtWeights()); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, "TF-TRT Internal Reshape"); *tensor = layer->getOutput(0); + if (precision_mode() == INT8MODE && !use_calibration()) { + // If we are in int8 mode and not calibrating, we need to explicitly set a + // quantization range for the output tensor of the IConstantLayer. Here we + // set the range to [min(weights), max(weights)]. + float min_range = 0.0f; + float max_range = 0.0f; + TF_RETURN_IF_ERROR( + GetWeightRange(input.weights(), &min_range, &max_range)); + // Avoid setting range to 0 because TRT will throw an error. 
If the + // weights are zero then the range doesn't matter: using 127.0f should + // ensure the quantized weight will be exactly zero. + if (min_range == 0.0f && max_range == 0.0f) { + min_range = -127.0f; + max_range = 127.0f; + } + ProvideQuantizationRange(const_cast(*tensor), + min_range, max_range); + } } return tensorflow::Status::OK(); } +void Converter::MarkQuantizationRangesAsInferrable(nvinfer1::ITensor* input, + nvinfer1::ITensor* output) { + quantization_infer_.push_back({input, output}); + quantization_infer_.push_back({output, input}); +} + +void Converter::ProvideQuantizationRange(nvinfer1::ITensor* tensor, + float min_range, float max_range) { + float symmetric_range = std::max(std::abs(min_range), std::abs(max_range)); + quantization_ranges_[tensor] = symmetric_range; +} + +void Converter::MaybeApplyQuantizationRanges() { + if (precision_mode() != INT8MODE) return; + + // Infer ranges across marked ops. + PropagateQuantizationRanges(); + // Apply ranges. +#if NV_TENSORRT_MAJOR >= 5 + for (auto pair : quantization_ranges_) { + nvinfer1::ITensor* tensor = pair.first; + const float range = pair.second; + VLOG(1) << "Setting range for: " << tensor->getName() << ": " << range; + // TODO(laigd): if 'tensor' already has a range set which doesn't match + // 'range', it should report error. + tensor->setDynamicRange(-range, range); + } +#endif + + // Warn user about tensors that are missing ranges. If TRT fuses some layers + // then these tensors may not actually be required, which is why this is + // just a warning. If we are still missing ranges even after fusion, + // Builder::buildCudaEngine() will return nullptr and we will catch the + // error at that point. + if (!use_calibration()) { + // Get all tensors from network + std::set all_tensors; + for (int i = 0; i < this->network()->getNbLayers(); i++) { + nvinfer1::ILayer* layer = this->network()->getLayer(i); + for (int j = 0; j < layer->getNbInputs(); j++) { + all_tensors.insert(layer->getInput(j)); + } + for (int j = 0; j < layer->getNbOutputs(); j++) { + all_tensors.insert(layer->getOutput(j)); + } + } + // Find tensors with no ranges + for (auto tensor : all_tensors) { + if (!quantization_ranges_.count(tensor)) { + // Note: there may be some warnings for "(Unnamed ITensor* N)". These + // are tensors which are created internally by TF-TRT. The ranges for + // these unnamed ITensors are always inferred from user provided ranges, + // thus there will also be a warning for the range(s) the user missed. + LOG(WARNING) << "Quantization range was not found for " + << tensor->getName() << ". " + << "This is okay if TensorRT does not need the range " + << "(e.g. due to node fusion)."; + } + } + } +} + +void Converter::PropagateQuantizationRanges() { + // Propagate ranges across edges in quantization_infer_ until no new + // information is added. + // Note: this function modifies quantization_infer_, it might be better to + // modify a copy instead if we for some reason need quantization_infer_ + // later. 
+ bool information_added = true; + while (information_added) { + information_added = false; + for (auto it = quantization_infer_.begin(); + it != quantization_infer_.end();) { + auto input_tensor_range = quantization_ranges_.find(it->first); + auto output_tensor_range = quantization_ranges_.find(it->second); + if (input_tensor_range != quantization_ranges_.end() && + output_tensor_range == quantization_ranges_.end()) { + // Input has range but output doesn't: copy range + // TODO(laigd): consider reporting error if it a different range is + // already set. + quantization_ranges_[it->second] = input_tensor_range->second; + information_added = true; + VLOG(1) << "Copy quantization range: " << it->first->getName() << " -> " + << it->second->getName(); + } + // We can remove edges when the output range is known + if (quantization_ranges_.find(it->second) != quantization_ranges_.end()) { + it = quantization_infer_.erase(it); + } else { + ++it; + } + } + } +} + Status Converter::GetInputs(const tensorflow::NodeDef& node_def, std::vector* inputs) const { for (auto const& input_name : node_def.input()) { @@ -1043,12 +1248,11 @@ TRT_ShapedWeights ConvertFP32ToFP16(TrtWeightStore* store, } // **************************************************************************** -// Constant folding functions -// TODO(jie): once optimizer kicks in, we should have done constant folding -// there. +// Constant folding functions for weights. +// TODO(laigd): we should probably use eigen directly. // ***************************************************************************** struct LambdaFactory { - enum class OP_CATEGORY : int { RSQRT = 0, NEG, ADD, MUL, SUB, RECIP }; + enum class OP_CATEGORY : int { RSQRT = 0, NEG, RECIP }; OP_CATEGORY op; template @@ -1063,84 +1267,10 @@ struct LambdaFactory { case OP_CATEGORY::RECIP: return [](T t) -> T { return 1.0 / t; }; default: - VLOG(2) << "Not supported op for unary: " << static_cast(op); + LOG(ERROR) << "Not supported op for unary: " << static_cast(op); return nullptr; } } - - template - std::function binary() { - switch (op) { - case OP_CATEGORY::ADD: - return [](T l, T r) -> T { return l + r; }; - case OP_CATEGORY::SUB: - return [](T l, T r) -> T { return l - r; }; - case OP_CATEGORY::MUL: - return [](T l, T r) -> T { return l * r; }; - default: - LOG(WARNING) << "Not supported op for binary: " << static_cast(op); - } - return [](T l, T r) -> T { - LOG(FATAL) << "Unsupported op type "; - return l; - }; - } - - template - std::function broadcast_r(T val) { - VLOG(2) << "LAMBDA VAL : " << val; - switch (op) { - case OP_CATEGORY::ADD: - return [val](T l) -> T { - VLOG(2) << "LAMBDA VAL : " << val; - return l + val; - }; - case OP_CATEGORY::SUB: - return [val](T l) -> T { - VLOG(2) << "LAMBDA VAL : " << val; - return l - val; - }; - case OP_CATEGORY::MUL: - return [val](T l) -> T { - VLOG(2) << "LAMBDA VAL : " << val; - return l * val; - }; - default: - LOG(WARNING) << "Not supported op for binary: " << static_cast(op); - } - return [val](T l) -> T { - LOG(FATAL) << "Unsupported op type "; - return l; - }; - } - - template - std::function broadcast_l(T val) { - VLOG(2) << "LAMBDA VAL : " << val; - switch (op) { - case OP_CATEGORY::ADD: - return [val](T l) -> T { - VLOG(2) << "LAMBDA VAL : " << val; - return val + l; - }; - case OP_CATEGORY::SUB: - return [val](T l) -> T { - VLOG(2) << "LAMBDA VAL : " << val; - return val - l; - }; - case OP_CATEGORY::MUL: - return [val](T l) -> T { - VLOG(2) << "LAMBDA VAL : " << val; - return val * l; - }; - default: - LOG(ERROR) << 
"Not supported op for binary: " << static_cast(op); - } - return [val](T l) -> T { - LOG(FATAL) << "Unsupported op type "; - return l; - }; - } }; template <> @@ -1148,15 +1278,18 @@ std::function LambdaFactory::unary() { switch (op) { case OP_CATEGORY::RSQRT: { VLOG(2) << "RSQRT GETS DONE"; - return [](Eigen::half t) -> Eigen::half { + return [](Eigen::half t) { return Eigen::half(1.0 / sqrt(static_cast(t))); }; } case OP_CATEGORY::NEG: - return [](Eigen::half t) -> Eigen::half { return -t; }; - // TODO(aaroey): can we support RECIP? + return [](Eigen::half t) { return -t; }; + case OP_CATEGORY::RECIP: + return [](Eigen::half t) { + return Eigen::half(1.0 / static_cast(t)); + }; default: - VLOG(2) << "Not supported op for unary: " << static_cast(op); + LOG(ERROR) << "Not supported op for unary: " << static_cast(op); return nullptr; } } @@ -1188,50 +1321,48 @@ tensorflow::Status UnaryCompute(const TRT_ShapedWeights& iweights, return tensorflow::Status::OK(); } +// If swapped_inputs is false, 'tensor' is the left operand and 'weights' is the +// right operand. If swapped_inputs is true, those two are swapped. +// // TODO(jie): broadcast is needed yet not implemented. -// Only implemented channel wise for the time being -tensorflow::Status BinaryTensorOpWeight(OpConverterParams* params, - const nvinfer1::ITensor* tensor, - TRT_ShapedWeights weights, - bool swapped_inputs) { +// Only implemented channel wise for the time being. +Status BinaryTensorOpWeight(OpConverterParams* params, + const nvinfer1::ITensor* tensor, + TRT_ShapedWeights weights, bool swapped_inputs) { + static const std::unordered_set supported_ops = {"Sub", "Add", "Mul", + "Div", "RealDiv"}; const auto& node_def = params->node_def; - // tensor is the left operand while weights is the right operand; - // when swapped_inputs set to true, those two are swapped. - // TODO(aaroey): use a set. - if (node_def.op() != "Sub" && node_def.op() != "Add" && - node_def.op() != "Mul" && node_def.op() != "Div" && - node_def.op() != "RealDiv") { - return tensorflow::errors::Unimplemented( - "op not supported: " + node_def.op() + ", at: " + node_def.name()); + if (!supported_ops.count(node_def.op())) { + return errors::Unimplemented(node_def.op(), " is not supported, at ", + node_def.name()); } - // Check type consistency - nvinfer1::DataType ttype; - TF_RETURN_IF_ERROR(ConvertDType(weights.type_, &ttype)); + // Check type consistency. + nvinfer1::DataType trt_dtype; + TF_RETURN_IF_ERROR(ConvertDType(weights.type_, &trt_dtype)); - // Check scale mode + // Check scale mode. 
auto dims_w = weights.shape_; - auto dims_t = tensor->getDimensions(); + const auto dims_t = tensor->getDimensions(); // TODO(jie): addScale checks for input tensor dimension if (dims_t.nbDims != 3) { - return tensorflow::errors::InvalidArgument( - "addScale requires tensor with rank 3, " + node_def.name()); + return errors::InvalidArgument("addScale requires tensor with rank 3, at ", + node_def.name()); } - // default to element-wise + // Default to element-wise auto scale_mode = nvinfer1::ScaleMode::kELEMENTWISE; // TODO(jie): maybe use a permutation instead to support more cases; - bool permutation_flag = false; + bool need_to_permute = false; if (weights.count() == 1) { - VLOG(2) << "UNIFORM"; scale_mode = nvinfer1::ScaleMode::kUNIFORM; } else { - // no broadcasting on Batch dimension; - VLOG(2) << "WEIGHTS DIM: " << dims_w.nbDims - << " tensor DIM: " << dims_t.nbDims; + VLOG(2) << "weights dims: " << DebugString(dims_w) + << "; tensor dims: " << DebugString(dims_t); + // Make sure no broadcasting on batch dimension. if (dims_w.nbDims == dims_t.nbDims + 1) { if (dims_w.d[0] == 1) { for (int i = 1; i < dims_w.nbDims; i++) { @@ -1239,72 +1370,70 @@ tensorflow::Status BinaryTensorOpWeight(OpConverterParams* params, } dims_w.nbDims--; } else { - return tensorflow::errors::InvalidArgument( - "Binary op cannot operate on batch, " + node_def.name()); + return errors::InvalidArgument("Binary op cannot operate on batch, at ", + node_def.name()); } } if (dims_w.nbDims == dims_t.nbDims && dims_w.d[0] == dims_t.d[0]) { scale_mode = nvinfer1::ScaleMode::kELEMENTWISE; - // default is element; + // Default is element-wise for (int i = 1; i < dims_w.nbDims; i++) { if (dims_w.d[i] != dims_t.d[i]) { - // if dimension does not match, switch back to channel; - VLOG(2) << "channel"; + // If dimension does not match, switch back to per-channel scale_mode = nvinfer1::ScaleMode::kCHANNEL; break; } } - // if channel as candidate, validate it + // If the mode is per-channel, since channel dimension is assumed to be + // the third to last dimension, we need to make sure all other dimensions + // have size 1. if (scale_mode == nvinfer1::ScaleMode::kCHANNEL) { for (int i = 1; i < dims_w.nbDims; i++) { if (dims_w.d[i] != 1) - return tensorflow::errors::InvalidArgument( - "Weight shape not compatible at, " + node_def.name()); + return errors::InvalidArgument( + "Weight dims not compatible for channel-wise broadcast at ", + node_def.name()); } - } else { - VLOG(2) << "elementwise"; } } else if (dims_w.nbDims == 1 && dims_w.d[0] == dims_t.d[dims_t.nbDims - 1]) { - // channel wise and broadcast required; - permutation_flag = true; + // Channel wise and broadcast required. We compare the last dimension of + // the tensor shape because of tensorflow default broadcasting rules. + need_to_permute = true; scale_mode = nvinfer1::ScaleMode::kCHANNEL; } else { - return tensorflow::errors::InvalidArgument( - "Weight shape not compatible at, " + node_def.name()); + return errors::InvalidArgument("Weight dims not compatible at ", + node_def.name()); } } + // TODO(laigd): we should add validation_only support in TransposeTensor() and + // PrepareTensorForShape(). + if (params->validation_only) return Status::OK(); - // transpose last dimension + // Transpose last dimension. std::vector permutation(dims_t.nbDims + 1); - if (permutation_flag) { - if (scale_mode == nvinfer1::ScaleMode::kCHANNEL && dims_t.nbDims > 1) { - // we swap the last dimension into channel for trt. - // because of tensorflow default broadcasting rules. 
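
A hedged sketch of the permutation built just below: with the implicit batch at index 0, the last dimension is swapped into slot 1 so that IScaleLayer's kCHANNEL mode treats it as the channel dimension. The helper name is illustrative only:

```
#include <vector>

// Identity permutation except that the last dimension and slot 1 (the channel
// slot, with the implicit batch at slot 0) are swapped.
std::vector<int> ChannelSwapPermutation(int rank_without_batch) {
  std::vector<int> perm(rank_without_batch + 1);
  for (int i = 0; i < static_cast<int>(perm.size()); ++i) perm[i] = i;
  perm[1] = rank_without_batch;   // last dim -> channel slot
  perm[rank_without_batch] = 1;   // old channel slot -> last dim
  return perm;
}
// Example: a rank-3 input yields {0, 3, 2, 1}; applying the same permutation
// again restores the original layout, which is how the output is transposed
// back after the scale layer.
```
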
- for (int i = 0; i < static_cast(permutation.size()); i++) { - permutation[i] = i; - } - permutation[1] = dims_t.nbDims; - permutation[dims_t.nbDims] = 1; - TF_RETURN_IF_ERROR(params->converter->TransposeTensor( - const_cast(tensor), permutation, &tensor)); - } else { - return tensorflow::errors::InvalidArgument( - "Transpose cannot be applied, " + node_def.name()); + if (need_to_permute) { + // We swap the last dimension into channel for trt, because of tensorflow + // default broadcasting rules. + for (int i = 0; i < static_cast(permutation.size()); i++) { + permutation[i] = i; } + permutation[1] = dims_t.nbDims; + permutation[dims_t.nbDims] = 1; + TF_RETURN_IF_ERROR(params->converter->TransposeTensor( + const_cast(tensor), permutation, &tensor)); } - if (params->converter->is_fp16()) { + if (params->converter->precision_mode() == FP16MODE) { weights = ConvertFP32ToFP16(params->weight_store, weights); } - // prepare weights + // Prepare weights TRT_ShapedWeights shift_weights(weights.type_); TRT_ShapedWeights scale_weights(weights.type_); TRT_ShapedWeights power_weights(weights.type_); - // Maybe I should do a switch if (node_def.op() == "Sub") { if (swapped_inputs) { shift_weights = weights; @@ -1312,6 +1441,10 @@ tensorflow::Status BinaryTensorOpWeight(OpConverterParams* params, *const_cast(tensor), nvinfer1::UnaryOperation::kNEG); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + // Since quantization ranges are symmetric, the same range as the input + // will work for the negation of the input. + params->converter->MarkQuantizationRangesAsInferrable( + const_cast(tensor), layer->getOutput(0)); tensor = layer->getOutput(0); } else { TRT_ShapedWeights neg_weights = @@ -1323,6 +1456,25 @@ tensorflow::Status BinaryTensorOpWeight(OpConverterParams* params, } } else if (node_def.op() == "Div" || node_def.op() == "RealDiv") { if (swapped_inputs) { + // We need to infer the quantization range for this intermediate tensor. + // + // x -> [Recip] -> 1/x -> [Scale] -> s/x + // ^ + // need range for this + // + // We have the quantization scales for x and s/x - can we divide the scale + // for s/x by s? Only if it is a scalar. + // + // Because of this issue, fall back to BinaryTensorOpTensor if we are + // doing INT8 with no calibration. There is most likely no performance + // penalty by falling back here. + if (params->converter->precision_mode() == INT8MODE && + !params->converter->use_calibration()) { + return errors::Unimplemented( + "Intermediate quantization range cannot be determined without" + " calibration. Falling back to BinaryTensorOpTensor for ", + node_def.op(), ", at ", node_def.name()); + } scale_weights = weights; nvinfer1::IUnaryLayer* layer = params->converter->network()->addUnary( *const_cast(tensor), @@ -1342,8 +1494,8 @@ tensorflow::Status BinaryTensorOpWeight(OpConverterParams* params, } else if (node_def.op() == "Add") { shift_weights = weights; } else { - return tensorflow::errors::Unimplemented("Binary op not supported: " + - node_def.op()); + // This should not happen. 
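
The Sub/Add/Mul/Div cases above all reduce to a single IScaleLayer, whose per-element semantics are (x * scale + shift) ^ power, with a kNEG or kRECIP unary layer prepended for the swapped cases. A sketch of that mapping using plain floats and a hypothetical helper (RealDiv is handled like Div):

```
#include <string>

struct ScaleArgs {
  float preprocessed_x;  // x after the optional kNEG / kRECIP unary layer
  float scale;
  float shift;
};

// For each supported op, the scale layer then computes
// preprocessed_x * scale + shift (power is left at its default of 1).
ScaleArgs MapBinaryOpToScale(const std::string& op, bool swapped,
                             float x, float w) {
  if (op == "Add") return {x, 1.0f, w};                     // x + w
  if (op == "Mul") return {x, w, 0.0f};                     // x * w
  if (op == "Sub" && !swapped) return {x, 1.0f, -w};        // x - w
  if (op == "Sub" && swapped) return {-x, 1.0f, w};         // w - x (kNEG)
  if (op == "Div" && !swapped) return {x, 1.0f / w, 0.0f};  // x / w
  return {1.0f / x, w, 0.0f};                               // w / x (kRECIP)
}
```
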
+ return errors::Unimplemented("Binary op not supported at ", node_def.op()); } nvinfer1::IScaleLayer* layer = params->converter->network()->addScale( @@ -1353,8 +1505,8 @@ tensorflow::Status BinaryTensorOpWeight(OpConverterParams* params, TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); const nvinfer1::ITensor* output_tensor = layer->getOutput(0); - // transpose back dimension - if (permutation_flag) { + // Transpose back dimension + if (need_to_permute) { TF_RETURN_IF_ERROR(params->converter->TransposeTensor( const_cast(output_tensor), permutation, &output_tensor)); @@ -1398,7 +1550,7 @@ tensorflow::Status ConvertConv2DHelper(OpConverterParams* params, int group) { return tensorflow::errors::Internal( "Conv2D expects kernel of dimension 4, at: " + node_def.name()); } - if (params->converter->is_fp16()) { + if (params->converter->precision_mode() == FP16MODE) { weights_rsck = ConvertFP32ToFP16(params->weight_store, inputs.at(1).weights()); } @@ -1445,6 +1597,8 @@ tensorflow::Status ConvertConv2DHelper(OpConverterParams* params, int group) { nvinfer1::DimsHW(padding[0].first, padding[1].first), nvinfer1::DimsHW(padding[0].second, padding[1].second)); TFTRT_RETURN_ERROR_IF_NULLPTR(pad_layer, node_def.name()); + params->converter->MarkQuantizationRangesAsInferrable( + const_cast(tensor), pad_layer->getOutput(0)); padding = {{0, 0}, {0, 0}}; tensor = pad_layer->getOutput(0); VLOG(2) << "TENSOR after: " << DebugString(tensor->getDimensions()); @@ -1486,9 +1640,9 @@ tensorflow::Status ConvertConv2DHelper(OpConverterParams* params, params->node_def.name()); } -tensorflow::Status BinaryTensorOpTensor(OpConverterParams* params, - const TRT_TensorOrWeights& operand_l, - const TRT_TensorOrWeights& operand_r) { +Status BinaryTensorOpTensor(OpConverterParams* params, + const TRT_TensorOrWeights& operand_l, + const TRT_TensorOrWeights& operand_r) { const auto& node_def = params->node_def; static const std::unordered_map ops{ {"Add", nvinfer1::ElementWiseOperation::kSUM}, @@ -1499,50 +1653,52 @@ tensorflow::Status BinaryTensorOpTensor(OpConverterParams* params, {"Minimum", nvinfer1::ElementWiseOperation::kMIN}, {"Maximum", nvinfer1::ElementWiseOperation::kMAX}, }; - - const nvinfer1::ITensor* tensor_l; - const nvinfer1::ITensor* tensor_r; - - nvinfer1::Dims dim_l; - nvinfer1::Dims dim_r; - - if (!TensorRTGetBroadcastShape(operand_l.GetTrtDims(), operand_l.is_tensor(), - operand_r.GetTrtDims(), operand_r.is_tensor(), - &dim_l, &dim_r)) { - return tensorflow::errors::InvalidArgument( - "Binary op broadcast scheme not supported by TensorRT op: " + - node_def.op() + ", at: " + node_def.name()); - } - - TF_RETURN_IF_ERROR( - params->converter->PrepareTensorForShape(operand_l, dim_l, &tensor_l)); - TF_RETURN_IF_ERROR( - params->converter->PrepareTensorForShape(operand_r, dim_r, &tensor_r)); - - // get trt type & shape - TFAttrs attrs(node_def); - // maybe this part has to be moved into the block of rsqrt later - nvinfer1::DataType dtype = attrs.get("T"); - - // check type consistency - TFTRT_CHECK_EQ_TYPE(tensor_l->getType(), dtype); - TFTRT_CHECK_EQ_TYPE(tensor_r->getType(), dtype); auto op_pair = ops.find(node_def.op()); if (op_pair == ops.end()) { - return tensorflow::errors::Unimplemented( - "binary op: ", node_def.op(), " not supported at: ", node_def.name()); + return errors::Unimplemented("Binary op ", node_def.op(), + " not supported at: ", node_def.name()); } + nvinfer1::Dims broadcasted_dims_l, broadcasted_dims_r; + Status status = params->converter->GetTrtBroadcastShape( + operand_l, 
operand_r, &broadcasted_dims_l, &broadcasted_dims_r); + if (!status.ok()) { + return errors::InvalidArgument( + "Unsupported binary op broadcast scheme for op ", node_def.name(), ": ", + status.error_message()); + } + if (params->validation_only) return Status::OK(); + + const nvinfer1::ITensor* tensor_l = nullptr; + const nvinfer1::ITensor* tensor_r = nullptr; + status = params->converter->PrepareTensorForShape( + operand_l, broadcasted_dims_l, &tensor_l); + if (status.ok()) { + status = params->converter->PrepareTensorForShape( + operand_r, broadcasted_dims_r, &tensor_r); + } + if (!status.ok()) { + return errors::Internal("Failed to convert binary op ", node_def.name(), + ": ", status.error_message()); + } + + // Check type consistency. + TFAttrs attrs(node_def); + nvinfer1::DataType dtype = attrs.get("T"); + TFTRT_CHECK_EQ_TYPE(tensor_l->getType(), dtype) + << DebugString(tensor_l->getType()) << " vs " << DebugString(dtype); + TFTRT_CHECK_EQ_TYPE(tensor_r->getType(), dtype) + << DebugString(tensor_r->getType()) << " vs " << DebugString(dtype); + + // Add ElementWise layer. nvinfer1::IElementWiseLayer* layer = params->converter->network()->addElementWise( - // TODO(aaroey): will tensor_l/tensor_r get modified? *const_cast(tensor_l), *const_cast(tensor_r), op_pair->second); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); - nvinfer1::ITensor* output_tensor = layer->getOutput(0); - // pass the output + // Pass the output params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); return tensorflow::Status::OK(); } @@ -1789,6 +1945,8 @@ tensorflow::Status ConvertPool(OpConverterParams* params) { nvinfer1::DimsHW(padding[0].first, padding[1].first), nvinfer1::DimsHW(padding[0].second, padding[1].second)); TFTRT_RETURN_ERROR_IF_NULLPTR(pad_layer, node_def.name()); + params->converter->MarkQuantizationRangesAsInferrable( + const_cast(tensor), pad_layer->getOutput(0)); padding = {{0, 0}, {0, 0}}; tensor = pad_layer->getOutput(0); } @@ -1796,6 +1954,11 @@ tensorflow::Status ConvertPool(OpConverterParams* params) { nvinfer1::IPoolingLayer* layer = params->converter->network()->addPooling( *const_cast(tensor), type, ksize); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + // TODO(tmorris): Average pooling may not be entirely safe to infer + // quantization range through (at least forwards - backwards should be fine). + // Max pooling is okay. 
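
GetTrtBroadcastShape, called above before adding the IElementWiseLayer, computes right-aligned broadcast shapes for the two operands. A sketch of the core idea, assuming plain int vectors and ignoring the implicit batch dimension that the real method additionally has to handle:

```
#include <algorithm>
#include <cstddef>
#include <vector>

// Right-align both shapes, pad the shorter one with 1s, and verify that each
// pair of dimensions is broadcast-compatible (equal, or one of them is 1).
bool RightAlignForBroadcast(std::vector<int> l, std::vector<int> r,
                            std::vector<int>* l_out, std::vector<int>* r_out) {
  const size_t rank = std::max(l.size(), r.size());
  l.insert(l.begin(), rank - l.size(), 1);
  r.insert(r.begin(), rank - r.size(), 1);
  for (size_t i = 0; i < rank; ++i) {
    if (l[i] != r[i] && l[i] != 1 && r[i] != 1) return false;
  }
  *l_out = l;
  *r_out = r;
  return true;
}
// Example: {5, 3} and {3} become {5, 3} and {1, 3}, which the element-wise
// layer can broadcast; {5, 3} and {2} are rejected.
```
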
+ params->converter->MarkQuantizationRangesAsInferrable( + const_cast(tensor), layer->getOutput(0)); layer->setStride(stride); layer->setPadding({padding[0].first, padding[1].first}); @@ -1813,110 +1976,290 @@ tensorflow::Status ConvertPool(OpConverterParams* params) { } tensorflow::Status ConvertActivation(OpConverterParams* params) { - const nvinfer1::ITensor* tensor = params->inputs.at(0).tensor(); + const auto& inputs = params->inputs; + const auto& node_def = params->node_def; + if (inputs.size() != 1) { + return tensorflow::errors::InvalidArgument( + node_def.op(), " expects one input, at ", node_def.name()); + } + if (!inputs.at(0).is_tensor()) { + return tensorflow::errors::Unimplemented( + node_def.op(), " is only implemented for tensors, at ", + node_def.name()); + } + static const std::unordered_map ops{ + {"Relu", nvinfer1::ActivationType::kRELU}, + {"Sigmoid", nvinfer1::ActivationType::kSIGMOID}, + {"Tanh", nvinfer1::ActivationType::kTANH}, + }; + auto op_pair = ops.find(node_def.op()); + if (op_pair == ops.end()) { + return tensorflow::errors::Unimplemented("Activation op: ", node_def.op(), + " not supported at: ", + node_def.name()); + } + if (params->validation_only) return tensorflow::Status::OK(); + + // Start conversion. + const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); nvinfer1::IActivationLayer* layer = params->converter->network()->addActivation( - *const_cast(tensor), - nvinfer1::ActivationType::kRELU); - TFTRT_RETURN_ERROR_IF_NULLPTR(layer, params->node_def.name()); + *const_cast(tensor), op_pair->second); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); nvinfer1::ITensor* output_tensor = layer->getOutput(0); + // Set quantization range for output of Sigmoid, Tanh. + if (node_def.op() == "Sigmoid") { + params->converter->ProvideQuantizationRange(output_tensor, 0.0f, 1.0f); + } else if (node_def.op() == "Tanh") { + params->converter->ProvideQuantizationRange(output_tensor, -1.0f, 1.0f); + } params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); return tensorflow::Status::OK(); } -tensorflow::Status ConvertScale(OpConverterParams* params) { +Status ConvertQuantize(OpConverterParams* params) { + const auto& inputs = params->inputs; + const auto& node_def = params->node_def; + if ((inputs.size() == 0) || + (node_def.op() == "FakeQuantWithMinMaxArgs" && inputs.size() != 1) || + (node_def.op() == "FakeQuantWithMinMaxVars" && inputs.size() != 3) || + (node_def.op() == "QuantizeAndDequantizeV2" && inputs.size() != 3) || + (node_def.op() == "QuantizeAndDequantizeV3" && inputs.size() != 4)) { + return errors::InvalidArgument("Invalid number of inputs for ", + node_def.op(), ", at ", node_def.name()); + } + if (inputs.at(0).is_weights()) { + // TensorRT will automatically quantize weights, so we will ignore ranges + // for weights. + params->outputs->push_back(inputs.at(0)); + return Status::OK(); + } + float min_range = 0.0f; + float max_range = 0.0f; + if (node_def.op() == "FakeQuantWithMinMaxArgs") { + // Get ranges via node attributes. + TFAttrs attrs(node_def); + if (attrs.count("min") == 0 || attrs.count("max") == 0) { + return errors::InvalidArgument("Min or max attribute not found for ", + node_def.op(), " at ", node_def.name()); + } + min_range = attrs.get("min"); + max_range = attrs.get("max"); + } else if (node_def.op() == "FakeQuantWithMinMaxVars" || + node_def.op() == "QuantizeAndDequantizeV2" || + node_def.op() == "QuantizeAndDequantizeV3") { + // Get ranges via inputs. 
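
Several converters in this file provide fixed output ranges rather than relying on calibration: Sigmoid and Softmax produce values in (0, 1), Tanh in (-1, 1), and Relu6 in [0, 6]. A small sketch of that lookup together with the symmetric bound that is actually stored; names here are illustrative, not part of the converter API:

```
#include <algorithm>
#include <cmath>
#include <string>
#include <utility>

// {min, max} for ops whose output range is known a priori; {0, 0} means the
// range has to come from calibration or an explicit quantization node.
std::pair<float, float> KnownOutputRange(const std::string& op) {
  if (op == "Sigmoid" || op == "Softmax") return {0.0f, 1.0f};
  if (op == "Tanh") return {-1.0f, 1.0f};
  if (op == "Relu6") return {0.0f, 6.0f};
  return {0.0f, 0.0f};
}

// ProvideQuantizationRange keeps only the symmetric bound.
float SymmetricBound(const std::pair<float, float>& range) {
  return std::max(std::abs(range.first), std::abs(range.second));
}
```
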
+ if (!inputs.at(1).is_weights() || !inputs.at(2).is_weights()) { + return errors::InvalidArgument("Min and max inputs for ", node_def.op(), + " must be weights not tensors, at ", + node_def.name()); + } + auto get_weights_value = [&inputs](int index) { + auto raw_weights = static_cast( + const_cast(inputs.at(index).weights().GetValues())); + return raw_weights[0]; + }; + min_range = get_weights_value(1); + max_range = get_weights_value(2); + } else { + return errors::InvalidArgument("Unknown quantization op ", node_def.op(), + ", at ", node_def.name()); + } + if (params->validation_only) return Status::OK(); + + // Store ranges for tensor + params->converter->ProvideQuantizationRange( + const_cast(inputs.at(0).tensor()), min_range, + max_range); + // Sometimes, TRT may not quantize a tensor, either because it chooses to + // execute a higher precision kernel or because of op fusion. In these cases, + // accuracy will suffer if the model was trained to expect quantization at + // that tensor. We should consider adding a clip(tensor, min_range, max_range) + // operation here to ensure that any arbitrarily placed quantize node will + // execute as expected. However, this will negatively affect performance. If + // users train their models in a way which models inference as close as + // possible (i.e. not quantizing in place where fusion will occur), then there + // is no problem with the current implementation. + params->outputs->push_back(inputs.at(0)); + return Status::OK(); +} + +// TODO(pdavoodi): we should update relu6 implementation once TensorRT supports +// Relu6 natively. +tensorflow::Status ConvertRelu6(OpConverterParams* params) { + const auto& inputs = params->inputs; + const auto& node_def = params->node_def; + if (inputs.size() != 1) { + return tensorflow::errors::InvalidArgument( + "Invalid number of inputs for Relu6, at ", node_def.name()); + } + if (inputs.at(0).is_weights()) { + return tensorflow::errors::Unimplemented( + "Relu6 is only implemented for tensors, not weights, at ", + node_def.name()); + } + if (params->validation_only) return Status::OK(); + // *************************************************************************** + // TensorRT does not implement Relu6 natively. This function converts Relu6 op + // to available TensorRT ops: Relu6(x) = min(Relu(x), 6) + // *************************************************************************** + + // Input Tensor + const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); + + // Relu operation i.e. Relu(x) = max(0, x) + nvinfer1::IActivationLayer* relu_layer = + params->converter->network()->addActivation( + *const_cast(tensor), + nvinfer1::ActivationType::kRELU); + TFTRT_RETURN_ERROR_IF_NULLPTR(relu_layer, node_def.name()); + + // Large range of relu is problematic during quantization in INT8 precision + // mode. Setting dynamic range of relu = [0.f, 6.0f] helps with quantization. + // TRT only uses dynamic ranges in INT8 precision mode, + // and this does not affect the FP32 path. + params->converter->ProvideQuantizationRange(relu_layer->getOutput(0), 0.0f, + 6.0f); + + // Create a constant layer to store the floating point weight i.e. 6.0f This + // tensor will be broadcasted uniformly during elementwise `min` operation. 
+ // The constant has to have the same rank as the input in order for TRT to + // broadcast + nvinfer1::Dims dims; + dims.nbDims = relu_layer->getOutput(0)->getDimensions().nbDims; + for (int i = 0; i < dims.nbDims; i++) { + dims.d[i] = 1; + } + TRT_ShapedWeights weights = params->weight_store->GetTempWeights( + tensorflow::DataType::DT_FLOAT, dims); + auto weights_ptr = + static_cast(const_cast(weights.GetValues())); + weights_ptr[0] = 6.0f; + nvinfer1::IConstantLayer* const6_layer = + params->converter->network()->addConstant(dims, weights.GetTrtWeights()); + TFTRT_RETURN_ERROR_IF_NULLPTR(const6_layer, node_def.name()); + params->converter->ProvideQuantizationRange(const6_layer->getOutput(0), 0.0f, + 6.0f); + + // ElementWise Min Operation + // Min op is a nop for INT8 execution path, as the input tensor + // to this layer will only have values in range [0.f, 6.0f]. + const nvinfer1::ITensor* tensor_l = relu_layer->getOutput(0); + const nvinfer1::ITensor* tensor_r = const6_layer->getOutput(0); + nvinfer1::IElementWiseLayer* relu6_layer = + params->converter->network()->addElementWise( + *const_cast(tensor_l), + *const_cast(tensor_r), + nvinfer1::ElementWiseOperation::kMIN); + TFTRT_RETURN_ERROR_IF_NULLPTR(relu6_layer, node_def.name()); + nvinfer1::ITensor* output_tensor = relu6_layer->getOutput(0); + params->converter->ProvideQuantizationRange(output_tensor, 0.0f, 6.0f); + + params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); + return Status::OK(); +} + +tensorflow::Status ConvertBiasAdd(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; if (inputs.size() != 2 || !inputs.at(0).is_tensor() || !inputs.at(1).is_weights()) { - return tensorflow::errors::Unimplemented( - "ConvertScale only supports tensorweight: ", node_def.name()); + return errors::InvalidArgument("Input expects tensor and weights, at ", + node_def.name()); } + if (params->validation_only) return Status::OK(); - const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); - TRT_ShapedWeights weights = inputs.at(1).weights(); - if (params->converter->is_fp16()) { - weights = ConvertFP32ToFP16(params->weight_store, inputs.at(1).weights()); - } - - TRT_ShapedWeights empty_weights(weights.type_); + nvinfer1::ITensor* tensor = + const_cast(inputs.at(0).tensor()); + const nvinfer1::Dims original_dims = tensor->getDimensions(); TFAttrs attrs(node_def); - - const auto data_format = attrs.get("data_format"); - int channel_index; - const auto dims = tensor->getDimensions(); - if (data_format == "NHWC") { - // 1). NHWC is really N+C - channel_index = dims.nbDims - 1; // batch dimension is implicit here! - } else { - // 2). NCHW is really N+CHW - channel_index = 0; // batch dimension is implicit here! - } + const string data_format = attrs.get("data_format"); + const int channel_index = + (data_format == "NHWC" ? original_dims.nbDims - 1 : 0); nvinfer1::Permutation permutation; - for (int32_t i = 0; i < dims.nbDims; ++i) { - permutation.order[i] = i; - } - - if (channel_index >= 0) { + if (channel_index != 0) { + // Permute the dimensions so that the channel dimension is the first + // dimension. 
+ for (int i = 0; i < original_dims.nbDims; ++i) { + permutation.order[i] = i; + } permutation.order[0] = channel_index; permutation.order[channel_index] = 0; - } else { - return tensorflow::errors::Unimplemented( - "TFTRT::BiasAdd cannot apply on batch dimension, at ", node_def.name()); + VLOG(1) << "ConvertBiasAdd permutation: " + << DebugString(permutation, original_dims.nbDims); } // TensorRT addScale requires input to be of rank 3, we need to apply - // transpose as well as reshape - if (channel_index != 0 || dims.nbDims != 3) { + // transpose as well as reshape. + // TODO(laigd): this doesn't match what the TRT doc says, fix the doc? + if (channel_index != 0 || original_dims.nbDims != 3) { nvinfer1::IShuffleLayer* shuffle_layer = - params->converter->network()->addShuffle( - *const_cast(tensor)); + params->converter->network()->addShuffle(*tensor); TFTRT_RETURN_ERROR_IF_NULLPTR(shuffle_layer, node_def.name()); + params->converter->MarkQuantizationRangesAsInferrable( + tensor, shuffle_layer->getOutput(0)); + + // NOTE(laigd): for some reason we need to apply the reshape + // unconditionally. The default shape has nbDims==-1 and it seems the + // behavior is undefined in some cases. nvinfer1::Dims reshape_dims; reshape_dims.nbDims = 3; - reshape_dims.d[0] = 0; // 0 copy from the input - reshape_dims.d[1] = dims.nbDims >= 2 ? 0 : 1; // 0 copy from the input - reshape_dims.d[2] = dims.nbDims >= 3 ? -1 : 1; // -1 infer from the rest + // 0 means copying from input; -1 means inferring from the rest. + reshape_dims.d[0] = 0; + reshape_dims.d[1] = original_dims.nbDims >= 2 ? 0 : 1; + reshape_dims.d[2] = original_dims.nbDims >= 3 ? -1 : 1; + shuffle_layer->setReshapeDimensions(reshape_dims); + if (channel_index != 0) { - // maybe we do not need this check. concerned about TRT optimization shuffle_layer->setFirstTranspose(permutation); } - shuffle_layer->setReshapeDimensions(reshape_dims); tensor = shuffle_layer->getOutput(0); } + TRT_ShapedWeights weights = inputs.at(1).weights(); + if (params->converter->precision_mode() == FP16MODE) { + weights = ConvertFP32ToFP16(params->weight_store, weights); + } nvinfer1::ScaleMode mode = nvinfer1::ScaleMode::kCHANNEL; if (weights.shape_.d[0] == 1) { mode = nvinfer1::ScaleMode::kUNIFORM; } + TRT_ShapedWeights empty_weights(weights.type_); nvinfer1::IScaleLayer* layer = params->converter->network()->addScale( - *const_cast(tensor), mode, weights.GetTrtWeights(), - empty_weights.GetTrtWeights(), empty_weights.GetTrtWeights()); + *tensor, mode, weights.GetTrtWeights(), empty_weights.GetTrtWeights(), + empty_weights.GetTrtWeights()); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); nvinfer1::ITensor* output_tensor = layer->getOutput(0); - // restore transpose & reshape - if (channel_index != 0 || dims.nbDims != 3) { + // Restore transpose & reshape. + if (channel_index != 0 || original_dims.nbDims != 3) { nvinfer1::IShuffleLayer* shuffle_layer = - params->converter->network()->addShuffle( - *const_cast(output_tensor)); + params->converter->network()->addShuffle(*output_tensor); TFTRT_RETURN_ERROR_IF_NULLPTR(shuffle_layer, node_def.name()); - nvinfer1::Dims reshape_dims = dims; - int tmp = reshape_dims.d[channel_index]; - reshape_dims.d[channel_index] = reshape_dims.d[0]; - reshape_dims.d[0] = tmp; + // NOTE: for same reason as mentioned above we need to apply the reshape + // unconditionally. 
+ nvinfer1::Dims reshape_dims = original_dims; + if (channel_index != 0) { + // NOTE: according to NVIDIA dimension types are deprecated, so we don't + // need to copy them back. + reshape_dims.d[channel_index] = original_dims.d[0]; + reshape_dims.d[0] = original_dims.d[channel_index]; + } shuffle_layer->setReshapeDimensions(reshape_dims); + if (channel_index != 0) { shuffle_layer->setSecondTranspose(permutation); } + params->converter->MarkQuantizationRangesAsInferrable( + output_tensor, shuffle_layer->getOutput(0)); output_tensor = shuffle_layer->getOutput(0); } params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); - return tensorflow::Status::OK(); + return Status::OK(); } Status GetTensorDimsWithProtoShape(const Tensor& tensor, @@ -2053,9 +2396,9 @@ tensorflow::Status ConvertConst(OpConverterParams* params) { uint8* data = reinterpret_cast(temp_weights.data()); std::copy(data, data + tensor.NumElements(), dst); } else { - return errors::FailedPrecondition( - "Unexpected data type: ", DataTypeString(dtype), - " at: ", node_def.name()); + return errors::FailedPrecondition("Unexpected data type: ", + DataTypeString(dtype), " at: ", + node_def.name()); } } } @@ -2070,32 +2413,41 @@ tensorflow::Status ConvertConst(OpConverterParams* params) { } tensorflow::Status ConvertIdentity(OpConverterParams* params) { + // TODO(tmorris): TRT's Identity layer does not get optimized away as of TRT + // 5.0, however once we know that it does it would be nice to use that + // instead. params->outputs->push_back(params->inputs.at(0)); return tensorflow::Status::OK(); } -tensorflow::Status ConvertBinary(OpConverterParams* params) { +Status ConvertBinary(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; if (inputs.size() != 2) { - return tensorflow::errors::FailedPrecondition( - "Binary ops require two tensor input, at ", node_def.name()); + return errors::InvalidArgument("Binary ops require two inputs, at ", + node_def.name()); } // Constant folding should have been done by TensorFlow - if (inputs.at(0).is_weights() && inputs.at(1).is_weights()) { - return tensorflow::errors::Unimplemented( + return errors::Unimplemented( "Constant folding is falled back to TensorFlow, binary op received " "both input as constant at: ", node_def.name()); } - // Try to convert into Scale layer first (for better performance) + // TODO(tmorris): TRT plans to deprecate IScaleLayer and will replace it with + // IElementwiseLayer. At that point, we can remove BinaryTensorOpWeight. For + // now, the performance will be slightly better with IScaleLayer because it + // can be fused in more situations. However, most of the benefits of + // IScaleLayer are when the layer performs both a shift and a scale, which we + // don't do except for convolutions. + // + // Try to convert into Scale layer first (for better performance). // Since scale layer supports restricted broadcast policy and op types, we // allow failure and try to handle it through Elementwise op - // (BinaryTensorOpTensor) - Status status = tensorflow::Status::OK(); + // (BinaryTensorOpTensor). 
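
The rank-3 reshape used by ConvertBiasAdd above relies on the IShuffleLayer convention that a 0 copies the corresponding input dimension and a -1 is inferred from the remaining element count. A sketch of the resulting shape computation, with a hypothetical helper name:

```
#include <cstddef>
#include <vector>

// Apply a reshape spec where 0 copies the input dimension at that position
// and -1 is inferred from the remaining element count.
std::vector<int> ApplyReshapeSpec(const std::vector<int>& input,
                                  const std::vector<int>& spec) {
  long total = 1;
  for (int d : input) total *= d;
  std::vector<int> out(spec.size(), 1);
  long known = 1;
  int infer_index = -1;
  for (size_t i = 0; i < spec.size(); ++i) {
    if (spec[i] == 0) {
      out[i] = input[i];  // copy from input (valid only if i < input rank)
    } else if (spec[i] == -1) {
      infer_index = static_cast<int>(i);
      continue;  // fill in below
    } else {
      out[i] = spec[i];
    }
    known *= out[i];
  }
  if (infer_index >= 0) out[infer_index] = static_cast<int>(total / known);
  return out;
}
// Example: an input of {8, 4, 4, 2} with spec {0, 0, -1} reshapes to
// {8, 4, 8}, i.e. rank 3 with the trailing dimensions folded together.
```
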
+ Status status = Status::OK(); if (inputs.at(0).is_tensor() && inputs.at(1).is_weights()) { status = BinaryTensorOpWeight(params, inputs.at(0).tensor(), inputs.at(1).weights(), false); @@ -2103,7 +2455,10 @@ tensorflow::Status ConvertBinary(OpConverterParams* params) { status = BinaryTensorOpWeight(params, inputs.at(1).tensor(), inputs.at(0).weights(), true); } + // If both input are tensors, or one of them is weights but the conversion + // above failed, try the conversion using BinaryTensorOpTensor. if ((inputs.at(0).is_tensor() && inputs.at(1).is_tensor()) || !status.ok()) { + if (!status.ok()) VLOG(1) << status; status = BinaryTensorOpTensor(params, inputs.at(0), inputs.at(1)); } return status; @@ -2133,6 +2488,20 @@ tensorflow::Status ConvertUnary(OpConverterParams* params) { nvinfer1::IUnaryLayer* layer; if (node_def.op() == "Rsqrt") { + // We will need a quantization range for intermediate tensor if not using + // calibration. + // + // x -> [Sqrt] -> sqrt(x) -> [Recip] -> 1/sqrt(x) + // ^ + // need range here + if (params->converter->precision_mode() == INT8MODE && + !params->converter->use_calibration()) { + return errors::Unimplemented( + "Intermediate quantization range cannot be determined without" + " calibration for Rsqrt, consider replacing with " + "Sqrt -> FakeQuant -> Reciprocal ops, at ", + node_def.name()); + } layer = params->converter->network()->addUnary( *const_cast(tensor), nvinfer1::UnaryOperation::kSQRT); @@ -2156,6 +2525,48 @@ tensorflow::Status ConvertUnary(OpConverterParams* params) { return tensorflow::Status::OK(); } +tensorflow::Status ConvertSquare(OpConverterParams* params) { + const auto& inputs = params->inputs; + const auto& node_def = params->node_def; + if (inputs.size() != 1) { + return tensorflow::errors::InvalidArgument("Square expects one input, at ", + node_def.name()); + } + if (inputs.at(0).is_weights()) { + return tensorflow::errors::Unimplemented( + "Square is only implemented for tensors, at ", node_def.name()); + } + if (params->validation_only) return Status::OK(); + + // Constant 2 with same rank as input + nvinfer1::Dims dims = inputs.at(0).GetTrtDims(); + for (int i = 0; i < dims.nbDims; i++) { + dims.d[i] = 1; + } + TRT_ShapedWeights weights = params->weight_store->GetTempWeights( + tensorflow::DataType::DT_FLOAT, dims); + auto weights_ptr = + static_cast(const_cast(weights.GetValues())); + weights_ptr[0] = 2.f; + nvinfer1::IConstantLayer* const2_layer = + params->converter->network()->addConstant(dims, weights.GetTrtWeights()); + TFTRT_RETURN_ERROR_IF_NULLPTR(const2_layer, node_def.name()); + + // ElementWise Pow Operation + const nvinfer1::ITensor* tensor_l = inputs.at(0).tensor(); + const nvinfer1::ITensor* tensor_r = const2_layer->getOutput(0); + nvinfer1::IElementWiseLayer* layer = + params->converter->network()->addElementWise( + *const_cast(tensor_l), + *const_cast(tensor_r), + nvinfer1::ElementWiseOperation::kPOW); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + nvinfer1::ITensor* output_tensor = layer->getOutput(0); + + params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); + return tensorflow::Status::OK(); +} + tensorflow::Status ConvertReduce(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; @@ -2692,6 +3103,8 @@ tensorflow::Status ConvertSoftmax(OpConverterParams* params) { layer->setAxes(1 << (nbDims - 1)); nvinfer1::ITensor* output_tensor = layer->getOutput(0); + // Quantization range for SoftMax is always (0, 1) + 
params->converter->ProvideQuantizationRange(output_tensor, 0.0f, 1.0f); params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); return tensorflow::Status::OK(); } @@ -2716,9 +3129,9 @@ tensorflow::Status ConvertTopK(OpConverterParams* params) { op = nvinfer1::TopKOperation::kMAX; reducedAxes |= 1 << (nbDims - 1); } else { - return tensorflow::errors::Unimplemented( - "Operation: " + node_def.op() + - " not implemented, at: " + node_def.name()); + return tensorflow::errors::Unimplemented("Operation: " + node_def.op() + + " not implemented, at: " + + node_def.name()); } nvinfer1::ITopKLayer* layer = params->converter->network()->addTopK( @@ -2732,40 +3145,52 @@ tensorflow::Status ConvertTopK(OpConverterParams* params) { return tensorflow::Status::OK(); } -void TrtNodeValidator::RegisterOpValidators() { +static void RegisterValidatableOpConverters( + std::unordered_map* registration) { // TODO(laigd): support all op types. - op_validators_["Const"] = ConvertConst; - op_validators_["Transpose"] = ConvertTranspose; - op_validators_["Reshape"] = ConvertReshape; - op_validators_["MatMul"] = ConvertMatMul; + (*registration)["BiasAdd"] = ConvertBiasAdd; + (*registration)["Const"] = ConvertConst; + (*registration)["Transpose"] = ConvertTranspose; + (*registration)["Reshape"] = ConvertReshape; + (*registration)["MatMul"] = ConvertMatMul; + (*registration)["Relu6"] = ConvertRelu6; + (*registration)["Square"] = ConvertSquare; + + for (auto quantization_op_type : + {"QuantizeAndDequantizeV2", "QuantizeAndDequantizeV3", + "FakeQuantWithMinMaxVars", "FakeQuantWithMinMaxArgs"}) { + (*registration)[quantization_op_type] = ConvertQuantize; + } + for (auto binary_op_type : + {"Add", "Mul", "Sub", "Div", "RealDiv", "Maximum", "Minimum"}) { + (*registration)[binary_op_type] = ConvertBinary; + } + for (auto activation_op_type : {"Relu", "Sigmoid", "Tanh"}) { + (*registration)[activation_op_type] = ConvertActivation; + } +} + +void TrtNodeValidator::RegisterOpValidators() { + RegisterValidatableOpConverters(&op_validators_); } void Converter::RegisterOpConverters() { - // vgg_16 slim implementation + RegisterValidatableOpConverters(&op_registry_); + op_registry_["Conv2D"] = ConvertConv2D; op_registry_["DepthwiseConv2dNative"] = ConvertConv2DDepthwise; - op_registry_["Relu"] = ConvertActivation; op_registry_["MaxPool"] = ConvertPool; op_registry_["AvgPool"] = ConvertPool; - op_registry_["BiasAdd"] = ConvertScale; - op_registry_["Const"] = ConvertConst; // TODO(ben,jie): this is a temp hack. 
op_registry_["Identity"] = ConvertIdentity; // Identity should be removed op_registry_["Snapshot"] = ConvertIdentity; // Snapshot should be removed - // resnet_50_v1 slim implementation - op_registry_["Add"] = ConvertBinary; - op_registry_["Mul"] = ConvertBinary; - op_registry_["Sub"] = ConvertBinary; op_registry_["Pad"] = ConvertPad; op_registry_["ConcatV2"] = ConvertConcat; op_registry_["FusedBatchNorm"] = ConvertFusedBatchNorm; op_registry_["FusedBatchNormV2"] = ConvertFusedBatchNorm; - op_registry_["Div"] = ConvertBinary; - op_registry_["RealDiv"] = ConvertBinary; - op_registry_["Rsqrt"] = ConvertUnary; op_registry_["Reciprocal"] = ConvertUnary; op_registry_["Exp"] = ConvertUnary; @@ -2774,20 +3199,19 @@ void Converter::RegisterOpConverters() { op_registry_["Abs"] = ConvertUnary; op_registry_["Neg"] = ConvertUnary; - op_registry_["Transpose"] = ConvertTranspose; - op_registry_["Reshape"] = ConvertReshape; - op_registry_["Sum"] = ConvertReduce; op_registry_["Prod"] = ConvertReduce; op_registry_["Max"] = ConvertReduce; op_registry_["Min"] = ConvertReduce; op_registry_["Mean"] = ConvertReduce; - op_registry_["Maximum"] = ConvertBinary; - op_registry_["Minimum"] = ConvertBinary; op_registry_["Softmax"] = ConvertSoftmax; - op_registry_["MatMul"] = ConvertMatMul; op_registry_["BatchMatMul"] = ConvertBatchMatMul; op_registry_["TopKV2"] = ConvertTopK; + op_registry_["Relu6"] = ConvertRelu6; + op_registry_["QuantizeAndDequantizeV2"] = ConvertQuantize; + op_registry_["QuantizeAndDequantizeV3"] = ConvertQuantize; + op_registry_["FakeQuantWithMinMaxVars"] = ConvertQuantize; + op_registry_["FakeQuantWithMinMaxArgs"] = ConvertQuantize; plugin_converter_ = ConvertPlugin; } @@ -2798,7 +3222,7 @@ tensorflow::Status ConvertGraphDefToEngine( const std::vector& input_shapes, Logger* logger, nvinfer1::IGpuAllocator* allocator, TRTInt8Calibrator* calibrator, - TrtUniquePtrType* engine, + TrtUniquePtrType* engine, bool use_calibration, bool* convert_successfully) { engine->reset(); if (convert_successfully) *convert_successfully = false; @@ -2813,7 +3237,11 @@ tensorflow::Status ConvertGraphDefToEngine( builder->setHalf2Mode(true); } else if (precision_mode == INT8MODE) { builder->setInt8Mode(true); - builder->setInt8Calibrator(calibrator); + if (use_calibration) { + builder->setInt8Calibrator(calibrator); + } else { + builder->setInt8Calibrator(nullptr); + } } // Create the network. @@ -2826,7 +3254,7 @@ tensorflow::Status ConvertGraphDefToEngine( // Build the network VLOG(1) << "Starting engine conversion "; - Converter converter(trt_network.get(), precision_mode == FP16MODE); + Converter converter(trt_network.get(), precision_mode, use_calibration); std::vector> output_tensors; // Graph nodes are already topologically sorted during construction for (const auto& node_def : gdef.node()) { @@ -2882,6 +3310,9 @@ tensorflow::Status ConvertGraphDefToEngine( TF_RETURN_IF_ERROR(converter.RenameAndMarkOutputTensors(output_tensors)); if (convert_successfully) *convert_successfully = true; + // Apply user provided quantization ranges to tensors + converter.MaybeApplyQuantizationRanges(); + // Build the engine. 
VLOG(1) << "Starting engine creation"; engine->reset(builder->buildCudaEngine(*converter.network())); @@ -3026,7 +3457,8 @@ tensorflow::Status ConvertSegmentToGraphDef( } } *common_scope = local_scope; - VLOG(0) << "Segment @scope '" << local_scope << "', converted to graph"; + VLOG(1) << "Converted TensorRT candidate segment @scope '" << local_scope + << "' to a GraphDef"; return tensorflow::Status::OK(); } diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.h b/tensorflow/contrib/tensorrt/convert/convert_nodes.h index 5cc28b33e7f..f1c4c121ae6 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.h +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.h @@ -92,7 +92,8 @@ struct EngineInfo { EngineInfo() : engine_type(EngineType::TRTStatic), max_workspace_size_bytes(0), - precision_mode(FP32MODE) {} + precision_mode(FP32MODE), + use_calibration(true) {} string engine_name; string device; @@ -109,6 +110,7 @@ struct EngineInfo { int maximum_cached_engines; std::vector cached_engine_batches; int precision_mode; + bool use_calibration; }; // Constructs a graphdef from the segment in the given graph. Adds placeholder @@ -145,7 +147,7 @@ tensorflow::Status ConvertGraphDefToEngine( const std::vector& input_shapes, Logger* logger, nvinfer1::IGpuAllocator* allocator, TRTInt8Calibrator* calibrator, - TrtUniquePtrType* engine, + TrtUniquePtrType* engine, bool use_calibration, bool* convert_successfully); // Helper class for the segmenter to determine whether an output edge from the @@ -392,7 +394,8 @@ class TrtNodeValidator { // Class to convert TF nodes to TRT network. class Converter { public: - Converter(nvinfer1::INetworkDefinition* trt_network, bool is_fp16); + Converter(nvinfer1::INetworkDefinition* trt_network, int precision_mode, + bool use_calibration); ////////////////////////////////////////////////////////////////////////////// // Methods used by the TRT engine builder to build a TRT network from a TF @@ -422,8 +425,43 @@ class Converter { // to add TRT layers. nvinfer1::INetworkDefinition* network() { return trt_network_; } - // Is the converter operating in fp16 mode? - bool is_fp16() const { return is_fp16_; } + // What precision are we targeting? + int precision_mode() const { return precision_mode_; } + + // Calibration will be or was previously performed on this network? + bool use_calibration() const { return use_calibration_; } + + // This should be called on the inputs and outputs of any layer we create + // where we know that the quantization range does not change during that + // operation. (e.g. Reshape, Transpose, Identity, MaxPool). + void MarkQuantizationRangesAsInferrable(nvinfer1::ITensor* input, + nvinfer1::ITensor* output); + + // This function should be called when we know the quantization range of a + // tensor, either from a quantize/dequantize node or when the output is a + // fixed range (e.g. SoftMax, Relu6, Sigmoid). + void ProvideQuantizationRange(nvinfer1::ITensor* tensor, float min_range, + float max_range); + + // Should be called when full TRT network has been constructed and before + // building the engine. + void MaybeApplyQuantizationRanges(); + + // This should be called on the inputs and outputs of any layer we create + // where we know that the quantization range does not change during that + // operation. (e.g. Reshape, Transpose, Identity, MaxPool). 
+ void MarkQuantizationRangesAsInferrable(nvinfer1::ITensor* input, + nvinfer1::ITensor* output); + + // This function should be called when we know the quantization range of a + // tensor, either from a quantize/dequantize node or when the output is a + // fixed range (e.g. SoftMax, Relu6, Sigmoid). + void ProvideQuantizationRange(nvinfer1::ITensor* tensor, + float min_range, float max_range); + + // Should be called when full TRT network has been constructed and before + // building the engine. + void ApplyQuantizationRanges(bool warn_missing_ranges); // Below are helper methods for op converters to add different layers to the // TRT network. @@ -440,6 +478,13 @@ class Converter { const nvinfer1::Dims& dims, const nvinfer1::ITensor** tensor); + // Return OK if the broadcast scheme is supported and compute the shapes after + // broadcasting. + Status GetTrtBroadcastShape(const TRT_TensorOrWeights& operand_l, + const TRT_TensorOrWeights& operand_r, + nvinfer1::Dims* operand_l_new_dims, + nvinfer1::Dims* operand_r_new_dims) const; + private: // Verify the provided batch_size is consistent with batch_size_ and update it // if necessary. @@ -457,6 +502,12 @@ class Converter { void RegisterOpConverters(); + void PropagateQuantizationRanges(); + + // Gets the min and max value in a TRT_ShapedWeights + Status GetWeightRange(const TRT_ShapedWeights& weights, float* out_min, + float* out_max) const; + // Registered op converters by op type. std::unordered_map op_registry_; @@ -472,7 +523,25 @@ class Converter { // Store the weights added during construction of trt_network_. TrtWeightStore weight_store_; - const bool is_fp16_; + // During conversion, this table is populated with quantization ranges per + // tensor. MaybeApplyQuantizationRanges() will use this table to set the TRT + // quantization ranges. Since TRT only supports symmetric ranges, we will + // store the range as a single float = max(abs(min_range), abs(max_range)). + // Range refers to the floating point values, e.g. min_range = 0.0f, max_range + // = 6.0f for Relu6. + std::unordered_map quantization_ranges_; + + // Edges where quantization ranges can be inferred (copied) across ops - from + // first tensor to second tensor. PropagateQuantizationRanges() will propagate + // known ranges from quantization_ranges_ across these edges, adding the new + // ranges to quantization_ranges_ so that they can be applied in + // MaybeApplyQuantizationRanges(). + std::vector> + quantization_infer_; + + const int precision_mode_; + + const bool use_calibration_; // Batch size of inputs to trt_network_ added by AddInputTensor(). During // network construction it will update this, use it to verify the batch diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc index c3a39395f3a..a95ab8dfbbb 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc @@ -35,7 +35,10 @@ limitations under the License. 
#include "tensorflow/core/grappler/costs/graph_properties.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/protobuf/config.pb.h" // NOLINT +#include "tensorflow/core/public/session.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT @@ -47,7 +50,9 @@ namespace tensorflow { namespace tensorrt { namespace convert { +using ::tensorflow::strings::StrCat; using ::testing::ElementsAre; +using ::testing::ElementsAreArray; // TODO(laigd): put this into some test utils file. void ExpectStatus(Status status, error::Code code = error::OK, @@ -69,6 +74,32 @@ nvinfer1::Dims GetTestDims(const std::vector& d) { return dims; } +nvinfer1::DataType TfDataTypeToTrt(DataType tf_dtype) { + switch (tf_dtype) { + case DT_FLOAT: + return nvinfer1::DataType::kFLOAT; + case DT_HALF: + return nvinfer1::DataType::kHALF; + case DT_INT32: + return nvinfer1::DataType::kINT32; + default: + QCHECK(false) << "Unexpected data type " << DataTypeString(tf_dtype); + } +} + +DataType TrtDataTypeToTf(nvinfer1::DataType trt_dtype) { + switch (trt_dtype) { + case nvinfer1::DataType::kFLOAT: + return DT_FLOAT; + case nvinfer1::DataType::kHALF: + return DT_HALF; + case nvinfer1::DataType::kINT32: + return DT_INT32; + default: + QCHECK(false) << "Unexpected data type " << static_cast(trt_dtype); + } +} + NodeDef MakeNodeDef(const string& name, const string& op, const std::vector& inputs) { NodeDef node_def; @@ -111,6 +142,35 @@ bool TrtDimsEqualsArray(const std::vector& lhs, return TrtDimsEquals(GetTestDims(lhs), rhs); } +// TODO(laigd): define a parameterized matcher that can compare against the +// vector. +void ExpectTrtDimsEqualsArray(const std::vector& lhs, + const nvinfer1::Dims& rhs) { + EXPECT_TRUE(TrtDimsEqualsArray(lhs, rhs)) + << "expected: " << DebugString(GetTestDims(lhs)) << "\n" + << " actual: " << DebugString(rhs); +} + +template +void ExpectArrayNear(const std::vector& lhs, const std::vector& rhs) { + ASSERT_EQ(lhs.size(), rhs.size()); + for (int i = 0; i < lhs.size(); i++) { + EXPECT_FLOAT_EQ(lhs[i], rhs[i]); + } +} + +// Eigen::half cannot implicitly convert to float which is required for +// EXPECT_FLOAT_EQ. +template <> +void ExpectArrayNear(const std::vector& lhs, + const std::vector& rhs) { + ASSERT_EQ(lhs.size(), rhs.size()); + for (int i = 0; i < lhs.size(); i++) { + EXPECT_FLOAT_EQ(Eigen::half_impl::half_to_float(lhs[i]), + Eigen::half_impl::half_to_float(rhs[i])); + } +} + bool TrtShapedWeightsEquals(const TRT_ShapedWeights& lhs, const TRT_ShapedWeights& rhs) { return TrtDimsEquals(lhs.shape_, rhs.shape_) && lhs.type_ == rhs.type_ && @@ -121,8 +181,7 @@ template void ValidateWeights(const TRT_ShapedWeights& weights, const std::vector& expected_dims, const std::vector& expected_value) { - EXPECT_TRUE(TrtDimsEqualsArray(expected_dims, weights.shape_)) - << weights.DebugString(); + ExpectTrtDimsEqualsArray(expected_dims, weights.shape_); ASSERT_EQ(expected_value.size(), weights.count()) << weights.DebugString(); const T* actual_values = static_cast(weights.GetValues()); for (int i = 0; i < expected_value.size(); ++i) { @@ -133,11 +192,12 @@ void ValidateWeights(const TRT_ShapedWeights& weights, // Fake ITensor implementation for testing purposes. 
class FakeITensor : public nvinfer1::ITensor { public: - FakeITensor() {} + FakeITensor() : dynamic_range_(0.0f) {} - FakeITensor(const nvinfer1::Dims& dims) : dims_(dims) {} + FakeITensor(const nvinfer1::Dims& dims) : dims_(dims), dynamic_range_(0.0f) {} - FakeITensor(const std::vector& dims) : dims_(GetTestDims(dims)) {} + FakeITensor(const std::vector& dims) + : dims_(GetTestDims(dims)), dynamic_range_(0.0f) {} void setName(const char* name) override { name_ = name; } @@ -166,7 +226,12 @@ class FakeITensor : public nvinfer1::ITensor { } #if NV_TENSORRT_MAJOR >= 5 - bool setDynamicRange(float min, float max) override {} + bool setDynamicRange(float min, float max) override { + dynamic_range_ = std::max(std::abs(min), std::abs(max)); + return true; + } + + float getDynamicRange() const override { return dynamic_range_; } #endif private: @@ -174,6 +239,7 @@ class FakeITensor : public nvinfer1::ITensor { nvinfer1::Dims dims_; nvinfer1::DataType type_; nvinfer1::TensorLocation location_; + float dynamic_range_; }; TEST(TRT_ShapedWeights_Test, Basic) { @@ -265,9 +331,7 @@ TEST(TRT_TensorOrWeights_Test, Basic) { EXPECT_EQ(1, ptr->batch_size()); } EXPECT_EQ(&itensor, ptr->tensor()); - EXPECT_TRUE(TrtDimsEqualsArray({1}, ptr->GetTrtDims())) - << "- expected: " << DebugString(dims) - << "\n vs\n- actual: " << DebugString(ptr->GetTrtDims()); + ExpectTrtDimsEqualsArray({1}, ptr->GetTrtDims()); } } } @@ -286,9 +350,7 @@ TEST(TRT_TensorOrWeights_Test, Basic) { EXPECT_EQ(false, ptr->is_weights()); EXPECT_EQ(1, ptr->batch_size()); EXPECT_NE(nullptr, ptr->tensor()); - EXPECT_TRUE(TrtDimsEqualsArray({1}, ptr->GetTrtDims())) - << "- expected: " << DebugString(dims) - << "\n vs\n- actual: " << DebugString(ptr->GetTrtDims()); + ExpectTrtDimsEqualsArray({1}, ptr->GetTrtDims()); } } // Test constructor with TRT_ShapedWeights argument. @@ -305,9 +367,7 @@ TEST(TRT_TensorOrWeights_Test, Basic) { nvinfer1::Dims dims; dims.nbDims = 0; - EXPECT_TRUE(TrtDimsEqualsArray({}, ptr->GetTrtDims())) - << "- expected: " << DebugString(dims) - << "\n vs\n- actual: " << DebugString(ptr->GetTrtDims()); + ExpectTrtDimsEqualsArray({}, ptr->GetTrtDims()); } } } @@ -341,34 +401,50 @@ TEST_F(ValidatorTest, ConvertToTensorOrWeights) { graph_properties, &output)); ValidateWeights(output.weights(), {2}, {1.0, 2.0}); } - // Convert non-Const. We test the case where the non-batch dimemsion is - // unknown as well, to make sure the validator allows that. - for (const int32 non_batch_dim : {-1, 2}) { - const int32 batch_size = 12; + // Helper method to run ConvertToTensorOrWeights() with predefined parameters. + auto convert_to_tensor_or_weights = [this](const std::vector& dims, + TRT_TensorOrWeights* output) { Scope s = Scope::NewRootScope(); - ops::Placeholder::Attrs attrs; - TF_EXPECT_OK(TensorShapeUtils::MakeShape( - std::vector{batch_size, non_batch_dim}, &attrs.shape_)); + const auto attrs = ops::Placeholder::Shape(PartialTensorShape{dims}); auto feed = ops::Placeholder(s.WithOpName("feed"), DT_FLOAT, attrs); auto add = ops::Add(s.WithOpName("add"), feed, feed); grappler::GrapplerItem item; TF_EXPECT_OK(s.ToGraphDef(&item.graph)); - grappler::GraphProperties graph_properties(item); TF_EXPECT_OK(graph_properties.InferStatically(true)); - - auto& node_def = add.operation.node()->def(); + const NodeDef& node_def = add.operation.node()->def(); + return this->ConvertToTensorOrWeights(node_def, /*output_port=*/0, + graph_properties, output); + }; + // Convert non-Const with #dims > nvinfer1::Dims::MAX_DIMS+1. 
+ { TRT_TensorOrWeights output; - ExpectStatus(ConvertToTensorOrWeights(node_def, /*output_port=*/0, - graph_properties, &output)); + ExpectStatus( + convert_to_tensor_or_weights( + std::vector(nvinfer1::Dims::MAX_DIMS + 2, 1), &output), + error::OUT_OF_RANGE, "Input tensor rank is greater than 9"); + } + // Convert non-Const with #dims < 2. + { + TRT_TensorOrWeights output; + ExpectStatus( + convert_to_tensor_or_weights({1}, &output), error::INVALID_ARGUMENT, + "Input tensor with rank<2 is not supported since the first dimension " + "is treated as batch dimension by TRT"); + } + // Convert non-Const. We test the case where the non-batch dimemsion is + // unknown as well, to make sure the validator allows that. + for (const int32 non_batch_dim : {-1, 2}) { + const int32 batch_size = 12; + TRT_TensorOrWeights output; + ExpectStatus( + convert_to_tensor_or_weights({batch_size, non_batch_dim}, &output)); EXPECT_EQ(true, output.is_tensor()); EXPECT_EQ(batch_size, output.batch_size()); EXPECT_NE(nullptr, output.tensor()); - EXPECT_TRUE(TrtDimsEqualsArray({non_batch_dim}, output.GetTrtDims())) - << "- expected: {" << non_batch_dim << "} \n vs\n" - << "- actual: " << DebugString(output.GetTrtDims()); + ExpectTrtDimsEqualsArray({non_batch_dim}, output.GetTrtDims()); } } @@ -405,7 +481,9 @@ class ConverterTest : public ::testing::Test { ConverterTest() { builder_.reset(nvinfer1::createInferBuilder(logger_)); network_.reset(builder_->createNetwork()); - converter_.reset(new Converter(network_.get(), /*fp16=*/false)); + converter_.reset(new Converter(network_.get(), + /*precision_mode=*/FP32MODE, + /*use_calibration=*/false)); weight_store_ = &converter_->weight_store_; } @@ -432,8 +510,21 @@ class ConverterTest : public ::testing::Test { return converter_->GetInputs(node_def, inputs); } + Status GetWeightRange(const TRT_ShapedWeights& weights, float* out_min, + float* out_max) const { + return converter_->GetWeightRange(weights, out_min, out_max); + } + + void PropagateQuantizationRanges() { + converter_->PropagateQuantizationRanges(); + } + int batch_size() const { return converter_->batch_size_; } + std::unordered_map& quantization_ranges() { + return converter_->quantization_ranges_; + } + private: Logger logger_; // These members are ordered in a way such that the destruction order is: @@ -504,9 +595,9 @@ TEST_F(ConverterTest, AddAndGetInputs) { EXPECT_EQ(nvinfer1::DataType::kFLOAT, inputs[0].tensor()->getType()); EXPECT_EQ(nvinfer1::DataType::kINT32, inputs[2].tensor()->getType()); EXPECT_EQ(nvinfer1::DataType::kHALF, inputs[3].tensor()->getType()); - EXPECT_TRUE(TrtDimsEqualsArray({1}, inputs[0].tensor()->getDimensions())); - EXPECT_TRUE(TrtDimsEqualsArray({2, 3}, inputs[2].tensor()->getDimensions())); - EXPECT_TRUE(TrtDimsEqualsArray({5, 3}, inputs[3].tensor()->getDimensions())); + ExpectTrtDimsEqualsArray({1}, inputs[0].tensor()->getDimensions()); + ExpectTrtDimsEqualsArray({2, 3}, inputs[2].tensor()->getDimensions()); + ExpectTrtDimsEqualsArray({5, 3}, inputs[3].tensor()->getDimensions()); } TEST_F(ConverterTest, RenameAndMarkOutputTensors) { @@ -552,7 +643,7 @@ TEST_F(ConverterTest, RenameAndMarkOutputTensors) { {{"my_op", "my_output"}, {"my_op:1", "my_output_1"}})); EXPECT_EQ(2, output_tensors.size()); for (auto output_tensor : output_tensors) { - EXPECT_TRUE(TrtDimsEqualsArray({2, 1}, output_tensor->getDimensions())); + ExpectTrtDimsEqualsArray({2, 1}, output_tensor->getDimensions()); } EXPECT_EQ("my_output", string(output_tensors[0]->getName())); EXPECT_EQ("my_output_1", 
string(output_tensors[1]->getName())); @@ -577,8 +668,7 @@ TEST_F(ConverterTest, TransposeTensor) { // OK. TF_EXPECT_OK( converter_->TransposeTensor(input_tensor, {0, 3, 1, 2}, &output_tensor)); - EXPECT_TRUE(TrtDimsEqualsArray({5, 2, 3}, output_tensor->getDimensions())) - << DebugString(*output_tensor); + ExpectTrtDimsEqualsArray({5, 2, 3}, output_tensor->getDimensions()); } TEST_F(ConverterTest, PrepareTensorForShape_Tensor) { @@ -590,7 +680,7 @@ TEST_F(ConverterTest, PrepareTensorForShape_Tensor) { // Shape size doesn't match. ExpectStatus(converter_->PrepareTensorForShape(tw, GetTestDims({2, 3, 6}), &output_tensor), - error::INVALID_ARGUMENT, "Reshape shapes are not compatible."); + error::INVALID_ARGUMENT, "Reshape shapes are not compatible"); // TODO(aaroey): we should check the case where uninferred dimensions are not // an exact divisor of input dim ensions, e.g. for dims {-1, 7}. @@ -598,14 +688,12 @@ TEST_F(ConverterTest, PrepareTensorForShape_Tensor) { // Infer shape, ok. TF_EXPECT_OK(converter_->PrepareTensorForShape(tw, GetTestDims({-1, 2}), &output_tensor)); - EXPECT_TRUE(TrtDimsEqualsArray({15, 2}, output_tensor->getDimensions())) - << DebugString(*output_tensor); + ExpectTrtDimsEqualsArray({15, 2}, output_tensor->getDimensions()); // Regular shape. TF_EXPECT_OK(converter_->PrepareTensorForShape(tw, GetTestDims({10, 3}), &output_tensor)); - EXPECT_TRUE(TrtDimsEqualsArray({10, 3}, output_tensor->getDimensions())) - << DebugString(*output_tensor); + ExpectTrtDimsEqualsArray({10, 3}, output_tensor->getDimensions()); } TEST_F(ConverterTest, PrepareTensorForShape_Weights) { @@ -615,8 +703,7 @@ TEST_F(ConverterTest, PrepareTensorForShape_Weights) { const nvinfer1::ITensor* output_tensor = nullptr; TF_EXPECT_OK(converter_->PrepareTensorForShape(tw, GetTestDims({10, 3}), &output_tensor)); - EXPECT_TRUE(TrtDimsEqualsArray({10, 3}, output_tensor->getDimensions())) - << DebugString(*output_tensor); + ExpectTrtDimsEqualsArray({10, 3}, output_tensor->getDimensions()); } TEST_F(ConverterTest, MaybeUpdateBatchSize) { @@ -656,6 +743,178 @@ TEST_F(ConverterTest, AddAndGetTensorOrWeights) { "tensor/weights my_tensor already exist"); } +template +void TestGetWeightRange(ConverterTest* test, TrtWeightStore* weight_store) { + TRT_ShapedWeights weights = + weight_store->GetTempWeights(DataTypeToEnum::v(), GetTestDims({2, 3})); + const std::vector values = {T(3), T(1), T(2), T(6), T(5), T(4)}; + memcpy(const_cast(weights.GetValues()), values.data(), + weights.size_bytes()); + + float out_min = 0.0f; + float out_max = 0.0f; + TF_EXPECT_OK(test->GetWeightRange(weights, &out_min, &out_max)); + EXPECT_EQ(1.0f, out_min); + EXPECT_EQ(6.0f, out_max); +} + +TEST_F(ConverterTest, GetWeightRange) { + TestGetWeightRange(this, weight_store_); + TestGetWeightRange(this, weight_store_); + TestGetWeightRange(this, weight_store_); +} + +TEST_F(ConverterTest, ProvideQuantizationRange) { + FakeITensor fake_tensor; + // Assymetric range + converter_->ProvideQuantizationRange(&fake_tensor, 0.0f, 6.0f); + EXPECT_EQ(6.0f, quantization_ranges()[&fake_tensor]); + converter_->ProvideQuantizationRange(&fake_tensor, 1.0f, 6.0f); + EXPECT_EQ(6.0f, quantization_ranges()[&fake_tensor]); + converter_->ProvideQuantizationRange(&fake_tensor, -8.0f, 6.0f); + EXPECT_EQ(8.0f, quantization_ranges()[&fake_tensor]); + converter_->ProvideQuantizationRange(&fake_tensor, -8.123f, -6.123f); + EXPECT_EQ(8.123f, quantization_ranges()[&fake_tensor]); + // Symmetric range + converter_->ProvideQuantizationRange(&fake_tensor, -6.123f, 6.123f); + 
EXPECT_EQ(6.123f, quantization_ranges()[&fake_tensor]); +} + +TEST_F(ConverterTest, MaybeApplyQuantizationRanges) { + // input -> infer1 -> infer2 -> infer3 + FakeITensor input, infer_1, infer_2, infer_3; + FakeITensor not_infer; + Converter int8_converter(/*trt_network=*/nullptr, INT8MODE, + /*use_calibration=*/true); + int8_converter.ProvideQuantizationRange(&input, -5.0f, 5.0f); + int8_converter.ProvideQuantizationRange(¬_infer, -100.0f, 100.0f); + int8_converter.MarkQuantizationRangesAsInferrable(&input, &infer_1); + int8_converter.MarkQuantizationRangesAsInferrable(&infer_1, &infer_2); + int8_converter.MarkQuantizationRangesAsInferrable(&infer_2, &infer_3); + + // Input range should be inferred along the chain and applied to tensors. + int8_converter.MaybeApplyQuantizationRanges(); +#if NV_TENSORRT_MAJOR >= 5 + EXPECT_EQ(input.getDynamicRange(), 5.0f); + EXPECT_EQ(infer_1.getDynamicRange(), 5.0f); + EXPECT_EQ(infer_2.getDynamicRange(), 5.0f); + EXPECT_EQ(infer_3.getDynamicRange(), 5.0f); + EXPECT_EQ(not_infer.getDynamicRange(), 100.0f); +#endif +} + +TEST_F(ConverterTest, PropagateQuantizationRanges) { + // infer0 <-> infer1 <-> infer2 <-> infer3 + // | + // infer4 <-> infer5 + FakeITensor infer[6]; + FakeITensor not_infer; + converter_->ProvideQuantizationRange(&infer[4], -5.0f, 5.0f); + converter_->MarkQuantizationRangesAsInferrable(&infer[0], &infer[1]); + converter_->MarkQuantizationRangesAsInferrable(&infer[1], &infer[2]); + converter_->MarkQuantizationRangesAsInferrable(&infer[3], &infer[2]); + converter_->MarkQuantizationRangesAsInferrable(&infer[4], &infer[1]); + converter_->MarkQuantizationRangesAsInferrable(&infer[4], &infer[5]); + + // Input range should be inferred along the chain. + PropagateQuantizationRanges(); + auto ranges = quantization_ranges(); + for (int i = 0; i < 6; ++i) { + EXPECT_EQ(5.0f, ranges[&infer[i]]); + } + EXPECT_EQ(ranges.count(¬_infer), 0); +} + +TEST_F(ConverterTest, GetTrtBroadcastShape) { + const bool kIsTensor = true; + const bool kIsNotTensor = false; + auto symmetric_test = [this](const std::vector& operand_1_shape, + const std::vector& operand_2_shape, + const bool operand_1_is_tensor, + const bool operand_2_is_tensor, + const std::vector& expected_operand_1_shape, + const std::vector& expected_operand_2_shape, + error::Code expected_code = error::OK, + const char* expected_error_msg_substr = nullptr, + const int operand_1_batch_size = -1, + const int operand_2_batch_size = -1) { + auto create_tensor_or_weights = [](const std::vector& shape, + bool is_tensor, int batch_size = -1) { + if (is_tensor) { + return TRT_TensorOrWeights{nvinfer1::DataType::kFLOAT, + GetTestDims(shape), batch_size}; + } + TRT_ShapedWeights weights; + weights.shape_ = GetTestDims(shape); + return TRT_TensorOrWeights(weights); + }; + + nvinfer1::Dims operand_1_new_dims, operand_2_new_dims; + TRT_TensorOrWeights operand_1 = create_tensor_or_weights( + operand_1_shape, operand_1_is_tensor, operand_1_batch_size); + TRT_TensorOrWeights operand_2 = create_tensor_or_weights( + operand_2_shape, operand_2_is_tensor, operand_2_batch_size); + + // operand_1 broadcast operand_2 + ExpectStatus( + this->converter_->GetTrtBroadcastShape( + operand_1, operand_2, &operand_1_new_dims, &operand_2_new_dims), + expected_code, expected_error_msg_substr); + if (expected_code == error::OK) { + ExpectTrtDimsEqualsArray(expected_operand_1_shape, operand_1_new_dims); + ExpectTrtDimsEqualsArray(expected_operand_2_shape, operand_2_new_dims); + } + // operand_2 broadcast operand_1 + ExpectStatus( + 
this->converter_->GetTrtBroadcastShape( + operand_2, operand_1, &operand_2_new_dims, &operand_1_new_dims), + expected_code, expected_error_msg_substr); + if (expected_code == error::OK) { + ExpectTrtDimsEqualsArray(expected_operand_1_shape, operand_1_new_dims); + ExpectTrtDimsEqualsArray(expected_operand_2_shape, operand_2_new_dims); + } + }; + + // Both inputs are weights. + symmetric_test( + {1}, {1}, kIsNotTensor, kIsNotTensor, {}, {}, error::INVALID_ARGUMENT, + "Broadcasting requires at least one of the operands be tensors"); + + // One tensor and one weights. + symmetric_test({1, 1, 1}, {2}, kIsTensor, kIsNotTensor, {1, 1, 1}, {1, 1, 2}); + symmetric_test({1, 1, 2}, {2}, kIsTensor, kIsNotTensor, {1, 1, 2}, {1, 1, 2}); + symmetric_test({1, 3, 2}, {1}, kIsTensor, kIsNotTensor, {1, 3, 2}, {1, 1, 1}); + symmetric_test({1, 1, 1}, {2, 3}, kIsTensor, kIsNotTensor, {1, 1, 1}, + {1, 2, 3}); + symmetric_test({1, 1, 1}, {2, 3, 4}, kIsTensor, kIsNotTensor, {1, 1, 1}, + {2, 3, 4}); + symmetric_test({1, 1, 1}, {1, 2, 3, 4}, kIsTensor, kIsNotTensor, {1, 1, 1}, + {2, 3, 4}); + symmetric_test({1, 3, 4}, {1, 2, 1, 4}, kIsTensor, kIsNotTensor, {1, 3, 4}, + {2, 1, 4}); + symmetric_test({1, 1, 1}, {2, 1, 1, 1}, kIsTensor, kIsNotTensor, {}, {}, + error::INVALID_ARGUMENT, "Infeasible broadcast scheme"); + symmetric_test({1, 1, 1}, {2, 1, 1, 1}, kIsTensor, kIsNotTensor, {}, {}, + error::INVALID_ARGUMENT, "Infeasible broadcast scheme", + /*operand_1_batch_size=*/2); + symmetric_test({1, 1, 1}, {1, 1, 1, 1, 1}, kIsTensor, kIsNotTensor, {}, {}, + error::INVALID_ARGUMENT, + "Broadcasting beyond batch dimension is not supported " + "(tensor #dims 4 vs broadcast #dims 5)"); + + // Both inputs are tensors. + symmetric_test({1, 1, 1}, {1, 1}, kIsTensor, kIsTensor, {}, {}, + error::INVALID_ARGUMENT, + "Broadcasting beyond batch dimension is not supported " + "(tensor #dims 3 vs broadcast #dims 4)"); + symmetric_test({1, 3, 4}, {2, 1, 4}, kIsTensor, kIsTensor, {1, 3, 4}, + {2, 1, 4}); + symmetric_test({1, 1, 1}, {1, 1, 1, 1}, kIsTensor, kIsTensor, {}, {}, + error::INVALID_ARGUMENT, + "Broadcasting beyond batch dimension is not supported " + "(tensor #dims 4 vs broadcast #dims 5)"); +} + // Class to test various op converters, using both a TrtNodeValidator and // Converter. class OpConverterTest : public ::testing::Test { @@ -684,15 +943,21 @@ class OpConverterTest : public ::testing::Test { // Reset the validator and converter. validator_.reset(new TrtNodeValidator); - converter_.reset(new Converter(network_.get(), /*fp16=*/false)); + converter_.reset(new Converter(network_.get(), + /*precision_mode=*/FP32MODE, + /*use_calibration=*/false)); // Reset other related artifacts. scope_ = Scope::NewRootScope(); validator_inputs_.clear(); } - void BuildAndRun(const char* input_name, const std::vector& input_data, - const char* output_name, std::vector* output_data) { + // TODO(laigd): test fp16 and int8 support. + template + void BuildAndRun( + const std::vector>>& + input_data, + const char* output_name, std::vector* output_data) { // Mark the output tensor as TRT engine output. TF_EXPECT_OK(converter_->RenameAndMarkOutputTensors( {{string(output_name), string(output_name)}})); @@ -703,25 +968,33 @@ class OpConverterTest : public ::testing::Test { CHECK_NOTNULL(engine_.get()); // Execute the TRT engine. 
- const int input_size = input_data.size() * sizeof(float); - const int output_size = output_data->size() * sizeof(float); - const int input_index = engine_->getBindingIndex(input_name); - const int output_index = engine_->getBindingIndex(output_name); + ASSERT_LE(input_data.size() + 1, 3); + void* buffers[3]; + for (const auto name_and_data : input_data) { + const int input_size = name_and_data.second.size() * sizeof(T); + const int input_index = engine_->getBindingIndex(name_and_data.first); + ASSERT_EQ(0, cudaMalloc(&buffers[input_index], input_size)); + ASSERT_EQ( + 0, cudaMemcpyAsync(buffers[input_index], name_and_data.second.data(), + input_size, cudaMemcpyHostToDevice, stream_)); + } - ASSERT_EQ(engine_->getNbBindings(), 2); - void* buffers[2]; - ASSERT_EQ(0, cudaMalloc(&buffers[input_index], input_size)); + const int output_size = output_data->size() * sizeof(T); + const int output_index = engine_->getBindingIndex(output_name); ASSERT_EQ(0, cudaMalloc(&buffers[output_index], output_size)); - ASSERT_EQ(0, cudaMemcpyAsync(buffers[input_index], input_data.data(), - input_size, cudaMemcpyHostToDevice, stream_)); + + ASSERT_EQ(engine_->getNbBindings(), input_data.size() + 1); + TrtUniquePtrType execution_context( engine_->createExecutionContext()); execution_context->enqueue(/*batchSize=*/1, buffers, stream_, nullptr); ASSERT_EQ(0, cudaMemcpyAsync(output_data->data(), buffers[output_index], output_size, cudaMemcpyDeviceToHost, stream_)); cudaStreamSynchronize(stream_); - ASSERT_EQ(0, cudaFree(buffers[input_index])); - ASSERT_EQ(0, cudaFree(buffers[output_index])); + + for (int i = 0; i < input_data.size() + 1; ++i) { + ASSERT_EQ(0, cudaFree(buffers[i])); + } } bool HasStaticShape(const nvinfer1::Dims& dims) const { @@ -736,18 +1009,7 @@ class OpConverterTest : public ::testing::Test { void AddTestTensor( const char* name, const std::vector& dims, int batch_size = 1, nvinfer1::DataType trt_dtype = nvinfer1::DataType::kFLOAT) { - DataType tf_dtype = DT_FLOAT; - switch (trt_dtype) { - case nvinfer1::DataType::kFLOAT: - tf_dtype = DT_FLOAT; - break; - case nvinfer1::DataType::kINT32: - tf_dtype = DT_INT32; - break; - default: - ASSERT_TRUE(false) << "Unexpected data type " - << static_cast(trt_dtype); - } + DataType tf_dtype = TrtDataTypeToTf(trt_dtype); ops::Placeholder::Attrs attrs; TF_EXPECT_OK(TensorShapeUtils::MakeShape(dims, &attrs.shape_)); attrs.shape_.InsertDim(0, batch_size); @@ -826,6 +1088,11 @@ class OpConverterTest : public ::testing::Test { } } + // Expose quantization_ranges_ for tests + std::unordered_map& quantization_ranges() { + return converter_->quantization_ranges_; + } + std::unique_ptr converter_; std::unique_ptr validator_; @@ -835,6 +1102,11 @@ class OpConverterTest : public ::testing::Test { TrtUniquePtrType network_; TrtUniquePtrType engine_; cudaStream_t stream_; + // Used to create placeholders with shape and data type information. The + // created placeholders will be used as inputs to the node to be verified, + // thus we need the shape and data type information to get a non-empty + // GraphProperties. + // TODO(laigd): consider use this Scope to create the NodeDef to verify. 
Scope scope_; std::unordered_map validator_inputs_; }; @@ -958,15 +1230,15 @@ TEST_F(OpConverterTest, ConvertTranspose) { Reset(); AddTestTensor("input", {1, 2, 3}); AddTestWeights("weights", {4}, {0, 3, 1, 2}); - RunConversion(node_def); + RunValidationAndConversion(node_def); TRT_TensorOrWeights output; TF_EXPECT_OK(GetTensorOrWeights("my_transpose", &output)); EXPECT_TRUE(output.is_tensor()); - EXPECT_TRUE(TrtDimsEqualsArray({3, 1, 2}, output.tensor()->getDimensions())) - << output.DebugString(); + ExpectTrtDimsEqualsArray({3, 1, 2}, output.tensor()->getDimensions()); std::vector output_data(6); - BuildAndRun("input", {1, 2, 3, 4, 5, 6}, "my_transpose", &output_data); + BuildAndRun({{"input", {1, 2, 3, 4, 5, 6}}}, "my_transpose", + &output_data); EXPECT_THAT(output_data, ElementsAre(1, 4, 2, 5, 3, 6)); } } @@ -1048,15 +1320,15 @@ TEST_F(OpConverterTest, ConvertReshape) { Reset(); AddTestTensor("input", ok_params[i].tensor_dims, ok_params[i].batch_size); AddTestWeights("weights", {4}, ok_params[i].shape); - RunConversion(node_def); + RunValidationAndConversion(node_def); TRT_TensorOrWeights output; TF_EXPECT_OK(GetTensorOrWeights("my_reshape", &output)); EXPECT_TRUE(output.is_tensor()); - EXPECT_TRUE(TrtDimsEqualsArray({1, 3, 2}, output.tensor()->getDimensions())) - << output.DebugString(); + ExpectTrtDimsEqualsArray({1, 3, 2}, output.tensor()->getDimensions()); std::vector output_data(6); - BuildAndRun("input", {1, 2, 3, 4, 5, 6}, "my_reshape", &output_data); + BuildAndRun({{"input", {1, 2, 3, 4, 5, 6}}}, "my_reshape", + &output_data); EXPECT_THAT(output_data, ElementsAre(1, 2, 3, 4, 5, 6)); } } @@ -1070,15 +1342,14 @@ TEST_F(OpConverterTest, ConvertMatMul) { "Input expects tensor and weights, at my_matmul"); } - // Get the NodeDef for Reshape. + // Get the NodeDef for MatMul. auto get_matmul_nodedef = [](DataType dtype, bool transpose_a, bool transpose_b) -> NodeDef { Scope s = Scope::NewRootScope(); auto input = ops::Placeholder(s.WithOpName("input"), dtype); auto weights = ops::Placeholder(s.WithOpName("weights"), dtype); - ops::MatMul::Attrs matmul_attrs; - matmul_attrs.transpose_a_ = transpose_a; - matmul_attrs.transpose_b_ = transpose_b; + const auto matmul_attrs = + ops::MatMul::TransposeA(transpose_a).TransposeB(transpose_b); auto matmul = ops::MatMul(s.WithOpName("my_matmul"), input, weights, matmul_attrs); return matmul.operation.node()->def(); @@ -1094,45 +1365,845 @@ TEST_F(OpConverterTest, ConvertMatMul) { node_def, error::UNIMPLEMENTED, "Data type is not supported, for node my_matmul got int32"); } - { - // transpose_a is set. - for (bool transpose_b : {false, true}) { - Reset(); - NodeDef node_def = - get_matmul_nodedef(DT_FLOAT, /*transpose_a=*/true, transpose_b); - AddTestTensor("input", {2}, /*batch_size=*/1); - AddTestWeights("weights", {2, 2}, {0, 1, 2, 3}); - RunValidationAndConversion( - node_def, error::INVALID_ARGUMENT, - "transpose_a is not supported for TensorRT FullyConnected"); + // transpose_a is set. + for (bool transpose_b : {false, true}) { + Reset(); + NodeDef node_def = + get_matmul_nodedef(DT_FLOAT, /*transpose_a=*/true, transpose_b); + AddTestTensor("input", {2}, /*batch_size=*/1); + AddTestWeights("weights", {2, 2}, {0, 1, 2, 3}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "transpose_a is not supported for TensorRT FullyConnected"); + } + // OK. 
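The expectations in the loop that follows fall out of multiplying the 1x2 input [0, 1] by the 2x2 row-major weight [[0, 1], [2, 3]], optionally transposed. A quick standalone check of that arithmetic, independent of the test harness:

```
#include <cassert>

int main() {
  const float input[2] = {0.0f, 1.0f};
  const float w[2][2] = {{0.0f, 1.0f}, {2.0f, 3.0f}};  // row-major, as in the test weights

  float out_plain[2] = {0.0f, 0.0f};       // transpose_b = false: input * w
  float out_transposed[2] = {0.0f, 0.0f};  // transpose_b = true:  input * w^T
  for (int j = 0; j < 2; ++j) {
    for (int k = 0; k < 2; ++k) {
      out_plain[j] += input[k] * w[k][j];
      out_transposed[j] += input[k] * w[j][k];
    }
  }
  assert(out_plain[0] == 2.0f && out_plain[1] == 3.0f);            // ElementsAre(2, 3)
  assert(out_transposed[0] == 1.0f && out_transposed[1] == 3.0f);  // ElementsAre(1, 3)
  return 0;
}
```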
+ for (bool transpose_b : {false, true}) { + Reset(); + NodeDef node_def = + get_matmul_nodedef(DT_FLOAT, /*transpose_a=*/false, transpose_b); + AddTestTensor("input", {2}, /*batch_size=*/1); + AddTestWeights("weights", {2, 2}, {0, 1, 2, 3}); + RunValidationAndConversion(node_def); + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_matmul", &output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray({2}, output.tensor()->getDimensions()); + + std::vector output_data(2); + BuildAndRun({{"input", {0, 1}}}, "my_matmul", &output_data); + if (transpose_b) { + EXPECT_THAT(output_data, ElementsAre(1, 3)); + } else { + EXPECT_THAT(output_data, ElementsAre(2, 3)); } } - { - // OK. - for (bool transpose_b : {false, true}) { - Reset(); - NodeDef node_def = - get_matmul_nodedef(DT_FLOAT, /*transpose_a=*/false, transpose_b); - AddTestTensor("input", {2}, /*batch_size=*/1); - AddTestWeights("weights", {2, 2}, {0, 1, 2, 3}); - RunConversion(node_def); - TRT_TensorOrWeights output; - TF_EXPECT_OK(GetTensorOrWeights("my_matmul", &output)); - EXPECT_TRUE(output.is_tensor()); - EXPECT_TRUE(TrtDimsEqualsArray({2}, output.tensor()->getDimensions())) - << output.DebugString(); +} - std::vector output_data(2); - BuildAndRun("input", {0, 1}, "my_matmul", &output_data); - if (transpose_b) { - EXPECT_THAT(output_data, ElementsAre(1, 3)); +template +void TestConvertBiasAdd(OpConverterTest* test) { + // Get the NodeDef for BiasAdd. + auto get_biasadd_nodedef = [](const string& data_format) -> NodeDef { + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), dtype); + auto weights = ops::Placeholder(s.WithOpName("weights"), dtype); + const auto biasadd_attrs = ops::BiasAdd::DataFormat(data_format); + auto biasadd = + ops::BiasAdd(s.WithOpName("my_biasadd"), input, weights, biasadd_attrs); + return biasadd.operation.node()->def(); + }; + + typedef typename EnumToDataType::Type CType; + for (const string& data_format : {"NHWC", "NCHW"}) { + for (const int trt_input_rank : {1, 2, 3, 4}) { + test->Reset(); + NodeDef node_def = get_biasadd_nodedef(data_format); + + // Add input, dims_array will be like {2, 1, ..., 1, 3} + std::vector dims_array(trt_input_rank, 1); + if (trt_input_rank == 1) { + dims_array[0] = (data_format == "NHWC" ? 3 : 2); } else { - EXPECT_THAT(output_data, ElementsAre(2, 3)); + dims_array[0] = 2; + dims_array[trt_input_rank - 1] = 3; + } + test->AddTestTensor("input", dims_array, /*batch_size=*/1, + TfDataTypeToTrt(dtype)); + + // Add bias weights. + const int channel_size = (data_format == "NHWC" ? 3 : 2); + std::vector bias(channel_size); + for (int i = 0; i < channel_size; ++i) { + bias[i] = CType(i + 1); // bias will be {1, 2, 3, ...} + } + test->AddTestWeights("weights", {channel_size}, bias); + + // Run the conversion. + test->RunValidationAndConversion(node_def); + TRT_TensorOrWeights output; + TF_EXPECT_OK(test->GetTensorOrWeights("my_biasadd", &output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray(dims_array, output.tensor()->getDimensions()); + + // Build and run the engine. + const int num_input = TrtDimsNumElements(GetTestDims(dims_array)); + ASSERT_EQ(trt_input_rank > 1 ? 6 : (data_format == "NHWC" ? 
3 : 2), + num_input); + std::vector output_data(num_input); + test->BuildAndRun( + {{"input", std::vector(num_input, CType(0))}}, "my_biasadd", + &output_data); + if (trt_input_rank == 1) { + if (data_format == "NHWC") { + EXPECT_THAT(output_data, ElementsAre(CType(1), CType(2), CType(3))); + } else { + EXPECT_THAT(output_data, ElementsAre(CType(1), CType(2))); + } + } else { + if (data_format == "NHWC") { + EXPECT_THAT(output_data, ElementsAre(CType(1), CType(2), CType(3), + CType(1), CType(2), CType(3))); + } else { + EXPECT_THAT(output_data, ElementsAre(CType(1), CType(1), CType(1), + CType(2), CType(2), CType(2))); + } } } } } +TEST_F(OpConverterTest, ConvertQuantize) { + { + // Input list is empty, should fail. + NodeDef node_def = + MakeNodeDef("my_quantize", "QuantizeAndDequantizeV2", {}); + RunConversion( + node_def, error::INVALID_ARGUMENT, + "Invalid number of inputs for QuantizeAndDequantizeV2, at my_quantize"); + } + { + // FakeQuantWithMinMaxArgs attributes are empty, should fail. + NodeDef node_def = + MakeNodeDef("my_quantize", "FakeQuantWithMinMaxArgs", {"input"}); + AddTestTensor("input", {1, 2, 3}); + RunConversion(node_def, error::INVALID_ARGUMENT, + "Min or max attribute not found for FakeQuantWithMinMaxArgs " + "at my_quantize"); + } + { + // FakeQuantWithMinMaxArgs ranges set via attributes, ok. + Reset(); + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + ops::FakeQuantWithMinMaxArgs::Attrs quantize_attrs; + quantize_attrs.min_ = -6.0f; + quantize_attrs.max_ = 6.0f; + auto quantize = ops::FakeQuantWithMinMaxArgs(s.WithOpName("my_quantize"), + input, quantize_attrs); + const NodeDef& node_def = quantize.operation.node()->def(); + AddTestTensor("input", {1, 2, 3}); + RunConversion(node_def); + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_quantize", &output)); + EXPECT_TRUE(output.is_tensor()); + auto ranges = quantization_ranges(); + EXPECT_EQ(ranges.count(output.tensor()), 1); + EXPECT_EQ(ranges[output.tensor()], 6.0f); + } + { + // FakeQuantWithMinMaxVars ranges set via inputs, ok. + Reset(); + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto weights_min = ops::Placeholder(s.WithOpName("weights_min"), DT_FLOAT); + auto weights_max = ops::Placeholder(s.WithOpName("weights_max"), DT_FLOAT); + auto quantize = ops::FakeQuantWithMinMaxVars( + s.WithOpName("my_quantize"), input, weights_min, weights_max); + const NodeDef& node_def = quantize.operation.node()->def(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("weights_min", {1}, {-6.0f}); + AddTestWeights("weights_max", {1}, {6.0f}); + RunConversion(node_def); + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_quantize", &output)); + EXPECT_TRUE(output.is_tensor()); + auto ranges = quantization_ranges(); + EXPECT_EQ(ranges.count(output.tensor()), 1); + EXPECT_EQ(ranges[output.tensor()], 6.0f); + } + { + // QuantizeAndDequantizeV2 ranges set via inputs, ok. 
+ Reset(); + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto weights_min = ops::Placeholder(s.WithOpName("weights_min"), DT_FLOAT); + auto weights_max = ops::Placeholder(s.WithOpName("weights_max"), DT_FLOAT); + auto quantize = ops::QuantizeAndDequantizeV2( + s.WithOpName("my_quantize"), input, weights_min, weights_max); + const NodeDef& node_def = quantize.operation.node()->def(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("weights_min", {1}, {-6.0f}); + AddTestWeights("weights_max", {1}, {6.0f}); + RunConversion(node_def); + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_quantize", &output)); + EXPECT_TRUE(output.is_tensor()); + auto ranges = quantization_ranges(); + EXPECT_EQ(ranges.count(output.tensor()), 1); + EXPECT_EQ(ranges[output.tensor()], 6.0f); + } + { + // QuantizeAndDequantizeV3 ranges set via inputs, ok. + Reset(); + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto weights_min = ops::Placeholder(s.WithOpName("weights_min"), DT_FLOAT); + auto weights_max = ops::Placeholder(s.WithOpName("weights_max"), DT_FLOAT); + auto num_bits = ops::Placeholder(s.WithOpName("num_bits"), DT_INT32); + auto quantize = ops::QuantizeAndDequantizeV3( + s.WithOpName("my_quantize"), input, weights_min, weights_max, num_bits); + const NodeDef& node_def = quantize.operation.node()->def(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("weights_min", {1}, {-6.0f}); + AddTestWeights("weights_max", {1}, {6.0f}); + AddTestWeights("num_bits", {1}, {8}); + RunConversion(node_def); + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_quantize", &output)); + EXPECT_TRUE(output.is_tensor()); + auto ranges = quantization_ranges(); + EXPECT_EQ(ranges.count(output.tensor()), 1); + EXPECT_EQ(ranges[output.tensor()], 6.0f); + } + { + // QuantizeAndDequantizeV2 Range inputs are tensors, should fail. + Reset(); + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto weights_min = ops::Placeholder(s.WithOpName("weights_min"), DT_FLOAT); + auto weights_max = ops::Placeholder(s.WithOpName("weights_max"), DT_FLOAT); + auto quantize = ops::QuantizeAndDequantizeV2( + s.WithOpName("my_quantize"), input, weights_min, weights_max); + const NodeDef& node_def = quantize.operation.node()->def(); + AddTestTensor("input", {1, 2, 3}); + AddTestTensor("weights_min", {1}); + AddTestTensor("weights_max", {1}); + RunConversion( + node_def, error::INVALID_ARGUMENT, + "Min and max inputs for QuantizeAndDequantizeV2 must be weights not " + "tensors, at my_quantize"); + } +} + +TEST_F(OpConverterTest, ConvertRelu6) { + { + // Input list is empty, should fail. + NodeDef node_def = MakeNodeDef("my_relu6", "Relu6", {}); + RunConversion(node_def, error::INVALID_ARGUMENT, + "Invalid number of inputs for Relu6, at my_relu6"); + } + + // Get the NodeDef for Relu6. + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto relu6 = ops::Relu6(s.WithOpName("my_relu6"), input); + const NodeDef& node_def = relu6.operation.node()->def(); + + { + // Clip tensor values and set quantization ranges, ok. 
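The values checked below follow from the definition relu6(x) = min(max(x, 0), 6) applied to the test input {-100, -1, 0, 3, 5, 9}, which is also why the converter can report a fixed quantization range of 6 for the output. A standalone restatement of that arithmetic:

```
#include <algorithm>
#include <cassert>
#include <vector>

float Relu6(float x) { return std::min(std::max(x, 0.0f), 6.0f); }

int main() {
  const std::vector<float> input = {-100, -1, 0, 3, 5, 9};
  const std::vector<float> expected = {0, 0, 0, 3, 5, 6};  // ElementsAre(0, 0, 0, 3, 5, 6)
  for (size_t i = 0; i < input.size(); ++i) {
    assert(Relu6(input[i]) == expected[i]);
  }
  return 0;
}
```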
+ Reset(); + AddTestTensor("input", {1, 2, 3}); + RunConversion(node_def); + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_relu6", &output)); + EXPECT_TRUE(output.is_tensor()); + auto ranges = quantization_ranges(); + EXPECT_EQ(ranges[output.tensor()], 6.0f); + + std::vector output_data(6); + BuildAndRun("input", {-100, -1, 0, 3, 5, 9}, "my_relu6", &output_data); + EXPECT_THAT(output_data, ElementsAre(0, 0, 0, 3, 5, 6)); + } + { + // Input is weights, should fail. + Reset(); + AddTestWeights("input", {1, 2, 3}, {-100, -1, 0, 3, 5, 9}); + RunConversion( + node_def, error::UNIMPLEMENTED, + "Relu6 is only implemented for tensors, not weights, at my_relu6"); + } +} + +TEST_F(OpConverterTest, ConvertBiasAdd) { + { + // Input list is empty, should fail. + NodeDef node_def = MakeNodeDef("my_biasadd", "BiasAdd", {}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Input expects tensor and weights, at my_biasadd"); + } + + // OK. Note that kINT32 is not supported by IScaleLayer, so we don't test + // DT_INT32 type here. + TestConvertBiasAdd(this); + TestConvertBiasAdd(this); +} + +template +NodeDef GetBinaryOpNodeDef(const string& input_name_l, + const string& input_name_r, DataType dtype) { + Scope s = Scope::NewRootScope(); + auto input_l = ops::Placeholder(s.WithOpName(input_name_l), dtype); + auto input_r = ops::Placeholder(s.WithOpName(input_name_r), dtype); + auto op = OpType(s.WithOpName("my_binary"), input_l, input_r); + return op.operation.node()->def(); +} + +void CheckAddedLayers(OpConverterTest* test, bool expect_scale_layer) { + bool element_wise_layer_found = false; + bool scale_layer_found = false; + for (int i = 0; i < test->converter_->network()->getNbLayers(); i++) { + nvinfer1::ILayer* layer = test->converter_->network()->getLayer(i); + if (dynamic_cast(layer)) { + scale_layer_found = true; + } else if (dynamic_cast(layer)) { + element_wise_layer_found = true; + } + } + EXPECT_EQ(expect_scale_layer, scale_layer_found); + EXPECT_NE(expect_scale_layer, element_wise_layer_found); +} + +template +void TestBinaryTensorOpWeightNoBroadcast(OpConverterTest* test) { + typedef typename EnumToDataType::Type CType; + for (auto swap_inputs : {false, true}) { + test->Reset(); + NodeDef node_def; + if (swap_inputs) { + node_def = GetBinaryOpNodeDef("weights", "input", dtype); + } else { + node_def = GetBinaryOpNodeDef("input", "weights", dtype); + } + + const std::vector operand1{CType(3), CType(7.5)}; + const std::vector operand2{CType(2), CType(3)}; + + // It requires the dims to be at least of rank 3 to apply an IScaleLayer. + test->AddTestTensor("input", /*dims=*/{1, 1, 2}, /*batch_size=*/1, + TfDataTypeToTrt(dtype)); + test->AddTestWeights("weights", /*dims=*/{1, 1, 2}, + /*values=*/swap_inputs ? operand1 : operand2); + test->RunValidationAndConversion(node_def); + + // Make sure it does use BinaryTensorOpWeight, not BinaryTensorOpTensor. + CheckAddedLayers(test, /*expect_scale_layer=*/true); + + // Check the dims of the output ITensor. + TRT_TensorOrWeights output; + TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray({1, 1, 2}, output.tensor()->getDimensions()); + + std::vector output_data(2); + test->BuildAndRun( + {{"input", + /*input_data=*/swap_inputs ? 
operand2 : operand1}}, + "my_binary", &output_data); + if (node_def.op() == "Add") { + EXPECT_THAT(output_data, ElementsAre(CType(5), CType(10.5))); + } else if (node_def.op() == "Sub") { + EXPECT_THAT(output_data, ElementsAre(CType(1), CType(4.5))); + } else if (node_def.op() == "Mul") { + EXPECT_THAT(output_data, ElementsAre(CType(6), CType(22.5))); + } else if (node_def.op() == "Div") { + EXPECT_THAT(output_data, ElementsAre(CType(1.5), CType(2.5))); + } else if (node_def.op() == "RealDiv") { + EXPECT_THAT(output_data, ElementsAre(CType(1.5), CType(2.5))); + } else { + ASSERT_TRUE(false); + } + } +} + +template +void TestBinaryTensorOpWeightWithChannelWiseBroadcast(OpConverterTest* test) { + typedef typename EnumToDataType::Type CType; + const NodeDef node_def = + GetBinaryOpNodeDef("input", "weights", dtype); + const std::vector input{CType(1), CType(2), CType(3), CType(4)}; + const std::vector weights{CType(10), CType(20)}; + // There are two types of valid dim pairs which requires channel-wise + // broadcasting: + // - input dims (X Y Z) vs weights dims (X 1 1) + // - input dims (X Y Z) vs weights dims (Z) + // Here X=Z=2 and Y=1. + for (auto weights_dims : std::vector>{{2, 1, 1}, {2}}) { + test->Reset(); + test->AddTestTensor("input", /*dims=*/{2, 1, 2}, /*batch_size=*/1, + TfDataTypeToTrt(dtype)); + test->AddTestWeights("weights", weights_dims, weights); + test->RunValidationAndConversion(node_def); + + // Make sure it does use BinaryTensorOpWeight, not BinaryTensorOpTensor. + CheckAddedLayers(test, /*expect_scale_layer=*/true); + + // Check the dims of the output ITensor. + TRT_TensorOrWeights output; + TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray({2, 1, 2}, output.tensor()->getDimensions()); + + std::vector output_data(4); + test->BuildAndRun({{"input", input}}, "my_binary", &output_data); + if (weights_dims.size() == 1) { + EXPECT_THAT(output_data, + ElementsAre(CType(11), CType(22), CType(13), CType(24))); + } else { + EXPECT_THAT(output_data, + ElementsAre(CType(11), CType(12), CType(23), CType(24))); + } + } +} + +template +void TestBinaryTensorOpWeightWithUniformlyBroadcast(OpConverterTest* test) { + typedef typename EnumToDataType::Type CType; + const NodeDef node_def = + GetBinaryOpNodeDef("input", "weights", dtype); + const std::vector input{CType(1), CType(2), CType(3), CType(4)}; + const std::vector weights{CType(10)}; + test->Reset(); + test->AddTestTensor("input", /*dims=*/{2, 1, 2}, /*batch_size=*/1, + TfDataTypeToTrt(dtype)); + test->AddTestWeights("weights", {1, 1, 1, 1}, weights); + test->RunValidationAndConversion(node_def); + + // Make sure it does use BinaryTensorOpWeight, not BinaryTensorOpTensor. + CheckAddedLayers(test, /*expect_scale_layer=*/true); + + // Check the dims of the output ITensor. 
+ TRT_TensorOrWeights output; + TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray({2, 1, 2}, output.tensor()->getDimensions()); + + std::vector output_data(4); + test->BuildAndRun({{"input", input}}, "my_binary", &output_data); + EXPECT_THAT(output_data, + ElementsAre(CType(11), CType(12), CType(13), CType(14))); +} + +template +void TestBinaryTensorOpWeightFallback(OpConverterTest* test, + const std::vector& input_dims, + const std::vector& weights_dims, + error::Code code = error::OK, + const char* error_msg_substr = nullptr, + const int input_batch_size = 1) { + const DataType dtype = DT_FLOAT; + typedef typename EnumToDataType::Type CType; + const size_t num_inputs = TrtDimsNumElements(GetTestDims(input_dims)); + const size_t num_weights = TrtDimsNumElements(GetTestDims(weights_dims)); + + test->Reset(); + const NodeDef node_def = + GetBinaryOpNodeDef("input", "weights", dtype); + test->AddTestTensor("input", /*dims=*/input_dims, input_batch_size, + TfDataTypeToTrt(dtype)); + test->AddTestWeights( + "weights", /*dims=*/weights_dims, + /*values=*/std::vector(num_weights, CType(1))); + test->RunValidationAndConversion(node_def, code, error_msg_substr); + if (code != error::OK) return; + + // Make sure it does use BinaryTensorOpTensor, not BinaryTensorOpWeight. + CheckAddedLayers(test, /*expect_scale_layer=*/false); + + TRT_TensorOrWeights output; + TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output)); + EXPECT_TRUE(output.is_tensor()); + + // Check the dims of the output ITensor. + std::vector expected_output_dims = input_dims; + for (int i = expected_output_dims.size() - 1, j = weights_dims.size() - 1; + i >= 0 && j >= 0; --i, --j) { + if (expected_output_dims[i] == 1) { + expected_output_dims[i] = weights_dims[j]; + } + } + ExpectTrtDimsEqualsArray(expected_output_dims, + output.tensor()->getDimensions()); + + // Check the result of running the engine. + const int expected_num_outputs = + TrtDimsNumElements(GetTestDims(expected_output_dims)); + std::vector output_data(expected_num_outputs); + test->BuildAndRun( + {{"input", + /*input_data=*/std::vector(num_inputs, CType(2))}}, + "my_binary", &output_data); + if (node_def.op() == "Add") { + EXPECT_THAT(output_data, ElementsAreArray(std::vector( + expected_num_outputs, CType(3)))); + } else if (node_def.op() == "Minimum") { + EXPECT_THAT(output_data, ElementsAreArray(std::vector( + expected_num_outputs, CType(1)))); + } else { + ASSERT_TRUE(false); + } +} + +template +void TestBinaryTensorOpTensor(OpConverterTest* test) { + typedef typename EnumToDataType::Type CType; + test->Reset(); + const NodeDef node_def = + GetBinaryOpNodeDef("input1", "input2", dtype); + test->AddTestTensor("input1", /*dims=*/{1, 2}, /*batch_size=*/1, + TfDataTypeToTrt(dtype)); + test->AddTestTensor("input2", /*dims=*/{2, 1}, /*batch_size=*/1, + TfDataTypeToTrt(dtype)); + test->RunValidationAndConversion(node_def); + + // Make sure it does use BinaryTensorOpTensor, not BinaryTensorOpWeight. + CheckAddedLayers(test, /*expect_scale_layer=*/false); + + // Check output dims. + TRT_TensorOrWeights output; + TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray({2, 2}, output.tensor()->getDimensions()); + + std::vector output_data(4); + // After broadcasting first input becomes {3, 6, 3, 6} and second input + // becomes {2, 3, 2, 3}. 
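Spelled out for the Add case below: broadcasting the {1, 2} operand {3, 6} against the {2, 1} operand {2, 3} gives a {2, 2} result in which the first operand repeats along rows and the second along columns (so it reads {2, 2, 3, 3} in row-major order), which is what the ElementsAre(5, 8, 6, 9) expectation encodes. A standalone check of that arithmetic:

```
#include <cassert>

int main() {
  const float in1[1][2] = {{3, 6}};    // shape {1, 2}
  const float in2[2][1] = {{2}, {3}};  // shape {2, 1}
  float sum[2][2];
  for (int r = 0; r < 2; ++r) {
    for (int c = 0; c < 2; ++c) {
      // Broadcasting: in1 repeats along rows, in2 repeats along columns.
      sum[r][c] = in1[0][c] + in2[r][0];
    }
  }
  // Row-major result: {5, 8, 6, 9}, matching the Add expectation below.
  assert(sum[0][0] == 5 && sum[0][1] == 8 && sum[1][0] == 6 && sum[1][1] == 9);
  return 0;
}
```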
+ test->BuildAndRun( + {{"input1", {CType(3), CType(6)}}, {"input2", {CType(2), CType(3)}}}, + "my_binary", &output_data); + if (node_def.op() == "Add") { + EXPECT_THAT(output_data, + ElementsAre(CType(5), CType(8), CType(6), CType(9))); + } else if (node_def.op() == "Sub") { + EXPECT_THAT(output_data, + ElementsAre(CType(1), CType(4), CType(0), CType(3))); + } else if (node_def.op() == "Mul") { + EXPECT_THAT(output_data, + ElementsAre(CType(6), CType(12), CType(9), CType(18))); + } else if (node_def.op() == "Div") { + EXPECT_THAT(output_data, + ElementsAre(CType(1.5), CType(3), CType(1), CType(2))); + } else if (node_def.op() == "RealDiv") { + EXPECT_THAT(output_data, + ElementsAre(CType(1.5), CType(3), CType(1), CType(2))); + } else if (node_def.op() == "Minimum") { + EXPECT_THAT(output_data, + ElementsAre(CType(2), CType(2), CType(3), CType(3))); + } else if (node_def.op() == "Maximum") { + EXPECT_THAT(output_data, + ElementsAre(CType(3), CType(6), CType(3), CType(6))); + } else { + ASSERT_TRUE(false); + } +} + +TEST_F(OpConverterTest, ConvertBinary) { + // Input size doesn't match, should fail. + for (size_t num_inputs = 0; num_inputs < 2; ++num_inputs) { + Reset(); + NodeDef node_def = MakeNodeDef("my_add", "Add", {num_inputs, "input"}); + AddTestTensor("input", {1}, /*batch_size=*/1, nvinfer1::DataType::kFLOAT); + RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, + "Binary ops require two inputs, at my_add"); + } + { + // Both inputs are weights. + Reset(); + NodeDef node_def = MakeNodeDef("my_add", "Add", {"weights1", "weights2"}); + AddTestWeights("weights1", {1}, {1}); + AddTestWeights("weights2", {1}, {1}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "Constant folding is falled back to TensorFlow, binary op received " + "both input as constant at: my_add"); + } + + // Test BinaryTensorOpWeight() without broadcasting. + TestBinaryTensorOpWeightNoBroadcast(this); + TestBinaryTensorOpWeightNoBroadcast(this); + TestBinaryTensorOpWeightNoBroadcast(this); + TestBinaryTensorOpWeightNoBroadcast(this); + TestBinaryTensorOpWeightNoBroadcast(this); +#if 0 + // TODO(b/119560144): it doesn't support FP16 constants and the following test + // will fail. + TestBinaryTensorOpWeightNoBroadcast(this); + TestBinaryTensorOpWeightNoBroadcast(this); + TestBinaryTensorOpWeightNoBroadcast(this); + TestBinaryTensorOpWeightNoBroadcast(this); + TestBinaryTensorOpWeightNoBroadcast(this); +#endif + + // Test BinaryTensorOpWeight() with channel-wise broadcasting. + TestBinaryTensorOpWeightWithChannelWiseBroadcast(this); + + // Test BinaryTensorOpWeight() with uniformly broadcasting. + TestBinaryTensorOpWeightWithUniformlyBroadcast(this); + + // Test BinaryTensorOpWeight() falling back to BinaryTensorOpTensor(). + // Unsupported op. + TestBinaryTensorOpWeightFallback(this, {1, 1, 1}, {1}); + // Rank of input tensor dimension <3. + TestBinaryTensorOpWeightFallback(this, {1, 1}, {1}); + // Broadcast on batch dimension, should fail. + TestBinaryTensorOpWeightFallback( + this, {1, 1, 1}, {2, 1, 1, 1}, error::INVALID_ARGUMENT, + "Unsupported binary op broadcast scheme for op my_binary", + /*input_batch_size=*/2); + // Incompatible dims with per-channel mode. + TestBinaryTensorOpWeightFallback(this, {1, 1, 1}, {1, 2, 1}); + // Incompatible dims. + TestBinaryTensorOpWeightFallback(this, {1, 2, 1}, {2}); + + // Test BinaryTensorOpTensor() with broadcasting. 
+ TestBinaryTensorOpTensor(this); + TestBinaryTensorOpTensor(this); + TestBinaryTensorOpTensor(this); + TestBinaryTensorOpTensor(this); + TestBinaryTensorOpTensor(this); + TestBinaryTensorOpTensor(this); + TestBinaryTensorOpTensor(this); + + TestBinaryTensorOpTensor(this); + TestBinaryTensorOpTensor(this); + TestBinaryTensorOpTensor(this); + TestBinaryTensorOpTensor(this); + TestBinaryTensorOpTensor(this); + TestBinaryTensorOpTensor(this); + TestBinaryTensorOpTensor(this); +} + +TEST_F(OpConverterTest, ConvertQuantize) { + for (const string& op : + {"FakeQuantWithMinMaxArgs", "FakeQuantWithMinMaxVars", + "QuantizeAndDequantizeV2", "QuantizeAndDequantizeV3"}) { + // Input list is empty, should fail. + NodeDef node_def = MakeNodeDef("my_quantize", op, {}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + StrCat("Invalid number of inputs for ", op, ", at my_quantize") + .c_str()); + } + { + // FakeQuantWithMinMaxArgs attributes are empty, should fail. + NodeDef node_def = + MakeNodeDef("my_quantize", "FakeQuantWithMinMaxArgs", {"input"}); + AddTestTensor("input", {1, 2, 3}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Min or max attribute not found for FakeQuantWithMinMaxArgs " + "at my_quantize"); + } + { + // FakeQuantWithMinMaxArgs ranges set via attributes, ok. + Reset(); + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto quantize_attrs = ops::FakeQuantWithMinMaxArgs::Min(-6.0f).Max(6.0f); + auto quantize = ops::FakeQuantWithMinMaxArgs(s.WithOpName("my_quantize"), + input, quantize_attrs); + const NodeDef& node_def = quantize.operation.node()->def(); + AddTestTensor("input", {1, 2, 3}); + RunValidationAndConversion(node_def); + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_quantize", &output)); + EXPECT_TRUE(output.is_tensor()); + auto ranges = quantization_ranges(); + EXPECT_EQ(1, ranges.count(output.tensor())); + EXPECT_EQ(6.0f, ranges[output.tensor()]); + } + { + // FakeQuantWithMinMaxVars ranges set via inputs, ok. + Reset(); + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto weights_min = ops::Placeholder(s.WithOpName("weights_min"), DT_FLOAT); + auto weights_max = ops::Placeholder(s.WithOpName("weights_max"), DT_FLOAT); + auto quantize = ops::FakeQuantWithMinMaxVars( + s.WithOpName("my_quantize"), input, weights_min, weights_max); + const NodeDef& node_def = quantize.operation.node()->def(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("weights_min", {1}, {-6.0f}); + AddTestWeights("weights_max", {1}, {6.0f}); + RunValidationAndConversion(node_def); + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_quantize", &output)); + EXPECT_TRUE(output.is_tensor()); + auto ranges = quantization_ranges(); + EXPECT_EQ(1, ranges.count(output.tensor())); + EXPECT_EQ(6.0f, ranges[output.tensor()]); + } + { + // QuantizeAndDequantizeV2 ranges set via inputs, ok. 
+ Reset(); + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto weights_min = ops::Placeholder(s.WithOpName("weights_min"), DT_FLOAT); + auto weights_max = ops::Placeholder(s.WithOpName("weights_max"), DT_FLOAT); + auto quantize = ops::QuantizeAndDequantizeV2( + s.WithOpName("my_quantize"), input, weights_min, weights_max); + const NodeDef& node_def = quantize.operation.node()->def(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("weights_min", {1}, {-6.0f}); + AddTestWeights("weights_max", {1}, {6.0f}); + RunValidationAndConversion(node_def); + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_quantize", &output)); + EXPECT_TRUE(output.is_tensor()); + auto ranges = quantization_ranges(); + EXPECT_EQ(1, ranges.count(output.tensor())); + EXPECT_EQ(6.0f, ranges[output.tensor()]); + } + { + // QuantizeAndDequantizeV2 Range inputs are tensors, should fail. + Reset(); + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto weights_min = ops::Placeholder(s.WithOpName("weights_min"), DT_FLOAT); + auto weights_max = ops::Placeholder(s.WithOpName("weights_max"), DT_FLOAT); + auto quantize = ops::QuantizeAndDequantizeV2( + s.WithOpName("my_quantize"), input, weights_min, weights_max); + const NodeDef& node_def = quantize.operation.node()->def(); + AddTestTensor("input", {1, 2, 3}); + AddTestTensor("weights_min", {1}); + AddTestTensor("weights_max", {1}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Min and max inputs for QuantizeAndDequantizeV2 must be weights not " + "tensors, at my_quantize"); + } + { + // QuantizeAndDequantizeV3 ranges set via inputs, ok. + Reset(); + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto weights_min = ops::Placeholder(s.WithOpName("weights_min"), DT_FLOAT); + auto weights_max = ops::Placeholder(s.WithOpName("weights_max"), DT_FLOAT); + auto num_bits = ops::Placeholder(s.WithOpName("num_bits"), DT_INT32); + auto quantize = ops::QuantizeAndDequantizeV3( + s.WithOpName("my_quantize"), input, weights_min, weights_max, num_bits); + const NodeDef& node_def = quantize.operation.node()->def(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("weights_min", {1}, {-6.0f}); + AddTestWeights("weights_max", {1}, {6.0f}); + AddTestWeights("num_bits", {1}, {8}); + RunValidationAndConversion(node_def); + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_quantize", &output)); + EXPECT_TRUE(output.is_tensor()); + auto ranges = quantization_ranges(); + EXPECT_EQ(1, ranges.count(output.tensor())); + EXPECT_EQ(6.0f, ranges[output.tensor()]); + } +} + +TEST_F(OpConverterTest, ConvertRelu6) { + { + // Input list is empty, should fail. + NodeDef node_def = MakeNodeDef("my_relu6", "Relu6", {}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Invalid number of inputs for Relu6, at my_relu6"); + } + + // Get the NodeDef for Relu6. + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto relu6 = ops::Relu6(s.WithOpName("my_relu6"), input); + const NodeDef node_def = relu6.operation.node()->def(); + { + // Input is weights, should fail. + Reset(); + AddTestWeights("input", {1}, {1.0f}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "Relu6 is only implemented for tensors, not weights, at my_relu6"); + } + { + // Clip tensor values and set quantization ranges, ok. 
+ Reset(); + AddTestTensor("input", {1, 2, 3}); + RunValidationAndConversion(node_def); + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_relu6", &output)); + EXPECT_TRUE(output.is_tensor()); + auto ranges = quantization_ranges(); + EXPECT_EQ(ranges[output.tensor()], 6.0f); + + std::vector output_data(6); + BuildAndRun({{"input", {-100, -1, 0, 3, 5, 9}}}, "my_relu6", + &output_data); + EXPECT_THAT(output_data, ElementsAre(0, 0, 0, 3, 5, 6)); + } +} + +template +void TestConvertSquare(OpConverterTest* test) { + test->Reset(); + typedef typename EnumToDataType::Type CType; + + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), dtype); + auto square = ops::Square(s.WithOpName("my_square"), input); + NodeDef node_def = square.operation.node()->def(); + + test->AddTestTensor("input", {1, 20}); + test->RunValidationAndConversion(node_def); + TRT_TensorOrWeights output; + TF_EXPECT_OK(test->GetTensorOrWeights("my_square", &output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray({1, 20}, output.tensor()->getDimensions()); + + const int num_inputs = 20; + std::vector input_data(num_inputs); + std::vector expected_output_data(num_inputs); + for (int i = 0; i < 20; i++) { + const CType value = CType(i - 9); + input_data[i] = value; + expected_output_data[i] = value * value; + } + std::vector output_data(num_inputs); + test->BuildAndRun({{"input", input_data}}, "my_square", &output_data); + ExpectArrayNear(expected_output_data, output_data); +} + +TEST_F(OpConverterTest, ConvertSquare) { + { + // Input list is empty, should fail. + NodeDef node_def = MakeNodeDef("my_square", "Square", {}); + RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, + "Square expects one input, at my_square"); + } + { + // Input is weights, should fail. + Reset(); + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto square = ops::Square(s.WithOpName("my_square"), input); + NodeDef node_def = square.operation.node()->def(); + AddTestWeights("input", {1, 2, 3}, {1, 2, 3, 4, -5, 6}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "Square is only implemented for tensors, at my_square"); + } + + // OK. Note that kINT32 is not supported by IElementWiseLayer, so we don't + // test DT_INT32 type here. + TestConvertSquare(this); + // TODO(tmorris): Looks like there may be a bug with this layer for FP16 + // inputs. Disabling for now. + // TestConvertSquare(this); +} + } // namespace convert } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc index b30d94b0282..4ac7e21d348 100644 --- a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc +++ b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc @@ -67,6 +67,9 @@ tensorflow::Status TRTOptimizationPass::Init( TF_RETURN_IF_ERROR(GetPrecisionMode( Uppercase(params.at("precision_mode").s()), &precision_mode_)); } + if (params.count("use_calibration")) { + use_calibration_ = params.at("use_calibration").b(); + } return tensorflow::Status::OK(); } @@ -222,6 +225,12 @@ tensorflow::Status TRTOptimizationPass::Optimize( TF_RETURN_IF_ERROR(static_graph_properties.InferStatically(true)); tensorflow::tensorrt::convert::ConversionParams cp; + if (use_calibration_ && precision_mode_ != INT8MODE) { + LOG(ERROR) << "Calibration with FP32 or FP16 is not implemented. 
" + << "Falling back to use_calibration = False."; + use_calibration_ = false; + } + std::vector nodes_to_preserve; for (const auto& n : item.NodesToPreserve()) { auto tokens = str_util::Split(n, ":"); @@ -250,6 +259,7 @@ tensorflow::Status TRTOptimizationPass::Optimize( cp.is_dyn_op = is_dynamic_op_; cp.cached_engine_batches = batches_; cp.max_cached_engines = max_cached_batches_; + cp.use_calibration = use_calibration_; auto status = tensorflow::tensorrt::convert::ConvertAfterShapes(cp); VLOG(1) << "Returning from " << name_; return status; diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h index 71b51d13681..3e8dc0978e4 100644 --- a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h +++ b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h @@ -38,7 +38,8 @@ class TRTOptimizationPass : public tensorflow::grappler::CustomGraphOptimizer { maximum_batch_size_(-1), is_dynamic_op_(false), max_cached_batches_(1), - max_workspace_size_bytes_(256LL << 20) { + max_workspace_size_bytes_(256LL << 20), + use_calibration_(true) { VLOG(1) << "Constructing " << name_; } @@ -67,6 +68,7 @@ class TRTOptimizationPass : public tensorflow::grappler::CustomGraphOptimizer { std::vector batches_; int max_cached_batches_; int64_t max_workspace_size_bytes_; + bool use_calibration_; }; } // namespace convert diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc index 019446813a5..117039683c0 100644 --- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc @@ -124,8 +124,10 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) OP_REQUIRES_OK(context, context->GetAttr("segment_funcdef_name", &funcdef_name_)); OP_REQUIRES_OK(context, GetPrecisionMode(precision_string, &precision_mode_)); - calibration_mode_ = - (precision_mode_ == INT8MODE && calibration_data.size() == 0); + OP_REQUIRES_OK(context, + context->GetAttr("use_calibration", &use_calibration_)); + calibration_mode_ = (use_calibration_ && precision_mode_ == INT8MODE && + calibration_data.size() == 0); if (calibration_data.size()) { calibrator_.reset(new TRTInt8Calibrator(calibration_data)); calibration_data.resize(0); @@ -252,9 +254,8 @@ int TRTEngineOp::GetEngineBatch(OpKernelContext* ctx) { cached_engine_batches_.push_back(num_batch); VLOG(1) << "Running with batch size " << num_batch; } else { - string msg = - StrCat("Engine buffer is full. buffer limit=", max_cached_engines_, - ", current entries="); + string msg = StrCat("Engine buffer is full. 
buffer limit=", + max_cached_engines_, ", current entries="); for (auto i : cached_engine_batches_) StrAppend(&msg, i, ","); StrAppend(&msg, " requested batch=", num_batch); LOG(WARNING) << msg; @@ -308,7 +309,7 @@ bool TRTEngineOp::ExecuteTrtEngine( std::vector buffers(num_binding); for (int i = 0; i < ctx->num_inputs(); i++) { const string input_name = StrCat(kInputPHName, i); - const size_t binding_index = + const int binding_index = trt_engine_ptr->getBindingIndex(input_name.c_str()); if (binding_index == -1) { LOG(ERROR) << "Input node not found, at " << input_name; @@ -345,7 +346,7 @@ bool TRTEngineOp::ExecuteTrtEngine( for (int i = 0; i < ctx->num_outputs(); i++) { // Create an output tensor const string output_name = StrCat(kOutputPHName, i); - const size_t binding_index = + const int binding_index = trt_engine_ptr->getBindingIndex(output_name.c_str()); Tensor* output_tensor = nullptr; @@ -491,13 +492,14 @@ TRTEngineOp::EngineCtxPair& TRTEngineOp::GetEngine(int batch_size, } TrtUniquePtrType engine; bool convert_successfully = false; - VLOG(0) << name() << " Constructing a new engine with batch size " - << batch_size; + LOG(INFO) << "Building a new TensorRT engine for " << name() + << " with batch size " << batch_size; // Up to this point, calibrator_ can never be empty, since otherwise it // means calibration_mode_ is true and this path won't get executed. auto status = convert::ConvertGraphDefToEngine( segment_graph_, precision_mode_, batch_size, workspace_size_, shapes, - &logger, allocator, calibrator_.get(), &engine, &convert_successfully); + &logger, allocator, calibrator_.get(), &engine, use_calibration_, + &convert_successfully); if (!status.ok()) { if (convert_successfully) { // This means it fail to build the engine even when the network is built @@ -567,8 +569,8 @@ tensorflow::Status TRTEngineOp::AllocateCalibrationResources( const int64 workspace_size_bytes = workspace_size_; cres->thr_.reset(new std::thread([cres, label, segment_graph, shapes, platform_gpu_id, workspace_size_bytes]() { - VLOG(0) << "Starting calibration thread on device " << platform_gpu_id - << ", Calibration Resource @ " << cres; + LOG(INFO) << "Starting calibration thread on device " << platform_gpu_id + << ", Calibration Resource @ " << cres; auto err = cudaSetDevice(platform_gpu_id); if (err != cudaSuccess) { // TODO(aaroey): should return error here. @@ -586,6 +588,7 @@ tensorflow::Status TRTEngineOp::AllocateCalibrationResources( *segment_graph, INT8MODE, cres->calibrator_->getBatchSize(), workspace_size_bytes, shapes, &cres->logger_, cres->allocator_.get(), cres->calibrator_.get(), &cres->engine_, + /*use_calibration=*/true, /*convert_successfully=*/nullptr); if (!s.ok()) { LOG(ERROR) << "Calibration failed: " << s; diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h index 8fe06758914..b545f497f32 100644 --- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h +++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h @@ -130,6 +130,10 @@ class TRTEngineOp : public AsyncOpKernel { // The finalized calibrator for inference. std::unique_ptr calibrator_; + + // If true, create calibration graph for INT8 mode. Otherwise, we are using + // user-provided quantization ranges. 
+ bool use_calibration_; }; } // namespace tensorrt diff --git a/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc b/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc index e0c7b627237..92405906eb7 100644 --- a/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc +++ b/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc @@ -16,6 +16,7 @@ limitations under the License. #if GOOGLE_CUDA #if GOOGLE_TENSORRT +#include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/shape_inference.h" @@ -39,18 +40,19 @@ REGISTER_OP("TRTEngineOp") .Attr("cached_engine_batches: list(int) = []") .Attr("max_cached_engines_count: int = 1") .Attr("workspace_size_bytes: int") - .Attr("precision_mode: {'FP32', 'FP16', 'INT8', 'INT8CALIB'}") + .Attr("precision_mode: {'FP32', 'FP16', 'INT8'}") .Attr("calibration_data: string = ''") + .Attr("use_calibration: bool = true") .Input("in_tensor: InT") - .Output("out_tensor: OutT"); -// TODO(jie): TF requires concrete output shape for concrete input shapes. -// This is tricky for batch dimension, since we cannot ensure which input -// would carry the correct batch dimension (for the current stage of the -// implementation, we do require all input tensor to carry the same batch -// size, but this could change in the future). Hence we disable shape -// inference function as a workaround. -// .SetShapeFn(shape_inference::TRTEngineOpShapeInference); - + .Output("out_tensor: OutT") + // TODO(jie): TF requires concrete output shape for concrete input shapes. + // This is tricky for batch dimension, since we cannot ensure which input + // would carry the correct batch dimension (for the current stage of the + // implementation, we do require all input tensor to carry the same batch + // size, but this could change in the future). Hence we disable shape + // inference function as a workaround. + // .SetShapeFn(shape_inference::TRTEngineOpShapeInference); + .SetShapeFn(shape_inference::UnknownShape); } // namespace tensorflow #endif // GOOGLE_TENSORRT diff --git a/tensorflow/contrib/tensorrt/python/trt_convert.py b/tensorflow/contrib/tensorrt/python/trt_convert.py index bb81fbf93f3..74a2c2392ad 100644 --- a/tensorflow/contrib/tensorrt/python/trt_convert.py +++ b/tensorflow/contrib/tensorrt/python/trt_convert.py @@ -63,19 +63,20 @@ class TrtPrecisionMode(object): return [TrtPrecisionMode.FP32, TrtPrecisionMode.FP16, TrtPrecisionMode.INT8] -def tensorrt_rewriter_config(rewriter_config=None, - max_batch_size=1, - max_workspace_size_bytes=2 << 20, - precision_mode=TrtPrecisionMode.FP32, - minimum_segment_size=3, - is_dynamic_op=False, - maximum_cached_engines=1, - cached_engine_batch_sizes=None): +def get_tensorrt_rewriter_config(rewriter_config=None, + max_batch_size=1, + max_workspace_size_bytes=2 << 20, + precision_mode=TrtPrecisionMode.FP32, + minimum_segment_size=3, + is_dynamic_op=False, + maximum_cached_engines=1, + cached_engine_batch_sizes=None, + use_calibration=True): """Returns a RewriterConfig proto for TRT transformation. Args: - rewriter_config: a RewriterConfig proto to append the TensorRTOptimizer to. - If None, it will create one with default settings. + rewriter_config: a template RewriterConfig proto used to create a + TRT-enabled RewriterConfig. If None, it will use a default one. max_batch_size: max size for the input batch max_workspace_size_bytes: the maximum GPU temporary memory which the TRT engine can use at execution time. 
This corresponds to the 'workspaceSize' @@ -95,6 +96,15 @@ def tensorrt_rewriter_config(rewriter_config=None, use this list to determine the batch sizes of the cached engines, instead of making the decision on the fly. This is useful when we know the most common batch size(s) the application is going to generate. + use_calibration: this argument is ignored if precision_mode is not INT8. If + set to True, a calibration graph will be created to calibrate the missing + ranges. The calibration graph must be converted to an inference graph + using calib_graph_to_infer_graph() after running calibration. If set to + False, quantization nodes will be expected for every tensor in the graph + (excluding those which will be fused). If a range is missing, an error + will occur. Please note that accuracy may be negatively affected if there + is a mismatch between which tensors TRT quantizes and which tensors were + trained with fake quantization. Returns: A RewriterConfig proto which sets a TensorRTOptimizer to run Grappler. @@ -107,13 +117,16 @@ def tensorrt_rewriter_config(rewriter_config=None, rewriter_config, rewriter_config_pb2.RewriterConfig): raise TypeError("rewriter_config should be a RewriterConfig proto.") + rewriter_config_with_trt = rewriter_config_pb2.RewriterConfig() if rewriter_config is None: - rewriter_config = rewriter_config_pb2.RewriterConfig() # Layout optimizer may add Const nodes followed by Reshape nodes, thus we # need to run constant folding again. - rewriter_config.optimizers.extend(["constfold", "layout", "constfold"]) - rewriter_config.meta_optimizer_iterations = ( + rewriter_config_with_trt.optimizers.extend( + ["constfold", "layout", "constfold"]) + rewriter_config_with_trt.meta_optimizer_iterations = ( rewriter_config_pb2.RewriterConfig.ONE) + else: + rewriter_config_with_trt.CopyFrom(rewriter_config) if precision_mode.upper() not in TrtPrecisionMode.supported_precision_modes(): raise ValueError(("precision mode '{}' is not supported." @@ -121,7 +134,7 @@ def tensorrt_rewriter_config(rewriter_config=None, precision_mode, TrtPrecisionMode.supported_precision_modes)) - optimizer = rewriter_config.custom_optimizers.add() + optimizer = rewriter_config_with_trt.custom_optimizers.add() optimizer.name = "TensorRTOptimizer" optimizer.parameter_map["minimum_segment_size"].i = minimum_segment_size optimizer.parameter_map["max_batch_size"].i = max_batch_size @@ -138,7 +151,8 @@ def tensorrt_rewriter_config(rewriter_config=None, "maximum_cached_engines items.") optimizer.parameter_map["cached_engine_batches"].list.i.extend( cached_engine_batch_sizes) - return rewriter_config + optimizer.parameter_map["use_calibration"].b = use_calibration + return rewriter_config_with_trt def create_inference_graph(input_graph_def, @@ -150,7 +164,7 @@ def create_inference_graph(input_graph_def, is_dynamic_op=False, maximum_cached_engines=1, cached_engine_batch_sizes=None, - rewriter_config=None, + use_calibration=True, input_saved_model_dir=None, input_saved_model_tags=None, output_saved_model_dir=None, @@ -182,8 +196,15 @@ def create_inference_graph(input_graph_def, use this list to determine the batch sizes of the cached engines, instead of making the decision on the fly. This is useful when we know the most common batch size(s) the application is going to generate. - rewriter_config: a RewriterConfig proto to append the TensorRTOptimizer to. - If None, it will create one with default settings. + use_calibration: this argument is ignored if precision_mode is not INT8.
If + set to True, a calibration graph will be created to calibrate the missing + ranges. The calibration graph must be converted to an inference graph + using calib_graph_to_infer_graph() after running calibration. If set to + False, quantization nodes will be expected for every tensor in the graph + (excluding those which will be fused). If a range is missing, an error + will occur. Please note that accuracy may be negatively affected if there + is a mismatch between which tensors TRT quantizes and which tensors were + trained with fake quantization. input_saved_model_dir: the directory to load the SavedModel which contains the input graph to transforms. Used only when input_graph_def is None. input_saved_model_tags: list of tags to load the SavedModel. @@ -191,8 +212,9 @@ def create_inference_graph(input_graph_def, returned GraphDef and save it to the specified directory. This option only works when the input graph is loaded from a SavedModel, i.e. when input_saved_model_dir is specified and input_graph_def is None. - session_config: the ConfigProto used to create a Session. If not specified, - a default ConfigProto will be used. + session_config: the ConfigProto used to create a Session. It's also used as + a template to create a TRT-enabled ConfigProto for conversion. If not + specified, a default ConfigProto will be used. Returns: A GraphDef transformed from input_graph_def (or the SavedModel graph def @@ -322,21 +344,30 @@ def create_inference_graph(input_graph_def, grappler_meta_graph_def.collection_def["train_op"].CopyFrom( output_collection) - # Create RewriterConfig. - rewriter_config = tensorrt_rewriter_config( + # Create TRT-enabled ConfigProto. + session_config_with_trt = config_pb2.ConfigProto() + session_config_with_trt.CopyFrom(session_config) + rewriter_config = None + if (session_config_with_trt.HasField("graph_options") and + session_config_with_trt.graph_options.HasField("rewrite_options")): + rewriter_config = session_config_with_trt.graph_options.rewrite_options + rewriter_config_with_trt = get_tensorrt_rewriter_config( rewriter_config, max_batch_size, max_workspace_size_bytes, precision_mode, minimum_segment_size, is_dynamic_op, maximum_cached_engines, - cached_engine_batch_sizes) + cached_engine_batch_sizes, use_calibration) + session_config_with_trt.graph_options.rewrite_options.CopyFrom( + rewriter_config_with_trt) # Run Grappler. transformed_graph_def = tf_optimizer.OptimizeGraph( - rewriter_config, grappler_meta_graph_def, graph_id=b"tf_graph") + session_config_with_trt, grappler_meta_graph_def, graph_id=b"tf_graph") # Optionally write the transformed graphdef as SavedModel. if output_saved_model_dir is not None: saved_model_builder = builder.SavedModelBuilder(output_saved_model_dir) with ops.Graph().as_default(): importer.import_graph_def(transformed_graph_def, name="") + # We don't use TRT here. 
with session.Session(config=session_config) as sess: saved_model_builder.add_meta_graph_and_variables( sess, diff --git a/tensorflow/contrib/tensorrt/python/trt_convert_test.py b/tensorflow/contrib/tensorrt/python/trt_convert_test.py index 9f2eeac990d..dbf8dd26144 100644 --- a/tensorflow/contrib/tensorrt/python/trt_convert_test.py +++ b/tensorflow/contrib/tensorrt/python/trt_convert_test.py @@ -47,9 +47,9 @@ from tensorflow.python.tools import saved_model_utils class TrtConvertTest(test_util.TensorFlowTestCase): """Class to test Tensorflow-TensorRT integration python API.""" - def testTensorrtRewriterConfig(self): - """Test case for trt_convert.tensorrt_rewriter_config().""" - rewriter_cfg = trt_convert.tensorrt_rewriter_config( + def testGetTensorrtRewriterConfig(self): + """Test case for trt_convert.get_tensorrt_rewriter_config().""" + rewriter_cfg = trt_convert.get_tensorrt_rewriter_config( rewriter_config=None, max_batch_size=128, max_workspace_size_bytes=1234, @@ -162,7 +162,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase): node_name_to_op = {node.name: node.op for node in graph_def.node} self.assertEqual({ "input": "Placeholder", - "my_trt_op_0": "TRTEngineOp", + "TRTEngineOp_0": "TRTEngineOp", "output": "Identity" }, node_name_to_op) @@ -189,10 +189,11 @@ class TrtConvertTest(test_util.TensorFlowTestCase): execute_engine_test_value = ("done" if expect_engine_is_run else "") execute_native_segment_test_value = ("" if expect_engine_is_run else "done") self.assertEqual(execute_engine_test_value, - trt_convert.get_test_value("my_trt_op_0:ExecuteTrtEngine")) + trt_convert.get_test_value( + "TRTEngineOp_0:ExecuteTrtEngine")) self.assertEqual( execute_native_segment_test_value, - trt_convert.get_test_value("my_trt_op_0:ExecuteNativeSegment")) + trt_convert.get_test_value("TRTEngineOp_0:ExecuteNativeSegment")) def testCreateInferenceGraph_MinimumSegmentSize(self): if not trt_convert.is_tensorrt_enabled(): diff --git a/tensorflow/contrib/tensorrt/resources/trt_resources.h b/tensorflow/contrib/tensorrt/resources/trt_resources.h index 840da6e78d8..aac9e5c7bd7 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_resources.h +++ b/tensorflow/contrib/tensorrt/resources/trt_resources.h @@ -39,7 +39,8 @@ namespace tensorrt { class TRTCalibrationResource : public tensorflow::ResourceBase { public: ~TRTCalibrationResource() { - VLOG(0) << "Destroying Calibration Resource " << std::endl << DebugString(); + LOG(INFO) << "Destroying Calibration Resource " << std::endl + << DebugString(); builder_.reset(); engine_.reset(); // We need to manually destroy the builder and engine before the allocator diff --git a/tensorflow/contrib/tensorrt/segment/segment.cc b/tensorflow/contrib/tensorrt/segment/segment.cc index 4f64b7a9522..d8f63779e64 100644 --- a/tensorflow/contrib/tensorrt/segment/segment.cc +++ b/tensorflow/contrib/tensorrt/segment/segment.cc @@ -33,6 +33,7 @@ namespace tensorflow { namespace tensorrt { namespace segment { using ::tensorflow::strings::StrAppend; +using ::tensorflow::strings::StrCat; // A simple graph representation to mirror tensorflow::Graph. This structure // helps saving memory since segmenter modifies the graph in place, preventing @@ -406,22 +407,42 @@ tensorflow::Status SegmentGraph( // Use a union-find to collect the nodes that belong to the same // segment. A node value of nullptr indicates that the node is not a candidate // for TRT. 
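To make the union-find bookkeeping described in the comment above concrete, here is a minimal, self-contained Python sketch. It is illustrative only and not part of the patch: `is_candidate` stands in for the real `candidate_fn`, and the actual segmenter additionally contracts edges in reverse topological order and rejects merges that would create cycles.

```python
class UnionFind(object):
  """Union-find cell; a value of None marks a node that is not a TRT candidate."""

  def __init__(self, value):
    self.value = value
    self.parent = self

  def root(self):
    if self.parent is not self:
      self.parent = self.parent.root()  # Path compression.
    return self.parent

  def merge(self, other):
    self.root().parent = other.root()


def segment_graph(nodes, edges, is_candidate):
  """Groups candidate nodes into segments by merging along edges."""
  cells = {n: UnionFind(n if is_candidate(n) else None) for n in nodes}
  for src, dst in edges:
    # Only merge edges whose endpoints are both TRT candidates.
    if cells[src].value is not None and cells[dst].value is not None:
      cells[src].merge(cells[dst])
  segments = {}
  for n in nodes:
    if cells[n].value is not None:
      segments.setdefault(id(cells[n].root()), []).append(n)
  return list(segments.values())


# 'relu' is treated as unsupported here, so it splits the graph in two:
# [['conv', 'bias_add'], ['matmul']].
print(segment_graph(
    nodes=["conv", "bias_add", "relu", "matmul"],
    edges=[("conv", "bias_add"), ("bias_add", "relu"), ("relu", "matmul")],
    is_candidate=lambda n: n != "relu"))
```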
+ std::unordered_set unsupported_ops; + int num_unsupported_ops = 0; std::vector> node_segments; for (int i = 0; i < graph->num_node_ids(); ++i) { SimpleNode* node = graph->FindNodeId(i); if (options.exclude_node_list.count(node->name()) != 0) { - VLOG(1) << "Not a TF-TRT candidate: " << node->name() - << " (excluded by segmenter option)."; + VLOG(1) << "Not a TF-TRT candidate, " + << "(Op type: " << node->tf_node()->type_string() << "), " + << "(Op name: " << node->name() << "), " + << "(Reason: excluded by segmenter option)"; + unsupported_ops.emplace(node->tf_node()->type_string()); + num_unsupported_ops++; node = nullptr; } else { const Status status = candidate_fn(node->tf_node()); if (!status.ok()) { - VLOG(1) << "Not a TF-TRT candidate: " << node->name() << ": " << status; + VLOG(1) << "Not a TF-TRT candidate, " + << "(Op type: " << node->tf_node()->type_string() << "), " + << "(Op name: " << node->name() << "), " + << "(Reason: " << status << ")"; + unsupported_ops.emplace(node->tf_node()->type_string()); + num_unsupported_ops++; node = nullptr; } } node_segments.emplace_back(node); } + string msg = StrCat( + "There are ", num_unsupported_ops, " ops of ", unsupported_ops.size(), + " different types in the graph that", " are not converted to TensorRT: "); + for (const auto& elem : unsupported_ops) { + StrAppend(&msg, elem, ", "); + } + LOG(INFO) << msg << "(For more information see " + << "https://docs.nvidia.com/deeplearning" + << "/dgx/integrate-tf-trt/index.html#support-ops)."; // The segmentation algorithm below visits nodes in reverse topological order // and attempts to merge nodes along output edges. That means that subgraphs @@ -439,7 +460,8 @@ tensorflow::Status SegmentGraph( std::vector order; order.reserve(graph->num_node_ids()); StableDFS(*graph, /*reverse=*/false, {graph->source_node()}, - /*enter=*/nullptr, [&order](const SimpleNode* n) { + /*enter=*/nullptr, + [&order](const SimpleNode* n) { order.push_back(n); return true; }); @@ -548,7 +570,7 @@ tensorflow::Status SegmentGraph( std::set& segment_nodes = itr.second; VLOG(1) << "Segment original size: " << segment_nodes.size(); while (true) { - std::deque in_nodes_que, out_nodes_que; + std::deque in_nodes_que, out_nodes_que; // Find an input node that is not eligible and add it to the queue. // Nodes that has no incoming edges should not be treated as "input", // as there are really no inputs to them. Similar for output nodes. @@ -594,8 +616,7 @@ tensorflow::Status SegmentGraph( // their outputs. In this way, for common cases the number of removed // nodes should be minimum. auto remove_nodes = [&segment_nodes]( - bool is_input_nodes, - std::deque* que) { + bool is_input_nodes, std::deque* que) { // Run a BFS on the queue to find all the input/output nodes. 
std::set visited; std::set logged(que->begin(), que->end()); diff --git a/tensorflow/contrib/tensorrt/test/base_test.py b/tensorflow/contrib/tensorrt/test/base_test.py index 18096e0ff1e..03faf1df243 100644 --- a/tensorflow/contrib/tensorrt/test/base_test.py +++ b/tensorflow/contrib/tensorrt/test/base_test.py @@ -56,8 +56,9 @@ class SimpleSingleEngineTest(trt_test.TfTrtIntegrationTestBase): strides=[1, 2, 2, 1], padding="SAME", name="conv") - bias = constant_op.constant( - [4., 1.5, 2., 3., 5., 7.], name="bias", dtype=dtype) + bias = constant_op.constant([4., 1.5, 2., 3., 5., 7.], + name="bias", + dtype=dtype) added = nn.bias_add(conv, bias, name="bias_add") relu = nn.relu(added, "relu") identity = array_ops.identity(relu, "identity") @@ -73,11 +74,12 @@ class SimpleSingleEngineTest(trt_test.TfTrtIntegrationTestBase): def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" - # TODO(aaroey): LayoutOptimizer adds additional nodes to the graph which - # breaks the connection check, fix it. - # - my_trt_op_0 should have ["weights", "conv", "bias", "bias_add", - # "relu", "identity", "max_pool"] - return ["my_trt_op_0"] + return { + "my_trt_op_0": [ + "weights", "conv", "bias", "bias_add", "relu", "identity", + "max_pool" + ] + } class SimpleMultiEnginesTest(trt_test.TfTrtIntegrationTestBase): @@ -92,7 +94,7 @@ class SimpleMultiEnginesTest(trt_test.TfTrtIntegrationTestBase): g = ops.Graph() with g.as_default(): inp = array_ops.placeholder( - dtype=dtype, shape=[None] + input_dims[1:], name=input_name) + dtype=dtype, shape=input_dims, name=input_name) with g.device("/GPU:0"): conv_filter = constant_op.constant( [[[[1., 0.5, 4., 6., 0.5, 1.], [1., 0.5, 1., 1., 0.5, 1.]]]], @@ -105,10 +107,10 @@ class SimpleMultiEnginesTest(trt_test.TfTrtIntegrationTestBase): padding="SAME", name="conv") c1 = constant_op.constant( - np.random.randn(input_dims[0], 12, 12, 6), dtype=dtype, name="c1") + np.random.randn(12, 12, 6), dtype=dtype, name="c1") p = math_ops.mul(conv, c1, name="mul") c2 = constant_op.constant( - np.random.randn(input_dims[0], 12, 12, 6), dtype=dtype, name="c2") + np.random.randn(12, 12, 6), dtype=dtype, name="c2") q = math_ops.div(conv, c2, name="div") edge = self.trt_incompatible_op(q, name="incompatible") @@ -129,22 +131,21 @@ class SimpleMultiEnginesTest(trt_test.TfTrtIntegrationTestBase): def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" - # TODO(aaroey): LayoutOptimizer adds additional nodes to the graph which - # breaks the connection check, fix it. - # - my_trt_op_0 should have ["mul", "sub", "div1", "mul1", "add1", - # "add", "sub1"]; - # - my_trt_op_1 should have ["weights","conv", "div"] - return ["my_trt_op_0", "my_trt_op_1"] + return { + "my_trt_op_0": [ + "add", "add1", "c1", "div1", "mul", "mul1", "sub", "sub1" + ], + "my_trt_op_1": ["c2", "conv", "div", "weights"] + } - def ShouldRunTest(self, run_params): - # TODO(aaroey): LayoutOptimizer adds Transpose(Const, Const) to the graph - # which breaks the conversion. We should fix it as: - # - Detect the invalid NodeDef earlier before adding them to segment - # - Let it able to change the RewriterConfig when calling - # create_inference_graph(). - # It will be good to add debugging feature for Grappler to print the graph - # after running each optimizer. 
- return False + def GetConversionParams(self, run_params): + """Return a ConversionParams for test.""" + return super( + SimpleMultiEnginesTest, self + ).GetConversionParams(run_params)._replace( + # Disable layout optimizer, since it'll add Transpose(Const, Const) to + # the graph and breaks the conversion check. + rewriter_config=trt_test.OptimizerDisabledRewriterConfig()) class PartiallyConvertedTestA(trt_test.TfTrtIntegrationTestBase): @@ -153,7 +154,7 @@ class PartiallyConvertedTestA(trt_test.TfTrtIntegrationTestBase): """Setup method.""" super(PartiallyConvertedTestA, self).setUp() # Let it fail to build the second engine. - trt_convert.add_test_value("my_trt_op_1:CreateTRTNode", "fail") + trt_convert.add_test_value("TRTEngineOp_1:CreateTRTNode", "fail") def GetParams(self): """Create a graph containing two segment.""" @@ -190,14 +191,16 @@ class PartiallyConvertedTestA(trt_test.TfTrtIntegrationTestBase): """Return the expected engines to build.""" return { # Only the first engine is built. - "my_trt_op_0": ["c0", "c1", "add0", "add1", "mul0", "mul1"] + "TRTEngineOp_0": ["c0", "c1", "add0", "add1", "mul0", "mul1"] } def ShouldRunTest(self, run_params): """Whether to run the test.""" # Disable the test in fp16 mode since multiple matmul and add ops together # can cause overflow. - return run_params.precision_mode != "FP16" + return ((run_params.precision_mode != "FP16") and + not (trt_test.IsQuantizationMode(run_params.precision_mode) and + not run_params.use_calibration)) class PartiallyConvertedTestB(PartiallyConvertedTestA): @@ -207,13 +210,13 @@ class PartiallyConvertedTestB(PartiallyConvertedTestA): super(PartiallyConvertedTestB, self).setUp() # Let it fail to build the first engine. trt_convert.clear_test_values("") - trt_convert.add_test_value("my_trt_op_0:CreateTRTNode", "fail") + trt_convert.add_test_value("TRTEngineOp_0:CreateTRTNode", "fail") def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" return { # Only the second engine is built. - "my_trt_op_1": ["c2", "c3", "add2", "add3", "mul2", "mul3"] + "TRTEngineOp_1": ["c2", "c3", "add2", "add3", "mul2", "mul3"] } @@ -257,8 +260,8 @@ class ConstInputTest(trt_test.TfTrtIntegrationTestBase): def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" return { - "my_trt_op_0": ["add", "add1", "mul"], - "my_trt_op_1": ["add2", "add3", "mul1"] + "TRTEngineOp_0": ["add", "add1", "mul"], + "TRTEngineOp_1": ["add2", "add3", "mul1"] } @@ -289,7 +292,7 @@ class ConstDataInputSingleEngineTest(trt_test.TfTrtIntegrationTestBase): def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" - return {"my_trt_op_0": ["c", "add", "add1", "mul"]} + return {"TRTEngineOp_0": ["c", "add", "add1", "mul"]} class ConstDataInputMultipleEnginesTest(trt_test.TfTrtIntegrationTestBase): @@ -324,12 +327,12 @@ class ConstDataInputMultipleEnginesTest(trt_test.TfTrtIntegrationTestBase): def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" return { - "my_trt_op_0": ["add2", "add3", "mul1"], + "TRTEngineOp_0": ["add2", "add3", "mul1"], # Why segment ["add", "add1", "mul"] was assigned segment id 1 # instead of 0: the parent node of this segment is actually const # node 'c', but it's removed later since it's const output of the # segment which is not allowed. 
- "my_trt_op_1": ["add", "add1", "mul"] + "TRTEngineOp_1": ["add", "add1", "mul"] } @@ -373,8 +376,8 @@ class ControlDependencyTest(trt_test.TfTrtIntegrationTestBase): def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" return { - "my_trt_op_0": ["c1", "add", "add1", "mul"], - "my_trt_op_1": ["c2", "add2", "add3", "mul1"] + "TRTEngineOp_0": ["c1", "add", "add1", "mul"], + "TRTEngineOp_1": ["c2", "add2", "add3", "mul1"] } diff --git a/tensorflow/contrib/tensorrt/test/batch_matmul_test.py b/tensorflow/contrib/tensorrt/test/batch_matmul_test.py index 4b888081787..f42308ecb7c 100644 --- a/tensorflow/contrib/tensorrt/test/batch_matmul_test.py +++ b/tensorflow/contrib/tensorrt/test/batch_matmul_test.py @@ -79,12 +79,12 @@ class BatchMatMulTest(trt_test.TfTrtIntegrationTestBase): """Return the expected engines to build.""" if (run_params.dynamic_engine and not trt_test.IsQuantizationMode(run_params.precision_mode)): - return ["my_trt_op_0", "my_trt_op_1"] - return ["my_trt_op_1"] + return ["TRTEngineOp_0", "TRTEngineOp_1"] + return ["TRTEngineOp_1"] def ExpectedEnginesToRun(self, run_params): """Return the expected engines to run.""" - return ["my_trt_op_1"] + return ["TRTEngineOp_1"] def ShouldRunTest(self, run_params): """Whether to run the test.""" diff --git a/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py b/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py index 7545bb9df20..053b38ff1c0 100644 --- a/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py +++ b/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py @@ -41,6 +41,7 @@ class BiasaddMatMulTest(trt_test.TfTrtIntegrationTestBase): input_name = "input" input_matrix_rows = 4 input_matrix_columns = 144 + # Note that tf.nn.bias_add supports up to 5 dimensions. 
input_dims = [input_matrix_rows, input_matrix_columns] output_name = "output" g = ops.Graph() @@ -74,18 +75,18 @@ class BiasaddMatMulTest(trt_test.TfTrtIntegrationTestBase): x5 = nn.bias_add(x5, b) x5 = gen_array_ops.reshape(x5, [4, -1]) - x6 = gen_array_ops.reshape(x, [4, 12, 12]) - b = self._ConstOp((12,)) + x6 = gen_array_ops.reshape(x, [4, 24, 6]) + b = self._ConstOp((6,)) x6 = nn.bias_add(x6, b, data_format="NHWC") x6 = gen_array_ops.reshape(x6, [4, -1]) - x7 = gen_array_ops.reshape(x, [4, 12, 3, 4]) - b = self._ConstOp((4,)) + x7 = gen_array_ops.reshape(x, [4, 12, 4, 3]) + b = self._ConstOp((3,)) x7 = nn.bias_add(x7, b, data_format="NHWC") x7 = gen_array_ops.reshape(x7, [4, -1]) - x8 = gen_array_ops.reshape(x, [4, 12, 3, 2, 2]) - b = self._ConstOp((2,)) + x8 = gen_array_ops.reshape(x, [4, 4, 3, 2, 6]) + b = self._ConstOp((6,)) x8 = nn.bias_add(x8, b, data_format="NHWC") x8 = gen_array_ops.reshape(x8, [4, -1]) @@ -94,13 +95,13 @@ class BiasaddMatMulTest(trt_test.TfTrtIntegrationTestBase): x9 = nn.bias_add(x9, b, data_format="NCHW") x9 = gen_array_ops.reshape(x9, [4, -1]) - x10 = gen_array_ops.reshape(x, [4, 12, 3, 4]) - b = self._ConstOp((12,)) + x10 = gen_array_ops.reshape(x, [4, 3, 4, 12]) + b = self._ConstOp((3,)) x10 = nn.bias_add(x10, b, data_format="NCHW") x10 = gen_array_ops.reshape(x10, [4, -1]) - x11 = gen_array_ops.reshape(x, [4, 12, 12]) - b = self._ConstOp((12,)) + x11 = gen_array_ops.reshape(x, [4, 6, 24]) + b = self._ConstOp((6,)) x11 = nn.bias_add(x11, b, data_format="NCHW") x11 = gen_array_ops.reshape(x11, [4, -1]) @@ -116,13 +117,18 @@ class BiasaddMatMulTest(trt_test.TfTrtIntegrationTestBase): def GetConversionParams(self, run_params): """Return a ConversionParams for test.""" - return super(BiasaddMatMulTest, - self).GetConversionParams(run_params)._replace( - max_batch_size=4, maximum_cached_engines=1) + conversion_params = super(BiasaddMatMulTest, + self).GetConversionParams(run_params) + return conversion_params._replace( + max_batch_size=4, + maximum_cached_engines=1, + # Disable layout optimizer, since it will convert BiasAdd with NHWC + # format to NCHW format under four dimensional input. 
+ rewriter_config=trt_test.OptimizerDisabledRewriterConfig()) def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" - return ["my_trt_op_0"] + return ["TRTEngineOp_0"] def ShouldRunTest(self, run_params): """Whether to run the test.""" diff --git a/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py b/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py index b53cb3c091e..169835956c0 100644 --- a/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py +++ b/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py @@ -26,7 +26,6 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_array_ops -from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -56,10 +55,10 @@ class BinaryTensorWeightBroadcastTest(trt_test.TfTrtIntegrationTestBase): ]: a = self._ConstOp(weights_shape) f = x + a - x = math_ops.sigmoid(f) + x = self.trt_incompatible_op(f) a = self._ConstOp(weights_shape) f = a + x - x = math_ops.sigmoid(f) + x = self.trt_incompatible_op(f) gen_array_ops.reshape(x, [5, -1], name=output_name) return trt_test.TfTrtIntegrationTestParams( gdef=g.as_graph_def(), @@ -70,7 +69,7 @@ class BinaryTensorWeightBroadcastTest(trt_test.TfTrtIntegrationTestBase): def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" - return ["my_trt_op_%d" % i for i in range(16)] + return ["TRTEngineOp_%d" % i for i in range(16)] if __name__ == "__main__": diff --git a/tensorflow/contrib/tensorrt/test/concatenation_test.py b/tensorflow/contrib/tensorrt/test/concatenation_test.py index 465cb022964..c3576f81d97 100644 --- a/tensorflow/contrib/tensorrt/test/concatenation_test.py +++ b/tensorflow/contrib/tensorrt/test/concatenation_test.py @@ -79,7 +79,7 @@ class ConcatenationTest(trt_test.TfTrtIntegrationTestBase): def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" - return ["my_trt_op_0"] + return ["TRTEngineOp_0"] if __name__ == "__main__": diff --git a/tensorflow/contrib/tensorrt/test/const_broadcast_test.py b/tensorflow/contrib/tensorrt/test/const_broadcast_test.py index e32f0478661..c1c883312d8 100644 --- a/tensorflow/contrib/tensorrt/test/const_broadcast_test.py +++ b/tensorflow/contrib/tensorrt/test/const_broadcast_test.py @@ -64,7 +64,7 @@ class ConstBroadcastTest(trt_test.TfTrtIntegrationTestBase): def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" - return ['my_trt_op_0'] + return ['TRTEngineOp_0'] def ExpectedAbsoluteTolerance(self, run_params): """The absolute tolerance to compare floating point results.""" diff --git a/tensorflow/contrib/tensorrt/test/memory_alignment_test.py b/tensorflow/contrib/tensorrt/test/memory_alignment_test.py index bc7c90081ff..104bac43a0b 100644 --- a/tensorflow/contrib/tensorrt/test/memory_alignment_test.py +++ b/tensorflow/contrib/tensorrt/test/memory_alignment_test.py @@ -68,7 +68,7 @@ class MemoryAlignmentTest(trt_test.TfTrtIntegrationTestBase): def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" - return ["my_trt_op_0"] + return ["TRTEngineOp_0"] def ExpectedAbsoluteTolerance(self, run_params): """The absolute tolerance to compare floating point results.""" diff --git a/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py 
b/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py index 11be4feaf7b..293f93d8a78 100644 --- a/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py +++ b/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py @@ -25,8 +25,6 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops -from tensorflow.python.ops import gen_math_ops -from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.platform import test @@ -60,14 +58,14 @@ class MultiConnectionNeighborEngineTest(trt_test.TfTrtIntegrationTestBase): b = constant_op.constant( np.random.normal(5.0, 1.0, [1, 4, 1, 1]), name="bias", dtype=dtype) q = conv - b - edge = math_ops.sigmoid(q) + edge = self.trt_incompatible_op(q) b = constant_op.constant( np.random.normal(5.0, 1.0, [1, 4, 1, 1]), name="bias", dtype=dtype) d = b + conv - edge3 = math_ops.sigmoid(d) + edge3 = self.trt_incompatible_op(d) - edge1 = gen_math_ops.tan(conv) + edge1 = self.trt_incompatible_op(conv) t = t - edge1 q = q + edge t = t + q @@ -83,7 +81,7 @@ class MultiConnectionNeighborEngineTest(trt_test.TfTrtIntegrationTestBase): def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" - return ["my_trt_op_0", "my_trt_op_1"] + return ["TRTEngineOp_0", "TRTEngineOp_1"] if __name__ == "__main__": diff --git a/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py b/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py index eddeafa38bc..3e1e4b088ba 100644 --- a/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py +++ b/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py @@ -66,8 +66,8 @@ class NeighboringEngineTest(trt_test.TfTrtIntegrationTestBase): def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" return { - "my_trt_op_0": ["bias", "mul", "sub"], - "my_trt_op_1": ["weights", "conv"] + "TRTEngineOp_0": ["bias", "mul", "sub"], + "TRTEngineOp_1": ["weights", "conv"] } diff --git a/tensorflow/contrib/tensorrt/test/quantization_mnist_test.py b/tensorflow/contrib/tensorrt/test/quantization_mnist_test.py new file mode 100644 index 00000000000..31cbef89e23 --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/quantization_mnist_test.py @@ -0,0 +1,290 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Script to test TF-TRT INT8 conversion without calibration on Mnist model.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.tensorrt.python import trt_convert +# pylint: disable=unused-import +from tensorflow.contrib.tensorrt.python.ops import trt_engine_op +# pylint: enable=unused-import +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python import data +from tensorflow.python import keras +from tensorflow.python.estimator.estimator import Estimator +from tensorflow.python.estimator.model_fn import EstimatorSpec +from tensorflow.python.estimator.model_fn import ModeKeys +from tensorflow.python.estimator.run_config import RunConfig +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import graph_util +from tensorflow.python.framework import importer +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.keras.datasets import mnist +from tensorflow.python.layers import layers +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import metrics +from tensorflow.python.ops import nn +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops.losses import losses +from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.summary import summary +from tensorflow.python.training import saver +from tensorflow.python.training.adam import AdamOptimizer +from tensorflow.python.training.checkpoint_management import latest_checkpoint +from tensorflow.python.training.training_util import get_global_step + +INPUT_NODE_NAME = 'input' +OUTPUT_NODE_NAME = 'output' + + +class QuantizationAwareTrainingMNISTTest(test_util.TensorFlowTestCase): + + def _BuildGraph(self, x): + + def _Quantize(x, r): + x = gen_array_ops.quantize_and_dequantize_v2(x, -r, r) + return x + + def _DenseLayer(x, num_inputs, num_outputs, quantization_range, name): + """Dense layer with quantized outputs. + + Args: + x: input to the dense layer + num_inputs: number of input columns of x + num_outputs: number of output columns + quantization_range: the min/max range for quantization + name: name of the variable scope + + Returns: + The output of the layer. 
+ """ + with variable_scope.variable_scope(name): + kernel = variable_scope.get_variable( + 'kernel', + shape=[num_inputs, num_outputs], + dtype=dtypes.float32, + initializer=keras.initializers.glorot_uniform()) + bias = variable_scope.get_variable( + 'bias', + shape=[num_outputs], + dtype=dtypes.float32, + initializer=keras.initializers.zeros()) + x = math_ops.matmul(x, kernel) + x = _Quantize(x, quantization_range) + x = nn.bias_add(x, bias) + x = _Quantize(x, quantization_range) + return x + + x = _Quantize(x, 1) + # Conv + Bias + Relu6 + x = layers.conv2d(x, filters=32, kernel_size=3, use_bias=True) + x = nn.relu6(x) + # Conv + Bias + Relu6 + x = layers.conv2d(x, filters=64, kernel_size=3, use_bias=True) + x = nn.relu6(x) + # Reduce + x = math_ops.reduce_mean(x, [1, 2]) + x = _Quantize(x, 6) + # FC1 + x = _DenseLayer(x, 64, 512, 6, name='dense') + x = nn.relu6(x) + # FC2 + x = _DenseLayer(x, 512, 10, 25, name='dense_1') + x = array_ops.identity(x, name=OUTPUT_NODE_NAME) + return x + + def _GetGraphDef(self, use_trt, max_batch_size, model_dir): + """Get the frozen mnist GraphDef. + + Args: + use_trt: whether use TF-TRT to convert the graph. + max_batch_size: the max batch size to apply during TF-TRT conversion. + model_dir: the model directory to load the checkpoints. + + Returns: + The frozen mnist GraphDef. + """ + graph = ops.Graph() + with self.session(graph=graph) as sess: + with graph.device('/GPU:0'): + x = array_ops.placeholder( + shape=(None, 28, 28, 1), dtype=dtypes.float32, name=INPUT_NODE_NAME) + self._BuildGraph(x) + # Load weights + mnist_saver = saver.Saver() + checkpoint_file = latest_checkpoint(model_dir) + mnist_saver.restore(sess, checkpoint_file) + # Freeze + graph_def = graph_util.convert_variables_to_constants( + sess, sess.graph_def, output_node_names=[OUTPUT_NODE_NAME]) + # Convert with TF-TRT + if use_trt: + logging.info('Number of nodes before TF-TRT conversion: %d', + len(graph_def.node)) + graph_def = trt_convert.create_inference_graph( + graph_def, + outputs=[OUTPUT_NODE_NAME], + max_batch_size=max_batch_size, + precision_mode='INT8', + max_workspace_size_bytes=4096 << 19, + minimum_segment_size=2, + use_calibration=False, + ) + logging.info('Number of nodes after TF-TRT conversion: %d', + len(graph_def.node)) + num_engines = len( + [1 for n in graph_def.node if str(n.op) == 'TRTEngineOp']) + self.assertEqual(1, num_engines) + return graph_def + + def _Run(self, is_training, use_trt, batch_size, num_epochs, model_dir): + """Train or evaluate the model. + + Args: + is_training: whether to train or evaluate the model. In training mode, + quantization will be simulated where the quantize_and_dequantize_v2 are + placed. + use_trt: if true, use TRT INT8 mode for evaluation, which will perform + real quantization. Otherwise use native TensorFlow which will perform + simulated quantization. Ignored if is_training is True. + batch_size: batch size. + num_epochs: how many epochs to train. Ignored if is_training is False. + model_dir: where to save or load checkpoint. + + Returns: + The Estimator evaluation result. 
+ """ + # Get dataset + train_data, test_data = mnist.load_data() + + def _PreprocessFn(x, y): + x = math_ops.cast(x, dtypes.float32) + x = array_ops.expand_dims(x, axis=2) + x = 2.0 * (x / 255.0) - 1.0 + y = math_ops.cast(y, dtypes.int32) + return x, y + + def _EvalInputFn(): + mnist_x, mnist_y = test_data + dataset = data.Dataset.from_tensor_slices((mnist_x, mnist_y)) + dataset = dataset.apply( + data.experimental.map_and_batch( + map_func=_PreprocessFn, + batch_size=batch_size, + num_parallel_calls=8)) + dataset = dataset.repeat(count=1) + iterator = data.make_one_shot_iterator(dataset) + features, labels = iterator.get_next() + return features, labels + + def _TrainInputFn(): + mnist_x, mnist_y = train_data + dataset = data.Dataset.from_tensor_slices((mnist_x, mnist_y)) + dataset = dataset.shuffle(2 * len(mnist_x)) + dataset = dataset.apply( + data.experimental.map_and_batch( + map_func=_PreprocessFn, + batch_size=batch_size, + num_parallel_calls=8)) + dataset = dataset.repeat(count=num_epochs) + iterator = data.make_one_shot_iterator(dataset) + features, labels = iterator.get_next() + return features, labels + + def _ModelFn(features, labels, mode): + if is_training: + logits_out = self._BuildGraph(features) + else: + graph_def = self._GetGraphDef(use_trt, batch_size, model_dir) + logits_out = importer.import_graph_def( + graph_def, + input_map={INPUT_NODE_NAME: features}, + return_elements=[OUTPUT_NODE_NAME + ':0'], + name='')[0] + + loss = losses.sparse_softmax_cross_entropy( + labels=labels, logits=logits_out) + summary.scalar('loss', loss) + + classes_out = math_ops.argmax(logits_out, axis=1, name='classes_out') + accuracy = metrics.accuracy( + labels=labels, predictions=classes_out, name='acc_op') + summary.scalar('accuracy', accuracy[1]) + + if mode == ModeKeys.EVAL: + return EstimatorSpec( + mode, loss=loss, eval_metric_ops={'accuracy': accuracy}) + elif mode == ModeKeys.TRAIN: + optimizer = AdamOptimizer(learning_rate=1e-2) + train_op = optimizer.minimize(loss, global_step=get_global_step()) + return EstimatorSpec(mode, loss=loss, train_op=train_op) + + config_proto = config_pb2.ConfigProto() + config_proto.gpu_options.allow_growth = True + estimator = Estimator( + model_fn=_ModelFn, + model_dir=model_dir if is_training else None, + config=RunConfig(session_config=config_proto)) + + if is_training: + estimator.train(_TrainInputFn) + results = estimator.evaluate(_EvalInputFn) + logging.info('accuracy: %s', str(results['accuracy'])) + return results + + # To generate the checkpoint, set a different model_dir and call self._Run() + # by setting is_training=True and num_epochs=1000, e.g.: + # model_dir = '/tmp/quantization_mnist' + # self._Run( + # is_training=True, + # use_trt=False, + # batch_size=128, + # num_epochs=100, + # model_dir=model_dir) + def testEval(self): + if not trt_convert.is_tensorrt_enabled(): + return + model_dir = test.test_src_dir_path('contrib/tensorrt/test/testdata') + + accuracy_tf_native = self._Run( + is_training=False, + use_trt=False, + batch_size=128, + num_epochs=None, + model_dir=model_dir)['accuracy'] + logging.info('accuracy_tf_native: %f', accuracy_tf_native) + self.assertAllClose(accuracy_tf_native, 0.9662) + + if trt_convert.get_linked_tensorrt_version()[0] < 5: + return + + accuracy_tf_trt = self._Run( + is_training=False, + use_trt=True, + batch_size=128, + num_epochs=None, + model_dir=model_dir)['accuracy'] + logging.info('accuracy_tf_trt: %f', accuracy_tf_trt) + self.assertAllClose(accuracy_tf_trt, 0.9677) + + +if __name__ == '__main__': + 
test.main() diff --git a/tensorflow/contrib/tensorrt/test/quantization_test.py b/tensorflow/contrib/tensorrt/test/quantization_test.py new file mode 100644 index 00000000000..28353273ede --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/quantization_test.py @@ -0,0 +1,144 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Model script to test TF-TensorRT integration.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.tensorrt.python import trt_convert +from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +def _GetParams(add_quantization_nodes, dtype=dtypes.float32): + input_name = "input" + input_dims = [8, 8] + output_name = "output" + + def _Quantize(x, r): + if add_quantization_nodes: + x = gen_array_ops.fake_quant_with_min_max_vars(x, -r, r) + return x + + g = ops.Graph() + with g.as_default(): + x = array_ops.placeholder( + dtype=dtype, shape=[None] + input_dims[1:], name=input_name) + x = _Quantize(x, 10.0) + x = x + 5 + x = _Quantize(x, 15.0) + x = x - 5 + x = _Quantize(x, 10.0) + x = x * 0.1 + x = _Quantize(x, 1.0) + w = constant_op.constant(np.ones((8, 1)), dtype=dtypes.float32) + x = math_ops.matmul(x, w) + x = _Quantize(x, 10.0) + x = array_ops.identity(x, name=output_name) + + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[input_dims], + output_names=[output_name], + expected_output_dims=[(8, 1)]) + + +class QuantizationMissingAllRangesTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Create a graph containing single segment with no quantization ranges.""" + return _GetParams(add_quantization_nodes=False) + + def ShouldRunTest(self, run_params): + if trt_convert.get_linked_tensorrt_version()[0] < 5: + return False + # Only test static engine mode, with or without calibration. + return (trt_test.IsQuantizationMode(run_params.precision_mode) and + not run_params.use_optimizer and not run_params.dynamic_engine) + + def ExpectedEnginesToBuild(self, run_params): + """Return the expected engines to build.""" + if run_params.use_calibration: + # In static engine mode with calibration, it should build a calibration + # engine. + return ["my_trt_op_0"] + # In static engine mode without calibration, the engine building will fail + # since no quantization ranges are set, which results in no TRT nodes. 
+ return [] + + +class QuantizationWithRangesTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Create a graph containing single segment with no quantization ranges.""" + return _GetParams(add_quantization_nodes=True) + + def ShouldRunTest(self, run_params): + if trt_convert.get_linked_tensorrt_version()[0] < 5: + return False + # Test static/dynamic engine with/without calibration. + return (trt_test.IsQuantizationMode(run_params.precision_mode) and + not run_params.use_optimizer) + + def ExpectedEnginesToBuild(self, run_params): + """Return the expected engines to build.""" + return ["my_trt_op_0"] + + def ExpectedAbsoluteTolerance(self, run_params): + """The absolute tolerance to compare floating point results.""" + return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-01 + + def ExpectedRelativeTolerance(self, run_params): + """The relative tolerance to compare floating point results.""" + return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-01 + + +class NonQuantizedPrecisionsWithRangesTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Create a graph containing single segment with no quantization ranges.""" + return _GetParams(add_quantization_nodes=True) + + def ShouldRunTest(self, run_params): + # Only test FP32/FP16 mode. + return not trt_test.IsQuantizationMode(run_params.precision_mode) + + def ExpectedEnginesToBuild(self, run_params): + """Return the expected engines to build.""" + # The fake quant ops are not supported in FP32/FP16 mode, and will split the + # graph into three TRT segments. + return ["my_trt_op_0", "my_trt_op_1", "my_trt_op_2", "my_trt_op_3"] + + def ExpectedAbsoluteTolerance(self, run_params): + """The absolute tolerance to compare floating point results.""" + return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-01 + + def ExpectedRelativeTolerance(self, run_params): + """The relative tolerance to compare floating point results.""" + return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-01 + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/tensorrt/test/rank_two_test.py b/tensorflow/contrib/tensorrt/test/rank_two_test.py index 74a4a059257..0cd733dca13 100644 --- a/tensorflow/contrib/tensorrt/test/rank_two_test.py +++ b/tensorflow/contrib/tensorrt/test/rank_two_test.py @@ -68,11 +68,11 @@ class RankTwoTest(trt_test.TfTrtIntegrationTestBase): def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" return { - "my_trt_op_0": [ + "TRTEngineOp_0": [ "add0_1", "add0_2", "add0_3", "c0_1", "c0_2", "c0_3", "abs0_1", "abs0_2" ], - "my_trt_op_1": [ + "TRTEngineOp_1": [ "add", "add1_1", "add1_2", "add1_3", "c1_1", "c1_2", "c1_3", "abs1_1", "abs1_2", "reciprocal0", "reciprocal1" ], diff --git a/tensorflow/contrib/tensorrt/test/reshape_transpose_test.py b/tensorflow/contrib/tensorrt/test/reshape_transpose_test.py index bbc724ab18e..207944468ab 100644 --- a/tensorflow/contrib/tensorrt/test/reshape_transpose_test.py +++ b/tensorflow/contrib/tensorrt/test/reshape_transpose_test.py @@ -79,8 +79,8 @@ class ReshapeTest(trt_test.TfTrtIntegrationTestBase): def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" return { - "my_trt_op_0": ["reshape-%d" % i for i in range(7)] + - ["reshape-%d/shape" % i for i in range(7)] + "TRTEngineOp_0": ["reshape-%d" % i for i in range(7)] + + ["reshape-%d/shape" % i for i in range(7)] } def ShouldRunTest(self, run_params): @@ -117,7 +117,7 @@ class 
TransposeTest(trt_test.TfTrtIntegrationTestBase): # Note: by default Grappler will run the TRT optimizer twice. At the # first time it will group the two transpose ops below to same segment # then fail the conversion due to the expected batch dimension problem. - # At the second time, since the input of bridge op is my_trt_op_0, it + # At the second time, since the input of bridge op is TRTEngineOp_0, it # will fail to do shape inference which then cause conversion to fail. # TODO(laigd): support shape inference, make TRT optimizer run only # once, and fix this. @@ -136,7 +136,7 @@ class TransposeTest(trt_test.TfTrtIntegrationTestBase): def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" return { - "my_trt_op_0": [ + "TRTEngineOp_0": [ "transpose-1", "transpose-1/perm", "transposeback", "transposeback/perm" ] diff --git a/tensorflow/contrib/tensorrt/test/testdata/checkpoint b/tensorflow/contrib/tensorrt/test/testdata/checkpoint new file mode 100644 index 00000000000..a603e1aec91 --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/testdata/checkpoint @@ -0,0 +1,3 @@ +model_checkpoint_path: "model.ckpt-46900" +all_model_checkpoint_paths: "model.ckpt-0" +all_model_checkpoint_paths: "model.ckpt-46900" diff --git a/tensorflow/contrib/tensorrt/test/testdata/model.ckpt-46900.data-00000-of-00001 b/tensorflow/contrib/tensorrt/test/testdata/model.ckpt-46900.data-00000-of-00001 new file mode 100644 index 00000000000..88a998f184b Binary files /dev/null and b/tensorflow/contrib/tensorrt/test/testdata/model.ckpt-46900.data-00000-of-00001 differ diff --git a/tensorflow/contrib/tensorrt/test/testdata/model.ckpt-46900.index b/tensorflow/contrib/tensorrt/test/testdata/model.ckpt-46900.index new file mode 100644 index 00000000000..53797657133 Binary files /dev/null and b/tensorflow/contrib/tensorrt/test/testdata/model.ckpt-46900.index differ diff --git a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py index a725d0651c9..495a9391a1e 100644 --- a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py +++ b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py @@ -30,6 +30,7 @@ from tensorflow.contrib.tensorrt.python import trt_convert from tensorflow.contrib.tensorrt.python.ops import trt_engine_op # pylint: enable=unused-import from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.framework import dtypes from tensorflow.python.framework import graph_io from tensorflow.python.framework import importer @@ -42,14 +43,15 @@ TfTrtIntegrationTestParams = namedtuple("TfTrtIntegrationTestParams", [ "gdef", "input_names", "input_dims", "output_names", "expected_output_dims" ]) -RunParams = namedtuple( - "RunParams", - ["use_optimizer", "precision_mode", "dynamic_engine", "test_name"]) +RunParams = namedtuple("RunParams", [ + "use_optimizer", "precision_mode", "dynamic_engine", "test_name", + "use_calibration" +]) ConversionParams = namedtuple("ConversionParams", [ "max_batch_size", "max_workspace_size_bytes", "precision_mode", "minimum_segment_size", "is_dynamic_op", "maximum_cached_engines", - "cached_engine_batch_sizes", "rewriter_config" + "cached_engine_batch_sizes", "rewriter_config", "use_calibration" ]) PRECISION_MODES = ["FP32", "FP16", "INT8"] @@ -65,6 +67,34 @@ class GraphState(object): INFERENCE = 2 +def OptimizerDisabledRewriterConfig(): + """Returns a RewriterConfig with all default 
Grappler optimizers disabled.""" + rewriter_config = rewriter_config_pb2.RewriterConfig() + + # Turn off all default Grappler optimizers. + off = rewriter_config_pb2.RewriterConfig.OFF + rewriter_config.layout_optimizer = off + rewriter_config.constant_folding = off + rewriter_config.shape_optimization = off + rewriter_config.remapping = off + rewriter_config.arithmetic_optimization = off + rewriter_config.dependency_optimization = off + rewriter_config.loop_optimization = off + rewriter_config.function_optimization = off + rewriter_config.debug_stripper = off + rewriter_config.disable_model_pruning = True + rewriter_config.scoped_allocator_optimization = off + rewriter_config.memory_optimization = ( + rewriter_config_pb2.RewriterConfig.NO_MEM_OPT) + rewriter_config.pin_to_host_optimization = off + rewriter_config.auto_parallel.enable = False + + # Run only once for each enabled optimizer. + rewriter_config.meta_optimizer_iterations = ( + rewriter_config_pb2.RewriterConfig.ONE) + return rewriter_config + + class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): """Class to test Tensorflow-TensorRT integration.""" @@ -139,11 +169,15 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): is_dynamic_op=run_params.dynamic_engine, maximum_cached_engines=1, cached_engine_batch_sizes=None, - rewriter_config=None) + rewriter_config=None, + use_calibration=run_params.use_calibration) def ShouldRunTest(self, run_params): """Whether to run the test.""" - return True + # This setting combination requires quantization nodes to be present in + # order to build the engine. + return not (IsQuantizationMode(run_params.precision_mode) and + not run_params.use_calibration) def VerifyRunForEngine(self, engine_name, graph_state, expect_run=True): """Verify the state of a particular engine after sess.run().""" @@ -194,34 +228,35 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): def _PrepareRun(self, graph_state): """Set up necessary testing environment before calling sess.run().""" # Clear test values added by TRTEngineOp. 
- trt_convert.clear_test_values("my_trt_op_.*:ExecuteTrtEngine") - trt_convert.clear_test_values("my_trt_op_.*:ExecuteCalibration") - trt_convert.clear_test_values("my_trt_op_.*:ExecuteNativeSegment") + trt_convert.clear_test_values("TRTEngineOp_.*:ExecuteTrtEngine") + trt_convert.clear_test_values("TRTEngineOp_.*:ExecuteCalibration") + trt_convert.clear_test_values("TRTEngineOp_.*:ExecuteNativeSegment") + + def _GetGPUOptions(self): + gpu_options = config_pb2.GPUOptions() + gpu_options.allow_growth = True + return gpu_options def _GetConfigProto(self, run_params, graph_state): """Get config proto based on specific settings.""" if graph_state != GraphState.ORIGINAL and run_params.use_optimizer: conversion_params = self.GetConversionParams(run_params) - rewriter_cfg = trt_convert.tensorrt_rewriter_config( + rewriter_cfg = trt_convert.get_tensorrt_rewriter_config( conversion_params.rewriter_config, conversion_params.max_batch_size, conversion_params.max_workspace_size_bytes, conversion_params.precision_mode, conversion_params.minimum_segment_size, conversion_params.is_dynamic_op, conversion_params.maximum_cached_engines, - conversion_params.cached_engine_batch_sizes) + conversion_params.cached_engine_batch_sizes, + conversion_params.use_calibration) graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_cfg) else: graph_options = config_pb2.GraphOptions() - gpu_options = config_pb2.GPUOptions() - gpu_options.allow_growth = True - if trt_convert.get_linked_tensorrt_version()[0] == 3: - gpu_options.per_process_gpu_memory_fraction = 0.50 - config = config_pb2.ConfigProto( - gpu_options=gpu_options, graph_options=graph_options) + gpu_options=self._GetGPUOptions(), graph_options=graph_options) return config def _ExpectTestValue(self, engine_name, method, expected_value): @@ -291,6 +326,11 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): params = self._GetParamsCached() conversion_params = self.GetConversionParams(run_params) logging.info(conversion_params) + + config_for_trt = config_pb2.ConfigProto(gpu_options=self._GetGPUOptions()) + if conversion_params.rewriter_config is not None: + config_for_trt.graph_options.rewrite_options.CopyFrom( + conversion_params.rewriter_config) return trt_convert.create_inference_graph( input_graph_def=gdef, outputs=params.input_names + params.output_names, @@ -301,7 +341,8 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): is_dynamic_op=conversion_params.is_dynamic_op, maximum_cached_engines=conversion_params.maximum_cached_engines, cached_engine_batch_sizes=conversion_params.cached_engine_batch_sizes, - rewriter_config=conversion_params.rewriter_config) + use_calibration=conversion_params.use_calibration, + session_config=config_for_trt) def _WriteGraph(self, run_params, gdef, graph_state): if graph_state == GraphState.ORIGINAL: @@ -400,10 +441,12 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): is_dynamic_engine = not node.attr["static_engine"].b self.assertEqual(run_params.dynamic_engine, is_dynamic_engine, node.name) + self.assertEqual(node.attr["use_calibration"].b, + run_params.use_calibration, node.name) has_calibration_data = len(node.attr["calibration_data"].s) if (IsQuantizationMode(run_params.precision_mode) and - graph_state == GraphState.INFERENCE): + run_params.use_calibration and graph_state == GraphState.INFERENCE): self.assertTrue(has_calibration_data, node.name) else: self.assertFalse(has_calibration_data, node.name) @@ -438,6 +481,11 @@ class 
TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): # types. scale = 10.0 if np.issubdtype(dtype, np.integer) else 1.0 dims = params.input_dims[i] + # TODO(laigd): add debug options. E.g. we can set the input data to be + # continuous natural numbers: + # seq = np.arange(np.prod(dims)) + # seq.resize(dims) + # input_data.append(scale * seq.astype(dtype)) input_data.append((scale * np.random.random_sample(dims)).astype(dtype)) self._VerifyGraphDef(run_params, input_gdef, GraphState.ORIGINAL) @@ -449,7 +497,8 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): config_no_trt, GraphState.ORIGINAL) # Run calibration if necessary. - if IsQuantizationMode(run_params.precision_mode): + if (IsQuantizationMode(run_params.precision_mode) and + run_params.use_calibration): calib_config = self._GetConfigProto(run_params, GraphState.CALIBRATE) logging.info("Running calibration graph, config:\n%s", str(calib_config)) @@ -519,27 +568,38 @@ def _AddTests(test_class): use_optimizer_options = [False, True] dynamic_engine_options = [False, True] - for (use_optimizer, precision_mode, dynamic_engine) in itertools.product( - use_optimizer_options, PRECISION_MODES, dynamic_engine_options): + use_calibration_options = [False, True] + opts = itertools.product(use_optimizer_options, PRECISION_MODES, + dynamic_engine_options, use_calibration_options) + for (use_optimizer, precision_mode, dynamic_engine, use_calibration) in opts: if IsQuantizationMode(precision_mode): if use_optimizer: # TODO(aaroey): if use_optimizer is True we need to get the inference # graphdef using custom python wrapper class, which is not currently # supported yet. continue - if not dynamic_engine: + if use_calibration and not dynamic_engine: + # A static engine is supported when use_calibration=False, so we want to + # test that combination. With use_calibration=True, only dynamic engines + # are supported. # TODO(aaroey): construction of static calibration engine is not # supported yet.
continue + else: + if use_calibration: + # Don't calibrate in FP32 or FP16 mode + continue conversion = "OptimizerConversion" if use_optimizer else "ToolConversion" - engine_type = ("DynamicEngine" if dynamic_engine else "StaticEngine") - test_name = "%s_%s_%s" % (conversion, precision_mode, engine_type) + engine_type = "DynamicEngine" if dynamic_engine else "StaticEngine" + calibration_type = "UseCalibration" if use_calibration else "NoCalibration" + test_name = "%s_%s_%s_%s" % (conversion, engine_type, precision_mode, + calibration_type) run_params = RunParams( use_optimizer=use_optimizer, precision_mode=precision_mode, dynamic_engine=dynamic_engine, - test_name=test_name) + test_name=test_name, + use_calibration=use_calibration) setattr(test_class, "testTfTrt_" + test_name, _GetTest(run_params)) diff --git a/tensorflow/contrib/tensorrt/test/unary_test.py b/tensorflow/contrib/tensorrt/test/unary_test.py index 8736bfb6449..9fc50e05952 100644 --- a/tensorflow/contrib/tensorrt/test/unary_test.py +++ b/tensorflow/contrib/tensorrt/test/unary_test.py @@ -107,8 +107,8 @@ class UnaryTest(trt_test.TfTrtIntegrationTestBase): def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" return [ - "my_trt_op_0", "my_trt_op_1", "my_trt_op_2", "my_trt_op_3", - "my_trt_op_4" + "TRTEngineOp_0", "TRTEngineOp_1", "TRTEngineOp_2", "TRTEngineOp_3", + "TRTEngineOp_4" ] diff --git a/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py b/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py index b0271a04b36..b29626d2c28 100644 --- a/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py +++ b/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py @@ -76,7 +76,7 @@ class VGGBlockNCHWTest(trt_test.TfTrtIntegrationTestBase): def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" - return ["my_trt_op_0"] + return ["TRTEngineOp_0"] if __name__ == "__main__": diff --git a/tensorflow/contrib/tensorrt/test/vgg_block_test.py b/tensorflow/contrib/tensorrt/test/vgg_block_test.py index d7c165784bf..9b0b1896260 100644 --- a/tensorflow/contrib/tensorrt/test/vgg_block_test.py +++ b/tensorflow/contrib/tensorrt/test/vgg_block_test.py @@ -67,7 +67,7 @@ class VGGBlockTest(trt_test.TfTrtIntegrationTestBase): def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" - return ["my_trt_op_0"] + return ["TRTEngineOp_0"] if __name__ == "__main__": diff --git a/tensorflow/contrib/timeseries/python/timeseries/BUILD b/tensorflow/contrib/timeseries/python/timeseries/BUILD index c230919168b..4b90b596b28 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/BUILD +++ b/tensorflow/contrib/timeseries/python/timeseries/BUILD @@ -104,8 +104,10 @@ py_test( srcs = [ "estimators_test.py", ], + shard_count = 3, srcs_version = "PY2AND3", tags = [ + "no_mac", "no_pip_gpu", # b/63391119 "nomsan", # Takes too long to run. 
"notsan", # b/67865658 diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators.py b/tensorflow/contrib/timeseries/python/timeseries/estimators.py index af68aa03cf6..146ed9f2713 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/estimators.py +++ b/tensorflow/contrib/timeseries/python/timeseries/estimators.py @@ -32,7 +32,7 @@ from tensorflow.contrib.timeseries.python.timeseries.state_space_models.filterin from tensorflow.python.estimator import estimator_lib from tensorflow.python.estimator.canned import optimizers from tensorflow.python.estimator.export import export_lib -from tensorflow.python.feature_column import feature_column +from tensorflow.python.feature_column import feature_column_lib as feature_column from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py index ffd838be40e..7d780559f97 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py +++ b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py @@ -30,7 +30,7 @@ from tensorflow.contrib.timeseries.python.timeseries import saved_model_utils from tensorflow.python.client import session from tensorflow.python.estimator import estimator_lib -from tensorflow.python.feature_column import feature_column +from tensorflow.python.feature_column import feature_column_lib as feature_column from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.platform import test diff --git a/tensorflow/contrib/timeseries/python/timeseries/head_test.py b/tensorflow/contrib/timeseries/python/timeseries/head_test.py index 90c7d8ac1a9..8f692d94da4 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/head_test.py +++ b/tensorflow/contrib/timeseries/python/timeseries/head_test.py @@ -38,7 +38,7 @@ from tensorflow.core.example import example_pb2 from tensorflow.python.client import session as session_lib from tensorflow.python.estimator import estimator_lib -from tensorflow.python.feature_column import feature_column +from tensorflow.python.feature_column import feature_column_lib as feature_column from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops diff --git a/tensorflow/contrib/timeseries/python/timeseries/math_utils.py b/tensorflow/contrib/timeseries/python/timeseries/math_utils.py index 43c5267e632..aab33064386 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/math_utils.py +++ b/tensorflow/contrib/timeseries/python/timeseries/math_utils.py @@ -802,7 +802,7 @@ class InputStatisticsFromMiniBatch(object): array_ops.shape(times)[1] - 1, self._dtype)) # Co-locate updates with their variables to minimize race conditions when # updating statistics. - with ops.colocate_with(auxiliary_variables.max_time_seen): + with ops.device(auxiliary_variables.max_time_seen.device): # There is a race condition if this value is being updated from multiple # workers. However, it should eventually reach the correct value if the # last chunk is presented enough times. 
@@ -810,16 +810,16 @@ class InputStatisticsFromMiniBatch(object): auxiliary_variables.max_time_seen, gen_math_ops.maximum(auxiliary_variables.max_time_seen, math_ops.reduce_max(times))) - with ops.colocate_with(auxiliary_variables.chunk_count): + with ops.device(auxiliary_variables.chunk_count.device): chunk_count_assign = state_ops.assign_add(auxiliary_variables.chunk_count, array_ops.shape( times, out_type=dtypes.int64)[0]) - with ops.colocate_with(auxiliary_variables.inter_observation_duration_sum): + with ops.device(auxiliary_variables.inter_observation_duration_sum.device): inter_observation_duration_assign = state_ops.assign_add( auxiliary_variables.inter_observation_duration_sum, math_ops.reduce_sum(batch_inter_observation_duration)) - with ops.colocate_with(auxiliary_variables.example_count): + with ops.device(auxiliary_variables.example_count.device): example_count_assign = state_ops.assign_add( auxiliary_variables.example_count, array_ops.size(times, out_type=dtypes.int64)) @@ -829,11 +829,11 @@ class InputStatisticsFromMiniBatch(object): # the series are then members of fewer chunks. For series which are much # longer than the chunk size (the usual/expected case), this effect becomes # irrelevant. - with ops.colocate_with(auxiliary_variables.overall_feature_sum): + with ops.device(auxiliary_variables.overall_feature_sum.device): overall_feature_sum_assign = state_ops.assign_add( auxiliary_variables.overall_feature_sum, math_ops.reduce_sum(values, axis=[0, 1])) - with ops.colocate_with(auxiliary_variables.overall_feature_sum_of_squares): + with ops.device(auxiliary_variables.overall_feature_sum_of_squares.device): overall_feature_sum_of_squares_assign = state_ops.assign_add( auxiliary_variables.overall_feature_sum_of_squares, math_ops.reduce_sum(values**2, axis=[0, 1])) @@ -869,7 +869,7 @@ class InputStatisticsFromMiniBatch(object): state_ops.assign(statistics.series_start_moments.mean, mean), state_ops.assign(statistics.series_start_moments.variance, variance)) - with ops.colocate_with(statistics.start_time): + with ops.device(statistics.start_time.device): series_start_update = control_flow_ops.cond( # Update moments whenever we even match the lowest time seen so far, # to ensure that series start statistics are eventually updated to diff --git a/tensorflow/contrib/timeseries/python/timeseries/model.py b/tensorflow/contrib/timeseries/python/timeseries/model.py index edd97b2a4c1..a8cd4287e00 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/model.py +++ b/tensorflow/contrib/timeseries/python/timeseries/model.py @@ -27,7 +27,7 @@ from tensorflow.contrib.timeseries.python.timeseries import math_utils from tensorflow.contrib.timeseries.python.timeseries.feature_keys import PredictionFeatures from tensorflow.contrib.timeseries.python.timeseries.feature_keys import TrainEvalFeatures -from tensorflow.python.feature_column import feature_column +from tensorflow.python.feature_column import feature_column_lib as feature_column from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD index 3c07a74ed8a..125750e7639 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD +++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD @@ -40,7 +40,10 @@ py_test( timeout = "long", # Moderate 
but for asan srcs = ["state_space_model_test.py"], srcs_version = "PY2AND3", - tags = ["no_windows"], # TODO: needs investigation on Windows + tags = [ + "no_mac", + "no_windows", # TODO: needs investigation on Windows + ], deps = [ ":state_space_model", "//tensorflow/contrib/layers:layers_py", diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD index a0a9cb3f31a..05d2ebd2e8a 100644 --- a/tensorflow/contrib/tpu/BUILD +++ b/tensorflow/contrib/tpu/BUILD @@ -14,6 +14,7 @@ load("//tensorflow:tensorflow.bzl", "tf_py_test") package( default_visibility = [ "//cloud/vmm/testing/tests/tpu:__subpackages__", + "//knowledge/cerebra/sense/im2query:__subpackages__", "//learning/brain:__subpackages__", "//learning/deepmind:__subpackages__", "//medical/pathology:__subpackages__", @@ -215,7 +216,7 @@ py_library( ], deps = [ ":tpu_lib", - "//tensorflow/contrib/cluster_resolver:tpu_cluster_resolver_py", + "//tensorflow/contrib/cluster_resolver:cluster_resolver_py", "//tensorflow/contrib/distribute", "//tensorflow/contrib/framework:framework_py", "//tensorflow/contrib/tpu/proto:compilation_result_proto_py", @@ -263,7 +264,7 @@ py_library( ":tpu_py", "//tensorflow/compiler/xla/experimental/xla_sharding", "//tensorflow/compiler/xla/python_api:xla_shape", - "//tensorflow/contrib/cluster_resolver:tpu_cluster_resolver_py", + "//tensorflow/contrib/cluster_resolver:cluster_resolver_py", "//tensorflow/contrib/compiler:xla", "//tensorflow/contrib/tpu/proto:compilation_result_proto_py", "//tensorflow/contrib/tpu/proto:optimization_parameters_proto_py", diff --git a/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py b/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py index 63641e00c5d..a081c4354a7 100644 --- a/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py +++ b/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py @@ -90,12 +90,12 @@ def main(unused_argv=None): tf_version = tf.__version__ print('TensorFlow version %s detected' % tf_version) - if FLAGS.service_addr is None and FLAGS.tpu is None: + if not FLAGS.service_addr and not FLAGS.tpu: sys.exit('You must specify either --service_addr or --tpu.') tpu_cluster_resolver = None - if FLAGS.service_addr is not None: - if FLAGS.tpu is not None: + if FLAGS.service_addr: + if FLAGS.tpu: tf.logging.warn('Both --service_addr and --tpu are set. Ignoring ' '--tpu and using --service_addr.') service_addr = FLAGS.service_addr diff --git a/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py b/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py index 1cf7f9fcf67..1b09ce173a6 100644 --- a/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py +++ b/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py @@ -80,6 +80,8 @@ class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook): self._summary_writer = None self._global_step_tensor = None + self._last_checkpoint_step = None + def _set_steps_per_run(self, steps_per_run): self._steps_per_run = steps_per_run @@ -137,8 +139,7 @@ class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook): last_step = session.run(self._global_step_tensor) - # Save the last checkpoint synchronously if needed. 
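The async_checkpoint.py change here records the global step of the most recent save in _last_checkpoint_step so that end() only writes a final synchronous checkpoint when training actually advanced past the last saved step. A compressed sketch of that bookkeeping (hypothetical class, threading details omitted):

```
# Sketch only: remember the last saved step and skip a redundant final save.
class _CheckpointBookkeeping(object):  # hypothetical stand-in for the hook
    def __init__(self):
        self._last_checkpoint_step = None

    def _save(self, step):
        # ... write the checkpoint ...
        self._last_checkpoint_step = step

    def end(self, last_step):
        if self._last_checkpoint_step != last_step:
            self._save(last_step)  # synchronous final save only when needed
```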
- if last_step != self._timer.last_triggered_step(): + if self._last_checkpoint_step != last_step: self._save(session, last_step, asynchronous=False) for l in self._listeners: @@ -174,6 +175,7 @@ class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook): logging.info("Checkpoint finished for %d into %s.", step, self._save_path) if not asynchronous: + self._last_checkpoint_step = step _save_fn() return @@ -183,6 +185,7 @@ class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook): logging.info("Saver thread still in progress, skipping checkpoint.") return + self._last_checkpoint_step = step self._save_thread = threading.Thread(target=_save_fn) self._save_thread.start() diff --git a/tensorflow/contrib/tpu/python/tpu/datasets.py b/tensorflow/contrib/tpu/python/tpu/datasets.py index c694e9c1bca..8d6245390fc 100644 --- a/tensorflow/contrib/tpu/python/tpu/datasets.py +++ b/tensorflow/contrib/tpu/python/tpu/datasets.py @@ -133,7 +133,7 @@ def StreamingFilesDataset(files, with ops.device('/job:%s' % file_reader_job): if isinstance(files, str): source_dataset = dataset_ops.Dataset.list_files(files) - elif isinstance(files, dataset_ops.Dataset): + elif isinstance(files, dataset_ops.DatasetV2): source_dataset = files else: raise ValueError('files was not a string or a dataset: %s' % files) @@ -156,7 +156,7 @@ def StreamingFilesDataset(files, source_dataset = source_dataset.prefetch(1) - source_iterator = source_dataset.make_one_shot_iterator() + source_iterator = dataset_ops.make_one_shot_iterator(source_dataset) source_handle = source_iterator.string_handle() @function.Defun(dtypes.string) diff --git a/tensorflow/contrib/tpu/python/tpu/datasets_test.py b/tensorflow/contrib/tpu/python/tpu/datasets_test.py index b58d05eac56..52d87b80040 100644 --- a/tensorflow/contrib/tpu/python/tpu/datasets_test.py +++ b/tensorflow/contrib/tpu/python/tpu/datasets_test.py @@ -70,7 +70,7 @@ class DatasetsTest(test.TestCase): dataset = datasets.StreamingFilesDataset( os.path.join(self.get_temp_dir(), 'text_line.*.txt'), filetype='text') - iterator = dataset.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(dataset) self._sess.run(iterator.initializer) get_next = iterator.get_next() @@ -94,7 +94,7 @@ class DatasetsTest(test.TestCase): dataset = datasets.StreamingFilesDataset( os.path.join(self.get_temp_dir(), 'tf_record*'), filetype='tfrecord') - iterator = dataset.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(dataset) self._sess.run(iterator.initializer) get_next = iterator.get_next() @@ -121,7 +121,7 @@ class DatasetsTest(test.TestCase): dataset = datasets.StreamingFilesDataset(filenames, filetype='tfrecord') - iterator = dataset.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(dataset) self._sess.run(iterator.initializer) get_next = iterator.get_next() @@ -154,7 +154,7 @@ class DatasetsTest(test.TestCase): os.path.join(self.get_temp_dir(), 'fixed_length*'), filetype=FixedLengthFile) - iterator = dataset.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(dataset) self._sess.run(iterator.initializer) get_next = iterator.get_next() @@ -177,7 +177,7 @@ class DatasetsTest(test.TestCase): dataset = datasets.StreamingFilesDataset( dataset_ops.Dataset.range(10), filetype=gen_dataset) - iterator = dataset.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(dataset) self._sess.run(iterator.initializer) get_next = 
iterator.get_next() diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py index 08f58a5f5b8..ebf40827e45 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_support.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py @@ -81,6 +81,7 @@ from tensorflow.python.keras import metrics as metrics_module from tensorflow.python.keras import models from tensorflow.python.keras import optimizers as keras_optimizers from tensorflow.python.keras.engine import base_layer +from tensorflow.python.keras.engine import base_layer_utils from tensorflow.python.keras.engine import training_arrays from tensorflow.python.keras.engine import training_utils from tensorflow.python.keras.layers import embeddings @@ -438,7 +439,7 @@ class TPURewriteContext(object): self._default_placeholder = array_ops.placeholder self._default_name_scope = ops.name_scope - self._default_make_variable = base_layer.make_variable + self._default_make_variable = base_layer_utils.make_variable self._default_random_normal = random_ops.random_normal self._default_qr = gen_linalg_ops.qr @@ -486,14 +487,14 @@ class TPURewriteContext(object): gen_linalg_ops.qr = qr ops.name_scope = _name_scope - base_layer.make_variable = variable_scope.get_variable + base_layer_utils.make_variable = variable_scope.get_variable logging.info('Overriding default placeholder.') return def __exit__(self, exc_type, exc_val, exc_tb): array_ops.placeholder = self._default_placeholder ops.name_scope = self._default_name_scope - base_layer.make_variable = self._default_make_variable + base_layer_utils.make_variable = self._default_make_variable random_ops.random_normal = self._default_random_normal gen_linalg_ops.qr = self._default_qr @@ -728,7 +729,7 @@ class TPUDatasetInfeedManager(TPUInfeedManager): dummy_x_shape[0] *= tpu_assignment.num_towers dummy_y_shape = dataset.output_shapes[1].as_list() dummy_y_shape[0] *= tpu_assignment.num_towers - self._iterator = dataset.make_initializable_iterator() + self._iterator = dataset_ops.make_initializable_iterator(dataset) K.get_session().run(self._iterator.initializer) self._get_next_ops = [] @@ -769,7 +770,7 @@ class TPUDatasetInfeedManager(TPUInfeedManager): def _verify_dataset_shape(self, dataset): """Verifies a dataset is of an appropriate shape for TPUs.""" - if not isinstance(dataset, dataset_ops.Dataset): + if not isinstance(dataset, dataset_ops.DatasetV2): raise ValueError('The function passed as the `x` parameter did not ' 'return a `tf.data.Dataset`.') if not isinstance(dataset.output_classes, tuple): @@ -1012,9 +1013,10 @@ class TPUFunction(object): optimizer=_replicated_optimizer(self._cloned_optimizer), loss=self.model.loss, loss_weights=self.model.loss_weights, - metrics=metrics_module.clone_metrics(self.model.metrics), + metrics=metrics_module.clone_metrics( + self.model._compile_metrics), weighted_metrics=metrics_module.clone_metrics( - self.model.weighted_metrics), + self.model._compile_weighted_metrics), target_tensors=tpu_targets, ) @@ -1184,12 +1186,9 @@ class TPUFunction(object): # pipelined loop. return None, None - if not isinstance(K.learning_phase(), int): + if isinstance(inputs[-1], int): # Remove the learning_phase flag at the end. We currently hard code the # learning_phase in TPUFunction. - assert isinstance(inputs[-1], int), ( - 'Expect the final element be learning_phase flag. 
Got {}'.format( - inputs[-1])) inputs = inputs[:-1] if (self.execution_mode == model_fn_lib.ModeKeys.TRAIN or @@ -1379,6 +1378,7 @@ class KerasTPUModel(models.Model): self.train_function = None self._fit_function = None self._eval_function = None + self._stateful_metric_functions = [] cluster_resolver = strategy._tpu_cluster_resolver self._tpu_name_or_address = cluster_resolver.get_master() @@ -1393,10 +1393,10 @@ class KerasTPUModel(models.Model): self.compile( self._cpu_model.optimizer, self._cpu_model.loss, - self._cpu_model.metrics, + self._cpu_model._compile_metrics, self._cpu_model.loss_weights, self._cpu_model.sample_weight_mode, - self._cpu_model.weighted_metrics, + self._cpu_model._compile_weighted_metrics, self._cpu_model.target_tensors, ) @@ -1466,7 +1466,7 @@ class KerasTPUModel(models.Model): assert not self._numpy_to_infeed_manager_list # Ensure empty. infeed_managers = [] # Managers to clean up at the end of the fit call. - if isinstance(x, dataset_ops.Dataset): + if isinstance(x, dataset_ops.DatasetV2): # TODO(b/111413240): Support taking a tf.data.Dataset directly. raise ValueError( 'Taking a Dataset directly is not yet supported. Please ' @@ -1492,7 +1492,7 @@ class KerasTPUModel(models.Model): y = infeed_manager.dummy_y infeed_managers.append((x, infeed_manager)) - if isinstance(validation_data, dataset_ops.Dataset): + if isinstance(validation_data, dataset_ops.DatasetV2): # TODO(b/111413240): Support taking a tf.data.Dataset directly. raise ValueError( 'Taking a Dataset directly is not yet supported. Please ' @@ -1551,7 +1551,7 @@ class KerasTPUModel(models.Model): with _tpu_session_context(): # Managers to clean up at the end of the evaluate call. infeed_managers = [] - if isinstance(x, dataset_ops.Dataset): + if isinstance(x, dataset_ops.DatasetV2): # TODO(b/111413240): Support taking a tf.data.Dataset directly. raise ValueError( 'Taking a Dataset directly is not yet supported. Please ' @@ -1676,14 +1676,10 @@ class KerasTPUModel(models.Model): callbacks, self, do_validation=do_validation, - val_inputs=val_inputs, - val_targets=val_targets, - val_sample_weights=val_sample_weights, batch_size=batch_size, epochs=epochs, steps_per_epoch=steps_per_epoch, samples=num_training_samples, - validation_steps=validation_steps, verbose=verbose, count_mode=count_mode) @@ -1700,7 +1696,7 @@ class KerasTPUModel(models.Model): callbacks.on_train_begin() for epoch in range(initial_epoch, epochs): # Reset stateful metrics - for m in self.stateful_metric_functions: + for m in self.metrics: m.reset_states() # Update callbacks callbacks.on_epoch_begin(epoch) @@ -1923,7 +1919,7 @@ class KerasTPUModel(models.Model): if validation_data: if (isinstance(validation_data, iterator_ops.Iterator) or isinstance(validation_data, iterator_ops.EagerIterator) or - isinstance(validation_data, dataset_ops.Dataset)): + isinstance(validation_data, dataset_ops.DatasetV2)): raise ValueError('KerasTPUModel cannot handle a Dataset or Iterator ' 'for validation_data. 
Please instead pass a function ' 'that returns a `tf.data.Dataset`.') @@ -1998,14 +1994,14 @@ class KerasTPUModel(models.Model): self._optimizer = optimizer @property - def stateful_metric_functions(self): + def metrics(self): if self._tpu_model: - return self._tpu_model.stateful_metric_functions + return self._tpu_model.metrics return self._stateful_metric_functions - @stateful_metric_functions.setter - def stateful_metric_functions(self, stateful_metric_functions): - self._stateful_metric_functions = stateful_metric_functions + @metrics.setter + def metrics(self, metrics): + self._stateful_metric_functions = metrics def _make_train_function(self): if not self.train_function: @@ -2230,10 +2226,10 @@ def tpu_model(model, strategy=None): cpu_model.compile( _clone_optimizer(model.optimizer, optimizer_config), model.loss, - metrics_module.clone_metrics(model.metrics), + metrics_module.clone_metrics(model._compile_metrics), model.loss_weights, model.sample_weight_mode, - metrics_module.clone_metrics(model.weighted_metrics), + metrics_module.clone_metrics(model._compile_weighted_metrics), ) if model_weights: diff --git a/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py b/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py index 28d3a938510..8b0b240dc73 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py @@ -217,6 +217,10 @@ class ReplicatedVariable(object): def get(self): return self._primary_var + @property + def _in_graph_mode(self): + return self._primary_var._in_graph_mode # pylint: disable=protected-access + def _should_act_as_resource_variable(self): """Pass resource_variable_ops.is_resource_variable check.""" pass diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py index e3e791faacb..def57da20d6 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu.py @@ -1001,8 +1001,8 @@ def rewrite(computation, `rewrite` is a list of tensors corresponding to the tensors from the output of `computation`. - All `Operation`s returned from `computation` will be executed when - evaluating any of the returned output tensors. + All `Operation`s constructed during `computation` will be executed when + evaluating any of the returned output tensors, not just the ones returned. inputs: A list of input tensors or `None` (equivalent to an empty list). infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple of arguments as inputs to `computation`. @@ -1111,7 +1111,7 @@ def validate_inference_rewrite_for_variables(graph): Raises: RuntimeError: if validation failed. """ - if not any([x.type == "GuaranteeConst" for x in graph.get_operations()]): + if not any(x.type == "GuaranteeConst" for x in graph.get_operations()): raise RuntimeError( "No GuaranteeConst ops found in the graph after running " "tpu.rewrite_for_inference(...). 
Please check that you are using " diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_context.py b/tensorflow/contrib/tpu/python/tpu/tpu_context.py index da6bdf67d68..67246244794 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_context.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_context.py @@ -41,7 +41,7 @@ _NUM_CORES_TO_COMPUTATION_SHAPE = { class TPUContext(object): - """The context of current input_fn invocation.""" + """A context that holds the current configuration of the TPU computation.""" def __init__(self, internal_ctx, diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_embedding.py b/tensorflow/contrib/tpu/python/tpu/tpu_embedding.py index 3fe896426a7..ccba8a46c7c 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_embedding.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_embedding.py @@ -1069,17 +1069,14 @@ def _create_partitioned_variables(name, 'As TPU embedding is not optimized for small tables, ' 'please consider other ways for this embedding lookup.') - slicing = [num_hosts, 1] - - # TODO(shizhiw): deprecated, use tf.get_variable()? - return partitioned_variables.create_partitioned_variables( - name=name, - slicing=slicing, + return list(variable_scope.get_variable( + name, shape=(vocabulary_size, embedding_dimension), + partitioner=partitioned_variables.fixed_size_partitioner(num_hosts), dtype=dtypes.float32, initializer=initializer, collections=collections, - trainable=False) + trainable=False)) @ops.RegisterGradient('TPUEmbeddingActivations') diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 7cb8c4aa7f1..a9dc542ae5e 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -298,9 +298,9 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote host_calls['host_call'] = host_call _OutfeedHostCall.validate(host_calls) - training_hooks = list(training_hooks or []) - evaluation_hooks = list(evaluation_hooks or []) - prediction_hooks = list(prediction_hooks or []) + training_hooks = tuple(training_hooks or []) + evaluation_hooks = tuple(evaluation_hooks or []) + prediction_hooks = tuple(prediction_hooks or []) for hook in training_hooks + evaluation_hooks + prediction_hooks: if not isinstance(hook, session_run_hook.SessionRunHook): @@ -335,7 +335,7 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote hooks = None if self.host_call is not None: hooks = [_OutfeedHostCallHook(host_call_ret['host_call'])] - hooks = list(hooks or []) + hooks = tuple(hooks or []) scaffold = self.scaffold_fn() if self.scaffold_fn else None return model_fn_lib.EstimatorSpec( mode=self.mode, @@ -2169,7 +2169,6 @@ class TPUEstimator(estimator_lib.Estimator): builder, input_receiver_fn_map, checkpoint_path, - strip_default_attrs, save_variables=True, mode=model_fn_lib.ModeKeys.PREDICT, export_tags=None, @@ -2184,7 +2183,6 @@ class TPUEstimator(estimator_lib.Estimator): builder, input_receiver_fn_map, checkpoint_path, - strip_default_attrs, save_variables, mode=mode, export_tags=export_tags, @@ -2201,7 +2199,6 @@ class TPUEstimator(estimator_lib.Estimator): builder, input_receiver_fn_map, checkpoint_path, - strip_default_attrs, save_variables=False, mode=mode, export_tags=export_tags, @@ -2783,7 +2780,7 @@ def _export_output_to_tensors(export_output): elif isinstance(export_output, export_output_lib.RegressionOutput): return [export_output.value] elif isinstance(export_output, export_output_lib.PredictOutput): - 
return export_output.outputs.values() + return list(export_output.outputs.values()) else: raise ValueError( '`export_output` must be have type `ClassificationOutput`, ' @@ -3059,7 +3056,7 @@ class _Inputs(object): @staticmethod def from_input_fn(return_values): """Returns an `_Inputs` instance according to `input_fn` return value.""" - if isinstance(return_values, dataset_ops.Dataset): + if isinstance(return_values, dataset_ops.DatasetV2): dataset = return_values return _Inputs(dataset=dataset) @@ -3084,7 +3081,7 @@ class _Inputs(object): The initializer must be run before calling `features_and_labels`. """ - self._iterator = self._dataset.make_initializable_iterator() + self._iterator = dataset_ops.make_initializable_iterator(self._dataset) return self._iterator.initializer def features_and_labels(self): diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator_signals_test.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator_signals_test.py index 3786e52b949..55235556de0 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator_signals_test.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator_signals_test.py @@ -71,7 +71,7 @@ class TPUEstimatorStoppingSignalsTest(test.TestCase): with ops.Graph().as_default(): dataset = input_fn(params) - features = dataset.make_one_shot_iterator().get_next() + features = dataset_lib.make_one_shot_iterator(dataset).get_next() # With tf.data.Dataset.batch, the batch is None, i.e., dynamic shape. self.assertIsNone(features['a'].shape.as_list()[0]) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_feed.py b/tensorflow/contrib/tpu/python/tpu/tpu_feed.py index e75a09492ec..d5957b7e8ec 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_feed.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_feed.py @@ -26,7 +26,6 @@ import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding -from tensorflow.compiler.xla.python_api import xla_shape from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import tpu from tensorflow.contrib.tpu.python.tpu import tpu_sharding @@ -92,8 +91,7 @@ class InfeedQueue(object): else: raise ValueError( "number of tuple elements cannot be inferred from InfeedQueue " - "constructor" - ) + "constructor") if number_of_tuple_elements <= 0: raise ValueError("number_of_tuple_elements %d must be > 0" % number_of_tuple_elements) @@ -293,9 +291,8 @@ class InfeedQueue(object): self.number_of_tuple_elements """ if len(input_tensors) != self.number_of_tuple_elements: - raise ValueError( - "input_tensors is %s, but should be a list of %d Tensors", ( - str(input_tensors), self.number_of_tuple_elements)) + raise ValueError("input_tensors is %s, but should be a list of %d Tensors" + % (str(input_tensors), self.number_of_tuple_elements)) self.set_tuple_shapes([t.shape for t in input_tensors]) self.set_tuple_types([t.dtype for t in input_tensors]) @@ -451,8 +448,8 @@ class InfeedQueue(object): for i in xrange(1, self.number_of_tuple_elements): if devices[0] != devices[i]: raise ValueError( - "input devices for shard %d are %s, but should all be the same", - index, str(devices)) + "input devices for shard %d are %s, but should all be the same" % + (index, str(devices))) with ops.colocate_with(inputs[0]): return tpu_ops.infeed_enqueue_tuple( inputs=inputs, @@ -792,18 +789,14 @@ class _PartitionedInfeedQueue(InfeedQueue): Args: tensor: Input tensor for partitioning. 
- dims: A list of integer describes how to partition the input tensor. + dims: A 1-D np.array describing how to partition the input tensor. Raises: ValueError: If the tensor can't be partitioned by dims or the num_cores_per_replica doesn't match the number of partitions(dims.prod()). """ - if dims is None: - return - - dims = np.array(dims) - if (dims < 1).any(): raise ValueError("All input partition dims must be >= 1.") @@ -823,11 +816,6 @@ class _PartitionedInfeedQueue(InfeedQueue): "partition dims = {}).".format(tensor.shape.as_list(), dims)) tensor.shape.assert_is_fully_defined() - if (np.array(tensor.shape.as_list()) % dims != 0).any(): - raise ValueError( - "All input partition dims must divide exactly into the `Tensor` " - "shape (tensor shape = {}, input partition dims = {}).".format( - tensor.shape.as_list(), dims)) def _partition_or_replicate_on_host(self, tensor, dims): """Partitions or replicates the input tensor. @@ -840,16 +828,39 @@ class _PartitionedInfeedQueue(InfeedQueue): Returns: An iterator of `Tensor`s or a list of partioned tensors. """ - self._check_input_partition_dims(tensor, dims) if dims is None: return itertools.repeat(tensor) - else: - output = [tensor] - for axis, dim in enumerate(dims): - if dim > 1: - output = [array_ops.split(x, dim, axis=axis) for x in output] - output = nest.flatten(output) - return output + dims = np.array(dims) + self._check_input_partition_dims(tensor, dims) + output = [tensor] + shape_list = np.array(tensor.shape.as_list()) + quotients, remainders = np.divmod(shape_list, dims) + for axis, (quotient, remainder, dim, original_size) in enumerate( + zip(quotients, remainders, dims, shape_list)): + if dim <= 1: + continue + if remainder > 0: + # When a dimension cannot be evenly partitioned, XLA assumes tensors are + # partitioned in a greedy manner, using ceil_ratio = ceil(size / dim) + # first. E.g. for a 2D tensor with shape (5, 14) and dims (2, 4): since + # 5 % 2 = 1 and 14 % 4 = 2, the split sizes are [3, 2] and [4, 4, 4, 2], + # giving partition shapes + # [[(3, 4), (3, 4), (3, 4), (3, 2)], + # [(2, 4), (2, 4), (2, 4), (2, 2)]] + ceil_ratio = quotient + 1 + num_full_slots, left_over = np.divmod(original_size, ceil_ratio) + num_or_size_splits = [ceil_ratio] * num_full_slots + [left_over] + if len(num_or_size_splits) < dim: + num_or_size_splits += [0] * (dim - len(num_or_size_splits)) + new_output = [] + for x in output: + new_output.append( + array_ops.split( + x, num_or_size_splits=num_or_size_splits, axis=axis)) + output = new_output + else: + output = [array_ops.split(x, dim, axis=axis) for x in output] + output = nest.flatten(output) + return output def _tag_sharding_attribute_for_dequeued_tensor(self, tensor, dims): """Tags appropriate XLA sharding attribute to the dequeued tensor.
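The uneven-partitioning branch above can be checked with plain numpy. The following standalone sketch reproduces the per-axis split sizes for the (5, 14) tensor and (2, 4) partition dims used in the comment; no TensorFlow is required.

```
# Sketch only: compute split sizes the same greedy way as the code above.
import numpy as np

shape = np.array([5, 14])
dims = np.array([2, 4])
quotients, remainders = np.divmod(shape, dims)
for axis, (q, r, dim, size) in enumerate(zip(quotients, remainders, dims, shape)):
    if r > 0:
        ceil_ratio = int(q) + 1
        num_full, left_over = divmod(int(size), ceil_ratio)
        splits = [ceil_ratio] * num_full + [left_over]
        splits += [0] * (int(dim) - len(splits))
    else:
        splits = [int(q)] * int(dim)
    print(axis, splits)
# Prints 0 [3, 2] and 1 [4, 4, 4, 2], i.e. partition shapes
# [[(3, 4), (3, 4), (3, 4), (3, 2)], [(2, 4), (2, 4), (2, 4), (2, 2)]].
```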
@@ -866,13 +877,9 @@ class _PartitionedInfeedQueue(InfeedQueue): elif np.prod(dims) == 1: return xla_sharding.assign_device(tensor, 0) else: - tile_shape = np.array(tensor.shape.as_list()) // dims tile_assignment = np.arange(np.prod(dims)).reshape(dims) return xla_sharding.tile( tensor=tensor, - tile_shape=xla_shape.CreateShapeFromDtypeAndTuple( - dtype=np.dtype(tensor.dtype.as_numpy_dtype), - shape_tuple=tile_shape), tile_assignment=tile_assignment) def _tag_sharding_attribute_for_dequeued_tensors(self, dequeues, dims): diff --git a/tensorflow/contrib/tpu/python/tpu/training_loop.py b/tensorflow/contrib/tpu/python/tpu/training_loop.py index b6c350ecd75..0187b4bec6e 100644 --- a/tensorflow/contrib/tpu/python/tpu/training_loop.py +++ b/tensorflow/contrib/tpu/python/tpu/training_loop.py @@ -166,8 +166,8 @@ def while_loop(condition, body, inputs=None, infeed_queue=None, name=None): # control dependencies from any side-effecting operations. if input_arity == 0: inputs = [array_ops.constant(0)] - return control_flow_ops.while_loop(condition_wrapper, body_wrapper, inputs, - name="") + return control_flow_ops.while_loop( + condition_wrapper, body_wrapper, inputs, name="", parallel_iterations=1) def repeat(n, body, inputs=None, infeed_queue=None, name=None): diff --git a/tensorflow/contrib/tpu/tpu_estimator.md b/tensorflow/contrib/tpu/tpu_estimator.md index b6514e19dc9..552febd80bd 100644 --- a/tensorflow/contrib/tpu/tpu_estimator.md +++ b/tensorflow/contrib/tpu/tpu_estimator.md @@ -89,12 +89,9 @@ handle training: dataset = tf.data.TFRecordDataset( filename, buffer_size=FLAGS.dataset_reader_buffer_size) - dataset = dataset.map(parser).cache().repeat().batch(batch_size) - images, labels = dataset.make_one_shot_iterator().get_next() - # set_shape to give inputs statically known shapes. 
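The tpu_estimator.md hunk here replaces manual iterator creation and set_shape calls with an input_fn that simply returns a dataset batched with drop_remainder=True, which keeps the batch dimension statically known. A hedged sketch of the resulting input_fn (file name, parser, and buffer size are placeholders):

```
# Sketch only: an input_fn whose batches have a fully defined shape.
import tensorflow as tf

def make_input_fn(filename, parser, buffer_size):
    def input_fn(params):
        batch_size = params["batch_size"]
        dataset = tf.data.TFRecordDataset(filename, buffer_size=buffer_size)
        # drop_remainder=True replaces the manual set_shape calls removed in
        # this hunk by guaranteeing a static batch dimension.
        return dataset.map(parser).cache().repeat().batch(
            batch_size, drop_remainder=True)
    return input_fn
```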
- images.set_shape([batch_size, 28 * 28]) - labels.set_shape([batch_size]) - return images, labels + dataset = dataset.map(parser).cache().repeat().batch( + batch_size, drop_remainder=True) + return dataset return input_fn diff --git a/tensorflow/contrib/training/BUILD b/tensorflow/contrib/training/BUILD index 00295f57f60..f6427ae05a2 100644 --- a/tensorflow/contrib/training/BUILD +++ b/tensorflow/contrib/training/BUILD @@ -26,7 +26,6 @@ py_library( "python/training/resample.py", "python/training/sampling_ops.py", "python/training/sequence_queueing_state_saver.py", - "python/training/tensor_queue_dataset.py", "python/training/training.py", "python/training/tuner.py", ], @@ -287,28 +286,6 @@ py_test( ], ) -py_test( - name = "tensor_queue_dataset_test", - size = "large", - srcs = ["python/training/tensor_queue_dataset_test.py"], - srcs_version = "PY2AND3", - tags = ["notsan"], - deps = [ - ":training_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:gradients", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform", - "//tensorflow/python:random_seed", - "//tensorflow/python:training", - "//tensorflow/python:variables", - "//tensorflow/python/data", - "//tensorflow/python/data/experimental/kernel_tests/serialization:dataset_serialization_test_base", - "//third_party/py/numpy", - ], -) - tf_proto_library( name = "protos_all", srcs = glob(["**/*.proto"]), diff --git a/tensorflow/contrib/training/__init__.py b/tensorflow/contrib/training/__init__.py index 3547e71184e..87ce57ef060 100644 --- a/tensorflow/contrib/training/__init__.py +++ b/tensorflow/contrib/training/__init__.py @@ -59,8 +59,6 @@ from tensorflow.contrib.training.python.training.hparam import * from tensorflow.contrib.training.python.training.resample import * from tensorflow.contrib.training.python.training.sampling_ops import * from tensorflow.contrib.training.python.training.sequence_queueing_state_saver import * -from tensorflow.contrib.training.python.training.tensor_queue_dataset import enqueue_in_queue_dataset -from tensorflow.contrib.training.python.training.tensor_queue_dataset import prepend_from_queue_and_padded_batch_dataset from tensorflow.contrib.training.python.training.training import add_gradients_summaries from tensorflow.contrib.training.python.training.training import clip_gradient_norms from tensorflow.contrib.training.python.training.training import clip_gradient_norms_fn @@ -79,7 +77,6 @@ _allowed_symbols = [ 'FeedingQueueRunner', 'get_or_create_eval_step', 'StopAfterNEvalsHook', 'SummaryAtEndHook', 'wait_for_new_checkpoint', 'add_gradients_summaries', 'clip_gradient_norms', 'clip_gradient_norms_fn', 'create_train_op', - 'multiply_gradients', 'enqueue_in_queue_dataset', - 'prepend_from_queue_and_padded_batch_dataset', 'train'] + 'multiply_gradients', 'train'] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/training/python/training/tensor_queue_dataset.py b/tensorflow/contrib/training/python/training/tensor_queue_dataset.py deleted file mode 100644 index 8896a95327a..00000000000 --- a/tensorflow/contrib/training/python/training/tensor_queue_dataset.py +++ /dev/null @@ -1,201 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Python wrappers for Datasets and Iterators.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import convert -from tensorflow.python.data.util import nest -from tensorflow.python.data.util import sparse -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import gen_dataset_ops -from tensorflow.python.util import nest as tf_nest - - -class _PrependFromQueueAndPaddedBatchDataset(dataset_ops.UnaryDataset): - """A `Dataset` that prepends a queue to another `Dataset`. - - A vector of handles to the queue is returned as the first component of - the associated iterator. This vector can be passed to - `enqueue_in_queue_dataset` to add new elements to the queue. - """ - - def __init__(self, input_dataset, batch_size, padded_shapes, padding_values): - """Initialize `PrependFromQueueAndPaddedBatchDataset`.""" - super(_PrependFromQueueAndPaddedBatchDataset, self).__init__(input_dataset) - if sparse.any_sparse(input_dataset.output_classes): - raise TypeError( - "Batching of padded sparse tensors is not currently supported") - self._input_dataset = input_dataset - self._batch_size = ops.convert_to_tensor( - batch_size, dtype=dtypes.int64, name="batch_size") - if padded_shapes is None: - self._padded_shapes = nest.map_structure( - convert.partial_shape_to_tensor, input_dataset.output_shapes) - else: - self._padded_shapes = nest.map_structure_up_to( - input_dataset.output_shapes, convert.partial_shape_to_tensor, - padded_shapes) - # pylint: disable=protected-access - padding_values = ( - padding_values if padding_values is not None else - dataset_ops._default_padding(input_dataset)) - self._padding_values = nest.map_structure_up_to( - input_dataset.output_shapes, dataset_ops._padding_value_to_tensor, - padding_values, input_dataset.output_types) - # pylint: enable=protected-access - - def _as_variant_tensor(self): - # pylint: disable=protected-access - return gen_dataset_ops.prepend_from_queue_and_padded_batch_dataset( - self._input_dataset._as_variant_tensor(), - batch_size=self._batch_size, - padded_shapes=[ - ops.convert_to_tensor(s, dtype=dtypes.int64) - for s in nest.flatten(self._padded_shapes) - ], - padding_values=nest.flatten(self._padding_values), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes))) - # pylint: enable=protected-access - - @property - def output_classes(self): - return (ops.Tensor, self._input_dataset.output_classes) - - def _as_batch_shape(self, shape_like): - return tensor_shape.vector(None).concatenate( - tensor_util.constant_value_as_shape(shape_like)) - - @property - def output_shapes(self): - # First output is a variant representing the Queue - return (tensor_shape.vector(None), - 
nest.map_structure(self._as_batch_shape, self._padded_shapes)) - - @property - def output_types(self): - # First output is a variant representing the Queue - return (dtypes.variant, self._input_dataset.output_types) - - -def prepend_from_queue_and_padded_batch_dataset(batch_size, - padding_values=None, - padded_shapes=None): - """A transformation that prepends a queue to a `Dataset` and batches results. - - A vector of handles to the queue is returned as the first component of the - associated iterator. This vector can be passed to `enqueue_in_queue_dataset` - to add new elements to the queue. - - Below is an example of how this dataset might be used to split incoming - variable-length sequences into "head" and "rest" parts, where "rest" parts - are re-enqueued back into the dataset. A more realistic example would - perform some calculation on the "head" and modify some components of "rest" - with the result (before re-enqueueing). - - ```python - dataset = tf.data.Dataset.from_tensor_slices([2*x for x in range(10)]) - # Make a dataset of variable-length vectors and their lengths. - dataset = dataset.map(lambda count: (count, tf.ones((count,)))) - # Emit a queue we can prepend to, and counts/values as padded batch. - dataset = dataset.apply( - tf.contrib.training.prepend_from_queue_and_padded_batch_dataset( - batch_size=10)) - dataset = dataset.prefetch(1) - - iterator = dataset.make_one_shot_iterator() - queue, (count, padded_value) = iterator.get_next() - - # Split the padded_value into two pieces: head and rest - rest_indices = tf.squeeze(tf.where(count > 3), axis=1) - bound = tf.minimum(3, tf.reduce_max(count)) - value_head = padded_value[:, :bound] - count_rest = tf.gather(count - 3, rest_indices) - value_rest = tf.gather(padded_value[:, bound:], rest_indices) - queue_rest = tf.gather(queue, rest_indices) - enqueue_rest_op = tf.contrib.training.enqueue_in_queue_dataset( - queue_rest, (count_rest, value_rest)) - with tf.control_dependencies([enqueue_rest_op]): - calculation = fn(value_head) - - while True: # Will raise OutOfRange when finished with all pieces. - session.run(calculation) - ``` - - Args: - batch_size: `int64` scalar tensor. The batch size to use when performing - padded batching. - padding_values: (optional) Nested tuple of scalar tensors. If provided, - the structure and dtypes of padding_values should match that of - incoming dataset's `output_types`. - padded_shapes: (optional) Nested tuple of `int64` vector tensors. - If provided, the structure must match that of the incoming dataset's - `output_types`. If not provided, the incoming dataset's `output_shapes` - is used. Any unknown (`None` or `-1`) dimensions in the shapes are - treated as being unique per-batch: for each batch time, an unknown - dimension is replaced with the maximum given value of this dimension - across all tensors for the given component in the batch. - - Returns: - A `Dataset` transformation function, which can be passed to - `tf.data.Dataset.apply`. - """ - - def _apply_fn(dataset): - return _PrependFromQueueAndPaddedBatchDataset( - dataset, - batch_size=batch_size, - padding_values=padding_values, - padded_shapes=padded_shapes) - - return _apply_fn - - -def enqueue_in_queue_dataset(queue, components): - """Enqueue components into queue from `PrependFromQueueAndPaddedBatchDataset`. - - The components' dtypes and shapes must be compatible with the `output_shapes` - attribute of the `dataset` created by - `prepend_from_queue_and_padded_batch_dataset`. 
This operation supports both - non-batched and batched modes. - - For more details, see the example in the docstring for - `prepend_from_queue_and_padded_batch_dataset`. - - Args: - queue: `variant` scalar or vector tensor. - The tensor emitted by the first component of the iterator associated with - `prepend_from_queue_and_padded_batch_dataset`. If this is a scalar, - then the `components` input tensors should not have a prepended batch - dimension. - components: Nested tuple of tensors, each with a leading batch dimension - if `queue` is a vector. The structure, dtypes, and shapes - (excluding batch dimension) must match the nested tuples - `dataset.output_types[1]` and `dataset.output_shapes[1]` (the non-queue - output types and shapes) of the `dataset` emitted by - the original `prepend_from_queue_and_padded_batch_dataset` call. - - Returns: - An `Operation` that enqueues `components` into the dataset(s) associated - with entries of `queue`. - """ - return gen_dataset_ops.enqueue_in_queue_dataset( - queue=queue, components=tf_nest.flatten(components)) diff --git a/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py b/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py deleted file mode 100644 index c1657fec7bb..00000000000 --- a/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py +++ /dev/null @@ -1,355 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Tests for TensorQueueDataset.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.training.python.training import tensor_queue_dataset as tqd -from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import string_ops -from tensorflow.python.platform import test - - -class PrependFromQueueAndPaddedBatchDatasetTest(test.TestCase): - - def testNoEnqueue(self): - dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2]) - dataset = dataset.apply( - tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=1)) - self.assertEqual((dtypes.variant, dtypes.int32), dataset.output_types) - self.assertAllEqual(([None],) * 2, - [x.as_list() for x in dataset.output_shapes]) - iterator = dataset.make_one_shot_iterator() - _, value = iterator.get_next() - self.assertEqual([0], self.evaluate(value)) - self.assertEqual([1], self.evaluate(value)) - self.assertEqual([2], self.evaluate(value)) - with self.assertRaisesOpError("End of sequence"): - self.evaluate(value) - - def testBatchedNoEnqueue(self): - dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2]) - dataset = dataset.apply( - tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=2)) - iterator = dataset.make_one_shot_iterator() - _, value = iterator.get_next() - self.assertAllEqual([0, 1], self.evaluate(value)) - self.assertAllEqual([2], self.evaluate(value)) - with self.assertRaisesOpError("End of sequence"): - self.evaluate(value) - - def testBatchedWithBiggerPaddingNoEnqueue(self): - dataset = dataset_ops.Dataset.from_tensor_slices([[0], [1], [2]]) - dataset = dataset.apply( - tqd.prepend_from_queue_and_padded_batch_dataset( - batch_size=2, padded_shapes=[3])) - iterator = dataset.make_one_shot_iterator() - _, value = iterator.get_next() - self.assertAllEqual([[0, 0, 0], [1, 0, 0]], self.evaluate(value)) - self.assertAllEqual([[2, 0, 0]], self.evaluate(value)) - with self.assertRaisesOpError("End of sequence"): - self.evaluate(value) - - def testBatchedWithBiggerPaddingOneEnqueue(self): - dataset = dataset_ops.Dataset.from_tensor_slices([[0], [1], [2]]) - dataset = dataset.apply( - tqd.prepend_from_queue_and_padded_batch_dataset( - batch_size=1, padded_shapes=[3])) - iterator = dataset.make_one_shot_iterator() - queue_handle, value = iterator.get_next() - enqueue_negative = tqd.enqueue_in_queue_dataset(queue_handle, -value) - with self.cached_session() as sess: - self.assertAllEqual([[0, 0, 0]], sess.run(value)) - value_1, _ = sess.run([value, enqueue_negative]) - self.assertAllEqual([[1, 0, 0]], value_1) - value_2, _ = sess.run([value, enqueue_negative]) - self.assertAllEqual([[-1, 0, 0]], value_2) - value_3 = sess.run(value) - self.assertAllEqual([[1, 0, 0]], value_3) - value_4, _ = sess.run([value, enqueue_negative]) - self.assertAllEqual([[2, 0, 0]], value_4) - value_5 = sess.run(value) - self.assertAllEqual([[-2, 0, 0]], value_5) - with self.assertRaisesOpError("End of sequence"): - sess.run(value) - - def testOneEnqueue(self): - dataset = 
dataset_ops.Dataset.from_tensor_slices([0, 1, 2]) - dataset = dataset.apply( - tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=1)) - iterator = dataset.make_one_shot_iterator() - queue_handle, value = iterator.get_next() - enqueue_negative = tqd.enqueue_in_queue_dataset(queue_handle, -value) - with self.cached_session() as sess: - self.assertEqual([0], sess.run(value)) - value_1, _ = sess.run([value, enqueue_negative]) - self.assertEqual([1], value_1) - value_2, _ = sess.run([value, enqueue_negative]) - self.assertEqual([-1], value_2) - value_3 = sess.run(value) - self.assertEqual([1], value_3) - value_4, _ = sess.run([value, enqueue_negative]) - self.assertEqual([2], value_4) - value_5 = sess.run(value) - self.assertEqual([-2], value_5) - with self.assertRaisesOpError("End of sequence"): - sess.run(value) - - def testBatchedOneEnqueue(self): - dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2]) - dataset = dataset.apply( - tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=2)) - iterator = dataset.make_one_shot_iterator() - queue_handle, value = iterator.get_next() - enqueue_negative = tqd.enqueue_in_queue_dataset(queue_handle, -value) - enqueue_zeroth = tqd.enqueue_in_queue_dataset([queue_handle[0]], - array_ops.expand_dims( - value[0], axis=0)) - with self.cached_session() as sess: - value_0, _ = sess.run([value, enqueue_negative]) - self.assertAllEqual([0, 1], value_0) - value_1, _ = sess.run([value, enqueue_zeroth]) - self.assertAllEqual([0, -1], value_1) - value_2, _ = sess.run([value, enqueue_negative]) - self.assertAllEqual([0, 2], value_2) - self.assertAllEqual([0, -2], sess.run(value)) - with self.assertRaisesOpError("End of sequence"): - sess.run(value) - - def testManyEnqueue(self): - dataset = dataset_ops.Dataset.from_tensor_slices([0, 1]) - dataset = dataset.apply( - tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=1)) - iterator = dataset.make_one_shot_iterator() - queue_handle, value = iterator.get_next() - enqueue_many_more = [ - tqd.enqueue_in_queue_dataset(queue_handle, value + 100 + i) - for i in range(1000) - ] - with self.cached_session() as sess: - value_0, _ = sess.run((value, enqueue_many_more)) - self.assertEqual([0], value_0) - rest = [] - for _ in range(1000): - rest.append(sess.run(value)) - self.assertEquals([[100 + i] for i in range(1000)], sorted(rest)) - # Going back to the original input. - value_1, _ = sess.run((value, enqueue_many_more)) - self.assertEqual(1, value_1) - rest = [] - for _ in range(1000): - rest.append(sess.run(value)) - self.assertEquals([[100 + i + 1] for i in range(1000)], sorted(rest)) - with self.assertRaisesOpError("End of sequence"): - sess.run(value) - - def testEnqueueWithPrefetch(self): - dataset = dataset_ops.Dataset.from_tensor_slices([0]) - dataset = dataset.apply( - tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=1)) - # Prefetching will request additional values before they are - # available to the queue. 
- dataset = dataset.prefetch(buffer_size=3) - iterator = dataset.make_one_shot_iterator() - queue_handle, value = iterator.get_next() - enqueue = tqd.enqueue_in_queue_dataset(queue_handle, value + 1) - with self.cached_session() as sess: - i = 0 - while i < 4: - received, _ = sess.run((value, enqueue)) - if received.size > 0: - self.assertAllEqual([i], received) - i += 1 - received_last = False - while True: - try: - received = sess.run(value) - if received.size > 0: - self.assertAllEqual([4], received) - received_last = True - except errors.OutOfRangeError: - break - self.assertTrue(received_last) - - def testDatasetWithPaddedShapeSmallerThanInputFails(self): - dataset = dataset_ops.Dataset.from_tensor_slices([[0, 0, 0]]).repeat(None) - dataset = dataset.apply( - tqd.prepend_from_queue_and_padded_batch_dataset( - batch_size=1, padded_shapes=[2])) - iterator = dataset.make_one_shot_iterator() - _, value = iterator.get_next() - with self.cached_session() as sess: - with self.assertRaisesOpError( - r"Incompatible input shapes at component 0 between " - r"input dataset this dataset: \[3\] vs. \[2\]"): - sess.run(value) - - def testEnqueueWithIncompatibleInputsFailsWithInformativeError(self): - dataset = dataset_ops.Dataset.from_tensor_slices([0]).repeat(None) - dataset = dataset.apply( - tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=1)) - iterator = dataset.make_one_shot_iterator() - queue_handle, value = iterator.get_next() - - enqueue_bad_structure = tqd.enqueue_in_queue_dataset( - queue_handle, (value, value)) - enqueue_bad_dtype = tqd.enqueue_in_queue_dataset(queue_handle, - np.array( - [1.0], - dtype=np.float32)) - enqueue_bad_shape_no_batch_dim = tqd.enqueue_in_queue_dataset( - queue_handle, ([1],)) - enqueue_bad_shape = tqd.enqueue_in_queue_dataset(queue_handle, - np.array( - [[1]], dtype=np.int32)) - - with self.cached_session() as sess: - with self.assertRaisesOpError( - "mismatched number of tensors. Queue expects 1 tensors but " - "tried to insert 2"): - sess.run(enqueue_bad_structure) - with self.assertRaisesOpError(r"Expected component 0 to have batched " - r"shape \[1,...\], but saw shape: \[\]"): - sess.run(enqueue_bad_shape_no_batch_dim) - with self.assertRaisesOpError( - r"mismatched shapes at component 0. Attempted to insert tensor " - r"with shape \[1\] but queue expected shape: \[\]"): - sess.run(enqueue_bad_shape) - with self.assertRaisesOpError( - r"mismatched dtypes at component 0. Attempted to insert tensor " - r"of type float but queue expected type: int32"): - sess.run(enqueue_bad_dtype) - - def testEnqueueWithPaddedBatchFailsWithInformativeError(self): - dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2]) - dataset = dataset.apply( - tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=1)) - with self.assertRaisesRegexp( - TypeError, r"Unable to create padding for field of type 'variant'"): - dataset.padded_batch(batch_size=10, padded_shapes=[1]) - - def testOneEnqueueWithPadding(self): - dataset = dataset_ops.Dataset.from_tensor_slices([0, 2, 4, 6]) - # Make a dataset of variable-length vectors and their lengths. - dataset = dataset.map( - lambda c: (c, c * array_ops.ones((c,), dtype=c.dtype))) - # Emit a queue we can prepend to, and counts/values as padded - # batch. 
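(The padding-focused test continues just below.) The padded-batch half of the API can also be looked at in isolation: `padded_shapes` fixes the width each component is padded to, and a batch is emitted once `batch_size` elements are collected or the input ends. A small sketch under the same assumptions as the previous snippet:

```python
# Sketch only -- mirrors testBatchedWithBiggerPaddingNoEnqueue above;
# same TF 1.x / contrib assumptions as the previous snippet.
import tensorflow as tf
from tensorflow.contrib.training.python.training import tensor_queue_dataset as tqd

dataset = tf.data.Dataset.from_tensor_slices([[0], [1], [2]])
dataset = dataset.apply(
    tqd.prepend_from_queue_and_padded_batch_dataset(
        batch_size=2, padded_shapes=[3]))   # pad each row out to width 3
_, value = dataset.make_one_shot_iterator().get_next()

with tf.Session() as sess:
  print(sess.run(value))   # [[0, 0, 0], [1, 0, 0]]
  print(sess.run(value))   # [[2, 0, 0]] -- the short final batch
```

As `testNonstandardPadding` further down shows, `padding_values` can replace the default zero fill (for example, padding with -1).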
- dataset = dataset.apply( - tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=3)) - - iterator = dataset.make_one_shot_iterator() - queue, (count, padded_value) = iterator.get_next() - - # Split the padded_value into two pieces: head and rest - rest_indices = array_ops.squeeze(array_ops.where(count > 2), axis=1) - bound = math_ops.minimum(2, math_ops.reduce_max(count)) - value_head = padded_value[:, :bound] - count_rest = array_ops.gather(count - 2, rest_indices) - value_rest = array_ops.gather(padded_value, rest_indices)[:, bound:] - queue_rest = array_ops.gather(queue, rest_indices) - enqueue_rest_op = tqd.enqueue_in_queue_dataset(queue_rest, - (count_rest, value_rest)) - with ops.control_dependencies([enqueue_rest_op]): - calc = array_ops.identity(value_head) - - with self.cached_session() as sess: - self.assertAllEqual([[0, 0], [2, 2], [4, 4]], sess.run(calc)) - self.assertAllEqual([[4, 4], [6, 6]], sess.run(calc)) - self.assertAllEqual([[6, 6]], sess.run(calc)) - self.assertAllEqual([[6, 6]], sess.run(calc)) - # Get some final batches due to prefetching. - for _ in range(3): - try: - self.assertAllEqual( - np.empty(shape=(0, 0), dtype=np.int32), sess.run(calc)) - except errors.OutOfRangeError as e: - self.assertTrue(str(e).startswith("End of sequence")) - - def testNonstandardPadding(self): - dataset = dataset_ops.Dataset.from_tensor_slices([0, 2, 4, 6]) - # Make a dataset of variable-length vectors and their lengths. - dataset = dataset.map( - lambda c: (c, c * array_ops.ones((c,), dtype=c.dtype))) - # Emit a queue we can prepend to, and counts/values as padded - # batch. - dataset = dataset.apply( - tqd.prepend_from_queue_and_padded_batch_dataset( - batch_size=3, padding_values=( - 0, - -1, - ))) - - iterator = dataset.make_one_shot_iterator() - _, (unused_count, padded_value) = iterator.get_next() - - with self.cached_session() as sess: - self.assertAllEqual([[-1, -1, -1, -1], [2, 2, -1, -1], [4, 4, 4, 4]], - sess.run(padded_value)) - self.assertAllEqual([[6] * 6], sess.run(padded_value)) - with self.assertRaisesOpError("End of sequence"): - sess.run(padded_value) - - -# TODO(ebrevdo): Figure out how to use run_core_tests to test state -# saving of an iterator that's had some tensors enqueued into its queue. 
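Before the serialization tests that follow, it is worth spelling out what `testOneEnqueueWithPadding` above exercises: consume a fixed number of timesteps from every row of the padded batch, and push the unfinished remainders back through the queue so they are revisited before fresh input. A condensed standalone restatement, under the same TF 1.x / contrib assumptions as the sketches above:

```python
# Sketch only -- condensed from testOneEnqueueWithPadding above; same
# TF 1.x / contrib assumptions. Each element is (length, values); we
# consume at most two timesteps per step and re-enqueue the leftovers.
import tensorflow as tf
from tensorflow.contrib.training.python.training import tensor_queue_dataset as tqd

dataset = tf.data.Dataset.from_tensor_slices([0, 2, 4, 6])
dataset = dataset.map(lambda c: (c, c * tf.ones((c,), dtype=c.dtype)))
dataset = dataset.apply(
    tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=3))
queue, (count, padded_value) = dataset.make_one_shot_iterator().get_next()

# Head: the first (up to) two timesteps of every row in the batch.
bound = tf.minimum(2, tf.reduce_max(count))
value_head = padded_value[:, :bound]
# Rest: rows with more than two timesteps left, pushed back for later.
rest_indices = tf.squeeze(tf.where(count > 2), axis=1)
count_rest = tf.gather(count - 2, rest_indices)
value_rest = tf.gather(padded_value, rest_indices)[:, bound:]
queue_rest = tf.gather(queue, rest_indices)
enqueue_rest = tqd.enqueue_in_queue_dataset(queue_rest, (count_rest, value_rest))
with tf.control_dependencies([enqueue_rest]):
  step = tf.identity(value_head)

with tf.Session() as sess:
  print(sess.run(step))  # [[0, 0], [2, 2], [4, 4]]
  print(sess.run(step))  # [[4, 4], [6, 6]] -- leftovers come back first
  # Further runs yield [[6, 6]] twice as the length-6 row is worked off.
```

In effect this is a chunked-consumption pattern: long sequences keep flowing back through the queue until they are exhausted.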
-class PrependFromQueueAndPaddedBatchDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def testPrependFromQueueAndPaddedBatch(self): - - def build_dataset(seq_lens): - return dataset_ops.Dataset.from_tensor_slices(seq_lens).map( - lambda x: array_ops.fill([x], x)).apply( - tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=4)) - - seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32) - seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32) - self.run_core_tests(lambda: build_dataset(seq_lens1), - lambda: build_dataset(seq_lens2), 8) - - def testPrependFromQueueAndPaddedBatchNonDefaultPadding(self): - - def build_dataset(seq_lens): - - def fill_tuple(x): - filled = array_ops.fill([x], x) - return (filled, string_ops.as_string(filled)) - - padded_shape = [-1] - return dataset_ops.Dataset.from_tensor_slices(seq_lens).map( - fill_tuple).apply( - tqd.prepend_from_queue_and_padded_batch_dataset( - batch_size=4, - padded_shapes=(padded_shape, padded_shape), - padding_values=(-1, ""))) - - seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32) - seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32) - self.run_core_tests(lambda: build_dataset(seq_lens1), - lambda: build_dataset(seq_lens2), 8) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/verbs/rdma.cc b/tensorflow/contrib/verbs/rdma.cc index f7c979e8632..9db80f6b573 100644 --- a/tensorflow/contrib/verbs/rdma.cc +++ b/tensorflow/contrib/verbs/rdma.cc @@ -30,7 +30,6 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" #include "tensorflow/core/distributed_runtime/session_mgr.h" -#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" #include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" @@ -1028,7 +1027,10 @@ Status RdmaTensorResponse::PrepareRecvTensor( return errors::Aborted( "RecvTensor expects a different device incarnation: ", parsed.src_incarnation, " vs. ", (*src_dev)->attributes().incarnation(), - ". Your worker job was probably restarted. Check your " + ". Your worker job (\"", + channel_->adapter_->worker_env_->session_mgr->LegacySession() + ->worker_name, + "\") was probably restarted. 
Check your " "worker job for the reason why it was restarted."); } diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index a701b38d4b3..575edfe7a93 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -95,7 +95,8 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test_gpu") load("//tensorflow:tensorflow.bzl", "tf_cc_tests_gpu") load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") load("//tensorflow:tensorflow.bzl", "tf_version_info_genrule") -load("//tensorflow:tensorflow.bzl", "if_not_tx2_llvm_or_windows_cuda") +load("//tensorflow:tensorflow.bzl", "if_nccl") +load("//tensorflow:tensorflow.bzl", "tensorflow_opensource_extra_deps") load("//tensorflow:tensorflow.bzl", "tf_cuda_only_cc_test") # For platform specific build config @@ -112,6 +113,7 @@ load( "tf_additional_device_tracer_test_flags", "tf_additional_gdr_lib_defines", "tf_additional_human_readable_json_deps", + "tf_additional_logger_deps", "tf_additional_lib_defines", "tf_additional_lib_deps", "tf_additional_lib_hdrs", @@ -300,6 +302,7 @@ filegroup( "platform/env_time.h", "platform/logging.h", "platform/macros.h", + "platform/platform_strings.h", "platform/types.h", ], visibility = ["//visibility:private"], @@ -442,6 +445,18 @@ cc_library( ] + tf_additional_human_readable_json_deps(), ) +cc_library( + name = "logger", + srcs = tf_platform_srcs(["logger.cc"]), + hdrs = ["platform/logger.h"] + tf_platform_hdrs(["logger.h"]), + copts = tf_copts(), + visibility = ["//visibility:public"], + deps = [ + ":lib", + ":lib_internal", + ] + tf_additional_logger_deps(), +) + filegroup( name = "platform_env_hdrs", srcs = [ @@ -519,6 +534,19 @@ cc_library( ], ) +cc_library( + name = "platform_strings", + srcs = tf_platform_srcs([ + "platform/platform_strings.cc", + "platform/platform_strings_computed.h", + ]), + hdrs = [ + "platform/platform_strings.h", + ], + visibility = ["//tensorflow/core:__subpackages__"], + deps = [":lib"], +) + filegroup( name = "platform_other_hdrs", srcs = [ @@ -841,6 +869,7 @@ tf_cuda_library( "framework/dataset_stateful_op_whitelist.h", "framework/device_base.h", "framework/function.h", + "framework/function_handle_cache.h", "framework/graph_def_util.h", "framework/graph_to_functiondef.h", "framework/kernel_def_builder.h", @@ -884,6 +913,7 @@ tf_cuda_library( "util/bcast.h", "util/cuda_kernel_helper.h", "util/device_name_utils.h", + "util/dump_graph.h", "util/events_writer.h", "util/example_proto_fast_parsing.h", "util/example_proto_helper.h", @@ -901,6 +931,7 @@ tf_cuda_library( "util/stream_executor_util.h", "util/strided_slice_op.h", "util/tensor_format.h", + "util/tensor_ops_util.h", "util/tensor_slice_reader.h", "util/tensor_slice_reader_cache.h", "util/tensor_slice_writer.h", @@ -1038,6 +1069,7 @@ tf_gen_op_libs( "batch_ops", "bitwise_ops", "boosted_trees_ops", + "tensor_forest_ops", "candidate_sampling_ops", "checkpoint_ops", "collective_ops", @@ -1085,7 +1117,11 @@ tf_gen_op_libs( op_lib_names = [ "string_ops", ], - deps = ["@com_google_absl//absl/strings"], + deps = [ + ":lib_internal", + ":lib_proto_parsing", + "@com_google_absl//absl/strings", + ], ) tf_gen_op_libs( @@ -1187,6 +1223,7 @@ cc_library( ":batch_ops_op_lib", ":bitwise_ops_op_lib", ":boosted_trees_ops_op_lib", + ":tensor_forest_ops_op_lib", ":candidate_sampling_ops_op_lib", ":checkpoint_ops_op_lib", ":collective_ops_op_lib", @@ -1340,6 +1377,7 @@ cc_library( "//tensorflow/core/kernels:batch_kernels", "//tensorflow/core/kernels:bincount_op", "//tensorflow/core/kernels:boosted_trees_ops", + 
"//tensorflow/core/kernels:tensor_forest_ops", "//tensorflow/core/kernels:candidate_sampler_ops", "//tensorflow/core/kernels:checkpoint_ops", "//tensorflow/core/kernels:collective_ops", @@ -1386,9 +1424,7 @@ cc_library( "//tensorflow/core/kernels:summary_kernels", "//tensorflow/core/kernels:training_ops", "//tensorflow/core/kernels:word2vec_kernels", - ] + tf_additional_cloud_kernel_deps() + if_not_tx2_llvm_or_windows_cuda([ - "//tensorflow/core/kernels:nccl_kernels", - ]) + if_not_windows([ + ] + tf_additional_cloud_kernel_deps() + if_not_windows([ "//tensorflow/core/kernels:fact_op", "//tensorflow/core/kernels:array_not_windows", "//tensorflow/core/kernels:math_not_windows", @@ -1413,6 +1449,8 @@ cc_library( ]) + if_cuda([ "//tensorflow/core/grappler/optimizers:gpu_swapping_kernels", "//tensorflow/core/grappler/optimizers:gpu_swapping_ops", + ]) + if_nccl([ + "//tensorflow/core/kernels:nccl_kernels", ]), ) @@ -1437,7 +1475,7 @@ tf_cuda_library( ":gpu_runtime", ":lib", ":ops", - ], + ] + tensorflow_opensource_extra_deps(), ) cc_library( @@ -1577,6 +1615,8 @@ filegroup( "util/stats_calculator.*", "util/reporter.*", "platform/**/cuda_libdevice_path.*", + "platform/**/logger.cc", + "platform/**/logger.h", "platform/default/test_benchmark.*", "platform/cuda.h", "platform/google/**/*", @@ -1671,8 +1711,8 @@ cc_library( cc_library( name = "mobile_additional_lib_deps", deps = tf_additional_lib_deps() + [ + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", - "@com_google_absl//absl/time", ], ) @@ -1763,7 +1803,7 @@ cc_library( # registration of ops to prune code size. cc_library( name = "android_tensorflow_lib_selective_registration", - srcs = if_android(["//tensorflow/core:android_srcs"]), + srcs = if_android(["//tensorflow/core:android_srcs_only_runtime"]), copts = tf_copts(android_optimization_level_override = None) + [ "-DSUPPORT_SELECTIVE_REGISTRATION", ], @@ -1775,9 +1815,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":protos_all_cc_impl", - "//third_party/eigen3", - "@double_conversion//:double-conversion", - "@nsync//:nsync_cpp", + "@com_google_absl//absl/container:flat_hash_set", "@protobuf_archive//:protobuf", ], alwayslink = 1, @@ -1787,7 +1825,7 @@ cc_library( # no proto_rtti. 
cc_library( name = "android_tensorflow_lib_selective_registration_nortti", - srcs = if_android(["//tensorflow/core:android_srcs"]), + srcs = if_android(["//tensorflow/core:android_srcs_only_runtime"]), copts = tf_copts(android_optimization_level_override = None) + tf_opts_nortti_if_android() + [ "-DSUPPORT_SELECTIVE_REGISTRATION", ], @@ -1799,9 +1837,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":protos_all_cc_impl", - "//third_party/eigen3", - "@double_conversion//:double-conversion", - "@nsync//:nsync_cpp", + "@com_google_absl//absl/container:flat_hash_set", "@protobuf_archive//:protobuf", ], alwayslink = 1, @@ -2045,9 +2081,7 @@ tf_proto_library_cc( srcs = ["protobuf/master.proto"], cc_api_version = 2, protodeps = tf_additional_all_protos(), - visibility = [ - "//tensorflow:internal", - ], + visibility = ["//tensorflow:internal"], ) tf_proto_library_cc( @@ -2187,6 +2221,7 @@ cc_library( "platform/**/env_time.cc", "platform/**/cuda_libdevice_path.cc", "platform/**/device_tracer.cc", + "platform/**/logger.cc", "platform/**/logging.cc", "platform/**/human_readable_json.cc", "platform/abi.cc", @@ -2199,6 +2234,7 @@ cc_library( "platform/**/stream_executor.h", "platform/**/env_time.cc", "platform/**/device_tracer.cc", + "platform/**/logger.cc", "platform/**/logging.cc", "platform/**/human_readable_json.cc", "platform/abi.cc", @@ -2641,6 +2677,8 @@ tf_cuda_library( ":stats_calculator_portable", ":version_lib", "@com_google_absl//absl/base", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", "//tensorflow/core/platform/default/build_config:platformlib", "//tensorflow/core/kernels:bounds_check", "//third_party/eigen3", @@ -2943,6 +2981,7 @@ tf_cuda_library( ":lib_internal", ":proto_text", ":protos_all_cc", + "@com_google_absl//absl/memory", "//third_party/eigen3", "//tensorflow/core/grappler:grappler_item", ] + mkl_deps(), @@ -3008,7 +3047,6 @@ tf_cuda_library( hdrs = ["common_runtime/metrics.h"], deps = [ ":lib", - "@com_google_absl//absl/time", ], ) @@ -3033,7 +3071,6 @@ tf_cuda_library( ":protos_all_cc", "//tensorflow/core/debug:debug_graph_utils", "//tensorflow/core/kernels:function_ops", - "@com_google_absl//absl/time", ], alwayslink = 1, ) @@ -3393,6 +3430,7 @@ tf_cc_tests( "platform/profile_utils/cpu_utils_test.cc", "platform/stacktrace_handler_test.cc", "platform/subprocess_test.cc", + "platform/vmodule_benchmark_test.cc", ], deps = [ ":lib", @@ -3406,6 +3444,20 @@ tf_cc_tests( ], ) +tf_cc_test( + name = "vmodule_test", + srcs = ["platform/vmodule_test.cc"], + tags = ["optonly"], + deps = [ + ":lib", + ":lib_internal", + ":lib_test_internal", + ":protos_all_cc", + ":test", + "//third_party/eigen3", + ], +) + tf_cc_test( name = "lib_random_random_distributions_test", srcs = ["lib/random/random_distributions_test.cc"], @@ -3421,6 +3473,16 @@ tf_cc_test( ], ) +tf_cc_test( + name = "platform_strings_test", + size = "small", + srcs = ["platform/platform_strings_test.cc"], + deps = [ + ":lib", + ":platform_strings", + ], +) + tf_cc_test( name = "platform_env_test", size = "small", @@ -3668,6 +3730,7 @@ tf_cc_tests( "util/bcast_test.cc", "util/command_line_flags_test.cc", "util/device_name_utils_test.cc", + "util/dump_graph_test.cc", "util/equal_graph_def_test.cc", "util/events_writer_test.cc", "util/example_proto_fast_parsing_test.cc", @@ -3798,6 +3861,7 @@ tf_cc_tests_gpu( ":test", ":test_main", ":testlib", + "@com_google_absl//absl/memory", ], ) @@ -3826,6 +3890,7 @@ tf_cc_tests_gpu( ":test", ":test_main", ":testlib", + 
"@com_google_absl//absl/memory", ], ) @@ -4099,6 +4164,7 @@ tf_cc_test( "//tensorflow/core/kernels:identity_op", "//tensorflow/core/kernels:immutable_constant_op", "//tensorflow/core/kernels:matmul_op", + "//tensorflow/core/kernels:topk_op", "//third_party/eigen3", ], ) @@ -4392,6 +4458,7 @@ tf_cc_test( "//tensorflow/core/kernels:random_ops", "//tensorflow/core/kernels:shape_ops", "//third_party/eigen3", + "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], ) @@ -4871,6 +4938,7 @@ transitive_hdrs( "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:platform_strings", "//tensorflow/core:protos_all_cc", "//tensorflow/core:stream_executor", ], diff --git a/tensorflow/core/api_def/api_test.cc b/tensorflow/core/api_def/api_test.cc index 6f988569159..d38a8424eb1 100644 --- a/tensorflow/core/api_def/api_test.cc +++ b/tensorflow/core/api_def/api_test.cc @@ -182,11 +182,14 @@ void TestDeprecationVersionSetCorrectly( for (const auto& name_and_api_def : api_defs_map) { const auto& name = name_and_api_def.first; const auto& api_def = name_and_api_def.second; - ASSERT_TRUE(api_def.deprecation_version() == 0 || - api_def.deprecation_message().empty()) - << "ApiDef that includes deprecation_version > 0 must also specify " - << "a deprecation_message. Op " << name - << " has deprecation_version > 0 but deprecation_message is not set."; + if (api_def.deprecation_version() != 0) { + ASSERT_TRUE(api_def.deprecation_version() > 0) + << "Found ApiDef with negative deprecation_version"; + ASSERT_FALSE(api_def.deprecation_message().empty()) + << "ApiDef that includes deprecation_version > 0 must also specify " + << "a deprecation_message. Op " << name + << " has deprecation_version > 0 but deprecation_message is not set."; + } } } } // namespace diff --git a/tensorflow/core/api_def/base_api/api_def_BatchDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_BatchDataset.pbtxt index 639d962874d..32def912f83 100644 --- a/tensorflow/core/api_def/base_api/api_def_BatchDataset.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_BatchDataset.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "BatchDataset" + visibility: HIDDEN in_arg { name: "batch_size" description: <