diff --git a/.bazelrc b/.bazelrc
index d4d7ad61867..27172e929b0 100644
--- a/.bazelrc
+++ b/.bazelrc
@@ -39,32 +39,46 @@ build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=0
 
 build:download_clang --crosstool_top=@local_config_download_clang//:toolchain
 build:download_clang --define=using_clang=true
+build:download_clang --action_env TF_DOWNLOAD_CLANG=1
 
 # Instruct clang to use LLD for linking.
 # This only works with GPU builds currently, since Bazel sets -B/usr/bin in
 # auto-generated CPU crosstool, forcing /usr/bin/ld.lld to be preferred over
 # the downloaded one.
 build:download_clang_use_lld --linkopt='-fuse-ld=lld'
 
-build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain
-build:cuda --define=using_cuda=true --define=using_cuda_nvcc=true
+# This config refers to building with CUDA available. It does not necessarily
+# mean that we build CUDA op kernels.
+build:using_cuda --define=using_cuda=true
+build:using_cuda --action_env TF_NEED_CUDA=1
+build:using_cuda --crosstool_top=@local_config_cuda//crosstool:toolchain
+
+# This config refers to building CUDA op kernels with nvcc.
+build:cuda --config=using_cuda
+build:cuda --define=using_cuda_nvcc=true
+
+# This config refers to building CUDA op kernels with clang.
+build:cuda_clang --config=using_cuda
+build:cuda_clang --define=using_cuda_clang=true
+build:cuda_clang --define=using_clang=true
+
+build:tensorrt --action_env TF_NEED_TENSORRT=1
 
 build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain
 build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true
-
-build:cuda_clang --crosstool_top=@local_config_cuda//crosstool:toolchain
-build:cuda_clang --define=using_cuda=true --define=using_cuda_clang=true --define=using_clang=true
+build:rocm --action_env TF_NEED_ROCM=1
 
 build:sycl --crosstool_top=@local_config_sycl//crosstool:toolchain
-build:sycl --define=using_sycl=true --define=using_trisycl=false
+build:sycl --define=using_sycl=true
+build:sycl --action_env TF_NEED_OPENCL_SYCL=1
 
-build:sycl_nodouble --crosstool_top=@local_config_sycl//crosstool:toolchain
-build:sycl_nodouble --define=using_sycl=true --cxxopt -DTENSORFLOW_SYCL_NO_DOUBLE
+build:sycl_nodouble --config=sycl
+build:sycl_nodouble --cxxopt -DTENSORFLOW_SYCL_NO_DOUBLE
 
-build:sycl_asan --crosstool_top=@local_config_sycl//crosstool:toolchain
-build:sycl_asan --define=using_sycl=true --define=using_trisycl=false --copt -fno-omit-frame-pointer --copt -fsanitize-coverage=3 --copt -DGPR_NO_DIRECT_SYSCALLS --linkopt -fPIC --linkopt -fsanitize=address
+build:sycl_asan --config=sycl
+build:sycl_asan --copt -fno-omit-frame-pointer --copt -fsanitize-coverage=3 --copt -DGPR_NO_DIRECT_SYSCALLS --linkopt -fPIC --linkopt -fsanitize=address
 
-build:sycl_trisycl --crosstool_top=@local_config_sycl//crosstool:toolchain
-build:sycl_trisycl --define=using_sycl=true --define=using_trisycl=true
+build:sycl_trisycl --config=sycl
+build:sycl_trisycl --define=using_trisycl=true
 
 # Options extracted from configure script
 build:gdr --define=with_gdr_support=true
@@ -87,6 +101,9 @@ build --spawn_strategy=standalone
 build --strategy=Genrule=standalone
 build -c opt
 
+# Make Bazel print out all options from rc files.
+build --announce_rc
+
 # Other build flags.
 build --define=grpc_no_ares=true
 
@@ -97,8 +114,7 @@ build:dynamic_kernels --copt=-DAUTOLOAD_DYNAMIC_KERNELS
 # Build TF with C++ 17 features.
 build:c++17 --cxxopt=-std=c++1z
 build:c++17 --cxxopt=-stdlib=libc++
-build:c++1z --cxxopt=-std=c++1z
-build:c++1z --cxxopt=-stdlib=libc++
+build:c++1z --config=c++17
 
 # Default paths for TF_SYSTEM_LIBS
 build --define=PREFIX=/usr
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 73782143a3d..b460bdde24f 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -38,7 +38,13 @@ working on getting your pull request submitted to our internal repository.
 After the change has been submitted internally, your pull request will be
 merged automatically on GitHub.
 
-If you want to contribute but you're not sure where to start, take a look at the
+If you want to contribute, start by working through the TensorFlow codebase,
+then navigate to the
+[GitHub "issues" tab](https://github.com/tensorflow/tensorflow/issues) and look
+through the open issues. If you are not sure where to start, try one of the
+smaller/easier issues, i.e. the
+[issues with the "good first issue" label](https://github.com/tensorflow/tensorflow/labels/good%20first%20issue),
+and then take a look at the
 [issues with the "contributions welcome" label](https://github.com/tensorflow/tensorflow/labels/stat%3Acontributions%20welcome).
 These are issues that we believe are particularly well suited for outside
 contributions, often because we probably won't get to them right now. If you
diff --git a/configure.py b/configure.py
index 2120a4b27d6..8d6772c199b 100644
--- a/configure.py
+++ b/configure.py
@@ -403,7 +403,8 @@ def set_action_env_var(environ_cp,
                        enabled_by_default,
                        question=None,
                        yes_reply=None,
-                       no_reply=None):
+                       no_reply=None,
+                       bazel_config_name=None):
   """Set boolean action_env variable.
 
   Ask user if query_item will be enabled. Default is used if no input is given.
@@ -418,12 +419,16 @@ def set_action_env_var(environ_cp,
     question: optional string for how to ask for user input.
     yes_reply: optional string for reply when feature is enabled.
     no_reply: optional string for reply when feature is disabled.
+    bazel_config_name: if set, write 'build --config=<name>' to .bazelrc instead of an action_env.
   """
 
   var = int(
       get_var(environ_cp, var_name, query_item, enabled_by_default, question,
               yes_reply, no_reply))
 
-  write_action_env_to_bazelrc(var_name, var)
+  if not bazel_config_name:
+    write_action_env_to_bazelrc(var_name, var)
+  elif var:
+    write_to_bazelrc('build --config=%s' % bazel_config_name)
+
   environ_cp[var_name] = str(var)
 
@@ -543,7 +548,8 @@ def set_tf_cuda_clang(environ_cp):
       False,
       question=question,
       yes_reply=yes_reply,
-      no_reply=no_reply)
+      no_reply=no_reply,
+      bazel_config_name='cuda_clang')
 
 
 def set_tf_download_clang(environ_cp):
@@ -558,7 +564,8 @@ def set_tf_download_clang(environ_cp):
      False,
      question=question,
      yes_reply=yes_reply,
-     no_reply=no_reply)
+     no_reply=no_reply,
+     bazel_config_name='download_clang')
 
 
 def get_from_env_or_user_or_default(environ_cp, var_name, ask_for_var,
@@ -782,8 +789,8 @@ def get_ndk_api_level(environ_cp, android_ndk_home_path):
     print('WARNING: The NDK version in %s is %s, which is not '
           'supported by Bazel (officially supported versions: %s). Please use '
          'another version. Compiling Android targets may result in confusing '
-          'errors.\n' % (android_ndk_home_path, ndk_version,
-                         _SUPPORTED_ANDROID_NDK_VERSIONS))
+          'errors.\n' %
+          (android_ndk_home_path, ndk_version, _SUPPORTED_ANDROID_NDK_VERSIONS))
 
   # Now grab the NDK API level to use.
Note that this is different from the # SDK API level, as the NDK API level is effectively the *min* target SDK @@ -952,6 +959,7 @@ def set_tf_nccl_version(environ_cp): ask_nccl_version, '') environ_cp['TF_NCCL_VERSION'] = tf_nccl_version + def get_native_cuda_compute_capabilities(environ_cp): """Get native cuda compute capabilities. @@ -1293,9 +1301,6 @@ def configure_ios(): """ if not is_macos(): return - if _TF_CURRENT_BAZEL_VERSION is None or _TF_CURRENT_BAZEL_VERSION < 23000: - print( - 'Building Bazel rules on Apple platforms requires Bazel 0.23 or later.') for filepath in APPLE_BAZEL_FILES: existing_filepath = os.path.join(_TF_WORKSPACE_ROOT, filepath + '.apple') renamed_filepath = os.path.join(_TF_WORKSPACE_ROOT, filepath) @@ -1386,7 +1391,7 @@ def main(): # environment variables. environ_cp = dict(os.environ) - current_bazel_version = check_bazel_version('0.24.1', '0.25.2') + current_bazel_version = check_bazel_version('0.24.1', '0.25.3') _TF_CURRENT_BAZEL_VERSION = convert_version_to_int(current_bazel_version) reset_tf_configure_bazelrc() @@ -1422,7 +1427,12 @@ def main(): set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support', xla_enabled_by_default, 'xla') - set_action_env_var(environ_cp, 'TF_NEED_OPENCL_SYCL', 'OpenCL SYCL', False) + set_action_env_var( + environ_cp, + 'TF_NEED_OPENCL_SYCL', + 'OpenCL SYCL', + False, + bazel_config_name='sycl') if environ_cp.get('TF_NEED_OPENCL_SYCL') == '1': set_host_cxx_compiler(environ_cp) set_host_c_compiler(environ_cp) @@ -1432,30 +1442,44 @@ def main(): else: set_trisycl_include_dir(environ_cp) - set_action_env_var(environ_cp, 'TF_NEED_ROCM', 'ROCm', False) + set_action_env_var( + environ_cp, 'TF_NEED_ROCM', 'ROCm', False, bazel_config_name='rocm') if (environ_cp.get('TF_NEED_ROCM') == '1' and 'LD_LIBRARY_PATH' in environ_cp and environ_cp.get('LD_LIBRARY_PATH') != '1'): write_action_env_to_bazelrc('LD_LIBRARY_PATH', environ_cp.get('LD_LIBRARY_PATH')) - set_action_env_var(environ_cp, 'TF_NEED_CUDA', 'CUDA', False) + environ_cp['TF_NEED_CUDA'] = str( + int(get_var(environ_cp, 'TF_NEED_CUDA', 'CUDA', False))) if (environ_cp.get('TF_NEED_CUDA') == '1' and 'TF_CUDA_CONFIG_REPO' not in environ_cp): - set_action_env_var(environ_cp, 'TF_NEED_TENSORRT', 'TensorRT', False) + set_action_env_var( + environ_cp, + 'TF_NEED_TENSORRT', + 'TensorRT', + False, + bazel_config_name='tensorrt') environ_save = dict(environ_cp) for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS): if validate_cuda_config(environ_cp): cuda_env_names = [ - 'TF_CUDA_VERSION', 'TF_CUBLAS_VERSION', 'TF_CUDNN_VERSION', - 'TF_TENSORRT_VERSION', 'TF_NCCL_VERSION', 'TF_CUDA_PATHS', + 'TF_CUDA_VERSION', + 'TF_CUBLAS_VERSION', + 'TF_CUDNN_VERSION', + 'TF_TENSORRT_VERSION', + 'TF_NCCL_VERSION', + 'TF_CUDA_PATHS', # Items below are for backwards compatibility when not using # TF_CUDA_PATHS. - 'CUDA_TOOLKIT_PATH', 'CUDNN_INSTALL_PATH', 'NCCL_INSTALL_PATH', - 'NCCL_HDR_PATH', 'TENSORRT_INSTALL_PATH' + 'CUDA_TOOLKIT_PATH', + 'CUDNN_INSTALL_PATH', + 'NCCL_INSTALL_PATH', + 'NCCL_HDR_PATH', + 'TENSORRT_INSTALL_PATH' ] # Note: set_action_env_var above already writes to bazelrc. for name in cuda_env_names: @@ -1506,8 +1530,6 @@ def main(): # CUDA not required. Ask whether we should download the clang toolchain and # use it for the CPU build. set_tf_download_clang(environ_cp) - if environ_cp.get('TF_DOWNLOAD_CLANG') == '1': - write_to_bazelrc('build --config=download_clang') # SYCL / ROCm / CUDA are mutually exclusive. # At most 1 GPU platform can be configured. 
diff --git a/tensorflow/api_template_v1.__init__.py b/tensorflow/api_template_v1.__init__.py index 8a14abc3c2c..a83ff3a16c2 100644 --- a/tensorflow/api_template_v1.__init__.py +++ b/tensorflow/api_template_v1.__init__.py @@ -59,7 +59,7 @@ except ImportError: from tensorflow.python.util.lazy_loader import LazyLoader # pylint: disable=g-import-not-at-top _CONTRIB_WARNING = """ -WARNING: The TensorFlow contrib module will not be included in TensorFlow 2.0. +The TensorFlow contrib module will not be included in TensorFlow 2.0. For more information, please see: * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md * https://github.com/tensorflow/addons diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 6928cf5d0ac..99eb28c1295 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -21,6 +21,9 @@ filegroup( srcs = [ "c_api.h", "c_api_experimental.h", + "tf_attrtype.h", + "tf_datatype.h", + "tf_status.h", ], visibility = ["//tensorflow:__subpackages__"], ) @@ -51,6 +54,8 @@ tf_cuda_library( hdrs = [ "c_api.h", "c_api_internal.h", + "tf_datatype.h", + "tf_status.h", ], visibility = [ "//tensorflow:internal", @@ -61,6 +66,7 @@ tf_cuda_library( "//tensorflow/core:android_tensorflow_lib_lite", ], "//conditions:default": [ + ":tf_attrtype", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -71,16 +77,26 @@ tf_cuda_library( }), ) +cc_library( + name = "tf_attrtype", + hdrs = ["tf_attrtype.h"], + visibility = ["//visibility:public"], +) + tf_cuda_library( name = "c_api", hdrs = [ "c_api.h", + "tf_attrtype.h", + "tf_datatype.h", + "tf_status.h", ], copts = tf_copts(), visibility = ["//visibility:public"], deps = [ ":c_api_no_xla", ":c_api_internal", + ":tf_attrtype", ] + select({ "//tensorflow:with_xla_support": [ "//tensorflow/compiler/tf2xla:xla_compiler", @@ -96,14 +112,21 @@ tf_cuda_library( "c_api.cc", "c_api_function.cc", ], - hdrs = ["c_api.h"], + hdrs = [ + "c_api.h", + ], copts = tf_copts(), visibility = ["//tensorflow/c:__subpackages__"], - deps = [":c_api_internal"] + select({ + deps = [ + ":c_api_internal", + ":tf_attrtype", + ":tf_datatype", + ] + select({ "//tensorflow:android": [ "//tensorflow/core:android_tensorflow_lib_lite", ], "//conditions:default": [ + ":tf_status", "@com_google_absl//absl/strings", "//tensorflow/cc/saved_model:loader_lite", "//tensorflow/cc:gradients", @@ -124,6 +147,37 @@ tf_cuda_library( }), ) +cc_library( + name = "tf_status", + srcs = ["tf_status.cc"], + hdrs = ["tf_status.h"], + visibility = ["//visibility:public"], + deps = select({ + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib_lite", + ], + "//conditions:default": [ + "//tensorflow/c:c_api_internal", + "//tensorflow/core:lib", + ], + }), +) + +cc_library( + name = "tf_datatype", + srcs = ["tf_datatype.cc"], + hdrs = ["tf_datatype.h"], + visibility = ["//visibility:public"], + deps = select({ + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib_lite", + ], + "//conditions:default": [ + "//tensorflow/core:framework", + ], + }), +) + tf_cuda_library( name = "c_api_experimental", srcs = [ @@ -137,6 +191,7 @@ tf_cuda_library( deps = [ ":c_api", ":c_api_internal", + ":checkpoint_reader", "//tensorflow/c/eager:c_api", "//tensorflow/c/eager:c_api_internal", "//tensorflow/compiler/jit:flags", @@ -151,15 +206,6 @@ tf_cuda_library( ], ) -cc_library( - name = "c_api_headers", - hdrs = [ - "c_api.h", - ], - copts = tf_copts(), - visibility = ["//tensorflow:__subpackages__"], -) - exports_files( 
[ "version_script.lds", diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 21d72ac96b5..4f519a7bd11 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/strings/match.h" // Required for IS_MOBILE_PLATFORM #include "tensorflow/core/platform/platform.h" // NOLINT @@ -97,7 +98,6 @@ using tensorflow::TensorId; using tensorflow::TensorShape; using tensorflow::TensorShapeProto; using tensorflow::VersionDef; -using tensorflow::error::Code; using tensorflow::errors::FailedPrecondition; using tensorflow::errors::InvalidArgument; using tensorflow::gtl::ArraySlice; @@ -108,34 +108,6 @@ extern "C" { // -------------------------------------------------------------------------- const char* TF_Version() { return TF_VERSION_STRING; } -// -------------------------------------------------------------------------- -size_t TF_DataTypeSize(TF_DataType dt) { - return static_cast( - tensorflow::DataTypeSize(static_cast(dt))); -} - -// -------------------------------------------------------------------------- - -TF_Status* TF_NewStatus() { return new TF_Status; } - -void TF_DeleteStatus(TF_Status* s) { delete s; } - -void TF_SetStatus(TF_Status* s, TF_Code code, const char* msg) { - if (code == TF_OK) { - s->status = Status::OK(); - return; - } - s->status = Status(static_cast(code), tensorflow::StringPiece(msg)); -} - -TF_Code TF_GetCode(const TF_Status* s) { - return static_cast(s->status.code()); -} - -const char* TF_Message(const TF_Status* s) { - return s->status.error_message().c_str(); -} - // -------------------------------------------------------------------------- namespace { @@ -1697,7 +1669,7 @@ TF_AttrMetadata TF_OperationGetAttrMetadata(TF_Operation* oper, if (metadata.list_size == 0) { for (int i = 0; i < oper->node.op_def().attr_size(); ++i) { const auto& a = oper->node.op_def().attr(i); - if (a.name().compare(attr_name) != 0) continue; + if (a.name() != attr_name) continue; const string& typestr = a.type(); if (typestr == "list(string)") { metadata.type = TF_ATTR_STRING; @@ -2517,8 +2489,7 @@ void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix, TF_Output* y, // used in this graph for (const auto& pair : g->name_map) { const string& name = pair.first; - if (name.compare(prefix) == 0 || - tensorflow::str_util::StartsWith(name, prefix_cmp)) { + if ((name == prefix) || absl::StartsWith(name, prefix_cmp)) { status->status = InvalidArgument( "prefix [", prefix, "] conflicts with existing node in the graph named [", name, "]"); @@ -2548,8 +2519,7 @@ void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix, TF_Output* y, // Adding the gradients to the graph can alter the prefix to prevent // name collisions only if this prefix has not been provided explicitly // by the user. If it was provided, assert that it remained intact. - if (prefix != nullptr && - !tensorflow::str_util::StartsWith(n->name(), prefix_cmp)) { + if (prefix != nullptr && !absl::StartsWith(n->name(), prefix_cmp)) { status->status = tensorflow::errors::Internal( "BUG: The gradients prefix have been unexpectedly altered when " "adding the nodes to the graph. This is a bug. Please file an " diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index 051de3a7dc0..9a538cb98db 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -19,6 +19,10 @@ limitations under the License. 
#include #include +#include "tensorflow/c/tf_attrtype.h" +#include "tensorflow/c/tf_datatype.h" +#include "tensorflow/c/tf_status.h" + // -------------------------------------------------------------------------- // C API for TensorFlow. // @@ -69,7 +73,7 @@ limitations under the License. // .dylib, .dll). // This duplicates the TF_EXPORT macro definition in // tensorflow/core/platform/macros.h in order to keep this .h file independent -// of any other includes.$a +// of any other includes. #ifdef SWIG #define TF_CAPI_EXPORT #else @@ -93,89 +97,6 @@ extern "C" { // TensorFlow library. TensorFlow using semantic versioning. TF_CAPI_EXPORT extern const char* TF_Version(void); -// -------------------------------------------------------------------------- -// TF_DataType holds the type for a scalar value. E.g., one slot in a tensor. -// The enum values here are identical to corresponding values in types.proto. -typedef enum TF_DataType { - TF_FLOAT = 1, - TF_DOUBLE = 2, - TF_INT32 = 3, // Int32 tensors are always in 'host' memory. - TF_UINT8 = 4, - TF_INT16 = 5, - TF_INT8 = 6, - TF_STRING = 7, - TF_COMPLEX64 = 8, // Single-precision complex - TF_COMPLEX = 8, // Old identifier kept for API backwards compatibility - TF_INT64 = 9, - TF_BOOL = 10, - TF_QINT8 = 11, // Quantized int8 - TF_QUINT8 = 12, // Quantized uint8 - TF_QINT32 = 13, // Quantized int32 - TF_BFLOAT16 = 14, // Float32 truncated to 16 bits. Only for cast ops. - TF_QINT16 = 15, // Quantized int16 - TF_QUINT16 = 16, // Quantized uint16 - TF_UINT16 = 17, - TF_COMPLEX128 = 18, // Double-precision complex - TF_HALF = 19, - TF_RESOURCE = 20, - TF_VARIANT = 21, - TF_UINT32 = 22, - TF_UINT64 = 23, -} TF_DataType; - -// TF_DataTypeSize returns the sizeof() for the underlying type corresponding -// to the given TF_DataType enum value. Returns 0 for variable length types -// (eg. TF_STRING) or on failure. -TF_CAPI_EXPORT extern size_t TF_DataTypeSize(TF_DataType dt); - -// -------------------------------------------------------------------------- -// TF_Code holds an error code. The enum values here are identical to -// corresponding values in error_codes.proto. -typedef enum TF_Code { - TF_OK = 0, - TF_CANCELLED = 1, - TF_UNKNOWN = 2, - TF_INVALID_ARGUMENT = 3, - TF_DEADLINE_EXCEEDED = 4, - TF_NOT_FOUND = 5, - TF_ALREADY_EXISTS = 6, - TF_PERMISSION_DENIED = 7, - TF_UNAUTHENTICATED = 16, - TF_RESOURCE_EXHAUSTED = 8, - TF_FAILED_PRECONDITION = 9, - TF_ABORTED = 10, - TF_OUT_OF_RANGE = 11, - TF_UNIMPLEMENTED = 12, - TF_INTERNAL = 13, - TF_UNAVAILABLE = 14, - TF_DATA_LOSS = 15, -} TF_Code; - -// -------------------------------------------------------------------------- -// TF_Status holds error information. It either has an OK code, or -// else an error code with an associated error message. -typedef struct TF_Status TF_Status; - -// Return a new status object. -TF_CAPI_EXPORT extern TF_Status* TF_NewStatus(void); - -// Delete a previously created status object. -TF_CAPI_EXPORT extern void TF_DeleteStatus(TF_Status*); - -// Record in *s. Any previous information is lost. -// A common use is to clear a status: TF_SetStatus(s, TF_OK, ""); -TF_CAPI_EXPORT extern void TF_SetStatus(TF_Status* s, TF_Code code, - const char* msg); - -// Return the code record in *s. -TF_CAPI_EXPORT extern TF_Code TF_GetCode(const TF_Status* s); - -// Return a pointer to the (null-terminated) error message in *s. The -// return value points to memory that is only usable until the next -// mutation to *s. Always returns an empty string if TF_GetCode(s) is -// TF_OK. 
-TF_CAPI_EXPORT extern const char* TF_Message(const TF_Status* s); - // -------------------------------------------------------------------------- // TF_Buffer holds a pointer to a block of data and its associated length. // Typically, the data consists of a serialized protocol buffer, but other data @@ -686,19 +607,6 @@ TF_CAPI_EXPORT extern int TF_OperationGetControlOutputs( TF_Operation* oper, TF_Operation** control_outputs, int max_control_outputs); -// TF_AttrType describes the type of the value of an attribute on an operation. -typedef enum TF_AttrType { - TF_ATTR_STRING = 0, - TF_ATTR_INT = 1, - TF_ATTR_FLOAT = 2, - TF_ATTR_BOOL = 3, - TF_ATTR_TYPE = 4, - TF_ATTR_SHAPE = 5, - TF_ATTR_TENSOR = 6, - TF_ATTR_PLACEHOLDER = 7, - TF_ATTR_FUNC = 8, -} TF_AttrType; - // TF_AttrMetadata describes the value of an attribute on an operation. typedef struct TF_AttrMetadata { // A boolean: 1 if the attribute value is a list, 0 otherwise. diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index 726ce2784ae..246fa91eccd 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -18,6 +18,7 @@ limitations under the License. #include "absl/strings/substitute.h" #include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api_internal.h" +#include "tensorflow/c/checkpoint_reader.h" #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/compiler/jit/flags.h" @@ -37,6 +38,7 @@ using tensorflow::FunctionDef; using tensorflow::Node; using tensorflow::NodeBuilder; using tensorflow::Status; +using tensorflow::errors::InvalidArgument; namespace { typedef std::unique_ptr @@ -149,7 +151,7 @@ const char* TF_GraphDebugString(TF_Graph* graph, size_t* len) { } char* TF_FunctionDebugString(TF_Function* func, size_t* len) { - const auto& debug_str = func->fdef.DebugString(); + const auto& debug_str = DebugString(func->fdef); *len = debug_str.size(); char* ret = static_cast(malloc(*len + 1)); memcpy(ret, debug_str.c_str(), *len + 1); @@ -576,6 +578,73 @@ void TF_MakeInternalErrorStatus(TF_Status* status, const char* errMsg) { status->status = tensorflow::errors::Internal(errMsg); } +struct TF_CheckpointReader : public tensorflow::checkpoint::CheckpointReader { + using tensorflow::checkpoint::CheckpointReader::CheckpointReader; + std::vector variable_list; +}; + +TF_CheckpointReader* TF_NewCheckpointReader(const char* filename, + TF_Status* status) { + TF_CheckpointReader* reader = new TF_CheckpointReader(filename, status); + if (!status->status.ok()) return nullptr; + const auto& m = reader->GetVariableToDataTypeMap(); + for (auto it = m.begin(); it != m.end(); ++it) + reader->variable_list.push_back(it->first); + std::sort(reader->variable_list.begin(), reader->variable_list.end()); + return reader; +} + +void TF_DeleteCheckpointReader(TF_CheckpointReader* reader) { delete reader; } + +int TF_CheckpointReaderHasTensor(TF_CheckpointReader* reader, + const char* name) { + return reader->HasTensor(name); +} + +const char* TF_CheckpointReaderGetVariable(TF_CheckpointReader* reader, + int index) { + return reader->variable_list[index].c_str(); +} + +int TF_CheckpointReaderSize(TF_CheckpointReader* reader) { + return reader->variable_list.size(); +} + +TF_DataType TF_CheckpointReaderGetVariableDataType(TF_CheckpointReader* reader, + const char* name) { + const auto& m = reader->GetVariableToDataTypeMap(); + return static_cast(m.at(name)); +} + +TF_Tensor* TF_CheckpointReaderGetTensor(TF_CheckpointReader* 
reader, + const char* name, TF_Status* status) { + std::unique_ptr tensor; + reader->GetTensor(name, &tensor, status); + if (!status->status.ok()) return nullptr; + return tensorflow::TF_TensorFromTensor(*tensor.get(), status); +} + +void TF_CheckpointReaderGetVariableShape(TF_CheckpointReader* reader, + const char* name, int64_t* dims, + int num_dims, TF_Status* status) { + const auto& shape = reader->GetVariableToShapeMap().at(name); + int rank = shape.dims(); + if (num_dims != rank) { + status->status = InvalidArgument("Expected rank is ", num_dims, + " but actual rank is ", rank); + return; + } + for (int i = 0; i < num_dims; i++) { + dims[i] = shape.dim_size(i); + } +} + +int TF_CheckpointReaderGetVariableNumDims(TF_CheckpointReader* reader, + const char* name) { + const auto& m = reader->GetVariableToShapeMap(); + return m.at(name).dims(); +} + // This builder is used in the eager API to build a NodeDef. struct TF_AttrBuilder : public tensorflow::AttrBuilder { using tensorflow::AttrBuilder::AttrBuilder; diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h index 795768a1415..25056904423 100644 --- a/tensorflow/c/c_api_experimental.h +++ b/tensorflow/c/c_api_experimental.h @@ -208,6 +208,34 @@ TF_CAPI_EXPORT extern void TFE_ExecuteOpNotificationWaitAndDelete( TF_CAPI_EXPORT extern void TF_MakeInternalErrorStatus(TF_Status* status, const char* errMsg); +// TF_NewCheckpointReader() return the CheckpointReader that can be use to +// investigate or load the variable from the checkpoint file +typedef struct TF_CheckpointReader TF_CheckpointReader; +TF_CAPI_EXPORT extern TF_CheckpointReader* TF_NewCheckpointReader( + const char* filename, TF_Status* status); +TF_CAPI_EXPORT extern void TF_DeleteCheckpointReader( + TF_CheckpointReader* reader); +TF_CAPI_EXPORT extern int TF_CheckpointReaderHasTensor( + TF_CheckpointReader* reader, const char* name); +// Get the variable name at the given index +TF_CAPI_EXPORT extern const char* TF_CheckpointReaderGetVariable( + TF_CheckpointReader* reader, int index); +// Get the number of variable in the checkpoint +TF_CAPI_EXPORT extern int TF_CheckpointReaderSize(TF_CheckpointReader* reader); +// Get the DataType of a variable +TF_CAPI_EXPORT extern TF_DataType TF_CheckpointReaderGetVariableDataType( + TF_CheckpointReader* reader, const char* name); +// Read the shape of a variable and write to `dims` +TF_CAPI_EXPORT extern void TF_CheckpointReaderGetVariableShape( + TF_CheckpointReader* reader, const char* name, int64_t* dims, int num_dims, + TF_Status* status); +// Get the number of dimension of a variable +TF_CAPI_EXPORT extern int TF_CheckpointReaderGetVariableNumDims( + TF_CheckpointReader* reader, const char* name); +// Load the weight of a variable +TF_CAPI_EXPORT extern TF_Tensor* TF_CheckpointReaderGetTensor( + TF_CheckpointReader* reader, const char* name, TF_Status* status); + // TF_NewAttrBuilder() returns an object that you can set attributes on as // though it were an op. This allows querying properties of that op for // type-checking purposes like if the op will run on a particular device type. 
diff --git a/tensorflow/c/c_api_experimental_test.cc b/tensorflow/c/c_api_experimental_test.cc index 6eb289107c5..55f3a8599fd 100644 --- a/tensorflow/c/c_api_experimental_test.cc +++ b/tensorflow/c/c_api_experimental_test.cc @@ -62,8 +62,8 @@ protocol: "grpc" TF_Buffer* null_result = TFE_GetServerDef(malformed_text_proto.c_str(), status); EXPECT_NE(TF_GetCode(status), TF_OK); - EXPECT_TRUE(tensorflow::str_util::StrContains( - TF_Message(status), "Invalid text proto for ServerDef")); + EXPECT_TRUE(absl::StrContains(TF_Message(status), + "Invalid text proto for ServerDef")); EXPECT_EQ(null_result, nullptr); // Cleanup diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc index 760f14cac5b..847a81f5424 100644 --- a/tensorflow/c/c_api_function_test.cc +++ b/tensorflow/c/c_api_function_test.cc @@ -253,7 +253,7 @@ class CApiFunctionTest : public ::testing::Test { const std::unordered_set& nodes) { ASSERT_EQ(nodes.size(), fdef.node_def_size()) << "Got unexpected number of nodes. Expected: [" - << str_util::Join(nodes, ", ") + << absl::StrJoin(nodes, ", ") << "] Actual nodes in fdef: " << fdef.DebugString(); for (const NodeDef& node_def : fdef.node_def()) { ASSERT_TRUE(nodes.find(node_def.name()) != nodes.end()) diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index 2be03bf0de6..49076039fa7 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -56,7 +56,7 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst); namespace { static void ExpectHasSubstr(StringPiece s, StringPiece expected) { - EXPECT_TRUE(str_util::StrContains(s, expected)) + EXPECT_TRUE(absl::StrContains(s, expected)) << "'" << s << "' does not contain '" << expected << "'"; } diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 8c2be2af3e0..0db85a17802 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -1,6 +1,8 @@ # Experimental extensions to the C API for eager execution of kernels. -licenses(["notice"]) # Apache 2.0 +package( + licenses = ["notice"], # Apache 2.0 +) load( "//tensorflow:tensorflow.bzl", diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 0b86a78d41e..d8476bec2e4 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -30,7 +30,9 @@ limitations under the License. 
 #include "tensorflow/c/c_api.h"
 #include "tensorflow/c/c_api_internal.h"
 #include "tensorflow/c/eager/c_api_internal.h"
+#include "tensorflow/core/framework/device_attributes.pb.h"
 #include "tensorflow/core/platform/host_info.h"
+#include "tensorflow/core/platform/platform.h" // NOLINT
 #ifdef TENSORFLOW_EAGER_USE_XLA
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #endif // TENSORFLOW_EAGER_USE_XLA
@@ -135,11 +137,12 @@ tensorflow::Status CreateRemoteContexts(
     const std::vector& remote_workers, int64 rendezvous_id,
     int keep_alive_secs, const tensorflow::ServerDef& server_def,
     tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
+    const tensorflow::eager::CreateContextRequest& base_request,
     tensorflow::gtl::FlatMap* remote_contexts) {
   for (int i = 0; i < remote_workers.size(); i++) {
     const string& remote_worker = remote_workers[i];
-    tensorflow::eager::CreateContextRequest request;
+    tensorflow::eager::CreateContextRequest request(base_request);
     tensorflow::eager::CreateContextResponse response;
     request.set_rendezvous_id(rendezvous_id);
     tensorflow::DeviceNameUtils::ParsedName parsed_name;
@@ -221,6 +224,23 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
       remote_workers, grpc_server->master_env()->worker_cache,
       &remote_device_mgr));
 
+  std::vector cluster_device_attributes;
+  remote_device_mgr->ListDeviceAttributes(&cluster_device_attributes);
+
+  std::vector local_device_attributes;
+  grpc_server->worker_env()->device_mgr->ListDeviceAttributes(
+      &local_device_attributes);
+
+  // This request makes sure that we can create Rendezvous properly between
+  // local and remote contexts.
+  tensorflow::eager::CreateContextRequest base_request;
+  for (const auto& da : cluster_device_attributes) {
+    *base_request.add_cluster_device_attributes() = da;
+  }
+  for (const auto& da : local_device_attributes) {
+    *base_request.add_cluster_device_attributes() = da;
+  }
+
   std::shared_ptr channel_cache = grpc_server->channel_cache();
   std::unique_ptr remote_eager_workers(
@@ -230,14 +250,16 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
   tensorflow::gtl::FlatMap remote_contexts;
   LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
       remote_workers, rendezvous_id, keep_alive_secs, server_def,
-      remote_eager_workers.get(), ctx->context->Async(), &remote_contexts));
+      remote_eager_workers.get(), ctx->context->Async(), base_request,
+      &remote_contexts));
 
   tensorflow::RemoteRendezvous* r =
       grpc_server->worker_env()->rendezvous_mgr->Find(rendezvous_id);
 
   auto session_name = tensorflow::strings::StrCat("eager_", rendezvous_id);
   TF_RETURN_IF_ERROR(grpc_server->worker_env()->session_mgr->CreateSession(
-      session_name, server_def, true));
+      session_name, server_def, base_request.cluster_device_attributes(),
+      true));
 
   std::shared_ptr worker_session;
   TF_RETURN_IF_ERROR(
@@ -250,9 +272,10 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
   auto* device_mgr = grpc_server->worker_env()->device_mgr;
 
   return ctx->context->InitializeRemote(
-      std::move(server), std::move(remote_eager_workers),
-      std::move(remote_device_mgr), remote_contexts, r, device_mgr,
-      keep_alive_secs);
+      std::move(server), grpc_server->worker_env(), worker_session,
+      std::move(remote_eager_workers), std::move(remote_device_mgr),
+      remote_contexts, r, device_mgr, keep_alive_secs,
+      worker_session->cluster_flr.get());
 #undef LOG_AND_RETURN_IF_ERROR
 }
 #endif // !IS_MOBILE_PLATFORM
@@ -970,6 +993,23 @@ const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory(
   return t;
 }
 
+TFE_TensorHandle*
TFE_TensorHandleMaybeCopyToHostCPU(TFE_TensorHandle* h, + TF_Status* status) { + // TensorHandles created by PyFuncOp lack context and therefore could + // not be copied. + if (!h->handle->OnHostCPU() && h->handle->Context() != nullptr) { + tensorflow::TensorHandle* handle; + status->status = tensorflow::EagerCopyToDevice( + h->handle, h->handle->Context(), "CPU:0", &handle); + if (status->status.ok()) { + return new TFE_TensorHandle(handle); + } else { + return nullptr; + } + } + return h; +} + void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf, TF_Status* status) { TFE_ContextAsyncWait(ctx, status); diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index d5223e63f13..076760161e1 100755 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -462,6 +462,9 @@ class Tensor; const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory( TFE_TensorHandle* h, TF_Status* status); + +TFE_TensorHandle* TFE_TensorHandleMaybeCopyToHostCPU(TFE_TensorHandle* h, + TF_Status* status); TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t); #endif diff --git a/tensorflow/c/eager/c_api_debug.cc b/tensorflow/c/eager/c_api_debug.cc index b4192716c4f..eaa520d72cc 100644 --- a/tensorflow/c/eager/c_api_debug.cc +++ b/tensorflow/c/eager/c_api_debug.cc @@ -78,7 +78,7 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo( status->status = tensorflow::Status::OK(); } else { VLOG(3) << "Fully padded shape of [" - << tensorflow::str_util::Join(shape_to_log, ", ") << "] is " + << absl::StrJoin(shape_to_log, ", ") << "] is " << padded_shape.DebugString(); } } diff --git a/tensorflow/c/eager/c_api_experimental_test.cc b/tensorflow/c/eager/c_api_experimental_test.cc index 4e48a7591a9..53984c0e6c0 100644 --- a/tensorflow/c/eager/c_api_experimental_test.cc +++ b/tensorflow/c/eager/c_api_experimental_test.cc @@ -33,7 +33,7 @@ namespace tensorflow { namespace { static bool HasSubstr(absl::string_view base, absl::string_view substr) { - bool ok = str_util::StrContains(base, substr); + bool ok = absl::StrContains(base, substr); EXPECT_TRUE(ok) << base << ", expected substring " << substr; return ok; } diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 57aa71d5b3b..57bea7311e6 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -1408,29 +1408,34 @@ void FunctionDefAndExecute(bool async) { status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_TensorHandle* m = TestMatrixTensorHandle(); - TFE_TensorHandle* retval[1] = {nullptr}; - int num_retvals = 1; - TFE_Op* op = TFE_NewOp(ctx, "MatMulFunction", status); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_OpAddInput(op, m, status); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_Execute(op, &retval[0], &num_retvals, status); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - ASSERT_EQ(1, num_retvals); - TFE_DeleteOp(op); - TFE_DeleteTensorHandle(m); - TF_Tensor* t = TFE_TensorHandleResolve(retval[0], status); - TFE_DeleteTensorHandle(retval[0]); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - float product[4] = {0}; - EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); - memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t)); - TF_DeleteTensor(t); - EXPECT_EQ(7, product[0]); - EXPECT_EQ(10, product[1]); - EXPECT_EQ(15, product[2]); - EXPECT_EQ(22, product[3]); + for (bool clear_cache : {true, false, true}) { + if (clear_cache) 
{ + TFE_ContextClearCaches(ctx); + } + TFE_TensorHandle* m = TestMatrixTensorHandle(); + TFE_TensorHandle* retval[1] = {nullptr}; + int num_retvals = 1; + TFE_Op* op = TFE_NewOp(ctx, "MatMulFunction", status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(op, m, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_Execute(op, &retval[0], &num_retvals, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + ASSERT_EQ(1, num_retvals); + TFE_DeleteOp(op); + TFE_DeleteTensorHandle(m); + TF_Tensor* t = TFE_TensorHandleResolve(retval[0], status); + TFE_DeleteTensorHandle(retval[0]); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + float product[4] = {0}; + EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); + memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + EXPECT_EQ(7, product[0]); + EXPECT_EQ(10, product[1]); + EXPECT_EQ(15, product[2]); + EXPECT_EQ(22, product[3]); + } TFE_ContextRemoveFunction(ctx, "MatMulFunction", status); ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); TFE_DeleteContext(ctx); diff --git a/tensorflow/c/experimental/BUILD b/tensorflow/c/experimental/BUILD index b66969eb3ff..bc408e637c2 100644 --- a/tensorflow/c/experimental/BUILD +++ b/tensorflow/c/experimental/BUILD @@ -1,7 +1,9 @@ # Description: # Experimental C APIs for TensorFlow. -licenses(["notice"]) # Apache 2.0 +package( + licenses = ["notice"], # Apache 2.0 +) load( "//tensorflow:tensorflow.bzl", diff --git a/tensorflow/c/kernels/BUILD b/tensorflow/c/kernels/BUILD index 597182ab016..c71f6f1cca2 100644 --- a/tensorflow/c/kernels/BUILD +++ b/tensorflow/c/kernels/BUILD @@ -6,10 +6,9 @@ load( package( default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 ) -licenses(["notice"]) # Apache 2.0 - tf_kernel_library( name = "bitcast_op", prefix = "bitcast_op", diff --git a/tensorflow/c/tf_attrtype.h b/tensorflow/c/tf_attrtype.h new file mode 100644 index 00000000000..0c1545db232 --- /dev/null +++ b/tensorflow/c/tf_attrtype.h @@ -0,0 +1,39 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_TF_ATTRTYPE_H_ +#define TENSORFLOW_C_TF_ATTRTYPE_H_ + +#ifdef __cplusplus +extern "C" { +#endif + +// TF_AttrType describes the type of the value of an attribute on an operation. +typedef enum TF_AttrType { + TF_ATTR_STRING = 0, + TF_ATTR_INT = 1, + TF_ATTR_FLOAT = 2, + TF_ATTR_BOOL = 3, + TF_ATTR_TYPE = 4, + TF_ATTR_SHAPE = 5, + TF_ATTR_TENSOR = 6, + TF_ATTR_PLACEHOLDER = 7, + TF_ATTR_FUNC = 8, +} TF_AttrType; + +#ifdef __cplusplus +} /* end extern "C" */ +#endif + +#endif // TENSORFLOW_C_TF_ATTRTYPE_H_ diff --git a/tensorflow/c/tf_datatype.cc b/tensorflow/c/tf_datatype.cc new file mode 100644 index 00000000000..d2a66d99dac --- /dev/null +++ b/tensorflow/c/tf_datatype.cc @@ -0,0 +1,23 @@ +/* Copyright 2019 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/tf_datatype.h" + +#include "tensorflow/core/framework/types.h" + +size_t TF_DataTypeSize(TF_DataType dt) { + return static_cast( + tensorflow::DataTypeSize(static_cast(dt))); +} diff --git a/tensorflow/c/tf_datatype.h b/tensorflow/c/tf_datatype.h new file mode 100644 index 00000000000..3e6121bf989 --- /dev/null +++ b/tensorflow/c/tf_datatype.h @@ -0,0 +1,83 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_TF_DATATYPE_H_ +#define TENSORFLOW_C_TF_DATATYPE_H_ + +#include + +// Macro to control visibility of exported symbols in the shared library (.so, +// .dylib, .dll). +// This duplicates the TF_EXPORT macro definition in +// tensorflow/core/platform/macros.h in order to keep this .h file independent +// of any other includes. +#ifdef SWIG +#define TF_CAPI_EXPORT +#else +#if defined(_WIN32) +#ifdef TF_COMPILE_LIBRARY +#define TF_CAPI_EXPORT __declspec(dllexport) +#else +#define TF_CAPI_EXPORT __declspec(dllimport) +#endif // TF_COMPILE_LIBRARY +#else +#define TF_CAPI_EXPORT __attribute__((visibility("default"))) +#endif // _WIN32 +#endif // SWIG + +#ifdef __cplusplus +extern "C" { +#endif + +// -------------------------------------------------------------------------- +// TF_DataType holds the type for a scalar value. E.g., one slot in a tensor. +// The enum values here are identical to corresponding values in types.proto. +typedef enum TF_DataType { + TF_FLOAT = 1, + TF_DOUBLE = 2, + TF_INT32 = 3, // Int32 tensors are always in 'host' memory. + TF_UINT8 = 4, + TF_INT16 = 5, + TF_INT8 = 6, + TF_STRING = 7, + TF_COMPLEX64 = 8, // Single-precision complex + TF_COMPLEX = 8, // Old identifier kept for API backwards compatibility + TF_INT64 = 9, + TF_BOOL = 10, + TF_QINT8 = 11, // Quantized int8 + TF_QUINT8 = 12, // Quantized uint8 + TF_QINT32 = 13, // Quantized int32 + TF_BFLOAT16 = 14, // Float32 truncated to 16 bits. Only for cast ops. 
+ TF_QINT16 = 15, // Quantized int16 + TF_QUINT16 = 16, // Quantized uint16 + TF_UINT16 = 17, + TF_COMPLEX128 = 18, // Double-precision complex + TF_HALF = 19, + TF_RESOURCE = 20, + TF_VARIANT = 21, + TF_UINT32 = 22, + TF_UINT64 = 23, +} TF_DataType; + +// TF_DataTypeSize returns the sizeof() for the underlying type corresponding +// to the given TF_DataType enum value. Returns 0 for variable length types +// (eg. TF_STRING) or on failure. +TF_CAPI_EXPORT extern size_t TF_DataTypeSize(TF_DataType dt); + +#ifdef __cplusplus +} /* end extern "C" */ +#endif + +#endif // TENSORFLOW_C_TF_DATATYPE_H_ diff --git a/tensorflow/c/tf_status.cc b/tensorflow/c/tf_status.cc new file mode 100644 index 00000000000..a77b18c2ca0 --- /dev/null +++ b/tensorflow/c/tf_status.cc @@ -0,0 +1,42 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/tf_status.h" + +#include "tensorflow/c/c_api_internal.h" +#include "tensorflow/core/lib/core/status.h" + +using ::tensorflow::Status; +using ::tensorflow::error::Code; + +TF_Status* TF_NewStatus() { return new TF_Status; } + +void TF_DeleteStatus(TF_Status* s) { delete s; } + +void TF_SetStatus(TF_Status* s, TF_Code code, const char* msg) { + if (code == TF_OK) { + s->status = Status::OK(); + return; + } + s->status = Status(static_cast(code), tensorflow::StringPiece(msg)); +} + +TF_Code TF_GetCode(const TF_Status* s) { + return static_cast(s->status.code()); +} + +const char* TF_Message(const TF_Status* s) { + return s->status.error_message().c_str(); +} diff --git a/tensorflow/c/tf_status.h b/tensorflow/c/tf_status.h new file mode 100644 index 00000000000..937f6bed2d7 --- /dev/null +++ b/tensorflow/c/tf_status.h @@ -0,0 +1,88 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_C_TF_STATUS_H_ +#define TENSORFLOW_C_TF_STATUS_H_ + +#ifdef SWIG +#define TF_CAPI_EXPORT +#else +#if defined(_WIN32) +#ifdef TF_COMPILE_LIBRARY +#define TF_CAPI_EXPORT __declspec(dllexport) +#else +#define TF_CAPI_EXPORT __declspec(dllimport) +#endif // TF_COMPILE_LIBRARY +#else +#define TF_CAPI_EXPORT __attribute__((visibility("default"))) +#endif // _WIN32 +#endif // SWIG + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct TF_Status TF_Status; + +// -------------------------------------------------------------------------- +// TF_Code holds an error code. The enum values here are identical to +// corresponding values in error_codes.proto. +typedef enum TF_Code { + TF_OK = 0, + TF_CANCELLED = 1, + TF_UNKNOWN = 2, + TF_INVALID_ARGUMENT = 3, + TF_DEADLINE_EXCEEDED = 4, + TF_NOT_FOUND = 5, + TF_ALREADY_EXISTS = 6, + TF_PERMISSION_DENIED = 7, + TF_UNAUTHENTICATED = 16, + TF_RESOURCE_EXHAUSTED = 8, + TF_FAILED_PRECONDITION = 9, + TF_ABORTED = 10, + TF_OUT_OF_RANGE = 11, + TF_UNIMPLEMENTED = 12, + TF_INTERNAL = 13, + TF_UNAVAILABLE = 14, + TF_DATA_LOSS = 15, +} TF_Code; + +// -------------------------------------------------------------------------- + +// Return a new status object. +TF_CAPI_EXPORT extern TF_Status* TF_NewStatus(void); + +// Delete a previously created status object. +TF_CAPI_EXPORT extern void TF_DeleteStatus(TF_Status*); + +// Record in *s. Any previous information is lost. +// A common use is to clear a status: TF_SetStatus(s, TF_OK, ""); +TF_CAPI_EXPORT extern void TF_SetStatus(TF_Status* s, TF_Code code, + const char* msg); + +// Return the code record in *s. +TF_CAPI_EXPORT extern TF_Code TF_GetCode(const TF_Status* s); + +// Return a pointer to the (null-terminated) error message in *s. The +// return value points to memory that is only usable until the next +// mutation to *s. Always returns an empty string if TF_GetCode(s) is +// TF_OK. +TF_CAPI_EXPORT extern const char* TF_Message(const TF_Status* s); + +#ifdef __cplusplus +} /* end extern "C" */ +#endif + +#endif // TENSORFLOW_C_TF_STATUS_H_ diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index bd741249cf2..c5e1262cec3 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -4,10 +4,9 @@ package( default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 ) -licenses(["notice"]) # Apache 2.0 - filegroup( name = "srcs", srcs = [ @@ -638,6 +637,7 @@ cc_library( "//tensorflow/core:op_gen_lib", "//tensorflow/core:proto_text", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", ], ) @@ -657,6 +657,7 @@ tf_cc_test( "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc index 0605a62b83a..a0353bf17a6 100644 --- a/tensorflow/cc/framework/cc_op_gen.cc +++ b/tensorflow/cc/framework/cc_op_gen.cc @@ -13,11 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include "tensorflow/cc/framework/cc_op_gen.h" + #include #include #include -#include "tensorflow/cc/framework/cc_op_gen.h" +#include "absl/strings/escaping.h" #include "tensorflow/core/framework/api_def.pb.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/attr_value_util.h" @@ -133,7 +135,7 @@ string MakeComment(StringPiece text, StringPiece indent) { } string PrintString(const string& str) { - return strings::StrCat("\"", str_util::CEscape(str), "\""); + return strings::StrCat("\"", absl::CEscape(str), "\""); } string PrintTensorShape(const TensorShapeProto& shape_proto) { @@ -191,7 +193,7 @@ string PrintTensor(const TensorProto& tensor_proto) { string ret; for (int64 i = 0; i < num_elts; ++i) { if (i > 0) strings::StrAppend(&ret, " "); - strings::StrAppend(&ret, str_util::CEscape(t.flat()(i))); + strings::StrAppend(&ret, absl::CEscape(t.flat()(i))); } return ret; } diff --git a/tensorflow/cc/framework/cc_op_gen_test.cc b/tensorflow/cc/framework/cc_op_gen_test.cc index 5d9dfd95a55..698ef5be20b 100644 --- a/tensorflow/cc/framework/cc_op_gen_test.cc +++ b/tensorflow/cc/framework/cc_op_gen_test.cc @@ -62,12 +62,12 @@ op { )"; void ExpectHasSubstr(StringPiece s, StringPiece expected) { - EXPECT_TRUE(str_util::StrContains(s, expected)) + EXPECT_TRUE(absl::StrContains(s, expected)) << "'" << s << "' does not contain '" << expected << "'"; } void ExpectDoesNotHaveSubstr(StringPiece s, StringPiece expected) { - EXPECT_FALSE(str_util::StrContains(s, expected)) + EXPECT_FALSE(absl::StrContains(s, expected)) << "'" << s << "' contains '" << expected << "'"; } diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc index e74ba009083..e93ca8633e6 100644 --- a/tensorflow/cc/framework/scope.cc +++ b/tensorflow/cc/framework/scope.cc @@ -275,7 +275,7 @@ std::unordered_set Scope::Impl::GetColocationConstraints( if (GetNodeAttr(attrs, kColocationAttrName, &node_constraints).ok()) { for (const string& entry : node_constraints) { StringPiece s(entry); - if (str_util::ConsumePrefix(&s, kColocationGroupPrefix)) { + if (absl::ConsumePrefix(&s, kColocationGroupPrefix)) { current_constraints.emplace(s); } } diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index 13bc88f7cd3..8626ed0087e 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -3,10 +3,9 @@ package( default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 ) -licenses(["notice"]) # Apache 2.0 - exports_files(["LICENSE"]) load( diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index 70f362cfeae..dfc7ccd9542 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -308,7 +308,7 @@ Status LoadSavedModel(const SessionOptions& session_options, const Status status = LoadSavedModelInternal(session_options, run_options, export_dir, tags, bundle); auto log_and_count = [&](const string& status_str) { - LOG(INFO) << "SavedModel load for tags { " << str_util::Join(tags, " ") + LOG(INFO) << "SavedModel load for tags { " << absl::StrJoin(tags, " ") << " }; Status: " << status_str << ". 
Took " << GetLatencyMicroseconds(start_microseconds) << " microseconds."; load_attempt_count->GetCell(export_dir, status_str)->IncrementBy(1); diff --git a/tensorflow/cc/saved_model/loader_test.cc b/tensorflow/cc/saved_model/loader_test.cc index 597e42bb65a..422994ba07c 100644 --- a/tensorflow/cc/saved_model/loader_test.cc +++ b/tensorflow/cc/saved_model/loader_test.cc @@ -136,7 +136,7 @@ TEST_F(LoaderTest, NoTagMatch) { Status st = LoadSavedModel(session_options, run_options, export_dir, {"missing-tag"}, &bundle); EXPECT_FALSE(st.ok()); - EXPECT_TRUE(str_util::StrContains( + EXPECT_TRUE(absl::StrContains( st.error_message(), "Could not find meta graph def matching supplied tags: { missing-tag }")) << st.error_message(); @@ -152,7 +152,7 @@ TEST_F(LoaderTest, NoTagMatchMultiple) { Status st = LoadSavedModel(session_options, run_options, export_dir, {kSavedModelTagServe, "missing-tag"}, &bundle); EXPECT_FALSE(st.ok()); - EXPECT_TRUE(str_util::StrContains( + EXPECT_TRUE(absl::StrContains( st.error_message(), "Could not find meta graph def matching supplied tags: ")) << st.error_message(); @@ -172,7 +172,7 @@ TEST_F(LoaderTest, SessionCreationFailure) { Status st = LoadSavedModel(session_options, run_options, export_dir, {kSavedModelTagServe}, &bundle); EXPECT_FALSE(st.ok()); - EXPECT_TRUE(str_util::StrContains(st.error_message(), kInvalidTarget)) + EXPECT_TRUE(absl::StrContains(st.error_message(), kInvalidTarget)) << st.error_message(); } diff --git a/tensorflow/cc/saved_model/reader.cc b/tensorflow/cc/saved_model/reader.cc index 2146c8a1974..799856f7fd4 100644 --- a/tensorflow/cc/saved_model/reader.cc +++ b/tensorflow/cc/saved_model/reader.cc @@ -51,7 +51,7 @@ Status ReadSavedModel(const string& export_dir, SavedModel* saved_model_proto) { Status FindMetaGraphDef(const SavedModel& saved_model_proto, const std::unordered_set& tags, MetaGraphDef* meta_graph_def) { - LOG(INFO) << "Reading meta graph with tags { " << str_util::Join(tags, " ") + LOG(INFO) << "Reading meta graph with tags { " << absl::StrJoin(tags, " ") << " }"; for (const MetaGraphDef& graph_def : saved_model_proto.meta_graphs()) { // Get tags from the graph_def. @@ -69,7 +69,7 @@ Status FindMetaGraphDef(const SavedModel& saved_model_proto, error::Code::NOT_FOUND, strings::StrCat( "Could not find meta graph def matching supplied tags: { ", - str_util::Join(tags, " "), + absl::StrJoin(tags, " "), " }. 
To inspect available tag-sets in the SavedModel, please " "use the SavedModel CLI: `saved_model_cli`")); } diff --git a/tensorflow/cc/saved_model/reader_test.cc b/tensorflow/cc/saved_model/reader_test.cc index 620e9c2eece..e898664c221 100644 --- a/tensorflow/cc/saved_model/reader_test.cc +++ b/tensorflow/cc/saved_model/reader_test.cc @@ -64,7 +64,7 @@ TEST_F(ReaderTest, NoTagMatch) { Status st = ReadMetaGraphDefFromSavedModel(export_dir, {"missing-tag"}, &meta_graph_def); EXPECT_FALSE(st.ok()); - EXPECT_TRUE(str_util::StrContains( + EXPECT_TRUE(absl::StrContains( st.error_message(), "Could not find meta graph def matching supplied tags: { missing-tag }")) << st.error_message(); @@ -78,7 +78,7 @@ TEST_F(ReaderTest, NoTagMatchMultiple) { Status st = ReadMetaGraphDefFromSavedModel( export_dir, {kSavedModelTagServe, "missing-tag"}, &meta_graph_def); EXPECT_FALSE(st.ok()); - EXPECT_TRUE(str_util::StrContains( + EXPECT_TRUE(absl::StrContains( st.error_message(), "Could not find meta graph def matching supplied tags: ")) << st.error_message(); diff --git a/tensorflow/cc/tools/BUILD b/tensorflow/cc/tools/BUILD index c173569a095..8e509aeeae8 100644 --- a/tensorflow/cc/tools/BUILD +++ b/tensorflow/cc/tools/BUILD @@ -3,10 +3,9 @@ package( default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 ) -licenses(["notice"]) # Apache 2.0 - exports_files(["LICENSE"]) load( diff --git a/tensorflow/cc/tutorials/example_trainer.cc b/tensorflow/cc/tutorials/example_trainer.cc index 5dbc4f5f6aa..789662f84d0 100644 --- a/tensorflow/cc/tutorials/example_trainer.cc +++ b/tensorflow/cc/tutorials/example_trainer.cc @@ -167,8 +167,7 @@ namespace { bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, int32* dst) { - if (tensorflow::str_util::ConsumePrefix(&arg, flag) && - tensorflow::str_util::ConsumePrefix(&arg, "=")) { + if (absl::ConsumePrefix(&arg, flag) && absl::ConsumePrefix(&arg, "=")) { char extra; return (sscanf(arg.data(), "%d%c", dst, &extra) == 1); } @@ -178,7 +177,7 @@ bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, bool ParseBoolFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, bool* dst) { - if (tensorflow::str_util::ConsumePrefix(&arg, flag)) { + if (absl::ConsumePrefix(&arg, flag)) { if (arg.empty()) { *dst = true; return true; diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD index 6362470abef..6daf18b51c4 100644 --- a/tensorflow/compiler/aot/tests/BUILD +++ b/tensorflow/compiler/aot/tests/BUILD @@ -1,7 +1,6 @@ -licenses(["notice"]) # Apache 2.0 - package( default_visibility = ["//visibility:private"], + licenses = ["notice"], # Apache 2.0 ) load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 4b3726b8475..7d5e889bf7d 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -1,4 +1,12 @@ -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = [ + ":internal", + # BEGIN-GOOGLE-INTERNAL + "//learning/brain/contrib/tpu_modeling/exp/tpu_inference_converter:__pkg__", + # END-GOOGLE-INTERNAL + ], + licenses = ["notice"], # Apache 2.0 +) package_group( name = "internal", @@ -14,15 +22,6 @@ package_group( ], ) -package( - default_visibility = [ - ":internal", - # BEGIN-GOOGLE-INTERNAL - "//learning/brain/contrib/tpu_modeling/exp/tpu_inference_converter:__pkg__", - # END-GOOGLE-INTERNAL - ], -) - load("//tensorflow:tensorflow.bzl", "tf_cc_test", 
"cc_header_only_library") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") @@ -200,6 +199,7 @@ cc_library( "//tensorflow/core/kernels:host_constant_op", "//tensorflow/core/kernels:identity_n_op", "//tensorflow/core/kernels:identity_op", + "//tensorflow/core/kernels:logging_ops", "//tensorflow/core/kernels:no_op", "//tensorflow/core/kernels:queue_op", "//tensorflow/core/kernels:resource_variable_ops", @@ -257,10 +257,8 @@ cc_library( name = "xla_launch_util", srcs = ["xla_launch_util.cc"], hdrs = ["xla_launch_util.h"], - # TODO(skyewm): remove this once XlaAllocator is factored out. visibility = [ ":internal", - "//tensorflow/compiler/xla/python:__pkg__", ], deps = [ ":common", diff --git a/tensorflow/compiler/jit/compilability_check_util.cc b/tensorflow/compiler/jit/compilability_check_util.cc index 91e85970cc0..2d12de53b45 100644 --- a/tensorflow/compiler/jit/compilability_check_util.cc +++ b/tensorflow/compiler/jit/compilability_check_util.cc @@ -244,11 +244,11 @@ bool RecursiveCompilabilityChecker::IsCompilableNode( "resource variable op in called function"); } - if (!op_filter_.allow_slow_and_inaccurate_ops && OpIsInaccurate(node)) { + if (!op_filter_.allow_inaccurate_ops && OpIsInaccurate(node)) { return LogNotCompilableAndReturn(node, "operation with correctness issues"); } - if (!op_filter_.allow_slow_and_inaccurate_ops && OpIsSlow(node)) { + if (!op_filter_.allow_slow_ops && OpIsSlow(node)) { return LogNotCompilableAndReturn(node, "slow operation"); } @@ -268,8 +268,8 @@ RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter( registration.elide_assert_and_checknumerics; op_filter.allow_ops_producing_or_consuming_variant = registration.cluster_variant_ops; - op_filter.allow_slow_and_inaccurate_ops = - registration.cluster_slow_and_inaccurate_ops; + op_filter.allow_slow_ops = registration.cluster_slow_ops; + op_filter.allow_inaccurate_ops = registration.cluster_inaccurate_ops; return op_filter; } diff --git a/tensorflow/compiler/jit/compilability_check_util.h b/tensorflow/compiler/jit/compilability_check_util.h index 4be8050f7da..a20fc976289 100644 --- a/tensorflow/compiler/jit/compilability_check_util.h +++ b/tensorflow/compiler/jit/compilability_check_util.h @@ -97,9 +97,12 @@ class RecursiveCompilabilityChecker { // live-out DT_VARIANT values. bool allow_ops_producing_or_consuming_variant; - // Whether ops known to be slow or to have correctness issues should be - // auto-clustered. - bool allow_slow_and_inaccurate_ops; + // Whether ops known to be slow on XLA-GPU should be considered compilable.. + bool allow_slow_ops; + + // Whether ops known to have numerical accuracy issues should be considered + // compilable.. + bool allow_inaccurate_ops; }; RecursiveCompilabilityChecker(const OperationFilter* op_filter, diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc index 0a92c06ad10..1f23c0880db 100644 --- a/tensorflow/compiler/jit/deadness_analysis.cc +++ b/tensorflow/compiler/jit/deadness_analysis.cc @@ -14,10 +14,12 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/jit/deadness_analysis.h" + #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/jit/deadness_analysis_internal.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -43,12 +45,12 @@ limitations under the License. // ------------------------------------------ // // If we ignore cycles for a moment, computing predicates is fairly -// straightforward. We traverse the graph in RPO, mapping each node to a -// predicate based on the predicates its inputs are mapped to. For instance a -// Merge(X, Y) node will be mapped to OR(PredicateFor(X), PredicateFor(Y)). -// Roughtly speaking, we abstract interpret each node on the "liveness" domain, -// where values in the domain represent if a tensor carries a dead signal or -// not. +// straightforward. We traverse the graph in a topological order, mapping each +// node to a predicate based on the predicates its inputs are mapped to. For +// instance a Merge(X, Y) node will be mapped to OR(PredicateFor(X), +// PredicateFor(Y)). Roughtly speaking, we abstractly interpret each node on +// the "liveness" domain, where values in the domain represent if a tensor +// carries a dead signal or not. // // // DEALING WITH CYCLES @@ -85,22 +87,28 @@ limitations under the License. // true on iteration 0, 1, 2 respectively. This is made more precise in the // comment on the AndRecurrence class. // -// The general algorithm that deals with cycles does two RPO (reverse post -// order) passes over the graph. On the first pass it assigns a symbolic -// predicate to merge nodes with backedges. On the second pass it tries to -// pattern matche the predicates for the backedges of these merges and infer an -// AndRecurrence for the merge. +// The general algorithm that deals with cycles does two topological-order +// iterations over the graph. On the first iteration it assigns a symbolic +// predicate to merge nodes with backedges. On the second iteration it tries +// to pattern match the predicates for the backedges of these merges and infer +// an AndRecurrence for the merge. In other words, we do a data flow analysis +// where the data-flow lattice has two elements, Symbolic and NonSymbolic with +// Symbolic > NonSymbolic. The lattice has height = 2 so two iterations are +// sufficient to converge. // -// In other words, we do a pessimistic data flow analysis where the data-flow -// lattice has two elements, Symbolic and NonSymbolic with Symbolic > -// NonSymbolic. The lattice has height = 2 so two iterations are sufficient to -// converge. We don't do an optimistic data flow analysis to make pattern -// matching easier: if we assigned the predicate of the initial value to the -// merge during the first pass, on the second pass the backedge may see a -// simplified value that would be difficult to pattern match. +// We first do an optimisitc analysis and, if it does not converge, we then fall +// back to a pessimistic analysis. The optimistic analysis assigns the same +// symbolic predicate to all the merge nodes whose preceding enter nodes have +// the same frame name on the first iteration. 
On the second iteration, if all +// the merge nodes are pattern matched into the same AndRecurrence predicate +// instance, the optimistic assignment of the same symbolic predicate is correct +// and the analyzed result is taken. // -// We still use symbolic predicates for merges for which we can't pattern match -// on the backedge predicate. This is conservatively correct. +// Otherwise, if the optimistic analysis fails to converge, we then obtain the +// result by falling back to the pessimistic analysis which assigns a unique +// symbolic predicate to each merge on the first iteration. We still use +// symbolic predicates for merges for which we can't pattern match on the +// backedge predicate. This is conservatively correct. namespace tensorflow { @@ -636,6 +644,35 @@ Predicate* PredicateFactory::MakeAndOrImpl( negated_ops.insert(negated_op); } + // Simplify {S,&,X} & ~X & ... => S & ... + if (is_and) { + absl::flat_hash_set to_remove; + std::vector to_add; + for (Predicate* op : simplified_ops) { + if (op->kind() == Predicate::Kind::kAndRecurrence) { + auto* and_rec = static_cast(op); + if (negated_ops.contains(and_rec->step())) { + // Remove and_rec and ~X and insert S. Note that checking the + // existence of ~X through negated_ops is sufficient since it makes + // sure the predicate is in the input operands. It does not need to + // be in simplified_ops if it was already cancelled out. + to_remove.insert(and_rec); + to_remove.insert(MakeNotPredicate(and_rec->step())); + to_add.push_back(and_rec->start()); + } + } + } + auto it = simplified_ops.begin(); + while (it != simplified_ops.end()) { + if (to_remove.contains(*it)) { + it = simplified_ops.erase(it); + } else { + ++it; + } + } + simplified_ops.insert(simplified_ops.end(), to_add.begin(), to_add.end()); + } + // If all ops contain the same subop, then factor it out thanks to the // distributive property. 
Such as: // - (A & B) | (A & C) | (A & D) => A & (B | C | D) @@ -699,8 +736,9 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis { explicit DeadnessAnalysisImpl(const Graph* graph) : graph_(*graph), vlog_(VLOG_IS_ON(2)) {} - Status Populate(); - Status PopulateWithReversePostOrder(absl::Span rpo); + Status Populate(bool enable_optimistic); + Status PopulateFrame(absl::Span topo, bool use_optimistic_mode, + bool* success); StatusOr GetPredicateFor( Node* n, int oidx) const override; void Print() const override; @@ -742,16 +780,29 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis { } Status HandleSwitch(Node* n, std::vector* should_revisit); - Status HandleMerge(Node* n, std::vector* should_revisit); + Status HandleMerge(Node* n, std::vector* should_revisit, + bool use_optimistic_mode); Status HandleRecv(Node* n, std::vector* should_revisit); Status HandleGeneric(Node* n, std::vector* should_revisit); - Status HandleNode(Node* n, std::vector* should_revisit); + Status HandleNode(Node* n, std::vector* should_revisit, + bool use_optimistic_mode = false); + + Status GetFrameBasedTopologicalOrder(std::vector* order); + + bool IsRootEnter(const Node* n) const { + return IsEnter(n) && control_flow_info_[n->id()].parent_frame->IsSource(); + } + + bool IsRootExit(const Node* n) const { + return IsExit(n) && control_flow_info_[n->id()].parent_frame->IsSource(); + } const Graph& graph_; absl::flat_hash_map predicate_map_; PredicateFactory predicate_factory_; std::vector control_flow_info_; bool vlog_; + absl::flat_hash_map frame_to_merge_node_; }; TensorId InputEdgeToTensorId(const Edge* e) { @@ -914,10 +965,32 @@ Status GetFullFrame(const Node* n, absl::Span cfi_infos, return Status::OK(); } + +// If the node is inside some frames, get the name of the outermost non-empty +// frame. Otherwise, get an empty frame name. +Status GetRootFrame(const Node* n, absl::Span cfi_infos, + absl::string_view* frame) { + int depth = 0; + const ControlFlowInfo* cfi_iter = &cfi_infos[n->id()]; + while (!cfi_iter->parent_frame->IsSource()) { + n = cfi_iter->parent_frame; + cfi_iter = &cfi_infos[n->id()]; + + if (depth++ > 5000) { + return errors::Internal( + "Frame of depth > 5000: Probably malformed graph or a bug in " + "BuildControlFlowInfo"); + } + } + + *frame = cfi_iter->frame_name; + return Status::OK(); +} } // namespace Status DeadnessAnalysisImpl::HandleMerge(Node* n, - std::vector* should_revisit) { + std::vector* should_revisit, + bool use_optimistic_mode) { // Merge ignores deadness of its control inputs. A merge that isn't the // target of a backedge has is alive iff any of its data inputs are. The // liveness of a merge that is the target of a backedge can sometimes be @@ -937,8 +1010,21 @@ Status DeadnessAnalysisImpl::HandleMerge(Node* n, // We're visiting this merge for the first time and it has an unvisited // backedge. Predicate* input_data_pred; - TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate( - n, /*output_idx=*/0, /*must_be_true=*/false, &input_data_pred)); + if (use_optimistic_mode) { + // In the optimistic mode, we use the first-seen Merge node per + // frame as the representative Merge node. It is just convenient and + // does not affect the result after pattern-matching into the + // AndRecurrence form. 
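The GetRootFrame helper added above walks the chain of parent frames until it reaches the source frame, returning the outermost non-empty frame name and guarding against pathologically deep (likely malformed) chains. A minimal standalone sketch of that walk, using a hypothetical Frame struct in place of Node/ControlFlowInfo:

#include <cstdio>
#include <string>

// Hypothetical stand-in for a node's control-flow frame; in the real code the
// chain is traversed through ControlFlowInfo::parent_frame.
struct Frame {
  std::string name;
  const Frame* parent;  // nullptr for a root-level frame.
};

// Writes the name of the outermost frame (or "" if `frame` is null, i.e. the
// node is not inside any frame). Returns false if the chain is suspiciously
// deep, which probably indicates a malformed graph.
bool GetRootFrameName(const Frame* frame, std::string* out) {
  int depth = 0;
  while (frame != nullptr && frame->parent != nullptr) {
    frame = frame->parent;
    if (++depth > 5000) return false;
  }
  *out = (frame != nullptr) ? frame->name : "";
  return true;
}

int main() {
  Frame outer{"outer_loop", nullptr};
  Frame inner{"inner_loop", &outer};
  std::string root;
  if (GetRootFrameName(&inner, &root)) {
    std::printf("root frame: %s\n", root.c_str());  // Prints "outer_loop".
  }
  return 0;
}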
+ absl::string_view frame_name = control_flow_info_[n->id()].frame_name; + auto insert_result = frame_to_merge_node_.insert({frame_name, n}); + Node* representative = insert_result.first->second; + TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate( + representative, /*output_idx=*/0, /*must_be_true=*/false, + &input_data_pred)); + } else { + TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate( + n, /*output_idx=*/0, /*must_be_true=*/false, &input_data_pred)); + } SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred, should_revisit); @@ -948,7 +1034,7 @@ Status DeadnessAnalysisImpl::HandleMerge(Node* n, std::vector input_preds; TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataOnly, &input_preds)); - // We're visiting this merge for the first time and it is a acyclic merge. + // We're visiting this merge for the first time and it is an acyclic merge. Predicate* input_data_pred = predicate_factory_.MakeOrPredicate(input_preds); SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred, @@ -1022,11 +1108,12 @@ Status DeadnessAnalysisImpl::HandleGeneric(Node* n, } Status DeadnessAnalysisImpl::HandleNode(Node* n, - std::vector* should_revisit) { + std::vector* should_revisit, + bool use_optimistic_mode) { if (n->IsSwitch()) { TF_RETURN_IF_ERROR(HandleSwitch(n, should_revisit)); } else if (n->IsMerge()) { - TF_RETURN_IF_ERROR(HandleMerge(n, should_revisit)); + TF_RETURN_IF_ERROR(HandleMerge(n, should_revisit, use_optimistic_mode)); } else if (n->IsControlTrigger()) { SetPredicate(n, Graph::kControlSlot, predicate_factory_.MakeTrue(), nullptr); @@ -1040,17 +1127,129 @@ Status DeadnessAnalysisImpl::HandleNode(Node* n, return Status::OK(); } -Status DeadnessAnalysisImpl::Populate() { - std::vector rpo; - GetReversePostOrder(graph_, &rpo, /*stable_comparator=*/NodeComparatorName(), - /*edge_filter=*/[](const Edge& edge) { - return !edge.src()->IsNextIteration(); - }); - return PopulateWithReversePostOrder(rpo); +// Compute a special topological order for the Graph, where nodes having the +// same root frame are placed adjacent to each other. The traversal uses a +// variant of Kahn's algorithm. num_ready_inputs is used to keep track of how +// many inputs of each node are ready; a node is ready to be scheduled if all +// of its inputs are ready. +// Ref. to https://en.wikipedia.org/wiki/Topological_sorting for details. +Status DeadnessAnalysisImpl::GetFrameBasedTopologicalOrder( + std::vector* order) { + absl::flat_hash_map num_enters_for_frame; + absl::flat_hash_map num_exits_for_frame; + std::vector num_ready_inputs(graph_.num_node_ids(), 0); + Node* src_node = graph_.source_node(); + for (const auto* node : graph_.op_nodes()) { + const ControlFlowInfo& cf = control_flow_info_[node->id()]; + if (IsRootEnter(node)) { + // Since we care only the root-level frame, full frame names are the same + // as frame names. + ++num_enters_for_frame[cf.frame_name]; + } else if (IsRootExit(node)) { + ++num_exits_for_frame[cf.frame_name]; + } + // Edge NextIteration->Merge is counted before starting the traveral to + // break the backedges. + if (IsMerge(node)) { + for (const Edge* e : node->in_edges()) { + if (IsNextIteration(e->src())) { + ++num_ready_inputs[node->id()]; + } + } + } + } + + // dequeue is used to ensure that the nodes are first-in-first-out. This + // order guarantees that the exits in the ready queue are visited before + // nodes that will become ready in the future. 
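The traversal below is a frame-aware variant of Kahn's algorithm. For reference, a minimal standalone version of the underlying worklist pattern (counting ready inputs and draining a FIFO queue), over a toy adjacency-list graph rather than the TensorFlow Graph type, might look like this:

#include <cstdio>
#include <deque>
#include <vector>

// Plain Kahn's algorithm. Node i becomes ready once all of its inputs have
// been visited; ready nodes are drained first-in-first-out.
std::vector<int> TopologicalOrder(
    const std::vector<std::vector<int>>& out_edges,
    const std::vector<int>& num_inputs) {
  const int n = static_cast<int>(out_edges.size());
  std::vector<int> num_ready_inputs(n, 0);
  std::deque<int> ready;
  for (int i = 0; i < n; ++i) {
    if (num_inputs[i] == 0) ready.push_back(i);
  }
  std::vector<int> order;
  while (!ready.empty()) {
    int curr = ready.front();
    ready.pop_front();
    order.push_back(curr);
    for (int out : out_edges[curr]) {
      if (++num_ready_inputs[out] == num_inputs[out]) ready.push_back(out);
    }
  }
  return order;  // Shorter than n if the graph has a cycle.
}

int main() {
  // 0 -> 1 -> 3 and 0 -> 2 -> 3.
  std::vector<std::vector<int>> out_edges = {{1, 2}, {3}, {3}, {}};
  std::vector<int> num_inputs = {0, 1, 1, 2};
  for (int id : TopologicalOrder(out_edges, num_inputs)) std::printf("%d ", id);
  std::printf("\n");  // Prints "0 1 2 3".
  return 0;
}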
+ std::deque ready; + ready.push_back(src_node); + // ready_enters_per_frame and ready_exits serve as a staging area to buffer + // the ready enters/exits before they are moved to the `ready` queue for + // controlling the start and end of a processing frame. + absl::flat_hash_map> + ready_enters_per_frame; + // Exit nodes shall all be from the same frame, as we process a frame at a + // time. So, one vector is enough. + std::vector ready_exits; + while (!ready.empty()) { + Node* curr_node = ready.front(); + ready.pop_front(); + + VLOG(4) << "Visiting " << curr_node->name(); + order->push_back(curr_node); + + for (const Edge* out_edge : curr_node->out_edges()) { + Node* out = out_edge->dst(); + int out_id = out->id(); + if (IsNextIteration(curr_node) && IsMerge(out)) { + // Edge NextIteration->Merge has been counted. + continue; + } + ++num_ready_inputs[out->id()]; + if (!out->IsOp()) continue; // Skip Sink/Source nodes. + if (num_ready_inputs[out->id()] != out->in_edges().size()) continue; + + absl::string_view frame_name = control_flow_info_[out_id].frame_name; + if (IsRootEnter(out)) { + ready_enters_per_frame[frame_name].push_back(out); + } else if (IsRootExit(out)) { + ready_exits.push_back(out); + } else { + ready.push_back(out); + } + } + + if (ready.empty()) { + // Try moving nodes from ready_enters_per_frame and ready_exits to + // `ready`. + if (!ready_exits.empty()) { + // If there are nodes in ready_exits we must process them before + // processing ready_enters_per_frame to make sure all nodes in the + // currently processing frame are visited before starting processing + // other frames. + absl::string_view frame_name = + control_flow_info_[ready_exits.front()->id()].frame_name; + CHECK_EQ(ready_exits.size(), num_exits_for_frame[frame_name]); + ready.insert(ready.end(), ready_exits.begin(), ready_exits.end()); + ready_exits.clear(); + } else { + // Otherwise, try moving nodes from ready_enters to `ready`. + for (auto iter = ready_enters_per_frame.begin(); + iter != ready_enters_per_frame.end(); ++iter) { + absl::string_view frame_name = iter->first; + const std::vector& ready_enters = iter->second; + if (ready_enters.size() == num_enters_for_frame[frame_name]) { + ready.insert(ready.end(), ready_enters.begin(), ready_enters.end()); + ready_enters_per_frame.erase(iter); + break; + } + } + } + } + } + + if (!ready_enters_per_frame.empty() || !ready_exits.empty()) { + return errors::InvalidArgument( + "Some enters/exits have never been visited in the traversal." + " Most probably the input graph is malformed."); + } + return Status::OK(); } -Status DeadnessAnalysisImpl::PopulateWithReversePostOrder( - absl::Span rpo) { +// We populate the nodes along a special topological order where nodes having +// the same root frame are placed adjacent to each other. This grouping enables +// processing the graph per root frame at a time and guarantees that when a root +// frame is being processed, nodes in the downstream frames have not yet been +// processed. This property is important because we need to process an entire +// frame to know whether the optimistic mode converges or not. In other words, +// nodes in the downstream frames shall not be populated until all of its +// upstream frames are populated. In effect, this order enables processing each +// (nested) tf.while one-by-one, as each (nested) tf.while creates a unique +// (root) frame. 
Note that we don't separate while loops belonging to the same +// nested while, as there is no clean cut for separating them in the topological +// order. +Status DeadnessAnalysisImpl::Populate(bool enable_optimistic) { std::vector unreachable_nodes; // Compute the loop structure of the graph. TF_RETURN_IF_ERROR( @@ -1069,14 +1268,63 @@ Status DeadnessAnalysisImpl::PopulateWithReversePostOrder( absl::StrJoin(unreachable_nodes, ", ")); } + std::vector topo; + TF_RETURN_IF_ERROR(GetFrameBasedTopologicalOrder(&topo)); + + size_t frame_start = 0; + while (frame_start < topo.size()) { + // Batching nodes who have the same root frame. + absl::string_view cur_frame_name; + TF_RETURN_IF_ERROR( + GetRootFrame(topo[frame_start], control_flow_info_, &cur_frame_name)); + size_t frame_end = frame_start; + for (size_t i = frame_start + 1; i < topo.size(); ++i) { + absl::string_view i_frame_name; + TF_RETURN_IF_ERROR( + GetRootFrame(topo[i], control_flow_info_, &i_frame_name)); + if (i_frame_name == cur_frame_name) { + frame_end = i; + } else { + break; + } + } + absl::Span sub_topo(topo.data() + frame_start, + /*length=*/frame_end - frame_start + 1); + frame_start = frame_end + 1; + + // First, try the optimistic mode. + bool success = false; + if (enable_optimistic && !cur_frame_name.empty()) { + TF_RETURN_IF_ERROR( + PopulateFrame(sub_topo, /*use_optimistic_mode=*/true, &success)); + } + if (!success) { + // The optimistic mode does not converge. Let's fall back to the + // pessimistic mode. + TF_RETURN_IF_ERROR( + PopulateFrame(sub_topo, /*use_optimistic_mode=*/false, nullptr)); + } + VLOG(2) << "Done populating frame " << cur_frame_name << " using the " + << (success ? "optimistic" : "pessimistic") << " mode."; + } + + return Status::OK(); +} + +Status DeadnessAnalysisImpl::PopulateFrame(absl::Span topo, + bool use_optimistic_mode, + bool* success) { + CHECK(use_optimistic_mode && success != nullptr || + !use_optimistic_mode && success == nullptr); + // This an abstract interpretation over the deadness propagation semantics of // the graph executor. // - // We iterate over the graph twice, each time in RPO. On the first iteration - // merge nodes with backedges are mapped to symbolic predicates. On the - // second iteration we use the predicates assigned to the backedges in the - // previous iteration to infer a more precise predicate for the backedge merge - // nodes and all the nodes that transitively use it. + // We iterate over the graph twice, each time in a topological order. On the + // first iteration merge nodes with backedges are mapped to symbolic + // predicates. On the second iteration we use the predicates assigned to the + // backedges in the previous iteration to infer a more precise predicate for + // the backedge merge nodes and all the nodes that transitively use it. // // We don't track the output indices for should_revisit. Instead, putting a // node in `should_revisit` denotes that the deadness flowing out from any @@ -1086,9 +1334,10 @@ Status DeadnessAnalysisImpl::PopulateWithReversePostOrder( // delta should not change in the second iteration. 
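Populate() above slices the frame-based topological order into contiguous runs that share a root frame and processes one run (frame) at a time, trying the optimistic mode first and falling back to the pessimistic mode. A standalone sketch of just the slicing step, with strings standing in for root frame names (empty meaning "not in any frame"):

#include <cstdio>
#include <string>
#include <vector>

#include "absl/types/span.h"

int main() {
  // A frame-based topological order guarantees that nodes with the same root
  // frame are adjacent, so every frame is one contiguous run.
  std::vector<std::string> root_frame_of_node = {"", "", "while_0", "while_0",
                                                 "while_0", "while_1", "while_1"};
  size_t frame_start = 0;
  while (frame_start < root_frame_of_node.size()) {
    size_t frame_end = frame_start;
    while (frame_end + 1 < root_frame_of_node.size() &&
           root_frame_of_node[frame_end + 1] == root_frame_of_node[frame_start]) {
      ++frame_end;
    }
    absl::Span<const std::string> sub_topo(
        root_frame_of_node.data() + frame_start, frame_end - frame_start + 1);
    std::printf("frame '%s' has %zu node(s)\n",
                root_frame_of_node[frame_start].c_str(), sub_topo.size());
    // The real code would now run PopulateFrame over `sub_topo`, optimistic
    // first (for non-empty frame names) and pessimistic as a fallback.
    frame_start = frame_end + 1;
  }
  return 0;
}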
std::vector should_revisit; should_revisit.resize(graph_.num_node_ids()); - for (Node* n : rpo) { + for (Node* n : topo) { VLOG(4) << "Visiting " << n->name(); - TF_RETURN_IF_ERROR(HandleNode(n, /*should_revisit=*/nullptr)); + TF_RETURN_IF_ERROR( + HandleNode(n, /*should_revisit=*/nullptr, use_optimistic_mode)); if (n->IsNextIteration()) { // If this is a backedge for a merge node then remember to reprocess the // merge the next time we run. @@ -1100,11 +1349,11 @@ Status DeadnessAnalysisImpl::PopulateWithReversePostOrder( } } - for (Node* n : rpo) { + for (Node* n : topo) { // The nodes added to should_revisit in the previous loop need to be // revisited now. Reprocesing these initial nodes may add *their* consumers // to should_revisit, and these newly added nodes will also be processed by - // this very same loop. Since we're traversing the graph in reverse post + // this very same loop. Since we're traversing the graph in topological // order (producers before consumers) and HandleNode(n) can only ever add // n's consumers to should_revisit, we won't "miss" an addition to // should_revisit. @@ -1114,6 +1363,71 @@ Status DeadnessAnalysisImpl::PopulateWithReversePostOrder( } } + // Check if the optimistic analysis converges. Specifically, check whether + // all the predicates of the merge nodes in the same frame are the same. If + // yes, report success. If not, report failure and clear the assigned + // predicates. + if (use_optimistic_mode) { + bool is_converged = true; + absl::flat_hash_map frame_to_pred; + for (Node* n : topo) { + if (!n->IsMerge()) { + continue; + } + const Edge* e; + TF_RETURN_IF_ERROR(FindUniqueBackedge(n, &e)); + if (e == nullptr) { + // Skip acyclic merge nodes. + continue; + } + Node* merge = n; + // Note that here uses frame names instead of root frame names. In the + // case of a nested while loop, each level of while loops can have merges + // with different predicate instances, while the merge nodes on the same + // level must have the same predicate instances. + absl::string_view frame_name = control_flow_info_[merge->id()].frame_name; + auto it = predicate_map_.find(TensorId(merge->name(), 0)); + Predicate* merge_pred = it->second; + if (merge_pred->kind() != Predicate::Kind::kAndRecurrence) { + is_converged = false; + VLOG(2) << "Running the optimistic mode on frame " << frame_name + << " does not converge because node " << merge->name() + << " cannot be mapped into the AndRecurrence form."; + break; + } + + auto insert_result = frame_to_pred.insert({frame_name, merge_pred}); + if (!insert_result.second) { + // If we have already seen this frame name, verify the predicate is the + // same as the previously seen one's. + Predicate* curr_andrec = merge_pred; + Predicate* prev_andrec = insert_result.first->second; + if (curr_andrec != prev_andrec) { + is_converged = false; + VLOG(2) << "Running the optimistic mode on frame " << frame_name + << " does not converge. Seeing different Merge predicates: \n" + << curr_andrec->ToString() << " and \n" + << prev_andrec->ToString(); + break; + } + } + } + + // Clear the assigned predicates if the optimistic mode does not converge. 
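The convergence check below declares the optimistic pass successful only if every cyclic merge within a frame ended up mapped to the same AndRecurrence predicate instance. A standalone sketch of that per-frame consistency check, with strings standing in for interned Predicate pointers:

#include <cstdio>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"

// Returns true if, within every frame, all merges were assigned the same
// predicate. (The real check additionally requires each predicate to be an
// AndRecurrence before comparing.)
bool OptimisticModeConverged(
    const std::vector<std::pair<std::string, std::string>>& frame_and_pred) {
  absl::flat_hash_map<std::string, std::string> frame_to_pred;
  for (const auto& entry : frame_and_pred) {
    auto insert_result = frame_to_pred.insert(entry);
    // insert() is a no-op when the frame was already seen; in that case
    // compare against the previously recorded predicate.
    if (!insert_result.second && insert_result.first->second != entry.second) {
      return false;
    }
  }
  return true;
}

int main() {
  std::printf("%d\n", OptimisticModeConverged(
                          {{"loop", "{#true,&,*iv0/cond:0}"},
                           {"loop", "{#true,&,*iv0/cond:0}"}}));  // Prints 1.
  std::printf("%d\n", OptimisticModeConverged(
                          {{"loop", "{#true,&,*iv0/cond:0}"},
                           {"loop", "div0/iv:0"}}));              // Prints 0.
  return 0;
}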
+ if (!is_converged) { + for (Node* n : topo) { + for (int oid = 0; oid < n->num_outputs(); ++oid) { + predicate_map_.erase(TensorId(n->name(), oid)); + } + predicate_map_.erase(TensorId(n->name(), Graph::kControlSlot)); + } + } + + if (success != nullptr) { + *success = is_converged; + } + } + return Status::OK(); } @@ -1149,7 +1463,7 @@ DeadnessAnalysis::~DeadnessAnalysis() {} const Graph& graph, std::unique_ptr* result) { std::unique_ptr analysis( new DeadnessAnalysisImpl(&graph)); - TF_RETURN_IF_ERROR(analysis->Populate()); + TF_RETURN_IF_ERROR(analysis->Populate(/*enable_optimistic=*/true)); if (VLOG_IS_ON(2)) { analysis->Print(); @@ -1170,22 +1484,18 @@ DeadnessAnalysisImpl::PredicateMapAsString() const { } namespace deadness_analysis_internal { -Status ComputePredicates(const Graph& graph, - PredicateMapTy* out_predicate_map) { +Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map, + bool enable_optimistic) { DeadnessAnalysisImpl impl(&graph); - TF_RETURN_IF_ERROR(impl.Populate()); + TF_RETURN_IF_ERROR(impl.Populate(enable_optimistic)); *out_predicate_map = impl.PredicateMapAsString(); return Status::OK(); } -Status ComputePredicates(const Graph& graph, - absl::Span reverse_post_order, - PredicateMapTy* out_predicate_map) { - DeadnessAnalysisImpl impl(&graph); - TF_RETURN_IF_ERROR(impl.PopulateWithReversePostOrder(reverse_post_order)); - *out_predicate_map = impl.PredicateMapAsString(); - return Status::OK(); -} } // namespace deadness_analysis_internal +string DeadnessAnalysis::DebugString(DeadnessPredicate predicate) const { + return static_cast(predicate.pred_)->ToString(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/deadness_analysis.h b/tensorflow/compiler/jit/deadness_analysis.h index 08d8ad011bc..c8527de503d 100644 --- a/tensorflow/compiler/jit/deadness_analysis.h +++ b/tensorflow/compiler/jit/deadness_analysis.h @@ -82,6 +82,8 @@ class DeadnessAnalysis { virtual void Print() const = 0; virtual ~DeadnessAnalysis(); + string DebugString(DeadnessPredicate predicate) const; + // Run the deadness analysis over `graph` and returns an error or a populated // instance of DeadnessAnalysis in `result`. static Status Run(const Graph& graph, diff --git a/tensorflow/compiler/jit/deadness_analysis_internal.h b/tensorflow/compiler/jit/deadness_analysis_internal.h index 354782374ad..b2f0e72bc14 100644 --- a/tensorflow/compiler/jit/deadness_analysis_internal.h +++ b/tensorflow/compiler/jit/deadness_analysis_internal.h @@ -25,15 +25,9 @@ namespace deadness_analysis_internal { // Returns a map describing the predicate each Tensor was mapped to. For // testing purposes only. using PredicateMapTy = absl::flat_hash_map; -Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map); +Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map, + bool enable_optimistic = true); -// Returns a map describing the predicate each Tensor was mapped to. For -// testing purposes only. Makes deadness analysis visit the graph in the order -// specified in `reverse_post_order` which must be a valid RPO for the graph -// minus NextIteration->Merge edges. 
-Status ComputePredicates(const Graph& graph, - absl::Span reverse_post_order, - PredicateMapTy* out_predicate_map); } // namespace deadness_analysis_internal } // namespace tensorflow diff --git a/tensorflow/compiler/jit/deadness_analysis_test.cc b/tensorflow/compiler/jit/deadness_analysis_test.cc index 3a44eb7db75..fae1e55c6ba 100644 --- a/tensorflow/compiler/jit/deadness_analysis_test.cc +++ b/tensorflow/compiler/jit/deadness_analysis_test.cc @@ -638,7 +638,22 @@ TEST(DeadnessAnalysisTest, ControlEquivalentLoopBodies) { } { PredicateMapTy predicate_map; - TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); + TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map, + /*enable_optimistic=*/true)); + + EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)], + "{#true,&,*iv0/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv0)], + predicate_map[ControlOutputFor(iv.induction_var)]); + EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv1)], + predicate_map[ControlOutputFor(iv.induction_var)]); + EXPECT_EQ(predicate_map[ControlOutputFor(add0)], + predicate_map[ControlOutputFor(iv.induction_var)]); + } + { + PredicateMapTy predicate_map; + TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map, + /*enable_optimistic=*/false)); EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)], "{#true,&,*iv0/cond:0}"); @@ -660,16 +675,6 @@ TEST(DeadnessAnalysisTest, LoopInvariantPredicateOnBackedge) { CreateDependentLoopInvariantValue(root, "div0", "frame", iv.loop_cond, 0); FixupSourceAndSinkEdges(root.graph()); - // To make deadness analysis think that dependent_iv is a loop we need an RPO - // that visits the merge before the backedge. This is a legal RPO for - // deadness analysis since it ignores NextIteration->Merge edges during RPO. - // Right now dependent_iv has an edge from Merge to NextIteration so do the - // RPO with this edge in place. Then remove this edge to get our test case. - std::vector rpo; - GetReversePostOrder(*root.graph(), &rpo, /*stable_comparator=*/{}, - /*edge_filter=*/[](const Edge& edge) { - return !edge.src()->IsNextIteration(); - }); TF_ASSERT_OK(root.graph()->UpdateEdge( iv.induction_var.node(), 0, dependent_iv.latch.output_true.node(), 0)); @@ -677,7 +682,16 @@ TEST(DeadnessAnalysisTest, LoopInvariantPredicateOnBackedge) { { PredicateMapTy predicate_map; - TF_ASSERT_OK(ComputePredicates(*root.graph(), rpo, &predicate_map)); + TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map, + /*enable_optimistic=*/true)); + + EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv.induction_var)], + "{#true,&,*iv0/cond:0}"); + } + { + PredicateMapTy predicate_map; + TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map, + /*enable_optimistic=*/false)); EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv.induction_var)], "div0/iv:0"); @@ -731,7 +745,34 @@ TEST(DeadnessAnalysisTest, ControlEquivalentNestedLoopBodies) { } { PredicateMapTy predicate_map; - TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); + TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map, + /*enable_optimistic=*/true)); + + EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer.induction_var)], + "{#true,&,*iv_outer/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(iv_inner.induction_var)], + "{(*iv_outer/cond:0 & " + "{#true,&,*iv_outer/cond:0}),&,*iv_inner/" + "cond:0}"); + + // enable_optimistic = true or not should produce the same results because + // of fallback. 
However, note that the order of iv_inner/cond:0 and + // iv_inner/iv:0 is different because the optimistic approach does not + // create predicates for all merges and it can change the predicate id and + // hence the symbol order. + EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv0)], + "{{#true,&,(iv_outer/iv:0 & " + "*iv_outer/cond:0)},&,(*iv_inner/cond:0 & " + "iv_inner/iv:0)}"); + EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv1)], + predicate_map[ControlOutputFor(dependent_inner_iv0)]); + EXPECT_EQ(predicate_map[ControlOutputFor(add0)], + predicate_map[ControlOutputFor(dependent_inner_iv0)]); + } + { + PredicateMapTy predicate_map; + TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map, + /*enable_optimistic=*/false)); EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer.induction_var)], "{#true,&,*iv_outer/cond:0}"); @@ -744,15 +785,10 @@ TEST(DeadnessAnalysisTest, ControlEquivalentNestedLoopBodies) { "{{#true,&,(iv_outer/iv:0 & " "*iv_outer/cond:0)},&,(iv_inner/iv:0 & " "*iv_inner/cond:0)}"); - EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv1)], - "{{#true,&,(iv_outer/iv:0 & " - "*iv_outer/cond:0)},&,(iv_inner/iv:0 & " - "*iv_inner/cond:0)}"); + predicate_map[ControlOutputFor(dependent_inner_iv0)]); EXPECT_EQ(predicate_map[ControlOutputFor(add0)], - "{{#true,&,(iv_outer/iv:0 & " - "*iv_outer/cond:0)},&,(iv_inner/iv:0 & " - "*iv_inner/cond:0)}"); + predicate_map[ControlOutputFor(dependent_inner_iv0)]); } } @@ -817,6 +853,104 @@ TEST(DeadnessAnalysisTest, ControlNonEquivalentNestedLoopBodies) { } } +TEST(DeadnessAnalysisTest, NestedLoopBodiesWithACapture) { + Scope root = Scope::NewRootScope().ExitOnError(); + InductionVarInfo iv_outer = + CreateInductionVariable(root, "iv_outer", "outer_loop", 0); + Output enter_constant_outer_loop = ops::internal::Enter( + root.WithOpName("constant_enter_outer_loop"), + ops::Const(root.WithOpName("constant"), 5), "outer_loop", + ops::internal::Enter::Attrs().IsConstant(true)); + ops::Switch inner_value(root.WithOpName("outer_is_live"), + enter_constant_outer_loop, iv_outer.loop_cond); + InductionVarInfo iv_inner = CreateInductionVariable( + root, "iv_inner", "inner_loop", inner_value.output_true); + + DependentInductionVar div0_outer = CreateDependentLoopInvariantValue( + root, "div0_outer", "outer_loop", iv_outer.loop_cond, 0); + DependentInductionVar div1_outer = CreateDependentLoopInvariantValue( + root, "div1_outer", "outer_loop", iv_outer.loop_cond, 0); + + DependentInductionVar div0_inner = CreateDependentLoopInvariantValue( + root, "div0_inner", "inner_loop", iv_inner.loop_cond, + div0_outer.induction_var); + DependentInductionVar div1_inner = CreateDependentLoopInvariantValue( + root, "div1_inner", "inner_loop", iv_inner.loop_cond, + div1_outer.induction_var); + + Output captured = ops::_Recv(root.WithOpName("captured"), DT_INT32, + "tensor_a", "sender", 0, "receiver"); + Output capture_enter_outer = ops::internal::Enter( + root.WithOpName("capture_enter_outer"), captured, "outer_loop", + ops::internal::Enter::Attrs().IsConstant(true)); + Output capture_enter_inner = ops::internal::Enter( + root.WithOpName("capture_enter_inner"), capture_enter_outer, "inner_loop", + ops::internal::Enter::Attrs().IsConstant(true)); + Output mul0 = ops::Mul(root.WithOpName("mul0"), div1_inner.induction_var, + capture_enter_inner); + TF_ASSERT_OK(root.graph()->UpdateEdge( + mul0.node(), 0, div1_inner.latch.output_true.node(), 0)); + + Output add0 = ops::Add(root.WithOpName("add0"), div0_inner.induction_var, + 
div1_inner.induction_var); + + VLogGraphIfAsked(*root.graph()); + + { + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + TF_ASSERT_OK_AND_ASSIGN( + bool has_inputs_with_mismatching_deadness, + HasInputsWithMismatchingDeadness(*result, *add0.node())); + EXPECT_TRUE(has_inputs_with_mismatching_deadness); + } +} + +TEST(DeadnessAnalysisTest, CyclicRecurrence) { + Scope root = Scope::NewRootScope().ExitOnError(); + InductionVarInfo iv = CreateInductionVariable(root, "iv0", "loop", 0); + DependentInductionVar div0 = + CreateDependentLoopInvariantValue(root, "div0", "loop", iv.loop_cond, 0); + DependentInductionVar div1 = + CreateDependentLoopInvariantValue(root, "div1", "loop", iv.loop_cond, 0); + FixupSourceAndSinkEdges(root.graph()); + TF_ASSERT_OK(root.graph()->UpdateEdge(div1.induction_var.node(), 0, + div0.latch.output_true.node(), 0)); + TF_ASSERT_OK(root.graph()->UpdateEdge(div0.induction_var.node(), 0, + div1.latch.output_true.node(), 0)); + + VLogGraphIfAsked(*root.graph()); + + { + PredicateMapTy predicate_map; + TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map, + /*enable_optimistic=*/true)); + + EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)], + "{#true,&,*iv0/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(div0.induction_var)], + "{#true,&,*iv0/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(div1.induction_var)], + "{#true,&,*iv0/cond:0}"); + + // This tests the rule {S,&,X} & ~X => S. + TensorId switch_false_out = {div1.latch.output_false.node()->name(), + div1.latch.output_false.index()}; + EXPECT_EQ(predicate_map[switch_false_out], "(#true)"); + } + { + PredicateMapTy predicate_map; + TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map, + /*enable_optimistic=*/false)); + + EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)], + "{#true,&,*iv0/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(div0.induction_var)], "div0/iv:0"); + EXPECT_EQ(predicate_map[ControlOutputFor(div1.induction_var)], "div1/iv:0"); + } +} + TEST(DeadnessAnalysisTest, AndRecurrenceNeedsFrameName) { Scope root = Scope::NewRootScope().ExitOnError(); InductionVarInfo iv_0 = CreateInductionVariable(root, "iv_0", "frame_0", 10); diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index b6d97434eb0..982803d501f 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -144,17 +144,8 @@ static const char* const kRecvAtHostOp = "_XlaRecvAtHost"; class Encapsulator { public: - Encapsulator(string group_attribute, string outside_compilation_attribute, - Graph const* graph_in) - : group_attribute_(std::move(group_attribute)), - outside_compilation_attribute_( - std::move(outside_compilation_attribute)), - graph_in_(graph_in) {} - - // Find dependencies between subgraphs and outside_compilation clusters that - // only manifest via edges between outside_compilation clusters in the outer - // (non-compiled) graph. - Status FindClusterDependencies(); + Encapsulator(string group_attribute, Graph const* graph_in) + : group_attribute_(std::move(group_attribute)), graph_in_(graph_in) {} // Find subgraphs marked with 'group_attribute', and build a new // subgraph, one for each value of 'group_attribute'. @@ -176,68 +167,22 @@ class Encapsulator { private: // A subgraph of the input, all marked with a common 'group_attribute' - // value. 
A subgraph may contain multiple `outside_compilation' clusters. + // value. // // In the following simple example, A, B, ..., E are nodes in the original - // graph. The group attributes and outside_compilation attributes g and oc are - // each shown as either 0 or empty. + // graph. The group attributes g are each shown as either 0 or empty. // // A --> B --> C --> D --> E // g: g:0 g:0 g:0 g: - // oc: oc: oc:0 oc: oc: // // The example is rewritten to two graphs; one on the host and one to be - // compiled. The host graph is as follows. RAH is a RecvAtHost node receiving - // input from the compiled cluster, and SFH is a SendFromHost node sending - // input back to the compiled cluster. Dotted edges are control edges. A - // 'sequencing' node S is inserted, and both RAH and SFH are connected via S - // to E (and in general all nodes that depend on nodes in the compiled - // cluster) to ensure that they are not pruned. + // compiled. The host graph is as follows. // // A --> Call --> E - // ^ - // . - // ........> S - // .... ^ - // .. . - // RAH --> C --> SFH // - // The compiled cluster is as follows. HC is a HostCompute node which is the - // source of a channel to the RAH node above and the destination of a channel - // from the SFH node above. + // The compiled cluster is as follows. // - // Arg --> B --> HC --> D --> Retval - // - // The channels HC/RAH and SFH/HC each transmit multiple tensors, so there is - // at most one RAH and SFH in each outside_compilation cluster. This design is - // preferred over adding separate Arg/Retval nodes for each transmitted value - // because it allows optimizations to the host code that would like to limit - // communication between host and device and, e.g., raise only one interrupt - // per channel rather than one per transmitted value. - // - // The shapes of the outputs from the HC node in general cannot be determined - // until the shapes of its inputs are known at compile time, since e.g., - // above, the shape of C's outputs aren't known until the shape of its inputs - // are known. If the shapes of the HC's outputs can be determined during the - // rewrite, they are stored in the node's 'shapes' attr. Otherwise a minimal - // graph is stored in the shape_inference_graph attr. This graph can be used - // when compiling the HC Op to determined the shape of the SFH inputs given - // the shapes of any ancestor RAH outputs. If it can be determined that the - // shape of the SFH inputs will not be inferrable even once the shapes of the - // RAH outputs are known, an error is returned by the rewriter. - // - // Once edges between compiled and outside_compilation clusters have been - // replaced by send/recv ops, some dependencies may no longer be apparent. - // A clustering pass finds all the dependencies between HC nodes that are only - // present as a result of edges between nodes in outside_compilation clusters. - // Suppose there is a path from outside_compilation cluster C in subgraph S - // to outside_compilation cluster D in subgraph T. If S != T then a control - // edge is added from the call node for S to the call node for T, which - // ensures that C will execute before D because S executes before T. If S==T - // then a control dependency is added between the HC nodes for C and D in S, - // and the HC node for C is added to an 'ancestors' attr in the HC node for D - // so that during compilation of the HC node for D, an XLA control dependency - // can be added to ensure C's SendToHost executes before D's RecvFromHost. 
+ // Arg --> B --> C --> D --> Retval class Subgraph { public: // Creates a graph to build the subgraph in, if it doesn't already exist, @@ -262,17 +207,6 @@ class Encapsulator { const std::unordered_map& node_images, Graph* graph_out); - // Adds _RecvAtHost and _SendFromHost nodes, where needed, to graph_out. - Status AddOutsideCompilationHostIONodes( - const string& group_attribute, const string& subgraph_name, - const string& outside_compilation_attribute, - const std::unordered_map& node_images, - Graph* graph_out); - - // Returns the names of all the outside_compilation subgraphs in this - // Subgraph. - void GetOutsideCompilationSubgraphNames(std::vector* names) const; - // Returns the Node that the inputs and outputs of the function should be // wired up to. Node* GetCallNode() const; @@ -283,24 +217,6 @@ class Encapsulator { // Returns the index of the result that the src of edge should connect to. int GetResultIndexForEdge(const Edge* edge) const; - // Returns the RecvAtHost node for an outside_compilation subgraph. - Node* GetRecvAtHostNode( - const string& outside_compilation_subgraph_name) const; - - // Returns the output slot for the RecvAtHost node that corresponds to the - // source of edge in an outside_compilation subgraph. - int GetRecvAtHostSlot(const string& outside_compilation_subgraph_name, - const Edge* edge) const; - - // Returns the SendFromHost node for an outside_compilation subgraph. - Node* GetSendFromHostNode( - const string& outside_compilation_subgraph_name) const; - - // Returns the input slot for the SendFromHost node that corresponds to the - // destination of edge in an outside_compilation subgraph. - int GetSendFromHostSlot(const string& outside_compilation_subgraph_name, - const Edge* edge) const; - // Creates an _Arg node for the src node of edge, and add its index to // args_by_src_, if none exists yet. Also adds its index to args_by_dst_, // and adds the edge within the subgraph from the _Arg node to the image of @@ -323,37 +239,6 @@ class Encapsulator { const Edge* edge, const std::unordered_map& node_images); - // Creates an outside_compilation subgraph for outside_compilation_id if - // none exists yet. Creates an entry for the src node of edge in the list of - // inputs for the outside_compilation subgraph, if none exists yet. - void RecordOutsideCompilationInputOrControl( - const string& outside_compilation_id, const Edge* edge); - - // Creates an outside_compilation subgraph for outside_compilation_id if - // none exists yet. Creates an entry for the src node of edge in the list of - // outputs by src for the outside_compilation subgraph, if none exists - // yet. Creates an entry for the dst node of edge in the list of outputs by - // dst for the outside_compilation subgraph. - void RecordOutsideCompilationOutputOrControl( - const string& outside_compilation_id, const Edge* edge); - - // Records the fact that there is a path from a node in outside_compilation - // cluster ancestor to node in cluster successor that does not go through - // the subgraph. - void RecordOutsideCompilationDependency(const string& successor, - const string& ancestor); - - // Returns the mapping from outside_compilation cluster C to the set of - // outside_compilation clusters that have a path to C entirely outside - // compiled subgraphs. - const std::unordered_map> - OutsideCompilationAncestorMap() const; - - // Adds the HostCompute nodes for each outside_compilation subgraph. 
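With the outside_compilation paths removed, the encapsulation pass reduces to grouping nodes by their 'group_attribute' value and building one function subgraph per group, as the A --> B --> C --> D --> E example in the comment above illustrates. A toy sketch of that grouping step, over a hypothetical (name, group) node list rather than the real Graph/Node API:

#include <cstdio>
#include <map>
#include <string>
#include <vector>

struct ToyNode {
  std::string name;
  std::string group;  // Value of the group attribute; empty if unmarked.
};

int main() {
  // A --> B --> C --> D --> E with only B, C, D marked as group "0".
  std::vector<ToyNode> nodes = {{"A", ""}, {"B", "0"}, {"C", "0"},
                                {"D", "0"}, {"E", ""}};
  std::map<std::string, std::vector<std::string>> subgraphs;
  for (const ToyNode& n : nodes) {
    if (!n.group.empty()) subgraphs[n.group].push_back(n.name);
  }
  // Each group becomes a function (Arg --> B --> C --> D --> Retval) while the
  // unmarked nodes stay in the host graph around a single Call node.
  for (const auto& entry : subgraphs) {
    std::printf("subgraph %s:", entry.first.c_str());
    for (const std::string& name : entry.second) std::printf(" %s", name.c_str());
    std::printf("\n");
  }
  return 0;
}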
- Status AddHostComputes( - const string& subgraph_name, - const std::unordered_map& node_images); - // Creates the sequencer node if it doesn't exist, adding it to graph_out. Status MakeSequencingNode(const string& subgraph_name, Graph* graph_out); @@ -361,102 +246,9 @@ class Encapsulator { // the call node. void ConnectSequencerToCallNode(Graph* graph_out); - Status AddShapeInferenceInfo( - const string& subgraph_name, - const string& outside_compilation_subgraph_name, - const std::vector& shapes, Graph* inference_graph, - FunctionLibraryDefinition* library); - Status ReplaceFunctionDef(FunctionLibraryDefinition* library); private: - struct OutsideCompilationSubgraph { - // Map from source (producer node/slot) tensors in the original graph to - // input index (slot number in the HostCompute/RecvAtHost nodes that will - // be created) for the outside_compilation subgraph. - std::unordered_map inputs; - - // Set of nodes in the original graph that are the source of control edges - // that cross from the containing compiled subgraph into the - // outside_compilation subgraph. These are recorded by - // RecordOutsideCompilationInputOrControl while walking all the subgraph - // edges, and lifted control edges within the subgraph are added by - // AddSendsToOutsideCompilation once the _HostCompute node has been - // created. The matching control edge from _RecvAtHost to the - // destination is added by CopyEdgeToOutputGraph. - std::unordered_set control_inputs; - - // Maps from source (producer node/slot) and destination (consumer - // node/slot) tensors in the original graph to output index (slot number - // in the SendFromHost/HostCompute nodes that will be created) for the - // outside_compilation subgraph. - struct ArgNumAndType { - int index; - DataType dtype; - - ArgNumAndType(int i, DataType t) : index(i), dtype(t) {} - }; - std::unordered_map - outputs_by_src; - std::unordered_map outputs_by_dst; - - // Set of nodes in the original graph that are the destination of control - // edges that cross from the outside_compilation subgraph into the - // containing compiled subgraph. These are recorded by - // RecordOutsideCompilationOutputOrControl while walking all the subgraph - // edges, and lifted control edges within the subgraph are added by - // AddRecvsFromToOutsideCompilation once the _HostCompute node has been - // created. The matching control edge from the source to _SendFromHost to - // the destination is added by CopyEdgeToOutputGraph. - std::unordered_set control_outputs; - - // Name of the _HostCompute node in the subgraph. - string host_compute_name; - - // _RecvAtHost node in the output graph. Not owned. - Node* recv_at_host = nullptr; - - // _SendFromHost node in the output graph. Not owned. - Node* send_from_host = nullptr; - }; - - // Creates an outside_compilation subgraph for outside_compilation_id if - // none exists yet. Returns the (possible newly created) subgraph for - // outside_compilation_id. - OutsideCompilationSubgraph* LookupOrCreateOutsideCompilationSubgraph( - const string& outside_compilation_id); - - // Builds a placeholder node used to provide the key input to a RecvAtHost - // or SendFromHost node. This placeholder node will be removed by a later - // pass. - Status AddHostComputeKeyPlaceholder(OutsideCompilationSubgraph* oc_subgraph, - Graph* graph_out); - - // Get the set of outside_compilation clusters and the dependency edges - // between them. 
- void GetActiveClusterDependencyGraph( - std::unordered_set* clusters, - std::unordered_set* has_successor, - std::unordered_map>* ancestors_map); - - // Builds a _RecvAtHost node producing all the inputs of an - // outside_compilation subgraph and stores it in oc_subgraph.recv_at_host. - Status AddRecvAtHostNode(const string& group_attribute, - const string& subgraph_name, - const string& outside_compilation_attribute, - const string& oc_subgraph_name, - OutsideCompilationSubgraph* oc_subgraph, - Graph* graph_out); - - // Builds a _SendFromHost node consuming all the outputs of an - // outside_compilation subgraph and stores it in oc_subgraph.send_from_host. - Status AddSendFromHostNode( - const std::unordered_map& node_images, - const string& group_attribute, const string& subgraph_name, - const string& outside_compilation_attribute, - const string& oc_subgraph_name, OutsideCompilationSubgraph* oc_subgraph, - Graph* graph_out); - // The subgraph extracted from the input graph, suitable for being turned // into a FunctionDef. Inputs are fed by _Arg nodes, and outputs are // returned by _Retval nodes. @@ -498,31 +290,13 @@ class Encapsulator { // removed from the graph. absl::flat_hash_set control_output_nodes_; - // The outside_compilation clusters in this subgraph. - std::unordered_map - outside_compilation_subgraphs_; - // For each outside_compilation cluster C, the outside_compilation clusters - // that have a path to C outside the compiled graph. - std::unordered_map> - outside_compilation_ancestors_; - // For each outside_compilation cluster C, the outside_compilation clusters - // that have a path from C outside the compiled graph. - std::unordered_map> - outside_compilation_successors_; - - // NoOp node in the output graph that is sequenced after the call node and - // used to prevent host-side outside_compilation sends and recvs from being - // pruned. + // NoOp node in the output graph that is sequenced after the call node. Node* sequencer_ = nullptr; }; - // Returns the key attribute and outside_compilation attribute associated - // with a node in attr, and outside_compilation_attr, respectively. Sets - // either result to the empty string if the respective attribute is not - // found. Returns error status if there is an outside_compilation attribute - // and no key attribute, - Status GetFunctionNameAttr(Node const* node, string* attr, - string* outside_compilation_attr) const; + // Returns the key attribute associated with a node in attr. Sets either + // result to the empty string if the respective attribute is not found. + Status GetFunctionNameAttr(Node const* node, string* attr) const; // Copies edges local to a subgraph. Adds _Arg and _Retval nodes to // subgraphs for data edges that cross subgraph boundaries. @@ -530,8 +304,7 @@ class Encapsulator { const std::unordered_map& node_images, std::vector>* src_arg_pairs); - // Copies all marked nodes to a subgraph. Does nothing for unmarked nodes, - // or nodes marked outside_compilation. + // Copies all marked nodes to a subgraph. Does nothing for unmarked nodes. Status CopySubgraphNodes(std::unordered_map* node_images); // Copies all nodes that aren't in a compiled subgraph to the output graph. @@ -543,92 +316,50 @@ class Encapsulator { const std::unordered_map& node_images, Graph* graph_out); - // Adds _RecvAtHost and _SendFromHost nodes, where needed, for all - // outside_compilation subgraphs. 
- Status AddOutsideCompilationHostIONodes( - const std::unordered_map& node_images, - Graph* graph_out); - // Finds the image of an edge source in the output graph. If the edge crosses // a subgraph boundary it is the output of a call node, otherwise it is a node // in the output graph. Status FindOutputImageOfEdgeSrc( - const string& src_func_id, const string& src_outside_compilation_id, - const string& dst_func_id, const string& dst_outside_compilation_id, + const string& src_func_id, const string& dst_func_id, const std::unordered_map& node_images, const Node* original_src_node, Node** src_image); // Finds an edge source slot in the output graph. If the edge crosses a - // subgraph boundary it is a slot on the output of a call node or a - // _RecvAtHost node, otherwise it is a slot on a node in the output graph. + // subgraph boundary it is a slot on the output of a call node, otherwise it + // is a slot on a node in the output graph. int FindOutputSlotOfEdgeSrc(const string& src_func_id, - const string& src_outside_compilation_id, const string& dst_func_id, - const string& dst_outside_compilation_id, const Edge* edge); // Finds the image of an edge destination in the output graph. If the edge - // crosses a subgraph boundary it is the input of a call node or a - // _SendFromHost node, otherwise it is a node in the output graph. + // crosses a subgraph boundary it is the input of a call node, otherwise it is + // a node in the output graph. Status FindOutputImageOfEdgeDst( - const string& src_func_id, const string& src_outside_compilation_id, - const string& dst_func_id, const string& dst_outside_compilation_id, + const string& src_func_id, const string& dst_func_id, const std::unordered_map& node_images, const Node* original_dst_node, Node** dst_image); // Finds an edge destination slot in the output graph. If the edge crosses a - // subgraph boundary it is a slot on the input of a call node or a - // _SendFromHost node, otherwise it is a slot on a node in the output graph. + // subgraph boundary it is a slot on the input of a call node, otherwise it is + // a slot on a node in the output graph. int FindOutputSlotOfEdgeDst(const string& src_func_id, - const string& src_outside_compilation_id, const string& dst_func_id, - const string& dst_outside_compilation_id, const Edge* edge); // Copies a single edge to the output graph. The edge is either entirely // within the output graph, or crosses into or out of a compiled subgraph. Status CopyEdgeToOutputGraph( - const Edge* edge, const string& src_func_id, - const string& src_outside_compilation_id, const string& dst_func_id, - const string& dst_outside_compilation_id, + const Edge* edge, const string& src_func_id, const string& dst_func_id, const std::unordered_map& node_images, Graph* graph_out, std::unordered_set, OutputInputTensorPairHasher>* edges_added); - // Adds control dependencies between subgraph call nodes that have - // dependencies via outside_compilation edges. - Status AddCallNodeDependencies(Graph* graph_out); - // Adds all edges to the output graph. Status AddEdgesToOutputGraph( const std::unordered_map& node_images, Graph* graph_out); - // Constructs a minimal shape inference graph that can be used to determine - // the shape of send_node at the time that the subgraph is compiled. - // recv_at_host_nodes contains the names of all the recv_at_host nodes that - // send_node might depend on. These recv_at_host nodes have shapes that are - // not known during the rewrite pass, but will be known at compile time. 
- // - // If the shapes of all the inputs to send_node can be determined during the - // rewrite pass, on exit graphdef_out is empty and the shapes are returned in - // static_shape_out. Otherwise graphdef_out contains a graph that can be used - // for shape inference at compile time, where all the source nodes of the - // graph are either constants with known shapes, or nodes named in - // recv_at_host_nodes. - // - // A non-OK status is returned if neither of the above conditions can be - // satisfied, e.g., because send_node depends on a node that doesn't have a - // registered shape inference function. - Status DoStaticShapeInferenceForOutsideCompilationSend( - const Graph& graph_in, const BackEdgeHelper& back_edge_helper, - const ShapeRefiner& shape_refiner, - const std::unordered_set& recv_at_host_nodes, Node* send_node, - FunctionLibraryDefinition* library, - std::vector* static_shape_out, - std::unique_ptr* graph_out); - // Makes a copy of graph containing only nodes that are ancestors of at least // one node in send_from_host_nodes and store it in pruned_graph. On exit // nodes_images contains a mapping from nodes in graph to nodes in @@ -639,35 +370,10 @@ class Encapsulator { std::unordered_map* node_images, FunctionLibraryDefinition* library); - // Makes a copy of graph containing only nodes that are ancestors of a - // send_from_host node in an outside_compilation subgraph, and store it in - // pruned_graph. Also perform shape inference on the pruned graph, using - // shape_refiner. On exit node_images contains a mapping from nodes in graph - // to nodes in pruned_graph. - Status MakeGraphForOutsideCompilationSends( - const Graph& graph, std::unique_ptr* pruned_graph, - BackEdgeHelper* back_edge_helper, ShapeRefiner* shape_refiner, - std::unordered_map* node_images, - FunctionLibraryDefinition* library); - - // Performs static shape inference, as far as possible, for the send_from_host - // nodes in each outside_compilation subgraph. Where it is not possible to - // determine the shape statically, stores a serialized GraphDef in the - // HostCompute 'shape_inference_graph' attr, to be used at compile time for - // final inference. If the shapes are known statically they are stored in the - // HostCompute 'shapes' attr. - Status GetShapeInfoForOutsideCompilationSends( - Graph* graph_out, FunctionLibraryDefinition* library); - const string group_attribute_; - const string outside_compilation_attribute_; const Graph* graph_in_; std::unordered_map subgraphs_; - // For each subgraph S the subgraphs S' such that there is a path in some - // outside_compilation cluster C in S to some outside_compilation cluster C' - // in S', that goes only through the uncompiled graph. 
- std::unordered_map> subgraph_ancestors_; TF_DISALLOW_COPY_AND_ASSIGN(Encapsulator); }; @@ -733,30 +439,6 @@ int Encapsulator::Subgraph::GetResultIndexForEdge(const Edge* edge) const { return results_.at(OutputTensor(edge->src(), edge->src_output())); } -Node* Encapsulator::Subgraph::GetRecvAtHostNode( - const string& outside_compilation_subgraph_name) const { - return outside_compilation_subgraphs_.at(outside_compilation_subgraph_name) - .recv_at_host; -} - -int Encapsulator::Subgraph::GetRecvAtHostSlot( - const string& outside_compilation_subgraph_name, const Edge* edge) const { - return outside_compilation_subgraphs_.at(outside_compilation_subgraph_name) - .inputs.at(OutputTensor(edge->src(), edge->src_output())); -} - -Node* Encapsulator::Subgraph::GetSendFromHostNode( - const string& outside_compilation_subgraph_name) const { - return outside_compilation_subgraphs_.at(outside_compilation_subgraph_name) - .send_from_host; -} - -int Encapsulator::Subgraph::GetSendFromHostSlot( - const string& outside_compilation_subgraph_name, const Edge* edge) const { - return outside_compilation_subgraphs_.at(outside_compilation_subgraph_name) - .outputs_by_dst.at(InputTensor(edge->dst(), edge->dst_input())); -} - Node* Encapsulator::Subgraph::MakeNodeImage(const Graph* graph_in, Node* node) { if (!graph_) { graph_.reset(new Graph(graph_in->op_registry())); @@ -854,217 +536,6 @@ Status Encapsulator::Subgraph::RecordResult( return Status::OK(); } -Encapsulator::Subgraph::OutsideCompilationSubgraph* -Encapsulator::Subgraph::LookupOrCreateOutsideCompilationSubgraph( - const string& outside_compilation_id) { - auto iter = outside_compilation_subgraphs_ - .emplace(outside_compilation_id, OutsideCompilationSubgraph()) - .first; - OutsideCompilationSubgraph* outside_subgraph = &iter->second; - return outside_subgraph; -} - -void Encapsulator::Subgraph::RecordOutsideCompilationInputOrControl( - const string& outside_compilation_id, const Edge* edge) { - OutsideCompilationSubgraph* outside_subgraph = - LookupOrCreateOutsideCompilationSubgraph(outside_compilation_id); - if (edge->IsControlEdge()) { - outside_subgraph->control_inputs.insert(edge->src()); - } else { - int input_index = outside_subgraph->inputs.size(); - outside_subgraph->inputs.emplace( - OutputTensor(edge->src(), edge->src_output()), input_index); - } -} - -void Encapsulator::Subgraph::RecordOutsideCompilationOutputOrControl( - const string& outside_compilation_id, const Edge* edge) { - OutsideCompilationSubgraph* outside_subgraph = - LookupOrCreateOutsideCompilationSubgraph(outside_compilation_id); - if (edge->IsControlEdge()) { - outside_subgraph->control_outputs.insert(edge->dst()); - } else { - DataType dtype = edge->dst()->input_type(edge->dst_input()); - auto output_iter = - outside_subgraph->outputs_by_src - .emplace(OutputTensor(edge->src(), edge->src_output()), - OutsideCompilationSubgraph::ArgNumAndType( - outside_subgraph->outputs_by_src.size(), dtype)) - .first; - const int output_index = output_iter->second.index; - outside_subgraph - ->outputs_by_dst[InputTensor(edge->dst(), edge->dst_input())] = - output_index; - } -} - -void Encapsulator::Subgraph::RecordOutsideCompilationDependency( - const string& successor, const string& ancestor) { - outside_compilation_ancestors_[successor].insert(ancestor); - outside_compilation_successors_[ancestor].insert(successor); -} - -const std::unordered_map> -Encapsulator::Subgraph::OutsideCompilationAncestorMap() const { - return outside_compilation_ancestors_; -} - -void 
Encapsulator::Subgraph::GetActiveClusterDependencyGraph( - std::unordered_set* clusters, - std::unordered_set* has_successor, - std::unordered_map>* ancestors_map) { - // During initial clustering the ancestor and successor datastructures may - // have been built including oc_cluster names that never turned into subgraphs - // because they had no edges into or out of the compiled cluster. Remove them - // before proceeding to simplify the logic. Get the set of clusters that was - // actually added, then remove references to the others. - for (const auto& oc_subgraph : outside_compilation_subgraphs_) { - clusters->insert(oc_subgraph.first); - } - for (const auto& cluster : outside_compilation_successors_) { - if (clusters->find(cluster.first) != clusters->end()) { - for (const auto& successor : cluster.second) { - if (clusters->find(successor) != clusters->end()) { - has_successor->insert(cluster.first); - break; - } - } - } - } - for (const auto& cluster : outside_compilation_ancestors_) { - if (clusters->find(cluster.first) != clusters->end()) { - std::unordered_set& ancestors = (*ancestors_map)[cluster.first]; - for (const auto& ancestor : cluster.second) { - if (clusters->find(ancestor) != clusters->end()) { - ancestors.insert(ancestor); - } - } - } - } -} - -Status Encapsulator::Subgraph::AddHostComputes( - const string& subgraph_name, - const std::unordered_map& node_images) { - // Get the set of outside_compilation clusters and the dependency edges - // between them. - std::unordered_set clusters; - std::unordered_set has_successor; - std::unordered_map> ancestors_map; - GetActiveClusterDependencyGraph(&clusters, &has_successor, &ancestors_map); - // Topologically sort the outside_compilation clusters according to their - // dependency relation. - std::vector sorted_clusters; - TopologicalClusterSort(clusters, has_successor, ancestors_map, - &sorted_clusters); - - // The host compute nodes added for each outside_compilation_cluster; - std::unordered_map host_compute_node; - for (const string& oc_subgraph_name : sorted_clusters) { - OutsideCompilationSubgraph& oc_subgraph = - outside_compilation_subgraphs_[oc_subgraph_name]; - if (!oc_subgraph.inputs.empty() || !oc_subgraph.control_inputs.empty() || - !oc_subgraph.outputs_by_src.empty() || - !oc_subgraph.control_outputs.empty()) { - // Build a _HostCompute node. - std::vector inputs(oc_subgraph.inputs.size()); - std::vector input_dtypes(oc_subgraph.inputs.size(), DT_INVALID); - std::vector output_dtypes(oc_subgraph.outputs_by_src.size(), - DT_INVALID); - - for (const auto& input_src : oc_subgraph.inputs) { - const Node* src_node = input_src.first.node; - Node* src_image = node_images.at(src_node); - int src_slot = input_src.first.index; - int input_index = input_src.second; - - DataType dtype = src_node->output_type(src_slot); - inputs[input_index].Reset(src_image->name(), src_slot, dtype); - input_dtypes[input_index] = dtype; - } - for (const auto& output : oc_subgraph.outputs_by_src) { - DataType dtype = output.second.dtype; - int output_index = output.second.index; - output_dtypes[output_index] = dtype; - } - - std::vector host_compute_ancestors; - const auto iter = ancestors_map.find(oc_subgraph_name); - if (iter != ancestors_map.end()) { - for (const string& ancestor_cluster : iter->second) { - host_compute_ancestors.push_back( - outside_compilation_subgraphs_[ancestor_cluster] - .host_compute_name); - } - } - - NodeDef host_compute_def; - // TODO(shikharagarwal): What source node should we use for errors? 
- NodeDefBuilder builder(absl::StrCat("outside_compilation_", - oc_subgraph_name, "_host_compute"), - kHostComputeOp); - builder.Input(inputs); - builder.Attr("Tinputs", input_dtypes); - builder.Attr("Toutputs", output_dtypes); - builder.Attr("ancestors", host_compute_ancestors); - builder.Attr("key", absl::StrCat("host_compute_channel_", subgraph_name, - "_", oc_subgraph_name)); - builder.Attr("_outside_compilation_subgraph", oc_subgraph_name); - Status s = builder.Finalize(&host_compute_def); - if (!s.ok()) return s; - - Node* host_compute = graph_->AddNode(host_compute_def, &s); - if (!s.ok()) return s; - host_compute_node[host_compute->name()] = host_compute; - oc_subgraph.host_compute_name = host_compute->name(); - - // Connect the _HostCompute node to its producers in the subgraph. - for (auto& input_src : oc_subgraph.inputs) { - const Node* src_node = input_src.first.node; - Node* src_image = node_images.at(src_node); - int src_slot = input_src.first.index; - int input_index = input_src.second; - graph_->AddEdge(src_image, src_slot, host_compute, input_index); - } - - // Connect the _HostCompute node to its control edge producers in the - // subgraph. - for (const auto& src_node : oc_subgraph.control_inputs) { - Node* src_image = node_images.at(src_node); - graph_->AddControlEdge(src_image, host_compute, - /* allow_duplicates= */ true); - } - - // Connect the _HostCompute node to its ancestor host compute nodes. - for (const auto& ancestor_name : host_compute_ancestors) { - Node* ancestor = host_compute_node[ancestor_name]; - graph_->AddControlEdge(ancestor, host_compute, - /* allow_duplicates= */ true); - } - - // Connect the consumers in the subgraph to the _HostCompute node. - for (const auto& output : oc_subgraph.outputs_by_dst) { - const Node* dst_node = output.first.node; - Node* dst_image = node_images.at(dst_node); - int dst_slot = output.first.index; - int output_index = output.second; - - graph_->AddEdge(host_compute, output_index, dst_image, dst_slot); - } - - // Connect the control edge consumers in the subgraph to the _HostCompute - // node. 
- for (const auto& dst_node : oc_subgraph.control_outputs) { - Node* dst_image = node_images.at(dst_node); - graph_->AddControlEdge(host_compute, dst_image, - /* allow_duplicates= */ true); - } - } - } - - return Status::OK(); -} - Status Encapsulator::Subgraph::MakeSequencingNode(const string& subgraph_name, Graph* graph_out) { if (sequencer_ == nullptr) { @@ -1167,48 +638,6 @@ Status Encapsulator::Subgraph::BuildFunctionDef( return Status::OK(); } -Status Encapsulator::Subgraph::AddShapeInferenceInfo( - const string& subgraph_name, - const string& outside_compilation_subgraph_name, - const std::vector& shapes, Graph* inference_graph, - FunctionLibraryDefinition* library) { - OutsideCompilationSubgraph& oc_subgraph = - outside_compilation_subgraphs_.at(outside_compilation_subgraph_name); - - Node* host_compute = nullptr; - for (Node* n : graph_->nodes()) { - if (n->name() == oc_subgraph.host_compute_name) { - host_compute = n; - break; - } - } - if (host_compute == nullptr) { - return errors::InvalidArgument( - "After rewriting subgraph ", outside_compilation_subgraph_name, - " there is no HostCompute Op for outside compilation subgraph ", - oc_subgraph.host_compute_name); - } - - if (inference_graph == nullptr) { - host_compute->AddAttr("shape_inference_graph", ""); - host_compute->AddAttr("shapes", shapes); - } else { - string inference_graph_name = - absl::StrCat("_outside_compilation_shape_inference_", subgraph_name, - "_", outside_compilation_subgraph_name); - FunctionDef fdef; - TF_RETURN_IF_ERROR( - GraphToFunctionDef(*inference_graph, inference_graph_name, &fdef)); - host_compute->AddAttr("shape_inference_graph", inference_graph_name); - host_compute->AddAttr("shapes", std::vector()); - // TODO(sibyl-Aix6ihai): Understand why there are multiple calls to Encapsulator. 
- if (library->Find(inference_graph_name) == nullptr) { - TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef)); - } - } - return Status::OK(); -} - Status Encapsulator::Subgraph::ReplaceFunctionDef( FunctionLibraryDefinition* library) { const string& name = function_def_name_; @@ -1241,214 +670,29 @@ Status Encapsulator::Subgraph::AddFunctionCallNode( return Status::OK(); } -Status Encapsulator::Subgraph::AddHostComputeKeyPlaceholder( - OutsideCompilationSubgraph* oc_subgraph, Graph* graph_out) { - TensorShapeProto shape_proto; - TensorShape shape({2}); - shape.AsProto(&shape_proto); - GraphDefBuilder::Options options(graph_out, /*status=*/nullptr); - NodeDef key_def; - NodeDefBuilder builder( - absl::StrCat(call_node_def_.name(), "_key_placeholder"), "Placeholder", - NodeDebugInfo(call_node_def_)); - builder.Attr("dtype", DT_STRING); - builder.Attr("shape", shape_proto); - builder.Attr("_host_compute_call_node", call_node_def_.name()); - Status s = builder.Finalize(&key_def); - if (!s.ok()) return s; - - host_compute_key_placeholder_ = graph_out->AddNode(key_def, &s); - if (!s.ok()) return s; - host_compute_key_placeholder_->set_assigned_device_name(device_); - - return Status::OK(); -} - -Status Encapsulator::Subgraph::AddRecvAtHostNode( - const string& group_attribute, const string& subgraph_name, - const string& outside_compilation_attribute, const string& oc_subgraph_name, - OutsideCompilationSubgraph* oc_subgraph, Graph* graph_out) { - if (host_compute_key_placeholder_ == nullptr) { - TF_RETURN_IF_ERROR(AddHostComputeKeyPlaceholder(oc_subgraph, graph_out)); - } - - std::vector dtypes(oc_subgraph->inputs.size(), DT_INVALID); - - for (const auto& input : oc_subgraph->inputs) { - const Node* src_node = input.first.node; - int src_slot = input.first.index; - int input_index = input.second; - - DataType dtype = src_node->output_type(src_slot); - dtypes[input_index] = dtype; - } - - NodeDef recv_def; - // TODO(shikharagarwal): What source node should we use for errors? - NodeDefBuilder builder(absl::StrCat("outside_compilation_", subgraph_name, - "_", oc_subgraph_name, "_recv"), - kRecvAtHostOp); - builder.Device(device_); - builder.Attr("Toutputs", dtypes); - // The correct device_ordinal will be inserted during replication in a - // subsequent rewrite. - builder.Attr("device_ordinal", 0); - builder.Attr("key", absl::StrCat("host_compute_channel_", subgraph_name, "_", - oc_subgraph_name)); - builder.Attr(group_attribute, subgraph_name); - builder.Attr(outside_compilation_attribute, oc_subgraph_name); - builder.Input(host_compute_key_placeholder_->name(), 0, DT_STRING); - Status s = builder.Finalize(&recv_def); - if (!s.ok()) return s; - - oc_subgraph->recv_at_host = graph_out->AddNode(recv_def, &s); - if (!s.ok()) return s; - graph_out->AddEdge(host_compute_key_placeholder_, 0, - oc_subgraph->recv_at_host, 0); - - // Add a control dependency forcing the RecvAtHost to run before the subgraph - // completes. This has no effect on execution order but prevents the - // RecvAtHost being pruned. 
- TF_RETURN_IF_ERROR(MakeSequencingNode(subgraph_name, graph_out)); - graph_out->AddControlEdge(oc_subgraph->recv_at_host, sequencer_, - true /* skip duplicates check */); - - return Status::OK(); -} - -Status Encapsulator::Subgraph::AddSendFromHostNode( - const std::unordered_map& node_images, - const string& group_attribute, const string& subgraph_name, - const string& outside_compilation_attribute, const string& oc_subgraph_name, - OutsideCompilationSubgraph* oc_subgraph, Graph* graph_out) { - if (host_compute_key_placeholder_ == nullptr) { - TF_RETURN_IF_ERROR(AddHostComputeKeyPlaceholder(oc_subgraph, graph_out)); - } - - std::vector dtypes(oc_subgraph->outputs_by_src.size(), DT_INVALID); - std::vector inputs( - oc_subgraph->outputs_by_src.size()); - - for (const auto& output : oc_subgraph->outputs_by_src) { - const Node* src_node = output.first.node; - Node* src_image = node_images.at(src_node); - int src_slot = output.first.index; - int output_index = output.second.index; - - DataType dtype = src_node->output_type(src_slot); - dtypes[output_index] = dtype; - inputs[output_index].Reset(src_image->name(), src_slot, dtype); - } - - NodeDef send_def; - // TODO(shikharagarwal): What source node should we use for errors? - NodeDefBuilder builder(absl::StrCat("outside_compilation_", subgraph_name, - "_", oc_subgraph_name, "_send"), - kSendFromHostOp); - builder.Device(device_); - builder.Attr("Tinputs", dtypes); - builder.Attr("key", absl::StrCat("host_compute_channel_", subgraph_name, "_", - oc_subgraph_name)); - // The correct device_ordinal will be inserted during replication in a - // subsequent rewrite. - builder.Attr("device_ordinal", 0); - builder.Attr(group_attribute, subgraph_name); - builder.Attr(outside_compilation_attribute, oc_subgraph_name); - builder.Input(inputs); - builder.Input(host_compute_key_placeholder_->name(), 0, DT_STRING); - Status s = builder.Finalize(&send_def); - if (!s.ok()) return s; - - oc_subgraph->send_from_host = graph_out->AddNode(send_def, &s); - if (!s.ok()) return s; - graph_out->AddEdge(host_compute_key_placeholder_, 0, - oc_subgraph->send_from_host, inputs.size()); - - // Add a control dependency forcing the SendFromHost to run before the - // subgraph completes. This has no effect on execution order but prevents the - // RecvAtHost being pruned. 
- TF_RETURN_IF_ERROR(MakeSequencingNode(subgraph_name, graph_out)); - graph_out->AddControlEdge(oc_subgraph->send_from_host, sequencer_, - /* allow_duplicates= */ true); - - return Status::OK(); -} - -Status Encapsulator::Subgraph::AddOutsideCompilationHostIONodes( - const string& group_attribute, const string& subgraph_name, - const string& outside_compilation_attribute, - const std::unordered_map& node_images, - Graph* graph_out) { - for (auto& outside_compilation_subgraph_entry : - outside_compilation_subgraphs_) { - const string& oc_name = outside_compilation_subgraph_entry.first; - OutsideCompilationSubgraph& oc_subgraph = - outside_compilation_subgraph_entry.second; - - if (!oc_subgraph.inputs.empty() || !oc_subgraph.control_inputs.empty()) { - TF_RETURN_IF_ERROR(AddRecvAtHostNode(group_attribute, subgraph_name, - outside_compilation_attribute, - oc_name, &oc_subgraph, graph_out)); - } - - if (!oc_subgraph.outputs_by_src.empty() || - !oc_subgraph.control_outputs.empty()) { - TF_RETURN_IF_ERROR(AddSendFromHostNode( - node_images, group_attribute, subgraph_name, - outside_compilation_attribute, oc_name, &oc_subgraph, graph_out)); - } - } - return Status::OK(); -} - -void Encapsulator::Subgraph::GetOutsideCompilationSubgraphNames( - std::vector* names) const { - for (auto& entry : outside_compilation_subgraphs_) { - names->push_back(entry.first); - } -} - -Status Encapsulator::GetFunctionNameAttr( - Node const* node, string* attr, string* outside_compilation_attr) const { +Status Encapsulator::GetFunctionNameAttr(Node const* node, string* attr) const { AttrSlice attrs = node->attrs(); attr->clear(); - outside_compilation_attr->clear(); bool found_group_attribute = false; - bool found_outside_compilation_attribute = false; for (const auto& node_attr : attrs) { if (node_attr.first == group_attribute_) { TF_RETURN_IF_ERROR(AttrValueHasType(node_attr.second, "string")); *attr = node_attr.second.s(); found_group_attribute = true; - } else if (node_attr.first == outside_compilation_attribute_) { - TF_RETURN_IF_ERROR(AttrValueHasType(node_attr.second, "string")); - *outside_compilation_attr = node_attr.second.s(); - found_outside_compilation_attribute = true; + break; } - if (found_group_attribute && found_outside_compilation_attribute) break; - } - - if (found_outside_compilation_attribute && !found_group_attribute) { - return errors::InvalidArgument( - "Node ", node->name(), " has ", outside_compilation_attribute_, - " attribute but no ", group_attribute_, " attribute."); - } else { - return Status::OK(); } + return Status::OK(); } -bool IsInSubgraph(const string& func_id, const string& outside_compilation_id) { - return !func_id.empty() && outside_compilation_id.empty(); -} +bool IsInSubgraph(const string& func_id) { return !func_id.empty(); } Status Encapsulator::CopySubgraphNodes( std::unordered_map* node_images) { for (Node* node : graph_in_->op_nodes()) { string func_id; - string outside_compilation_id; - TF_RETURN_IF_ERROR( - GetFunctionNameAttr(node, &func_id, &outside_compilation_id)); - if (!IsInSubgraph(func_id, outside_compilation_id)) continue; + TF_RETURN_IF_ERROR(GetFunctionNameAttr(node, &func_id)); + if (!IsInSubgraph(func_id)) continue; Subgraph& subgraph = subgraphs_[func_id]; Node* image = subgraph.MakeNodeImage(graph_in_, node); @@ -1463,19 +707,14 @@ Status Encapsulator::CopySubgraphEdges( std::vector>* src_arg_pairs) { for (const Edge* edge : graph_in_->edges()) { string src_func_id; - string src_outside_compilation_id; - 
TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->src(), &src_func_id, - &src_outside_compilation_id)); + TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->src(), &src_func_id)); string dst_func_id; - string dst_outside_compilation_id; - TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->dst(), &dst_func_id, - &dst_outside_compilation_id)); + TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->dst(), &dst_func_id)); Node* src_image = gtl::FindWithDefault(node_images, edge->src(), nullptr); Node* dst_image = gtl::FindWithDefault(node_images, edge->dst(), nullptr); // Copy edges that are local to a subgraph. - if (IsInSubgraph(src_func_id, src_outside_compilation_id) && - IsInSubgraph(dst_func_id, dst_outside_compilation_id) && + if (IsInSubgraph(src_func_id) && IsInSubgraph(dst_func_id) && src_func_id == dst_func_id) { Graph* g = subgraphs_[src_func_id].GetGraph(); if (edge->IsControlEdge()) { @@ -1488,7 +727,7 @@ Status Encapsulator::CopySubgraphEdges( } // Record 'src' as an output of its subgraph, if applicable. - if (IsInSubgraph(src_func_id, src_outside_compilation_id)) { + if (IsInSubgraph(src_func_id)) { if (!edge->IsControlEdge()) { DataType dtype = edge->src()->output_type(edge->src_output()); if (IsRefType(dtype)) { @@ -1500,23 +739,15 @@ Status Encapsulator::CopySubgraphEdges( } Subgraph& src_subgraph = subgraphs_[src_func_id]; - if (src_func_id == dst_func_id) { - // src is in the subgraph and dst is outside_compilation in the same - // subgraph. - src_subgraph.RecordOutsideCompilationInputOrControl( - dst_outside_compilation_id, edge); + if (edge->IsControlEdge()) { + TF_RETURN_IF_ERROR(src_subgraph.RecordControlResult(edge, node_images)); } else { - if (edge->IsControlEdge()) { - TF_RETURN_IF_ERROR( - src_subgraph.RecordControlResult(edge, node_images)); - } else { - TF_RETURN_IF_ERROR(src_subgraph.RecordResult(edge, node_images)); - } + TF_RETURN_IF_ERROR(src_subgraph.RecordResult(edge, node_images)); } } // Record 'dst' as an input of its subgraph, if applicable. - if (IsInSubgraph(dst_func_id, dst_outside_compilation_id)) { + if (IsInSubgraph(dst_func_id)) { // Look at the type of the destination not the source, since Ref output // Tensors can be automatically cast to non-Ref Tensors at the // destination. @@ -1531,18 +762,11 @@ Status Encapsulator::CopySubgraphEdges( } Subgraph& dst_subgraph = subgraphs_[dst_func_id]; - if (src_func_id == dst_func_id) { - // dst is in the subgraph and src is outside_compilation in the same - // subgraph. - dst_subgraph.RecordOutsideCompilationOutputOrControl( - src_outside_compilation_id, edge); - } else { - // Ignore control edges entering the subgraph. We will lift them onto - // the enclosing call operators in BuildOutputGraph(). - if (!edge->IsControlEdge()) { - TF_RETURN_IF_ERROR( - dst_subgraph.RecordArg(edge, node_images, src_arg_pairs)); - } + // Ignore control edges entering the subgraph. We will lift them onto + // the enclosing call operators in BuildOutputGraph(). + if (!edge->IsControlEdge()) { + TF_RETURN_IF_ERROR( + dst_subgraph.RecordArg(edge, node_images, src_arg_pairs)); } } } @@ -1564,16 +788,6 @@ Status Encapsulator::SplitIntoSubgraphs(FunctionLibraryDefinition* library) { TF_RETURN_IF_ERROR(CopySubgraphNodes(&node_images)); TF_RETURN_IF_ERROR(CopySubgraphEdges(node_images, &src_arg_pairs)); - // For each subgraph, add the nodes that deal with inputs and outputs its - // nested outside_compilation subgraphs. 
These could not be added earlier - // during CopySubgraphEdges since we need to discover all the types of the - // inputs and outputs for an outside_compilation subgraph before creating a - // single input and output node for it. - for (auto& entry : subgraphs_) { - Subgraph& subgraph = entry.second; - TF_RETURN_IF_ERROR(subgraph.AddHostComputes(entry.first, node_images)); - } - MarkGuaranteedConstants(*graph_in_, src_arg_pairs); for (auto& entry : subgraphs_) { @@ -1609,12 +823,10 @@ Status Encapsulator::CopyNodesToOutputGraph( Graph* graph_out, std::unordered_map* node_images) { for (Node* node : graph_in_->op_nodes()) { string func_id; - string outside_compilation_id; - TF_RETURN_IF_ERROR( - GetFunctionNameAttr(node, &func_id, &outside_compilation_id)); + TF_RETURN_IF_ERROR(GetFunctionNameAttr(node, &func_id)); // Don't copy nodes that are going to be encapsulated. - if (IsInSubgraph(func_id, outside_compilation_id)) continue; + if (IsInSubgraph(func_id)) continue; Node* image = graph_out->CopyNode(node); (*node_images)[node] = image; @@ -1634,37 +846,14 @@ Status Encapsulator::AddFunctionCallNodes( return Status::OK(); } -Status Encapsulator::AddOutsideCompilationHostIONodes( - const std::unordered_map& node_images, - Graph* graph_out) { - for (auto& subgraph_entry : subgraphs_) { - const string& subgraph_name = subgraph_entry.first; - Subgraph& subgraph = subgraph_entry.second; - TF_RETURN_IF_ERROR(subgraph.AddOutsideCompilationHostIONodes( - group_attribute_, subgraph_name, outside_compilation_attribute_, - node_images, graph_out)); - } - return Status::OK(); -} - Status Encapsulator::FindOutputImageOfEdgeSrc( - const string& src_func_id, const string& src_outside_compilation_id, - const string& dst_func_id, const string& dst_outside_compilation_id, + const string& src_func_id, const string& dst_func_id, const std::unordered_map& node_images, const Node* original_src_node, Node** src_image) { - if (IsInSubgraph(src_func_id, src_outside_compilation_id)) { - if (dst_func_id == src_func_id) { - // The edge is from a subgraph to an outside_compilation cluster in the - // same subgraph so use the appropriate _RecvAtHost node in the output - // graph. - TF_RET_CHECK(!dst_outside_compilation_id.empty()); - *src_image = subgraphs_.at(src_func_id) - .GetRecvAtHostNode(dst_outside_compilation_id); - } else { - // The edge is from a subgraph to a regular node in the output graph so - // use the subgraph's call node output. - *src_image = subgraphs_.at(src_func_id).GetCallNode(); - } + if (IsInSubgraph(src_func_id)) { + // The edge is from a subgraph to a regular node in the output graph so + // use the subgraph's call node output. + *src_image = subgraphs_.at(src_func_id).GetCallNode(); } else { // The source of the edge is in the output graph so use the node image in // the output graph. @@ -1673,21 +862,14 @@ Status Encapsulator::FindOutputImageOfEdgeSrc( return Status::OK(); } -int Encapsulator::FindOutputSlotOfEdgeSrc( - const string& src_func_id, const string& src_outside_compilation_id, - const string& dst_func_id, const string& dst_outside_compilation_id, - const Edge* edge) { - if (IsInSubgraph(src_func_id, src_outside_compilation_id)) { +int Encapsulator::FindOutputSlotOfEdgeSrc(const string& src_func_id, + const string& dst_func_id, + const Edge* edge) { + if (IsInSubgraph(src_func_id)) { const Subgraph& src_subgraph = subgraphs_.at(src_func_id); - if (src_func_id == dst_func_id) { - // 'src' is in a subgraph and 'dst' is outside_compilation in the same - // subgraph. 
Use the corresponding _RecvAtHost output instead. - return src_subgraph.GetRecvAtHostSlot(dst_outside_compilation_id, edge); - } else { - // 'src' is in a subgraph and 'dst' is a regular node in the output - // graph. Use the corresponding call output instead. - return src_subgraph.GetResultIndexForEdge(edge); - } + // 'src' is in a subgraph and 'dst' is a regular node in the output + // graph. Use the corresponding call output instead. + return src_subgraph.GetResultIndexForEdge(edge); } else { // The source of the edge is in the output graph so use the regular edge // slot. @@ -1696,23 +878,13 @@ int Encapsulator::FindOutputSlotOfEdgeSrc( } Status Encapsulator::FindOutputImageOfEdgeDst( - const string& src_func_id, const string& src_outside_compilation_id, - const string& dst_func_id, const string& dst_outside_compilation_id, + const string& src_func_id, const string& dst_func_id, const std::unordered_map& node_images, const Node* original_dst_node, Node** dst_image) { - if (IsInSubgraph(dst_func_id, dst_outside_compilation_id)) { - if (src_func_id == dst_func_id) { - // The edge is to a subgraph from an outside_compilation cluster in the - // same subgraph so use the appropriate _SendFromHost node in the output - // graph. - TF_RET_CHECK(!src_outside_compilation_id.empty()); - *dst_image = subgraphs_.at(dst_func_id) - .GetSendFromHostNode(src_outside_compilation_id); - } else { - // The edge is to a subgraph from a regular node in the output graph so - // use the subgraph's call node input. - *dst_image = subgraphs_.at(dst_func_id).GetCallNode(); - } + if (IsInSubgraph(dst_func_id)) { + // The edge is to a subgraph from a regular node in the output graph so + // use the subgraph's call node input. + *dst_image = subgraphs_.at(dst_func_id).GetCallNode(); } else { // The destination of the edge is in the output graph so use the node image // in the output graph. @@ -1721,21 +893,14 @@ Status Encapsulator::FindOutputImageOfEdgeDst( return Status::OK(); } -int Encapsulator::FindOutputSlotOfEdgeDst( - const string& src_func_id, const string& src_outside_compilation_id, - const string& dst_func_id, const string& dst_outside_compilation_id, - const Edge* edge) { - if (IsInSubgraph(dst_func_id, dst_outside_compilation_id)) { +int Encapsulator::FindOutputSlotOfEdgeDst(const string& src_func_id, + const string& dst_func_id, + const Edge* edge) { + if (IsInSubgraph(dst_func_id)) { const Subgraph& dst_subgraph = subgraphs_.at(dst_func_id); - if (dst_func_id == src_func_id) { - // 'dst' is in a subgraph and 'src' is outside_compilation in the same - // subgraph. Use the corresponding _SendFromHost input instead. - return dst_subgraph.GetSendFromHostSlot(src_outside_compilation_id, edge); - } else { // 'dst' is in a subgraph and 'src' is a regular node in the output // graph. Use the corresponding call input instead. return dst_subgraph.GetArgIndexForEdge(edge); - } } else { // The destination of the edge is in the output graph so use the regular // edge slot. 
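With the outside_compilation branches removed, an edge endpoint that lies inside a compiled subgraph always resolves to the subgraph's call node, using the argument or result index recorded when the edges were first scanned (RecordArg/GetArgIndexForEdge and RecordResult/GetResultIndexForEdge above). The standalone C++ sketch below illustrates only the deduplication idea behind that bookkeeping; the types and names are toy stand-ins, not the pass's actual API:

// Illustrative sketch only: toy stand-ins for TensorFlow's Node/Edge types.
#include <cstdio>
#include <map>
#include <string>
#include <utility>

// A producer tensor is identified by (node name, output slot).
using TensorId = std::pair<std::string, int>;

class ToySubgraph {
 public:
  // Returns the call-node input slot for an incoming edge, assigning a new
  // dense argument index the first time a given source tensor is seen.
  int ArgIndexFor(const TensorId& src) {
    auto it = args_by_src_.find(src);
    if (it == args_by_src_.end()) {
      it = args_by_src_.emplace(src, static_cast<int>(args_by_src_.size()))
               .first;
    }
    return it->second;
  }

 private:
  std::map<TensorId, int> args_by_src_;  // source tensor -> call input slot
};

int main() {
  ToySubgraph cluster;
  std::printf("%d\n", cluster.ArgIndexFor({"A", 0}));  // 0
  std::printf("%d\n", cluster.ArgIndexFor({"B", 1}));  // 1
  std::printf("%d\n", cluster.ArgIndexFor({"A", 0}));  // 0, deduplicated
}

Two edges fed by the same producer tensor therefore share one call-node input slot, which is why the slot lookup above can reduce to a plain map access once the host I/O special cases are gone.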
@@ -1744,20 +909,16 @@ int Encapsulator::FindOutputSlotOfEdgeDst( } Status Encapsulator::CopyEdgeToOutputGraph( - const Edge* edge, const string& src_func_id, - const string& src_outside_compilation_id, const string& dst_func_id, - const string& dst_outside_compilation_id, + const Edge* edge, const string& src_func_id, const string& dst_func_id, const std::unordered_map& node_images, Graph* graph_out, std::unordered_set, OutputInputTensorPairHasher>* edges_added) { Node* src_image; TF_RETURN_IF_ERROR(FindOutputImageOfEdgeSrc( - src_func_id, src_outside_compilation_id, dst_func_id, - dst_outside_compilation_id, node_images, edge->src(), &src_image)); + src_func_id, dst_func_id, node_images, edge->src(), &src_image)); Node* dst_image; TF_RETURN_IF_ERROR(FindOutputImageOfEdgeDst( - src_func_id, src_outside_compilation_id, dst_func_id, - dst_outside_compilation_id, node_images, edge->dst(), &dst_image)); + src_func_id, dst_func_id, node_images, edge->dst(), &dst_image)); // If this is a control edge then copy it and return. Lift control edges onto // the enclosing call operator. @@ -1774,13 +935,9 @@ Status Encapsulator::CopyEdgeToOutputGraph( return Status::OK(); } - int src_output = - FindOutputSlotOfEdgeSrc(src_func_id, src_outside_compilation_id, - dst_func_id, dst_outside_compilation_id, edge); + int src_output = FindOutputSlotOfEdgeSrc(src_func_id, dst_func_id, edge); - int dst_input = - FindOutputSlotOfEdgeDst(src_func_id, src_outside_compilation_id, - dst_func_id, dst_outside_compilation_id, edge); + int dst_input = FindOutputSlotOfEdgeDst(src_func_id, dst_func_id, edge); // Add the edge, if we have not already added it. if (edges_added @@ -1792,18 +949,6 @@ Status Encapsulator::CopyEdgeToOutputGraph( return Status::OK(); } -Status Encapsulator::AddCallNodeDependencies(Graph* graph_out) { - for (const auto& ancestors : subgraph_ancestors_) { - const string& subgraph = ancestors.first; - for (const string& ancestor : ancestors.second) { - graph_out->AddControlEdge(subgraphs_[ancestor].GetCallNode(), - subgraphs_[subgraph].GetCallNode(), - /* allow_duplicates= */ true); - } - } - return Status::OK(); -} - Status Encapsulator::AddEdgesToOutputGraph( const std::unordered_map& node_images, Graph* graph_out) { @@ -1816,18 +961,13 @@ Status Encapsulator::AddEdgesToOutputGraph( for (const Edge* edge : graph_in_->edges()) { string src_func_id; - string src_outside_compilation_id; - TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->src(), &src_func_id, - &src_outside_compilation_id)); + TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->src(), &src_func_id)); string dst_func_id; - string dst_outside_compilation_id; - TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->dst(), &dst_func_id, - &dst_outside_compilation_id)); + TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->dst(), &dst_func_id)); // Ignore edges that are strictly contained within one subgraph, unless // we are constructing parallel check graphs. - if (IsInSubgraph(src_func_id, src_outside_compilation_id) && - IsInSubgraph(dst_func_id, dst_outside_compilation_id) && + if (IsInSubgraph(src_func_id) && IsInSubgraph(dst_func_id) && src_func_id == dst_func_id) { continue; } @@ -1835,15 +975,13 @@ Status Encapsulator::AddEdgesToOutputGraph( // We have an edge that crosses a cluster boundary or is entirely within the // unclustered graph. 
TF_RETURN_IF_ERROR(CopyEdgeToOutputGraph( - edge, src_func_id, src_outside_compilation_id, dst_func_id, - dst_outside_compilation_id, node_images, graph_out, &edges_added)); + edge, src_func_id, dst_func_id, node_images, graph_out, &edges_added)); } for (auto& subgraph_entry : subgraphs_) { Subgraph& subgraph = subgraph_entry.second; subgraph.ConnectSequencerToCallNode(graph_out); } - TF_RETURN_IF_ERROR(AddCallNodeDependencies(graph_out)); return Status::OK(); } @@ -1893,413 +1031,8 @@ Node* AddDummyShapedNode(const Node* src_node, int src_port, return node; } -// Adds a copy of node_in to graph_out and adds the mapping to -// copied_node_images. -Status CopyShapeInferenceNodeToGraph( - Node* node_in, const Node* send_node, - const std::unordered_map& dummy_node_images, - FunctionLibraryDefinition* library, - std::unordered_map* copied_node_images, Graph* graph_out) { - // Once all the ancestor nodes have been added to graph_out, add this node - // and connect it to its ancestors. - Node* node_out = graph_out->CopyNode(node_in); - (*copied_node_images)[node_in] = node_out; - // Don't bother to build the shape inference graph if there's a node with no - // shape inference function, since it would just result in an error later at - // compile time. - const OpRegistrationData* op_reg_data; - TF_RETURN_IF_ERROR(library->LookUp(node_in->type_string(), &op_reg_data)); - if (op_reg_data->shape_inference_fn == nullptr) { - return errors::InvalidArgument( - "Shape inference is not possible for outside_compilation " - "SendFromHost node ", - send_node->name(), " because it depends on node ", node_in->name(), - " which does not have a shape inference function registered."); - } - // Add all the edges to the newly copied node. - for (const Edge* in_edge : node_in->in_edges()) { - if (!in_edge->IsControlEdge()) { - Node* src = in_edge->src(); - const auto iter = dummy_node_images.find(src); - if (iter == dummy_node_images.end()) { - // The src is a copied node so use the original output port. - graph_out->AddEdge((*copied_node_images)[in_edge->src()], - in_edge->src_output(), node_out, - in_edge->dst_input()); - } else { - // The src is a dummy node so use output port 0. - graph_out->AddEdge(iter->second, 0, node_out, in_edge->dst_input()); - } - } - } - // Work around the fact that Enter nodes refuse to propagate shape information - // unless they are marked loop invariant. Since we are never going to execute - // this graph, marking them all loop invariant is fine. - if (node_out->type_string() == "Enter") { - node_out->ClearAttr("is_constant"); - node_out->AddAttr("is_constant", true); - } - return Status::OK(); -} - } // namespace -Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend( - const Graph& graph_in, const BackEdgeHelper& back_edge_helper, - const ShapeRefiner& shape_refiner, - const std::unordered_set& recv_at_host_nodes, Node* send_node, - FunctionLibraryDefinition* library, - std::vector* static_shape_out, - std::unique_ptr* graph_out) { - // Get the control flow structure of the input graph so we can build - // well-formed output graphs. - std::vector control_flow_info; - TF_RETURN_IF_ERROR(BuildControlFlowInfo(&graph_in, &control_flow_info)); - - // Maps from nodes in graph_in to nodes in graph_out. - // - // When an edge has fully defined shape the source node in graph_in is - // replaced in graph_out by a dummy constant node. The mapping from nodes - // in graph_in to dummy nodes is stored in dummy_node_images. 
- // - // When a node in graph_in has at least one ancestor that doesn't have fully - // defined shape, it is copied into graph_out. The mapping from nodes in - // graph_in to copied nodes is stored in copied_node_images. - // - // The two types of node are treated differently because, when adding edges to - // graph_out, an output from a dummy node always uses port 0, whereas an - // output from a copied node uses the same port that was used in graph_in. - std::unordered_map dummy_node_images; - std::unordered_map copied_node_images; - - graph_out->reset(new Graph(graph_in.op_registry())); - (*graph_out)->set_versions(graph_in.versions()); - // The final input to the send node is the dynamic key, which we don't include - // in the static shapes. - static_shape_out->resize(send_node->num_inputs() - 1); - - // We don't use the standard ReverseDFS because we want to cut off traversal - // whenever we find an output with fully defined shape. - struct Work { - Node* node; - bool leave; // Are we entering or leaving node? - }; - std::vector stack({{send_node, false}}); - std::vector visited(graph_in.num_node_ids(), false); - while (!stack.empty()) { - Work w = stack.back(); - stack.pop_back(); - Node* n = w.node; - - if (w.leave) { - TF_RETURN_IF_ERROR(CopyShapeInferenceNodeToGraph( - n, send_node, dummy_node_images, library, &copied_node_images, - graph_out->get())); - } else { - if (visited[n->id()]) continue; - visited[n->id()] = true; - - // Arrange to revisit when all done with all inputs. - stack.push_back(Work{n, true}); - - bool has_parent_with_unknown_shape = false; - for (const Edge* in_edge : n->in_edges()) { - if (!in_edge->IsControlEdge()) { - Node* src_node = in_edge->src(); - int src_port = in_edge->src_output(); - shape_inference::InferenceContext* context = - shape_refiner.GetContext(src_node); - shape_inference::ShapeHandle shape = context->output(src_port); - if (context->FullyDefined(shape)) { - // This ancestor has known shape, so instead of adding it to the - // stack, add a dummy node with that shape to graph_out and - // continue. - TensorShapeProto proto; - context->ShapeHandleToProto(shape, &proto); - VLOG(2) << "Node " << src_node->name() - << " has known shape: " << proto.DebugString(); - if (dummy_node_images.find(src_node) == dummy_node_images.end()) { - dummy_node_images[src_node] = - AddDummyShapedNode(src_node, src_port, control_flow_info, - proto, graph_out->get()); - } - // The final input to the send node is the dynamic key, which we - // don't include in the static shapes. - if (n == send_node && - in_edge->dst_input() < static_shape_out->size()) { - (*static_shape_out)[in_edge->dst_input()] = proto; - } - } else { - has_parent_with_unknown_shape = true; - if (!visited[src_node->id()]) { - if (VLOG_IS_ON(2)) { - TensorShapeProto proto; - context->ShapeHandleToProto(shape, &proto); - VLOG(2) << "Node " << src_node->name() - << " has unknown shape: " << proto.DebugString(); - } - stack.push_back({src_node, false}); - } - } - } - } - if (!has_parent_with_unknown_shape) { - if (n == send_node) { - // The shapes of all the inputs to send_node are statically known. We - // won't have to do any inference at compile time so return now: the - // shapes were stored in static_shape_out above. - graph_out->reset(); - return Status::OK(); - } else { - // Any shape that is being processed is either the original send node - // or has at least one output with statically-unknown shape. 
If the - // latter and it doesn't have any inputs with statically-unknown - // shape, then check that it is of the recv nodes that we can fill in - // the shape of at run-time later. If it isn't one of those, then we - // won't have any additional knowledge at compile time, so we already - // know we won't be able to do shape inference and we can return an - // error now. - if (recv_at_host_nodes.find(n->name()) == recv_at_host_nodes.end()) { - return errors::InvalidArgument( - "Shape inference is not possible for outside_compilation " - "SendFromHost node ", - send_node->name(), " because shape of node ", - FormatNodeForError(*n), - " will not be known at compilation time."); - } - } - } - } - } - - for (const auto edge : back_edge_helper.RemovedEdges()) { - if (copied_node_images.find(edge.dst) != copied_node_images.end()) { - // The destination of this back edge was added to the inference graph, so - // fix it up. - Node* dst = copied_node_images[edge.dst]; - if (dst->type_string() != "Merge") { - return errors::InvalidArgument( - "outside_compilation cluster contains a back-edge to node ", - dst->name(), " of type ", dst->type_string(), - ". The analysis pass only supports back-edges to Merge nodes."); - } - const Edge* existing_input_edge; - if (edge.dst_input != 1 || dst->num_inputs() != 2 || - !dst->input_edge(0, &existing_input_edge).ok()) { - // TODO(misard) if we see graphs built with a different structure, relax - // this constraint. Leaving it here for now to avoid writing unnecessary - // complex code since we believe graphs generated by front ends all have - // the back edge as the second input to the merge node. - return errors::Internal( - "Internal assumption failed while rewriting an outside_compilation " - "cluster that contains a while loop. Logic assumes back-edge is to " - "port 1 of a 2-input Merge node."); - } - // Connect the existing edge to both inputs of the Merge node so that the - // graph will be well-formed. - (*graph_out) - ->AddEdge(existing_input_edge->src(), - existing_input_edge->src_output(), dst, edge.dst_input); - } - } - - return Status::OK(); -} - -namespace { - -// Helper struct for building cluster dependencies and also debugging cycles in -// the dependencies. While computing dependencies we construct a mapping from -// Node* to PathDetails. -struct PathDetails { - struct SubgraphAndCluster { - string subgraph; - string outside_compilation_cluster; - bool operator==(const SubgraphAndCluster& other) const { - return subgraph == other.subgraph && - outside_compilation_cluster == other.outside_compilation_cluster; - } - }; - - struct SubgraphAndClusterHash { - inline std::size_t operator()(const SubgraphAndCluster& v) const { - return hash()( - absl::StrCat(v.subgraph, v.outside_compilation_cluster)); - } - }; - - typedef std::unordered_set - SubgraphAndClusterSet; - - // Returns the set of (subgraph, oc_cluster) pairs that should be recorded as - // ancestors for any successor of this node. If the node is in the outer - // graph, it returns the transitive union of the ancestors of the node's - // inputs. If the node is in an outside_compilation cluster, it returns just - // that cluster. If the node is compiled, it returns the empty set. 
- SubgraphAndClusterSet AncestorsForSuccessor() { - if (subgraph.empty()) { - return ancestor_clusters; - } else if (outside_compilation_cluster.empty()) { - return SubgraphAndClusterSet(); - } else { - SubgraphAndCluster entry; - entry.subgraph = subgraph; - entry.outside_compilation_cluster = outside_compilation_cluster; - return SubgraphAndClusterSet({entry}); - } - } - - // The transitive union of the ancestor's of this node's inputs. This is only - // saved for debugging in order to print out enough information to debug a - // discovered cycle. - SubgraphAndClusterSet ancestor_clusters; - // The subgraph attr on this node. - string subgraph; - // The outside_compilation attr on this node. - string outside_compilation_cluster; -}; - -// Adds an edge from ancestor to successor to the cycle detector, and returns an -// error if that edge causes the formation of a cycle. In the error case, logs -// the contents of the node_ancestors_map to facilitate debugging. -Status CheckClusterDependencyForCycles( - const string& ancestor, const string& successor, - const std::unordered_map>& ancestors, - const std::unordered_map& node_ancestors_map, - GraphCycles* cycle_detector, - std::unordered_map* cycle_detector_map) { - if (cycle_detector_map->find(ancestor) == cycle_detector_map->end()) { - (*cycle_detector_map)[ancestor] = cycle_detector->NewNode(); - } - if (cycle_detector_map->find(successor) == cycle_detector_map->end()) { - (*cycle_detector_map)[successor] = cycle_detector->NewNode(); - } - - if (!cycle_detector->InsertEdge((*cycle_detector_map)[ancestor], - (*cycle_detector_map)[successor])) { - LOG(ERROR) << "Cycle in outside_compilation clusters"; - for (const auto& cluster : ancestors) { - LOG(ERROR) << "Cluster " << cluster.first << " depends on:"; - for (const auto& ancestor : cluster.second) { - LOG(ERROR) << " " << ancestor; - } - } - for (const auto& node_ancestors : node_ancestors_map) { - LOG(ERROR) << "Node " << node_ancestors.first->name() << " (" - << node_ancestors.second.subgraph << ";" - << node_ancestors.second.outside_compilation_cluster - << ") has ancestor clusters:"; - for (const auto& ancestor : node_ancestors.second.ancestor_clusters) { - LOG(ERROR) << " " << ancestor.subgraph << ";" - << ancestor.outside_compilation_cluster; - } - } - return errors::InvalidArgument( - "Can't compile outside_compilation clusters because there is a " - "dependency cycle: see error log for details."); - } - return Status::OK(); -} - -} // namespace - -Status Encapsulator::FindClusterDependencies() { - // Map from nodes to ancestor details. A node is entered into the map if it is - // in a compilation subgraph, and outside_compilation cluster, or appears on a - // path in the outer graph leading from an outside_compilation subgraph. - std::unordered_map node_ancestors_map; - // We check that clusters are acyclic using this cycle detector. - GraphCycles cycle_detector; - // Map from cluster name to cycle detector node id. - std::unordered_map cycle_detector_map; - // Process the nodes in topologically-sorted order. - std::vector nodes; - GetReversePostOrder(*graph_in_, &nodes); - for (Node* node : nodes) { - string subgraph_name; - string oc_cluster; - TF_RETURN_IF_ERROR(GetFunctionNameAttr(node, &subgraph_name, &oc_cluster)); - // First create an entry in the ancestors map if the node is in a compiled - // subgraph or outside_compilation cluster, or if any incoming edge is from - // a node with an ancestor map entry; and find the union of all the - // ancestors. 
- if (!subgraph_name.empty()) { - node_ancestors_map[node].subgraph = subgraph_name; - node_ancestors_map[node].outside_compilation_cluster = oc_cluster; - } - for (Node* src : node->in_nodes()) { - const auto iter = node_ancestors_map.find(src); - if (iter != node_ancestors_map.end()) { - const auto& ancestors_to_follow = iter->second.AncestorsForSuccessor(); - for (const auto& ancestor : ancestors_to_follow) { - if (ancestor.subgraph != subgraph_name || - ancestor.outside_compilation_cluster != oc_cluster) { - node_ancestors_map[node].ancestor_clusters.insert(ancestor); - } - } - } - } - if (!subgraph_name.empty()) { - // The node is in a compiled subgraph or an outside_compilation cluster. - if (oc_cluster.empty()) { - // The node is not in an outside_compilation cluster. Record the - // subgraph's ancestor dependencies. - for (const auto& cluster : node_ancestors_map[node].ancestor_clusters) { - if (cluster.subgraph != subgraph_name) { - subgraph_ancestors_[subgraph_name].insert(cluster.subgraph); - TF_RETURN_IF_ERROR(CheckClusterDependencyForCycles( - cluster.subgraph, subgraph_name, subgraph_ancestors_, - node_ancestors_map, &cycle_detector, &cycle_detector_map)); - } - } - } else { - Subgraph& subgraph = subgraphs_[subgraph_name]; - // The node is in an outside_compilation cluster. Record the cluster - // and/or subgraph ancestor dependencies. - for (const auto& cluster : node_ancestors_map[node].ancestor_clusters) { - if (cluster.subgraph == subgraph_name) { - // The ancestor is in the same subgraph. - if (cluster.outside_compilation_cluster != oc_cluster) { - // But not in the same oc_cluster, so record the dependency. - subgraph.RecordOutsideCompilationDependency( - oc_cluster, cluster.outside_compilation_cluster); - TF_RETURN_IF_ERROR(CheckClusterDependencyForCycles( - cluster.outside_compilation_cluster, oc_cluster, - subgraph.OutsideCompilationAncestorMap(), node_ancestors_map, - &cycle_detector, &cycle_detector_map)); - } - } else { - // The ancestor is in a different subgraph, so record the - // dependency. - subgraph_ancestors_[subgraph_name].insert(cluster.subgraph); - TF_RETURN_IF_ERROR(CheckClusterDependencyForCycles( - cluster.subgraph, subgraph_name, subgraph_ancestors_, - node_ancestors_map, &cycle_detector, &cycle_detector_map)); - } - } - } - } - } - if (VLOG_IS_ON(2)) { - // Print debug information. - VLOG(2) << "node_ancestors_map:"; - for (const auto& node_iter : node_ancestors_map) { - VLOG(2) << "\t" << node_iter.first->name() << ": subgraph = '" - << node_iter.second.subgraph - << "', outside_compilation_cluster = '" - << node_iter.second.outside_compilation_cluster - << "', ancestor_clusters: " - << (node_iter.second.ancestor_clusters.empty() ? 
"(empty)" : ""); - for (const auto& cluster_iter : node_iter.second.ancestor_clusters) { - VLOG(2) << "\t\tsubgraph = '" << cluster_iter.subgraph - << "', outside_compilation_cluster = '" - << cluster_iter.outside_compilation_cluster << "'"; - } - } - } - return Status::OK(); -} - Status Encapsulator::MakePrunedGraphCopyAndInline( const Graph& graph, const std::vector& sink_nodes, std::unique_ptr* pruned_graph, @@ -2362,118 +1095,6 @@ Status Encapsulator::MakePrunedGraphCopyAndInline( return Status::OK(); } -Status Encapsulator::MakeGraphForOutsideCompilationSends( - const Graph& graph, std::unique_ptr* pruned_graph, - BackEdgeHelper* back_edge_helper, ShapeRefiner* shape_refiner, - std::unordered_map* node_images, - FunctionLibraryDefinition* library) { - // Find all the send_from_host nodes in all subgraphs, to use as roots for the - // pruning. - std::vector send_from_host_nodes; - for (auto& subgraph_entry : subgraphs_) { - Subgraph& subgraph = subgraph_entry.second; - std::vector outside_compilation_names; - subgraph.GetOutsideCompilationSubgraphNames(&outside_compilation_names); - for (const auto& name : outside_compilation_names) { - Node* send_node = subgraph.GetSendFromHostNode(name); - if (send_node != nullptr) { - send_from_host_nodes.push_back(send_node); - } - } - } - - // Make a copy of all the graph nodes needed to evaluate the send_from_host - // nodes, inlining any functions as needed. - TF_RETURN_IF_ERROR(MakePrunedGraphCopyAndInline( - graph, send_from_host_nodes, pruned_graph, node_images, library)); - FixupSourceAndSinkEdges(pruned_graph->get()); - - // Remove back edges from any cycles in the pruned graph to simplify shape - // inference traversal. They will be fixed up in the per-subgraph shape - // inference graphs stored in the function library. - TF_RETURN_IF_ERROR(back_edge_helper->Remove(pruned_graph->get())); - - // Perform shape inference on the pruned graph. - shape_refiner->set_require_shape_inference_fns(false); - std::vector post_order; - GetReversePostOrder(*(*pruned_graph), &post_order); - for (auto node : post_order) { - // Ignore the status returned by the shape_refiner. At this point we want - // the best effort shapes, even if no shape function is registered for a - // node. - Status status = shape_refiner->AddNode(node); - if (!status.ok()) { - VLOG(1) << "Shape inference failed for node: " << status; - } - } - - return Status::OK(); -} - -Status Encapsulator::GetShapeInfoForOutsideCompilationSends( - Graph* graph_out, FunctionLibraryDefinition* library) { - BackEdgeHelper back_edge_helper; - std::unique_ptr pruned_graph; - ShapeRefiner shape_refiner(graph_out->versions(), graph_out->op_registry()); - std::unordered_map node_images; - TF_RETURN_IF_ERROR(MakeGraphForOutsideCompilationSends( - *graph_out, &pruned_graph, &back_edge_helper, &shape_refiner, - &node_images, library)); - - if (VLOG_IS_ON(1)) { - DumpGraphToFile("pruned_graph_for_shape_inference", *pruned_graph, library); - } - - for (auto& subgraph_entry : subgraphs_) { - const string& subgraph_name = subgraph_entry.first; - Subgraph& subgraph = subgraph_entry.second; - // Find all the recv_at_host nodes in this subgraph. 
- std::vector outside_compilation_names; - subgraph.GetOutsideCompilationSubgraphNames(&outside_compilation_names); - std::unordered_set recv_at_host_names; - for (const auto& oc_name : outside_compilation_names) { - Node* recv_node = subgraph.GetRecvAtHostNode(oc_name); - if (recv_node != nullptr) { - recv_at_host_names.insert(recv_node->name()); - } - } - // For each send_from_host node, do as much shape inference as possible - // without knowing the shape of the recv_at_host nodes, and store the - // result, along with enough information to complete the job at compile time - // once the recv_at_host shapes are known. - for (const auto& oc_name : outside_compilation_names) { - Node* send_node = subgraph.GetSendFromHostNode(oc_name); - std::vector static_shape; - std::unique_ptr graph; - if (send_node != nullptr) { - TF_RETURN_IF_ERROR(DoStaticShapeInferenceForOutsideCompilationSend( - *pruned_graph, back_edge_helper, shape_refiner, recv_at_host_names, - node_images[send_node], library, &static_shape, &graph)); - if (graph == nullptr) { - VLOG(2) << "Send node " << send_node->name() << " shapes"; - for (int i = 0; i < static_shape.size(); ++i) { - VLOG(2) << static_shape[i].DebugString(); - } - } else { - if (VLOG_IS_ON(2)) { - GraphDef graphdef; - graph->ToGraphDef(&graphdef); - VLOG(2) << "Send node " << send_node->name() << " graph\n" - << graphdef.DebugString(); - } - } - } - TF_RETURN_IF_ERROR(subgraph.AddShapeInferenceInfo( - subgraph_name, oc_name, static_shape, graph.get(), library)); - } - if (!outside_compilation_names.empty()) { - TF_RETURN_IF_ERROR(subgraph.ReplaceFunctionDef(library)); - } - } - - return Status::OK(); -} - Status Encapsulator::BuildOutputGraph(Graph* graph_out, FunctionLibraryDefinition* library) { // Map from nodes in the input graph to nodes in the output graph. 
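The shape-inference plumbing deleted above was layered on a primitive that survives this change: MakePrunedGraphCopyAndInline, which copies only the nodes that can reach a given set of sink nodes. Below is a minimal standalone sketch of that reverse-reachability pruning, written against a toy adjacency list rather than TensorFlow's Graph type; the names are illustrative only:

// Illustrative sketch only: prune a toy graph to the ancestors of its sinks.
#include <cstdio>
#include <map>
#include <set>
#include <string>
#include <vector>

// edges: producer -> consumers. Returns every node that can reach a sink,
// including the sinks themselves.
std::set<std::string> AncestorsOfSinks(
    const std::map<std::string, std::vector<std::string>>& edges,
    const std::vector<std::string>& sinks) {
  // Build reverse edges: consumer -> producers.
  std::map<std::string, std::vector<std::string>> reverse_edges;
  for (const auto& entry : edges) {
    for (const std::string& dst : entry.second) {
      reverse_edges[dst].push_back(entry.first);
    }
  }
  std::set<std::string> keep(sinks.begin(), sinks.end());
  std::vector<std::string> stack(sinks.begin(), sinks.end());
  while (!stack.empty()) {
    const std::string node = stack.back();
    stack.pop_back();
    for (const std::string& producer : reverse_edges[node]) {
      if (keep.insert(producer).second) stack.push_back(producer);
    }
  }
  return keep;
}

int main() {
  // A -> B -> send, and C -> D, which never reaches the sink.
  const std::map<std::string, std::vector<std::string>> edges = {
      {"A", {"B"}}, {"B", {"send"}}, {"C", {"D"}}};
  for (const std::string& node : AncestorsOfSinks(edges, {"send"})) {
    std::printf("%s\n", node.c_str());  // prints A, B, send
  }
}

Only the outside_compilation callers of that pruning step are removed here; the helper itself is left in place.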
@@ -2481,26 +1102,19 @@ Status Encapsulator::BuildOutputGraph(Graph* graph_out, TF_RETURN_IF_ERROR(CopyNodesToOutputGraph(graph_out, &node_images)); TF_RETURN_IF_ERROR(AddFunctionCallNodes(node_images, graph_out)); - TF_RETURN_IF_ERROR(AddOutsideCompilationHostIONodes(node_images, graph_out)); TF_RETURN_IF_ERROR(AddEdgesToOutputGraph(node_images, graph_out)); - TF_RETURN_IF_ERROR( - GetShapeInfoForOutsideCompilationSends(graph_out, library)); - return Status::OK(); } } // anonymous namespace Status EncapsulateSubgraphsInFunctions( - string group_attribute, string outside_compilation_attribute, - const Graph& graph_in, const RewriteSubgraphFn& rewrite_subgraph_fn, - bool reuse_existing_functions, std::unique_ptr* graph_out, - FunctionLibraryDefinition* library) { + string group_attribute, const Graph& graph_in, + const RewriteSubgraphFn& rewrite_subgraph_fn, bool reuse_existing_functions, + std::unique_ptr* graph_out, FunctionLibraryDefinition* library) { Encapsulator encapsulator(std::move(group_attribute), - std::move(outside_compilation_attribute), &graph_in); - TF_RETURN_IF_ERROR(encapsulator.FindClusterDependencies()); TF_RETURN_IF_ERROR(encapsulator.SplitIntoSubgraphs(library)); TF_RETURN_IF_ERROR(encapsulator.BuildFunctionDefs( @@ -2685,9 +1299,8 @@ Status EncapsulateSubgraphsPass::Run( TF_RETURN_WITH_CONTEXT_IF_ERROR( EncapsulateSubgraphsInFunctions( - kXlaClusterAttr, kXlaOutsideCompilationAttr, **options.graph, - rewrite_subgraph, /*reuse_existing_functions=*/false, &graph_out, - library), + kXlaClusterAttr, **options.graph, rewrite_subgraph, + /*reuse_existing_functions=*/false, &graph_out, library), "EncapsulateSubgraphsPass failed"); if (VLOG_IS_ON(1)) { diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h index 90354a801af..62b752cf40f 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h @@ -52,16 +52,6 @@ typedef std::function* graph_out, - FunctionLibraryDefinition* library); + string group_attribute, const Graph& graph_in, + const RewriteSubgraphFn& rewrite_subgraph_fn, bool reuse_existing_functions, + std::unique_ptr* graph_out, FunctionLibraryDefinition* library); // The attribute that marks function calls produced by the encapsulate // subgraphs pass and that should in turn be compiled via XlaLaunch operators. 
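With the outside_compilation attribute argument gone, callers of EncapsulateSubgraphsInFunctions pass only the group attribute, the input graph, the optional rewrite function, and the output locations. A call-site fragment consistent with the updated declaration (graph and lib_def are assumed to be set up as in the test fixture in the next file):

std::unique_ptr<Graph> graph_out;
Status s = EncapsulateSubgraphsInFunctions(
    "_encapsulate", *graph,
    /*rewrite_subgraph_fn=*/{},
    /*reuse_existing_functions=*/false, &graph_out, lib_def.get());
if (!s.ok()) return s;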
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index 958b0a5f61c..d162c16cc16 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -514,10 +514,10 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library, auto flr = pflr->GetFLR("/job:localhost/replica:0/task:0/cpu:0"); std::unique_ptr graph_out; - s = EncapsulateSubgraphsInFunctions( - "_encapsulate", /*outside_compilation_attribute=*/"", *graph, - /*rewrite_subgraph_fn=*/{}, - /*reuse_existing_functions=*/false, &graph_out, lib_def.get()); + s = EncapsulateSubgraphsInFunctions("_encapsulate", *graph, + /*rewrite_subgraph_fn=*/{}, + /*reuse_existing_functions=*/false, + &graph_out, lib_def.get()); if (!s.ok()) return s; std::unordered_map clusters; @@ -746,7 +746,7 @@ TEST(EncapsulateSubgraphsTest, InputDeduplication) { FunctionLibraryDefinition library(OpRegistry::Global(), {}); std::unique_ptr graph; TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( - "_cluster", "", graph_before_encapsulation, + "_cluster", graph_before_encapsulation, /*rewrite_subgraph_fn=*/{}, /*reuse_existing_functions=*/false, &graph, &library)); @@ -798,7 +798,7 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) { FunctionLibraryDefinition library(OpRegistry::Global(), {}); int guaranteed_consts = 0; TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( - "_encapsulate", "", graph_before, + "_encapsulate", graph_before, /*rewrite_subgraph_fn=*/ [&guaranteed_consts](const std::vector& arg_source_tensors, std::unique_ptr* graph_ptr, @@ -843,7 +843,7 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) { FunctionLibraryDefinition library(OpRegistry::Global(), {}); int guaranteed_consts = 0; TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( - "_encapsulate", "", graph_before, + "_encapsulate", graph_before, /*rewrite_subgraph_fn=*/ [&guaranteed_consts](const std::vector& arg_source_tensors, std::unique_ptr* graph_ptr, @@ -1109,7 +1109,7 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { absl::Span( {"_xla_token_arg_node", "outside_compilation_O1_host_compute"})}}, - {"F"}}, + {"F", "outside_compilation_O1_host_compute"}}, {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"C:o:0", "D:o:0"}, @@ -1990,7 +1990,8 @@ TEST(EncapsulateSubgraphsTest, {"_xla_token_input_nodes", absl::Span( {"_xla_token_arg_node", - "outside_compilation_O1_host_compute"})}}}, + "outside_compilation_O1_host_compute"})}}, + {"outside_compilation_O1_host_compute"}}, }, {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, {"h_0_retval_retval", "H:o:0"}}); @@ -2117,7 +2118,8 @@ TEST(EncapsulateSubgraphsTest, {"_xla_token_input_nodes", absl::Span( {"_xla_token_arg_node", - "outside_compilation_O1_host_compute"})}}}, + "outside_compilation_O1_host_compute"})}}, + {"outside_compilation_O1_host_compute"}}, {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"D:o:0"}, @@ -2267,7 +2269,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { {"_xla_token_input_nodes", absl::Span( {"_xla_token_arg_node", "outside_compilation_O1_host_compute"})}}, - {}}, + {"outside_compilation_O1_host_compute"}}, {{"outside_compilation_O3_host_compute"}, "XlaHostCompute", {"D:o:0"}, @@ -2282,7 +2284,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { absl::Span({"_xla_token_arg_node", "outside_compilation_O1_host_compute", 
"outside_compilation_O2_host_compute"})}}, - {}}}, + {"outside_compilation_O1_host_compute", + "outside_compilation_O2_host_compute"}}}, {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, {"h_0_retval_retval", "H:o:0"}}); diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc index 4e65971191a..2c2cd094133 100644 --- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc @@ -231,9 +231,9 @@ Status RewriteSubgraph(const std::vector& arg_source_tensors, auto output = absl::make_unique((*graph)->op_registry()); TF_RETURN_WITH_CONTEXT_IF_ERROR( - EncapsulateSubgraphsInFunctions( - kXlaClusterAttr, "", **graph, RewriteSubgraph, - /*reuse_existing_functions=*/true, &output, flib_def), + EncapsulateSubgraphsInFunctions(kXlaClusterAttr, **graph, RewriteSubgraph, + /*reuse_existing_functions=*/true, + &output, flib_def), "EncapsulateXlaComputationsPass failed"); graph->swap(output); return Status::OK(); diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc index a6e66657fb5..0667de9d230 100644 --- a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc @@ -393,7 +393,7 @@ Status ValidateOutsideCompilationCallNode(Node* call_node) { // Replace outside compilation function call node with XlaHostCompute node. // If the function call node has no input/output edges, we will just remove it // and not create a XlaHostCompute node. -Status ReplaceOrRemoveOutsideCompilationCallNode( +xla::StatusOr ReplaceOrRemoveOutsideCompilationCallNode( Graph* g, Node* call_node, const std::map& host_compute_core, const absl::flat_hash_map>& cluster_deps) { // If the function call node has no input/output edges, just remove it. @@ -413,7 +413,7 @@ Status ReplaceOrRemoveOutsideCompilationCallNode( if (!has_edge) { VLOG(4) << "Did not add HostCompute node for " << call_node->DebugString(); g->RemoveNode(call_node); - return Status::OK(); + return nullptr; } // Build XlaHostCompute NodeDef. @@ -424,7 +424,7 @@ Status ReplaceOrRemoveOutsideCompilationCallNode( ReplaceNode(g, call_node, node_def)); VLOG(4) << "Added HostCompute node: " << host_compute_node->DebugString(); - return Status::OK(); + return host_compute_node; } // Resets "device_ordinal" attr to placeholder value for related nodes @@ -1634,7 +1634,7 @@ Status ExtractOutsideCompilationForFunction( RewriteOutsideCompilationSubgraphFn rewrite_fn( xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name); TF_RETURN_IF_ERROR(EncapsulateSubgraphsInFunctions( - outside_compilation_attr_name, "", *fbody->graph, rewrite_fn, + outside_compilation_attr_name, *fbody->graph, rewrite_fn, /*reuse_existing_functions=*/true, &graph_out, fld)); // Replace outside_compilation function nodes with HostCompute ops. 
@@ -1670,10 +1670,35 @@ Status ExtractOutsideCompilationForFunction( } } } + std::map host_compute_nodes; for (Node* n : outside_compilation_nodes) { TF_RETURN_IF_ERROR(ValidateOutsideCompilationCallNode(n)); - TF_RETURN_IF_ERROR(ReplaceOrRemoveOutsideCompilationCallNode( - graph_out.get(), n, host_compute_core, *cluster_deps)); + auto host_compute_node_or = ReplaceOrRemoveOutsideCompilationCallNode( + graph_out.get(), n, host_compute_core, *cluster_deps); + TF_RETURN_IF_ERROR(host_compute_node_or.status()); + Node* host_compute_node = host_compute_node_or.ValueOrDie(); + if (host_compute_node) { + host_compute_nodes[host_compute_node->name()] = host_compute_node; + } + } + // For XlaHostCompute nodes with dependencies, add control edges between them + // so XlaCompiler can handle them in correct order. + for (auto iter : host_compute_nodes) { + Node* host_compute_node = iter.second; + std::vector token_input_node_names; + TF_RETURN_IF_ERROR(GetNodeAttr(host_compute_node->def(), + kXlaTokenInputNodesAttrName, + &token_input_node_names)); + for (const string& node_name : token_input_node_names) { + if (node_name == kXlaTokenArgNodeName) { + continue; + } + + auto iter = host_compute_nodes.find(node_name); + if (iter != host_compute_nodes.end()) { + graph_out->AddControlEdge(iter->second, host_compute_node); + } + } } // Handle nodes with associated functions. diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc index 93817378e96..2717487c78e 100644 --- a/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc @@ -990,6 +990,16 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_1->attrs()), "_xla_token_input_nodes", &token_input_nodes)); EXPECT_EQ(token_input_nodes, expected_token_input_nodes_1); + + // Check there is a control edge from host_compute_0 to host_compute_1. + bool has_control_edge = false; + for (const Edge *e : host_compute_1->in_edges()) { + if (e->IsControlEdge() && e->src() == host_compute_0) { + has_control_edge = true; + break; + } + } + EXPECT_TRUE(has_control_edge); } TEST_F(ExtractOutsideCompilationForFunctionTest, @@ -1062,5 +1072,15 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_1->attrs()), "_xla_token_input_nodes", &token_input_nodes)); EXPECT_EQ(token_input_nodes, expected_token_input_nodes_1); + + // Check there is a control edge from host_compute_0 to host_compute_1. 
+ bool has_control_edge = false; + for (const Edge *e : host_compute_1->in_edges()) { + if (e->IsControlEdge() && e->src() == host_compute_0) { + has_control_edge = true; + break; + } + } + EXPECT_TRUE(has_control_edge); } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/graphcycles/BUILD b/tensorflow/compiler/jit/graphcycles/BUILD index f9be7c45743..69c67c87615 100644 --- a/tensorflow/compiler/jit/graphcycles/BUILD +++ b/tensorflow/compiler/jit/graphcycles/BUILD @@ -1,9 +1,8 @@ -licenses(["notice"]) # Apache 2.0 - package( default_visibility = [ "//tensorflow/compiler/tf2xla:internal", ], + licenses = ["notice"], # Apache 2.0 ) load("//tensorflow:tensorflow.bzl", "tf_cc_test") diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD index 3524da23fb3..0a65529cdb9 100644 --- a/tensorflow/compiler/jit/kernels/BUILD +++ b/tensorflow/compiler/jit/kernels/BUILD @@ -1,9 +1,8 @@ -licenses(["notice"]) # Apache 2.0 - package( default_visibility = [ "//tensorflow/compiler/tf2xla:internal", ], + licenses = ["notice"], # Apache 2.0 ) cc_library( @@ -29,6 +28,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:state_ops_op_lib", "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/stream_executor:tf_allocator_adapter", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", ], diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index 6df0991e354..e825a77b1d1 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -61,7 +61,7 @@ XlaPlatformInfo PlatformInfoFromContext(OpKernelConstruction* ctx) { DeviceType device_type = ctx->device_type(); se::Platform::Id platform_id = nullptr; const XlaDevice::Metadata* xla_device_metadata = nullptr; - std::unique_ptr xla_allocator; + std::unique_ptr xla_allocator; se::DeviceMemoryAllocator* device_allocator = nullptr; if (ctx->device_type() == DeviceType(DEVICE_CPU)) { @@ -93,7 +93,7 @@ XlaPlatformInfo PlatformInfoFromContext(OpKernelConstruction* ctx) { se::MultiPlatformManager::PlatformWithId(platform_id); OP_REQUIRES_OK_RETURN(ctx, XlaPlatformInfo(), maybe_platform.status()); - xla_allocator = absl::make_unique( + xla_allocator = absl::make_unique( maybe_platform.ValueOrDie(), ctx->device()->GetAllocator({})); } diff --git a/tensorflow/compiler/jit/kernels/xla_ops.h b/tensorflow/compiler/jit/kernels/xla_ops.h index eaa686780e4..3a1009ec8a7 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.h +++ b/tensorflow/compiler/jit/kernels/xla_ops.h @@ -27,6 +27,7 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/util/stream_executor_util.h" +#include "tensorflow/stream_executor/tf_allocator_adapter.h" namespace tensorflow { @@ -36,11 +37,11 @@ class XlaPlatformInfo { public: XlaPlatformInfo() : device_type_("") {} XlaPlatformInfo(XlaPlatformInfo&&) = default; - explicit XlaPlatformInfo(const DeviceType device_type, - se::Platform::Id platform_id, - const XlaDevice::Metadata* xla_device_metadata, - std::unique_ptr xla_allocator, - se::DeviceMemoryAllocator* device_allocator) + explicit XlaPlatformInfo( + const DeviceType device_type, se::Platform::Id platform_id, + const XlaDevice::Metadata* xla_device_metadata, + std::unique_ptr xla_allocator, + se::DeviceMemoryAllocator* device_allocator) : device_type_(device_type), platform_id_(platform_id), xla_device_metadata_(xla_device_metadata), @@ -84,8 +85,8 @@ class XlaPlatformInfo { // then device_allocator_ is the xla::Backend's memory allocator and // xla_allocator_ is null. If the op is placed on a regular CPU or GPU device // then device_allocator_ is null and xla_allocator_ points to an appropriate - // XlaAllocator instance. - std::unique_ptr xla_allocator_; + // se::TfAllocatorAdapter instance. + std::unique_ptr xla_allocator_; se::DeviceMemoryAllocator* device_allocator_; TF_DISALLOW_COPY_AND_ASSIGN(XlaPlatformInfo); diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 4142de56813..81ffea31c30 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -229,32 +229,18 @@ class MarkForCompilationPassImpl { // Initialize some internal data structures. Status Initialize(); - // Runs through all the nodes in `cycles_graph_` and tries to create clusters. - // Returns true if any new clusters were created. - StatusOr RunEdgeContractionLoopInPostOrderOnce(); + // Runs through the entire cluster graph in post-order and calls `fn(from, + // to)` on each edge. `fn(from, to)` is expected to return true if it was + // able to contract `from`->`to`. + // + // Returns true if `fn` returned true for any edge. + template + StatusOr ForEachEdgeInPostOrder(FnTy fn); - // Runs through all the nodes in `cycles_graph_` and tries to contract high - // priority edges for clusters. Returns true if any new clusters were created. - // - // There are potentially many maximal clustering results, but they will not - // all be equally performant. Some clustering decision are likely to improve - // performance much more than others, and we cannot order contractions on this - // cost function, nor can we look at global information while deciding on - // individual edges to contract. Instead, we will make decisions on these - // important edges then make decisions on all other edges, causing the highest - // chance of all most important edges to be contracted. - // - // An example of where this might occur is with a digraph: - // {A -> B, B -> C, A -> X, X -> C} where B is a Size operation and X is - // not-compilable. In this case, the valid clusterings are {A,B} or {B,C}. B - // should be clustered with A because it will prevent a potentially large - // tensor from A being computed and copied. - // - // This pass will ensure that contraction happens, which cannot be enforced in - // a single pass with the current algorithm. - // graph and prevent B->C from being clusterd in anticipation of a later A->B - // cluster. 
- StatusOr ContractPreferredEdges(); + // If from->to is a "preferred" edge (i.e. if we have a choice, we want to + // prioritize contracting from->to over contracting other edges) then + // contracts it and returns true. Else returns false. + StatusOr ContractEdgeIfPreferred(Cluster* from, Cluster* to); // Contracts as many edges as possible to create XLA clusters. After this // finishes the clustering decisions made are implicitly stored in @@ -276,10 +262,6 @@ class MarkForCompilationPassImpl { // true if successful. StatusOr TryToContractEdge(Cluster* from, Cluster* to); - // Tries to contract each edge from `cluster_from`. Returns true if any edges - // were contracted, false otherwise. - StatusOr TryToContractEdgesFrom(Cluster* cluster_from); - // Nodes that XLA can compile are put in `compilation_candidates_`. Status FindCompilationCandidates(); @@ -401,6 +383,13 @@ class MarkForCompilationPassImpl { return true; } + string EdgeContractionFailureMsg(Cluster* from, Cluster* to, + absl::string_view reason) { + return absl::StrCat("Could not contract ", from->DebugString(*graph_), + " -> ", to->DebugString(*graph_), " because ", reason, + "."); + } + DebugOptions debug_options_; Graph* graph_; FunctionLibraryDefinition* flib_def_; @@ -611,7 +600,8 @@ Status MarkForCompilationPassImpl::Initialize() { return BuildInitialClusterSet(); } -StatusOr MarkForCompilationPassImpl::ContractPreferredEdges() { +template +StatusOr MarkForCompilationPassImpl::ForEachEdgeInPostOrder(FnTy fn) { bool changed = false; for (int32 node : cycles_graph_.AllNodesInPostOrder()) { Cluster* cluster_from = GetClusterForCyclesGraphNode(node); @@ -632,55 +622,33 @@ StatusOr MarkForCompilationPassImpl::ContractPreferredEdges() { continue; } - if (cluster_to->cluster_size() == 1) { - Node* n = graph_->FindNodeId(cluster_to->GetIdOfOnlyNode()); - - // Shape consuming operations are desirable to cluster with their - // operands because they return a small set of scalar values after - // consuming a large amount of data. For example, given a graph X -> Y - // -> Size -> Z, where the possible clustering is [{X, Y, Size}, {Z}] or - // [{X, Y}, {Size, Z}], the better clustering is Size with Y because the - // output of size will be a small tensor while Y is a potentially large - // tensor that must be computed and possible transposed/copied before - // the second cluster executes. - if (IsShapeConsumerOp(*n)) { - TF_ASSIGN_OR_RETURN(bool contracted_edge, - TryToContractEdge(cluster_from, cluster_to)); - changed |= contracted_edge; - } - } + TF_ASSIGN_OR_RETURN(bool contracted_edge, fn(cluster_from, cluster_to)); + changed |= contracted_edge; } } return changed; } -StatusOr -MarkForCompilationPassImpl::RunEdgeContractionLoopInPostOrderOnce() { - bool changed = false; - // Iterating over the graph once in post-order is sufficient to produce a - // maximal clustering: - // - // A. We visit a cluster only after maximally clustering all its children. - // B. By the time we're done with `node` (in `TryToContractEdgesFrom`) all of - // its children that could have been absorbed into `node` have been - // absorbed. - // C. We have an invariant that making a cluster larger does not make edges - // leaving it more contractable. That is, if we have - // digraph { X->Y; Y->Z; } then collapsing X->Y does not make it possible - // to contract Y->Z if Y->Z was not contractible originally. 
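For context on the refactor in this file: the two specialized loops, `ContractPreferredEdges` and `RunEdgeContractionLoopInPostOrderOnce`, are folded into one `ForEachEdgeInPostOrder` template that walks each (from, to) cluster edge in post order and defers the contraction decision to a callback. The sketch below shows that shape with a toy graph and plain `bool` results instead of `StatusOr<bool>`; it is an illustration of the pattern, not the real cluster-graph code:

```cpp
// Toy sketch of the ForEachEdgeInPostOrder pattern: one generic post-order
// edge walk, with the per-edge policy (preferred edges first, everything else
// second) supplied as a callback that reports whether it contracted the edge.
#include <iostream>
#include <map>
#include <set>
#include <vector>

struct ToyClusterGraph {
  std::map<int, std::set<int>> succ;  // cluster id -> successor cluster ids
  std::vector<int> post_order;        // assumed precomputed, in post order

  // Merge `to` into `from`: inherit its successors and drop it from the graph.
  bool ContractEdge(int from, int to) {
    succ[from].erase(to);
    for (int s : succ[to]) {
      if (s != from) succ[from].insert(s);  // avoid creating a self edge
    }
    succ.erase(to);
    return true;
  }

  // Calls fn(from, to) on every surviving edge, visiting nodes in post order.
  // Returns true if fn contracted anything, mirroring the template in the diff.
  template <typename FnTy>
  bool ForEachEdgeInPostOrder(FnTy fn) {
    bool changed = false;
    for (int from : post_order) {
      if (!succ.count(from)) continue;  // already merged into another cluster
      // Copy the successor set: fn may mutate it while we iterate.
      const std::vector<int> successors(succ[from].begin(), succ[from].end());
      for (int to : successors) {
        if (!succ.count(to)) continue;  // merged away by an earlier callback
        changed |= fn(from, to);
      }
    }
    return changed;
  }
};

int main() {
  ToyClusterGraph g;
  g.succ = {{2, {1}}, {1, {0}}, {0, {}}};
  g.post_order = {0, 1, 2};
  // Policy callback: contract every edge (stand-in for TryToContractEdge).
  const bool changed = g.ForEachEdgeInPostOrder(
      [&](int from, int to) { return g.ContractEdge(from, to); });
  std::cout << "changed=" << changed << " clusters=" << g.succ.size() << "\n";
  return 0;
}
```

In the hunks that follow, `RunEdgeContractionLoop` simply calls this walker with the preferred-edge policy, then with `TryToContractEdge`, and then once more to check that no further contraction is possible.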
- for (int32 node : cycles_graph_.AllNodesInPostOrder()) { - Cluster* cluster_from = GetClusterForCyclesGraphNode(node); - if (!cluster_from) { - continue; - } +StatusOr MarkForCompilationPassImpl::ContractEdgeIfPreferred( + Cluster* from, Cluster* to) { + if (to->cluster_size() == 1) { + Node* n = graph_->FindNodeId(to->GetIdOfOnlyNode()); - TF_ASSIGN_OR_RETURN(bool contracted_one_edge, - TryToContractEdgesFrom(cluster_from)); - changed |= contracted_one_edge; + // Shape consuming operations are desirable to cluster with their + // operands because they return a small set of scalar values after + // consuming a large amount of data. For example, given a graph X -> Y + // -> Size -> Z, where the possible clustering is [{X, Y, Size}, {Z}] or + // [{X, Y}, {Size, Z}], the better clustering is Size with Y because the + // output of size will be a small tensor while Y is a potentially large + // tensor that must be computed and possibly transposed/copied before + // the second cluster executes. + if (IsShapeConsumerOp(*n)) { + return TryToContractEdge(from, to); + } } - return changed; + return false; } Status MarkForCompilationPassImpl::RunEdgeContractionLoop() { @@ -694,25 +662,68 @@ Status MarkForCompilationPassImpl::RunEdgeContractionLoop() { // without restrictions. This helps to minimize data output from clusters (and // possible transpose operations before outputs) that might occur if a // ShapeConsumingOp is on the edge of 2 clusters due to cycle considerations. - TF_ASSIGN_OR_RETURN(bool changed, ContractPreferredEdges()); + // + // There are potentially many maximal clustering results, but they will not + // all be equally performant. Some clustering decisions are likely to improve + // performance much more than others, and we cannot order contractions on this + // cost function, nor can we look at global information while deciding on + // individual edges to contract. Instead, we will make decisions on these + // important edges first and then make decisions on all other edges, giving the + // most important edges the highest chance of being contracted. + // + // An example of where this might occur is with a digraph: + // {A -> B, B -> C, A -> X, X -> C} where B is a Size operation and X is + // not-compilable. In this case, the valid clusterings are {A,B} or {B,C}. B + // should be clustered with A because it will prevent a potentially large + // tensor from A being computed and copied. + // + // This pass will ensure that contraction happens, which cannot be enforced in + // a single pass with the current algorithm: a single pass cannot look ahead in the + // graph and prevent B->C from being clustered in anticipation of a later A->B + // cluster. - TF_ASSIGN_OR_RETURN(changed, RunEdgeContractionLoopInPostOrderOnce()); + TF_ASSIGN_OR_RETURN(bool changed, + ForEachEdgeInPostOrder([&](Cluster* from, Cluster* to) { + return ContractEdgeIfPreferred(from, to); + })); - // Check that RunEdgeContractionLoopInPostOrderOnce is idempotent. Once the - // linear time post-order scheme has been battle tested we can move this to - // happen only in debug builds. - TF_ASSIGN_OR_RETURN(changed, RunEdgeContractionLoopInPostOrderOnce()); + // Iterating over the whole graph once in post-order is sufficient to produce + // a maximal clustering: + // + // A. We visit a cluster only after maximally clustering all its children. + // B. By the time we're done with `node` (in `ForEachEdgeInPostOrder`) all of + // its children that could have been absorbed into `node` have been + // absorbed. + // C.
We have an invariant that making a cluster larger does not make edges + // leaving it more contractable. That is, if we have + // digraph { X->Y; Y->Z; } then collapsing X->Y does not make it possible + // to contract Y->Z if Y->Z was not contractible originally. + TF_ASSIGN_OR_RETURN(changed, + ForEachEdgeInPostOrder([&](Cluster* from, Cluster* to) { + return TryToContractEdge(from, to); + })); + + // Check that the conclusion made above (that iterating over the graph once in + // post order gives a maximal clustering) holds. Once the linear time + // post-order scheme has been battle tested we can move this to happen only in + // debug builds. + TF_ASSIGN_OR_RETURN(changed, + ForEachEdgeInPostOrder([&](Cluster* from, Cluster* to) { + return TryToContractEdge(from, to); + })); TF_RET_CHECK(!changed); return Status::OK(); } +std::atomic cluster_sequence_num; + +int64 GetNextClusterSequenceNumber() { return cluster_sequence_num++; } + Status MarkForCompilationPassImpl::CreateClusters() { TF_RET_CHECK(initialized_ && edges_contracted_ && !clusters_created_); clusters_created_ = true; - static std::atomic cluster_sequence_num; - // Names for each cluster. std::unordered_map cluster_names; @@ -745,7 +756,7 @@ Status MarkForCompilationPassImpl::CreateClusters() { string& name = cluster_names[cluster->cycles_graph_node_id()]; if (name.empty()) { - name = absl::StrCat("cluster_", cluster_sequence_num++); + name = absl::StrCat("cluster_", GetNextClusterSequenceNumber()); } n->AddAttr(kXlaClusterAttr, name); @@ -1065,8 +1076,7 @@ bool MarkForCompilationPassImpl::CompilationDisallowedByXlaCompileAttr( bool MarkForCompilationPassImpl::LogNotContractableAndReturnFalse( Cluster* from, Cluster* to, absl::string_view reason) { - VLOG(3) << "Could not contract " << from->DebugString(*graph_) << " -> " - << to->DebugString(*graph_) << " because " << reason << "."; + VLOG(3) << EdgeContractionFailureMsg(from, to, reason); return false; } @@ -1075,8 +1085,14 @@ StatusOr MarkForCompilationPassImpl::TryToContractEdge(Cluster* from, DCHECK(from->deadness_predicate().has_value() == to->deadness_predicate().has_value()); if (from->deadness_predicate() != to->deadness_predicate()) { - return LogNotContractableAndReturnFalse( - from, to, "the two nodes have mismatching deadness"); + VLOG(3) << EdgeContractionFailureMsg( + from, to, + absl::StrCat( + "the two nodes have mismatching deadness: ", + deadness_analysis_->DebugString(*from->deadness_predicate()), + " and ", + deadness_analysis_->DebugString(*to->deadness_predicate()))); + return false; } TF_ASSIGN_OR_RETURN(bool devices_compatible, @@ -1133,32 +1149,6 @@ StatusOr MarkForCompilationPassImpl::TryToContractEdge(Cluster* from, return MergeClusters(from, to); } -StatusOr MarkForCompilationPassImpl::TryToContractEdgesFrom( - Cluster* cluster_from) { - bool changed = false; - - // Make a copy of the set of successors because we may modify the graph in - // TryToContractEdge. - std::vector successors_copy = - cycles_graph_.SuccessorsCopy(cluster_from->cycles_graph_node_id()); - - for (int to : successors_copy) { - iteration_count_++; - - Cluster* cluster_to = GetClusterForCyclesGraphNode(to); - if (!cluster_to) { - continue; - } - - TF_ASSIGN_OR_RETURN(bool contracted_edge, - TryToContractEdge(cluster_from, cluster_to)); - - changed |= contracted_edge; - } - - return changed; -} - Status MarkForCompilationPassImpl::Run() { // Make sure that kernels have been registered on the JIT device. 
XlaOpRegistry::RegisterCompilationKernels(); @@ -1485,7 +1475,8 @@ bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) { op_filter.allow_control_trigger = true; op_filter.allow_eliding_assert_and_checknumerics_ops = true; op_filter.allow_ops_producing_or_consuming_variant = true; - op_filter.allow_slow_and_inaccurate_ops = true; + op_filter.allow_slow_ops = true; + op_filter.allow_inaccurate_ops = true; return RecursiveCompilabilityChecker{&op_filter, &jit_device_type} .IsCompilableCall(ndef, flr); @@ -1522,4 +1513,8 @@ Status MarkForCompilationPass::RunForTest( return MarkForCompilation(options, debug_options); } + +namespace testing { +void ResetClusterSequenceNumber() { cluster_sequence_num = 0; } +} // namespace testing } // namespace tensorflow diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.h b/tensorflow/compiler/jit/mark_for_compilation_pass.h index 16b8427b60e..2eee144e645 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.h +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.h @@ -51,6 +51,13 @@ class MarkForCompilationPass : public GraphOptimizationPass { // function is compilable iff every operator in the function body is // compilable. bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef); + +namespace testing { +// DO NOT USE IN PRODUCTION. +// +// Resets some internal state to let us write reliable unit tests. +void ResetClusterSequenceNumber(); +} // namespace testing } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_H_ diff --git a/tensorflow/compiler/jit/ops/BUILD b/tensorflow/compiler/jit/ops/BUILD index 64409d93347..3b7a74ec780 100644 --- a/tensorflow/compiler/jit/ops/BUILD +++ b/tensorflow/compiler/jit/ops/BUILD @@ -1,7 +1,6 @@ -licenses(["notice"]) # Apache 2.0 - package( default_visibility = ["//tensorflow/compiler/tf2xla:internal"], + licenses = ["notice"], # Apache 2.0 ) load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") diff --git a/tensorflow/compiler/jit/test_util.cc b/tensorflow/compiler/jit/test_util.cc index cada272090a..f50ecdf2287 100644 --- a/tensorflow/compiler/jit/test_util.cc +++ b/tensorflow/compiler/jit/test_util.cc @@ -49,7 +49,7 @@ Status ShapeAnnotationsMatch( missing.push_back(entry.first); } return errors::InvalidArgument("Missing shapes for nodes: ", - str_util::Join(missing, ",")); + absl::StrJoin(missing, ",")); } return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc index 19e3793f29b..fbfda449ebd 100644 --- a/tensorflow/compiler/jit/xla_cpu_device.cc +++ b/tensorflow/compiler/jit/xla_cpu_device.cc @@ -60,7 +60,8 @@ Status XlaCpuDeviceFactory::CreateDevices( registration.cluster_control_trigger = true; registration.elide_assert_and_checknumerics = true; registration.cluster_variant_ops = true; - registration.cluster_slow_and_inaccurate_ops = true; + registration.cluster_slow_ops = true; + registration.cluster_inaccurate_ops = true; XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_CPU, registration); static XlaDeviceOpRegistrations* registrations = diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h index 3b9c4160b95..ff5d5d38e8c 100644 --- a/tensorflow/compiler/jit/xla_device_context.h +++ b/tensorflow/compiler/jit/xla_device_context.h @@ -71,7 +71,7 @@ class XlaDeviceContext : public DeviceContext { StatusCallback done) const override; xla::LocalClient* client() const { return client_; } - se::Stream* stream() const 
{ return stream_.get(); } + se::Stream* stream() const override { return stream_.get(); } se::Stream* host_to_device_stream() const { return host_to_device_stream_.get(); } diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index 293ea3997cc..68c8b64cc82 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/core/kernels/host_constant_op.h" #include "tensorflow/core/kernels/identity_n_op.h" #include "tensorflow/core/kernels/identity_op.h" +#include "tensorflow/core/kernels/logging_ops.h" #include "tensorflow/core/kernels/no_op.h" #include "tensorflow/core/kernels/queue_op.h" #include "tensorflow/core/kernels/resource_variable_ops.h" @@ -81,6 +82,11 @@ class XlaAssignVariableOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE), KERNEL); #define REGISTER_XLA_DEVICE_KERNELS(DEVICE, TYPES) \ + REGISTER_KERNEL_BUILDER(Name("Assert") \ + .Device(DEVICE) \ + .HostMemory("condition") \ + .HostMemory("data"), \ + AssertOp); \ REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE), SendOp); \ REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE), RecvOp); \ REGISTER_KERNEL_BUILDER( \ diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc index 913612f9a6c..02eed3ee16f 100644 --- a/tensorflow/compiler/jit/xla_gpu_device.cc +++ b/tensorflow/compiler/jit/xla_gpu_device.cc @@ -95,7 +95,8 @@ Status XlaGpuDeviceFactory::CreateDevices( registration.cluster_control_trigger = true; registration.elide_assert_and_checknumerics = true; registration.cluster_variant_ops = true; - registration.cluster_slow_and_inaccurate_ops = true; + registration.cluster_slow_ops = true; + registration.cluster_inaccurate_ops = true; XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_GPU, registration); static XlaDeviceOpRegistrations* registrations = diff --git a/tensorflow/compiler/jit/xla_interpreter_device.cc b/tensorflow/compiler/jit/xla_interpreter_device.cc index 4252e2e24ac..f720183e196 100644 --- a/tensorflow/compiler/jit/xla_interpreter_device.cc +++ b/tensorflow/compiler/jit/xla_interpreter_device.cc @@ -63,7 +63,8 @@ Status XlaInterpreterDeviceFactory::CreateDevices( registration.cluster_control_trigger = true; registration.elide_assert_and_checknumerics = true; registration.cluster_variant_ops = true; - registration.cluster_slow_and_inaccurate_ops = true; + registration.cluster_slow_ops = true; + registration.cluster_inaccurate_ops = true; XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_INTERPRETER, registration); diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index 3bb698b33d6..d66c80fea90 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -167,32 +167,6 @@ Status SnapshotResourceVariables(OpKernelContext* ctx, return Status::OK(); } -XlaAllocator::XlaAllocator(const se::Platform* platform, Allocator* wrapped) - : se::DeviceMemoryAllocator(platform), wrapped_(wrapped) {} - -XlaAllocator::~XlaAllocator() {} - -xla::StatusOr XlaAllocator::Allocate( - int device_ordinal, uint64 size, bool retry_on_failure) { - AllocationAttributes attrs; - attrs.no_retry_on_failure = !retry_on_failure; - void* data = nullptr; - if (size != 0) { - data = wrapped_->AllocateRaw(Allocator::kAllocatorAlignment, size, attrs); - if (data == nullptr) { - return errors::ResourceExhausted( - "Out of memory 
while trying to allocate ", size, " bytes."); - } - } - return se::OwningDeviceMemory(se::DeviceMemoryBase(data, size), - device_ordinal, this); -} - -Status XlaAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase mem) { - wrapped_->DeallocateRaw(mem.opaque()); - return Status::OK(); -} - XlaComputationLaunchContext::XlaComputationLaunchContext( xla::LocalClient* client, se::DeviceMemoryAllocator* xla_allocator, bool allocate_xla_tensors, bool use_multiple_streams) diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h index 4cb020ffe34..429ff0a065c 100644 --- a/tensorflow/compiler/jit/xla_launch_util.h +++ b/tensorflow/compiler/jit/xla_launch_util.h @@ -32,7 +32,6 @@ limitations under the License. #include "tensorflow/stream_executor/device_memory_allocator.h" namespace tensorflow { -class XlaAllocator; // Struct that represents a possibly-absent Tensor. struct OptionalTensor { @@ -104,74 +103,6 @@ class VariableInfo { Status LockVariables(absl::Span variables) EXCLUSIVE_LOCK_FUNCTION(); -// Adapter class that wraps a Tensorflow allocator as an XLA allocator. -// Assumes that the Tensorflow allocator permits asynchronous deallocation: -// see comment on `AllowsAsynchronousDeallocation()`. -class XlaAllocator : public se::DeviceMemoryAllocator { - public: - XlaAllocator(const se::Platform* platform, Allocator* wrapped); - ~XlaAllocator() override; - xla::StatusOr Allocate( - int device_ordinal, uint64 size, bool retry_on_failure) override; - Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override; - - // The Tensorflow BFC allocator used on GPU allows host-side deallocation - // before GPU execution takes place. Tensorflow uses the ordering of the main - // compute stream to enforce a happens-before relationship between a memory - // allocation and code that reuses the same memory. If Tensorflow adds - // support for multiple GPU streams or allocators with different ordering - // requirements, this code may need to change. - // (This attribute has no effect on CPU.) - bool AllowsAsynchronousDeallocation() const override { return true; } - - private: - Allocator* wrapped_; -}; - -// Adapter class that wraps per-device TF allocators as an XLA allocator. -// Assumes that the Tensorflow allocator permits asynchronous deallocation; -// see comment on `AllowsAsynchronousDeallocation()`. -class MultiDeviceAdapter : public se::DeviceMemoryAllocator { - public: - MultiDeviceAdapter( - const se::Platform* platform, - std::vector> tf_allocators) - : DeviceMemoryAllocator(platform), - tf_allocators_(std::move(tf_allocators)) { - for (const auto& tf_allocator : tf_allocators_) { - per_device_allocators_.emplace_back(platform, tf_allocator.get()); - } - } - - xla::StatusOr Allocate( - int device_ordinal, uint64 size, bool retry_on_failure) override { - CHECK_LT(device_ordinal, per_device_allocators_.size()); - return per_device_allocators_[device_ordinal].Allocate(device_ordinal, size, - retry_on_failure); - } - - Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override { - CHECK_LT(device_ordinal, per_device_allocators_.size()); - return per_device_allocators_[device_ordinal].Deallocate(device_ordinal, - mem); - } - - // The Tensorflow BFC allocator used on GPU allows host-side deallocation - // before GPU execution takes place. Tensorflow uses the ordering of the main - // compute stream to enforce a happens-before relationship between a memory - // allocation and code that reuses the same memory. 
If Tensorflow adds - // support for multiple GPU streams or allocators with different ordering - // requirements, this code may need to change. - // (This attribute has no effect on CPU.) - bool AllowsAsynchronousDeallocation() const override { return true; } - - private: - std::vector per_device_allocators_; - // The wrapped TF allocators backing per_device_allocators_ (XlaAllocator does - // not take ownership of its underlying Allocator). - std::vector> tf_allocators_; -}; - // Helper class to perform the marshalling of TensorFlow inputs and outputs to // ShapedBuffers suitable for passing to an XLA computation. class XlaComputationLaunchContext { diff --git a/tensorflow/compiler/plugin/BUILD b/tensorflow/compiler/plugin/BUILD index 238fd15166c..c2ba5cb3ecd 100644 --- a/tensorflow/compiler/plugin/BUILD +++ b/tensorflow/compiler/plugin/BUILD @@ -28,10 +28,9 @@ ** Please don't remove this file - it is supporting some 3rd party plugins ** """ -licenses(["notice"]) - package( default_visibility = ["//visibility:public"], + licenses = ["notice"], ) cc_library( diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index fbb60d17316..43dbab1e9a7 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -954,7 +954,7 @@ tf_xla_py_test( tf_xla_py_test( name = "ternary_ops_test", - size = "small", + size = "medium", srcs = ["ternary_ops_test.py"], deps = [ ":xla_test", diff --git a/tensorflow/compiler/tests/cond_test.py b/tensorflow/compiler/tests/cond_test.py index 5963020bbb7..a28c2c5ca88 100644 --- a/tensorflow/compiler/tests/cond_test.py +++ b/tensorflow/compiler/tests/cond_test.py @@ -19,11 +19,13 @@ from __future__ import division from __future__ import print_function from tensorflow.compiler.tests import xla_test +from tensorflow.python.client import session from tensorflow.python.compiler.xla import xla from tensorflow.python.eager import function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -46,8 +48,8 @@ class CondTest(xla_test.XLATestCase): def f(): ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1) output = control_flow_ops.cond( - constant_op.constant( - True), lambda: ta.write(0, 5.), lambda: ta.write(0, 10.)) + constant_op.constant(True), + lambda: ta.write(0, 5.), lambda: ta.write(0, 10.)) return output.stack() @@ -56,6 +58,46 @@ class CondTest(xla_test.XLATestCase): xla_context.Exit() + def testCondAndTensorArrayInDefun_constFolding(self): + g = ops.Graph() + with session.Session(graph=g), g.as_default(), self.test_scope(): + xla_context = control_flow_ops.XLAControlFlowContext() + xla_context.Enter() + + @function.defun + def f(): + ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1) + output = control_flow_ops.cond( + constant_op.constant(False), + lambda: ta.write(0, 5.), lambda: ta.write(0, 10.)) + + return output.stack() + + output_t = f() + self.assertAllEqual([10.], self.evaluate(output_t)) + + xla_context.Exit() + + def testCondAndTensorArray_xlaCompile(self): + self.skipTest("b/127846988") + # Fails with "Uninitialized arguments" in XlaIfOp::Compile + with self.session(), self.test_scope(): + xla_context = control_flow_ops.XLAControlFlowContext() + xla_context.Enter() + + def f(): + ta = 
tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1) + output = control_flow_ops.cond( + constant_op.constant(True), + lambda: ta.write(0, 5.), lambda: ta.write(0, 10.)) + + return output.stack() + + output_t, = xla.compile(f) + self.assertAllEqual([5.], self.evaluate(output_t)) + + xla_context.Exit() + def testCondConstPropagation(self): with self.session() as sess, self.test_scope(): xla_context = control_flow_ops.XLAControlFlowContext() @@ -199,6 +241,28 @@ class CondTest(xla_test.XLATestCase): xla_context.Exit() + def testSwitchCaseAndTensorArray_xlaCompile(self): + self.skipTest("b/127846988") + with self.session(), self.test_scope(): + xla_context = control_flow_ops.XLAControlFlowContext() + xla_context.Enter() + + def f(): + ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1) + output = control_flow_ops.switch_case( + constant_op.constant(1), { + 0: lambda: ta.write(0, 5.), + 1: lambda: ta.write(0, 10.), + 2: lambda: ta.write(0, 15.), + }) + + return output.stack() + + output_t, = xla.compile(f) + self.assertAllEqual([10.], self.evaluate(output_t)) + + xla_context.Exit() + def testSwitchCaseConstPropagation(self): self.skipTest("b/127846988") with self.session() as sess, self.test_scope(): diff --git a/tensorflow/compiler/tests/extract_image_patches_op_test.py b/tensorflow/compiler/tests/extract_image_patches_op_test.py index 9e9b1f367e2..d0686c4bcb8 100644 --- a/tensorflow/compiler/tests/extract_image_patches_op_test.py +++ b/tensorflow/compiler/tests/extract_image_patches_op_test.py @@ -130,5 +130,20 @@ class ExtractImagePatches(xla_test.XLATestCase): padding="VALID", patches=patches) + def testKsize2x2Stride1x1Rate1x1ValidDepth2(self): + """Test for 2x2 kernel with VALID padding.""" + # [1, 2, 2, 2] + image = [[[[1, 5], [2, 6]], [[3, 7], [4, 8]]]] + # [1, 1, 1, 8] + patches = [[[[1, 5, 2, 6, 3, 7, 4, 8]]]] + self._VerifyValues( + image, + ksizes=[2, 2], + strides=[1, 1], + rates=[1, 1], + padding="VALID", + patches=patches) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/compiler/tests/ternary_ops_test.py b/tensorflow/compiler/tests/ternary_ops_test.py index 7e8edc5f0b1..200851ee500 100644 --- a/tensorflow/compiler/tests/ternary_ops_test.py +++ b/tensorflow/compiler/tests/ternary_ops_test.py @@ -72,21 +72,21 @@ class TernaryOpsTest(xla_test.XLATestCase): for dtype in self.numeric_types: self._testTernary( array_ops.where, - np.array(0, dtype=np.bool), + np.array(False), np.array(2, dtype=dtype), np.array(7, dtype=dtype), expected=np.array(7, dtype=dtype)) self._testTernary( array_ops.where, - np.array(1, dtype=np.bool), + np.array(True), np.array([1, 2, 3, 4], dtype=dtype), np.array([5, 6, 7, 8], dtype=dtype), expected=np.array([1, 2, 3, 4], dtype=dtype)) self._testTernary( array_ops.where, - np.array(0, dtype=np.bool), + np.array(False), np.array([[1, 2], [3, 4], [5, 6]], dtype=dtype), np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype), expected=np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype)) @@ -105,6 +105,74 @@ class TernaryOpsTest(xla_test.XLATestCase): np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype), expected=np.array([[7, 8], [3, 4], [11, 12]], dtype=dtype)) + def testSelectV2(self): + for dtype in self.numeric_types: + self._testTernary( + array_ops.where_v2, + np.array(False), + np.array(2, dtype=dtype), + np.array(7, dtype=dtype), + expected=np.array(7, dtype=dtype)) + + self._testTernary( + array_ops.where_v2, + np.array(True), + np.array([1, 2, 3, 4], dtype=dtype), + np.array([5, 6, 7, 8], dtype=dtype), + 
expected=np.array([1, 2, 3, 4], dtype=dtype)) + + self._testTernary( + array_ops.where_v2, + np.array(False), + np.array([[1, 2], [3, 4], [5, 6]], dtype=dtype), + np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype), + expected=np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype)) + + self._testTernary( + array_ops.where_v2, + np.array([0, 1, 1, 0], dtype=np.bool), + np.array([1, 2, 3, 4], dtype=dtype), + np.array([5, 6, 7, 8], dtype=dtype), + expected=np.array([5, 2, 3, 8], dtype=dtype)) + + # Broadcast the condition + self._testTernary( + array_ops.where_v2, + np.array([0, 1], dtype=np.bool), + np.array([[1, 2], [3, 4], [5, 6]], dtype=dtype), + np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype), + expected=np.array([[7, 2], [9, 4], [11, 6]], dtype=dtype)) + + # Broadcast the then branch to the else + self._testTernary( + array_ops.where_v2, + np.array([[0, 1], [1, 0], [1, 1]], dtype=np.bool), + np.array([[1, 2]], dtype=dtype), + np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype), + expected=np.array([[7, 2], [1, 10], [1, 2]], dtype=dtype)) + + # Broadcast the else branch to the then + self._testTernary( + array_ops.where_v2, + np.array([[1, 0], [0, 1], [0, 0]], dtype=np.bool), + np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype), + np.array([[1, 2]], dtype=dtype), + expected=np.array([[7, 2], [1, 10], [1, 2]], dtype=dtype)) + + # Broadcast the then/else branches to the condition + self._testTernary( + array_ops.where_v2, + np.array([[1, 0], [0, 1], [1, 1]], dtype=np.bool), + np.array(7, dtype=dtype), + np.array(8, dtype=dtype), + expected=np.array([[7, 8], [8, 7], [7, 7]], dtype=dtype)) + self._testTernary( + array_ops.where_v2, + np.array([[1, 0], [0, 1], [0, 0]], dtype=np.bool), + np.array(7, dtype=dtype), + np.array([8, 9], dtype=dtype), + expected=np.array([[7, 9], [8, 7], [8, 9]], dtype=dtype)) + def testSlice(self): for dtype in self.numeric_types: self._testTernary( diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc index 723eba7eb96..34d4ee79542 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc @@ -104,6 +104,24 @@ struct EdgePtrCompare { } }; +// TODO(laigd): instead of deciding the device here, the converter should accept +// a device name as one of the conversion parameter so users can control on +// which device they want to run the conversion. +std::pair GetFirstValidDeviceId() { + for (int tf_gpu_id_value = 0; tf_gpu_id_value < 100; ++tf_gpu_id_value) { + TfGpuId tf_gpu_id(tf_gpu_id_value); + PlatformGpuId platform_gpu_id; + Status s = GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id); + if (s.ok()) { + VLOG(1) << "Found TF GPU " << tf_gpu_id.value() << " at cuda device " + << platform_gpu_id.value(); + return std::make_pair(tf_gpu_id, platform_gpu_id); + } + } + LOG(ERROR) << "Could not find any TF GPUs"; + return std::make_pair(TfGpuId(-1), PlatformGpuId(-1)); +} + // Function to get subsegment information structure. Status GetEngineInfo(const Graph* g, const grappler::GraphProperties& graph_properties, @@ -128,27 +146,43 @@ Status GetEngineInfo(const Graph* g, if (segment_nodes.count(node) == 0) continue; auto node_device = node->requested_device(); if (!node_device.empty()) { - // If device is CPU, treat as if no device was assigned. Don't add CPU to - // segment_device because that would cause a segfault in - // GetDeviceAndAllocator. 
This is because GetDeviceAndAllocator assumes - // any already set device is a GPU. + // If device is set, it means device placement may have been done before, + // so we need to assign a device for the TRTEngineOp to maintain the + // invariance. + // If the device is CPU in this case, it tries to find the first available + // GPU and use it as the device. DeviceNameUtils::ParsedName parsed_name; - DeviceNameUtils::ParseFullName(node_device, &parsed_name); - if (parsed_name.type == "CPU") { - VLOG(1) << "Node " << node->name() << " was assigned to the CPU. " - << "Attempting to place on GPU."; + const bool parse_succeeded = + DeviceNameUtils::ParseFullName(node_device, &parsed_name); + if (!parse_succeeded || (parse_succeeded && parsed_name.type == "CPU")) { + string msg; + if (!parse_succeeded) { + msg = StrCat("Failed to parse assigned device of node ", node->name(), + ". "); + } else { + msg = StrCat("Node ", node->name(), " was assigned to the CPU. "); + } + VLOG(1) << msg << "Attempting to place on GPU."; + TfGpuId tf_gpu_id; + PlatformGpuId platform_gpu_id; + std::tie(tf_gpu_id, platform_gpu_id) = GetFirstValidDeviceId(); + if (tf_gpu_id.value() >= 0) { + parsed_name.type = "GPU"; + parsed_name.id = tf_gpu_id.value(); + segment_devices.insert(DeviceNameUtils::FullName( + parsed_name.job, parsed_name.replica, parsed_name.task, + parsed_name.type, parsed_name.id)); + } } else { segment_devices.insert(node_device); } + } else if (node->has_assigned_device_name()) { + // It appears that nodes will not have assigned devices at this point in + // execution. + segment_devices.insert(node->assigned_device_name()); } else { - if (node->has_assigned_device_name()) { - // It appears that nodes will not have assigned devices at this point in - // execution. - segment_devices.insert(node->assigned_device_name()); - } else { - VLOG(2) << "Node " << node->name() - << " neither have requested device nor assigned device"; - } + VLOG(2) << "Node " << node->name() + << " neither have requested device nor assigned device"; } subgraph_nodes.push_back(node); @@ -251,13 +285,11 @@ Status GetEngineInfo(const Graph* g, info->engine_name = StrCat(scope_name, info->engine_name); VLOG(1) << "Converted TensorRT candidate segment '" << info->engine_name << "' to a GraphDef"; - // TODO(sami): This should not happen once segmenter is updated. if (segment_devices.size() == 1) { info->device = *segment_devices.begin(); } else if (segment_devices.size() > 1) { - LOG(WARNING) << "Detected multiple(" << segment_devices.size() - << ") devices for the segment. Picking first one to continue " - << "but this shouldn't have happened"; + LOG(WARNING) << "Detected multiple (" << segment_devices.size() + << ") devices for the segment. Picking first one to continue."; info->device = *segment_devices.begin(); } else { VLOG(1) << "No device is assigned to the segment. 
" @@ -543,10 +575,10 @@ Status RegisterSegmentFunctionToFunctionLibrary(Graph* graph, std::map io_nodes; int num_inputs = 0; for (auto n : sgraph.op_nodes()) { - if (str_util::StartsWith(n->name(), kInputPHName)) { + if (absl::StartsWith(n->name(), kInputPHName)) { num_inputs++; io_nodes.insert({n->name(), n}); - } else if (str_util::StartsWith(n->name(), kOutputPHName)) { + } else if (absl::StartsWith(n->name(), kOutputPHName)) { io_nodes.insert({n->name(), n}); } } @@ -640,24 +672,17 @@ std::pair GetDeviceAndAllocator(const ConversionParams& params, if (params.cluster == nullptr || params.cluster->GetDeviceSet() == nullptr || engine.device.empty()) { // If device is not set, use the first found GPU device for the conversion. - for (int tf_gpu_id_value = 0; tf_gpu_id_value < 100; ++tf_gpu_id_value) { - TfGpuId tf_gpu_id(tf_gpu_id_value); - PlatformGpuId platform_gpu_id; - Status s = GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id); - if (s.ok()) { - VLOG(1) << "Found TF GPU " << tf_gpu_id.value() << " at cuda device " - << platform_gpu_id.value(); - cuda_device_id = platform_gpu_id.value(); - GPUOptions gpu_options; - // If the TF to Cuda gpu id mapping exist, the device and corresponding - // allocator must have been initialized already, so the - // GetGPUAllocator() call won't create a new allocator. - dev_allocator = GPUProcessState::singleton()->GetGPUAllocator( - gpu_options, tf_gpu_id, 1); - break; - } - LOG(ERROR) << "TF GPU with id " << tf_gpu_id_value << " does not exist " - << s; + TfGpuId tf_gpu_id; + PlatformGpuId platform_gpu_id; + std::tie(tf_gpu_id, platform_gpu_id) = GetFirstValidDeviceId(); + cuda_device_id = platform_gpu_id.value(); + if (cuda_device_id >= 0) { + GPUOptions gpu_options; + // If the TF to Cuda gpu id mapping exist, the device and corresponding + // allocator must have been initialized already, so the + // GetGPUAllocator() call won't create a new allocator. 
+ dev_allocator = GPUProcessState::singleton()->GetGPUAllocator( + gpu_options, tf_gpu_id, 1); } return std::make_pair(cuda_device_id, dev_allocator); } @@ -750,8 +775,8 @@ Status ConvertAfterShapes(const ConversionParams& params) { EngineInfo curr_engine; curr_engine.engine_name = StrCat("TRTEngineOp_", t); Status status = - GetEngineInfo(&graph, *params.graph_properties, curr_segment.first, - node_map, reverse_topo_order, &curr_engine); + GetEngineInfo(&graph, *params.graph_properties, curr_segment, node_map, + reverse_topo_order, &curr_engine); if (!status.ok()) { LOG(WARNING) << "Failed to get engine info for segment " << t << ": " << status; @@ -776,7 +801,7 @@ Status ConvertAfterShapes(const ConversionParams& params) { engine_bytes_size.push_back(curr_engine.segment_graph_def.ByteSizeLong()); total_engine_bytes_size += engine_bytes_size.back(); - total_num_nodes_in_segments += curr_segment.first.size(); + total_num_nodes_in_segments += curr_segment.size(); engine_segments.push_back(std::move(curr_engine)); converted_segments.push_back(std::move(curr_segment)); @@ -806,7 +831,7 @@ Status ConvertAfterShapes(const ConversionParams& params) { engine.max_workspace_size_bytes = params.max_workspace_size_bytes * (engine_bytes_size.at(i) / total_engine_bytes_size + - converted_segments.at(i).first.size() / total_num_nodes_in_segments) / + converted_segments.at(i).size() / total_num_nodes_in_segments) / 2.0; VLOG(1) << "Assigned " << engine.max_workspace_size_bytes << " bytes to " << engine.engine_name; @@ -828,9 +853,9 @@ Status ConvertAfterShapes(const ConversionParams& params) { CreateTRTNode(params, engine_segments, i, params.max_batch_size, &graph, alloc.get(), &engine_nodes); - string msg = StrCat("TensorRT node ", engine.engine_name, - " added for segment ", i, " consisting of ", - converted_segments.at(i).first.size(), " nodes"); + string msg = + StrCat("TensorRT node ", engine.engine_name, " added for segment ", i, + " consisting of ", converted_segments.at(i).size(), " nodes"); if (status.ok()) { LOG(INFO) << msg << " succeeded."; } else { @@ -839,7 +864,7 @@ Status ConvertAfterShapes(const ConversionParams& params) { } if (VLOG_IS_ON(1)) { msg = "Segment consists of nodes: "; - for (const Node* node : converted_segments.at(i).first) { + for (const Node* node : converted_segments.at(i)) { StrAppend(&msg, node->name(), ", "); } VLOG(1) << msg; @@ -848,7 +873,7 @@ Status ConvertAfterShapes(const ConversionParams& params) { // If status is ok, we successfully added the node to the graph and can // remove segment ops. Otherwise graph is not modified. 
if (status.ok()) { - for (const Node* node : converted_segments.at(i).first) { + for (const Node* node : converted_segments.at(i)) { graph.RemoveNode(const_cast(node)); } } diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc index d8db0ffac7e..647c9b5068b 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc @@ -239,7 +239,7 @@ class ConvertAfterShapesTest : public ::testing::Test { params.output_names = &output_names; params.max_workspace_size_bytes = 8 << 20; params.output_graph_def = output_graph_def; - params.minimum_segment_size = 2; + params.minimum_segment_size = 1; params.graph_properties = &graph_properties; params.use_calibration = false; diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index 0ac508822f1..a1ccb3b3e6e 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -385,11 +385,10 @@ string DebugString(const nvinfer1::ITensor& tensor) { ", dims=", DebugString(tensor.getDimensions()), ")"); } -Status Converter::GetTrtBroadcastShape( - const TRT_TensorOrWeights& operand_l, const TRT_TensorOrWeights& operand_r, - nvinfer1::Dims* operand_l_new_dims, - nvinfer1::Dims* operand_r_new_dims) const { - // *************************************************************************** +Status GetTrtBroadcastShape(const TRT_TensorOrWeights& operand_l, + const TRT_TensorOrWeights& operand_r, + nvinfer1::Dims* operand_l_new_dims, + nvinfer1::Dims* operand_r_new_dims) { // TensorRT Elementwise op supports broadcast but requires both tensor to be // of Identical rank // @@ -473,14 +472,13 @@ nvinfer1::ITensor* Converter::CreateConstantLayer( nvinfer1::Weights trt_weights = weights.GetTrtWeights(); nvinfer1::IConstantLayer* layer = network()->addConstant(dims, trt_weights); if (!layer) return nullptr; - const nvinfer1::DataType trt_dtype = trt_weights.type; nvinfer1::ITensor* trt_tensor = layer->getOutput(0); #if !IS_TRT_VERSION_GE(5, 1, 3, 0) // TODO(laigd): there is a bug in TensorRT 5.0 library that, if we don't set // the data type below, it will always be kFLOAT regardless what the data type // of the weights is. Once NVIDIA fixes this bug, we should remove the data // type setting logic below and test should still pass. - trt_tensor->setType(trt_dtype); + trt_tensor->setType(trt_weights.type); #endif return trt_tensor; } @@ -1677,190 +1675,6 @@ Status UnaryCompute(const TRT_ShapedWeights& iweights, return Status::OK(); } -// If swapped_inputs is false, 'tensor' is the left operand and 'weights' is the -// right operand. If swapped_inputs is true, those two are swapped. -// -// TODO(jie): broadcast is needed yet not implemented. -// Only implemented channel wise for the time being. -Status BinaryTensorOpWeight(OpConverterParams* params, - nvinfer1::ITensor* tensor, - TRT_ShapedWeights weights, bool swapped_inputs) { - static const std::unordered_set supported_ops = {"Sub", "Add", "Mul", - "Div", "RealDiv"}; - const auto& node_def = params->node_def; - if (!supported_ops.count(node_def.op())) { - return errors::Unimplemented(node_def.op(), " is not supported, at ", - node_def.name()); - } - - // Check scale mode. 
- auto dims_w = weights.shape_; - const auto dims_t = tensor->getDimensions(); - - // TODO(jie): addScale checks for input tensor dimension - if (dims_t.nbDims != 3) { - return errors::InvalidArgument("addScale requires tensor with rank 3, at ", - node_def.name()); - } - - // Default to element-wise - auto scale_mode = nvinfer1::ScaleMode::kELEMENTWISE; - - // TODO(jie): maybe use a permutation instead to support more cases; - bool need_to_permute = false; - - if (weights.count() == 1) { - scale_mode = nvinfer1::ScaleMode::kUNIFORM; - } else { - VLOG(2) << "weights dims: " << DebugString(dims_w) - << "; tensor dims: " << DebugString(dims_t); - // Make sure no broadcasting on batch dimension. - if (dims_w.nbDims == dims_t.nbDims + 1) { - if (dims_w.d[0] == 1) { - for (int i = 1; i < dims_w.nbDims; i++) { - dims_w.d[i - 1] = dims_w.d[i]; - } - dims_w.nbDims--; - } else { - return errors::InvalidArgument("Binary op cannot operate on batch, at ", - node_def.name()); - } - } - - if (dims_w.nbDims == dims_t.nbDims && dims_w.d[0] == dims_t.d[0]) { - scale_mode = nvinfer1::ScaleMode::kELEMENTWISE; - // Default is element-wise - for (int i = 1; i < dims_w.nbDims; i++) { - if (dims_w.d[i] != dims_t.d[i]) { - // If dimension does not match, switch back to per-channel - scale_mode = nvinfer1::ScaleMode::kCHANNEL; - break; - } - } - // If the mode is per-channel, since channel dimension is assumed to be - // the third to last dimension, we need to make sure all other dimensions - // have size 1. - if (scale_mode == nvinfer1::ScaleMode::kCHANNEL) { - for (int i = 1; i < dims_w.nbDims; i++) { - if (dims_w.d[i] != 1) - return errors::InvalidArgument( - "Weight dims not compatible for channel-wise broadcast at ", - node_def.name()); - } - } - } else if (dims_w.nbDims == 1 && - dims_w.d[0] == dims_t.d[dims_t.nbDims - 1]) { - // Channel wise and broadcast required. We compare the last dimension of - // the tensor shape because of tensorflow default broadcasting rules. - need_to_permute = true; - scale_mode = nvinfer1::ScaleMode::kCHANNEL; - } else { - return errors::InvalidArgument("Weight dims not compatible at ", - node_def.name()); - } - } - // TODO(laigd): we should add validation_only support in TransposeTensor() and - // PrepareTensorForShape(). - if (params->validation_only) return Status::OK(); - - // Transpose last dimension. - std::vector permutation(dims_t.nbDims + 1); - if (need_to_permute) { - // We swap the last dimension into channel for trt, because of tensorflow - // default broadcasting rules. - for (int i = 0; i < static_cast(permutation.size()); i++) { - permutation[i] = i; - } - permutation[1] = dims_t.nbDims; - permutation[dims_t.nbDims] = 1; - TF_RETURN_IF_ERROR( - params->converter->TransposeTensor(tensor, permutation, &tensor)); - } - - // Prepare weights - TRT_ShapedWeights shift_weights(weights.TrtDType()); - TRT_ShapedWeights scale_weights(weights.TrtDType()); - TRT_ShapedWeights power_weights(weights.TrtDType()); - - if (node_def.op() == "Sub") { - if (swapped_inputs) { - shift_weights = weights; - nvinfer1::IUnaryLayer* layer = params->converter->network()->addUnary( - *tensor, nvinfer1::UnaryOperation::kNEG); - TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); - // Since quantization ranges are symmetric, the same range as the input - // will work for the negation of the input. 
- params->converter->MarkQuantizationRangesAsInferrable( - tensor, layer->getOutput(0)); - tensor = layer->getOutput(0); - } else { - TRT_ShapedWeights neg_weights = - params->weight_store->GetTempWeights(weights); - LambdaFactory unary_op; - unary_op.op = LambdaFactory::OP_CATEGORY::NEG; - TF_RETURN_IF_ERROR(UnaryCompute(weights, &neg_weights, unary_op)); - shift_weights = neg_weights; - } - } else if (node_def.op() == "Div" || node_def.op() == "RealDiv") { - if (swapped_inputs) { - // We need to infer the quantization range for this intermediate tensor. - // - // x -> [Recip] -> 1/x -> [Scale] -> s/x - // ^ - // need range for this - // - // We have the quantization scales for x and s/x - can we divide the scale - // for s/x by s? Only if it is a scalar. - // - // Because of this issue, fall back to BinaryTensorOpTensor if we are - // doing INT8 with no calibration. There is most likely no performance - // penalty by falling back here. - if (params->converter->precision_mode() == TrtPrecisionMode::INT8 && - !params->converter->use_calibration()) { - return errors::Unimplemented( - "Intermediate quantization range cannot be determined without" - " calibration. Falling back to BinaryTensorOpTensor for ", - node_def.op(), ", at ", node_def.name()); - } - scale_weights = weights; - nvinfer1::IUnaryLayer* layer = params->converter->network()->addUnary( - *tensor, nvinfer1::UnaryOperation::kRECIP); - TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); - tensor = layer->getOutput(0); - } else { - TRT_ShapedWeights recip_weights = - params->weight_store->GetTempWeights(weights); - LambdaFactory unary_op; - unary_op.op = LambdaFactory::OP_CATEGORY::RECIP; - TF_RETURN_IF_ERROR(UnaryCompute(weights, &recip_weights, unary_op)); - scale_weights = recip_weights; - } - } else if (node_def.op() == "Mul") { - scale_weights = weights; - } else if (node_def.op() == "Add") { - shift_weights = weights; - } else { - // This should not happen. - return errors::Unimplemented("Binary op not supported at ", node_def.op()); - } - - nvinfer1::IScaleLayer* layer = params->converter->network()->addScale( - *tensor, scale_mode, shift_weights.GetTrtWeights(), - scale_weights.GetTrtWeights(), power_weights.GetTrtWeights()); - TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); - - nvinfer1::ITensor* output_tensor = layer->getOutput(0); - // Transpose back dimension - if (need_to_permute) { - TF_RETURN_IF_ERROR(params->converter->TransposeTensor( - output_tensor, permutation, &output_tensor)); - } - - // Pass the output - params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); - return Status::OK(); -} - Status ConvertConv2DHelper(OpConverterParams* params, int group, bool is_conv2d_backprop_input) { const auto& inputs = params->inputs; @@ -1951,7 +1765,8 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group, kernel_size.h() = weights.shape_.d[2]; kernel_size.w() = weights.shape_.d[3]; - // Add padding. +// Before TRT 5.1.3, we have to calculate padding ourselves. +#if !IS_TRT_VERSION_GE(5, 1, 3, 0) std::vector> padding; if (attrs.get("padding") == "SAME") { nvinfer1::DimsHW effective_kernel_size = kernel_size; @@ -1978,12 +1793,12 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group, padding = {{0, 0}, {0, 0}}; } -// TensorRT 5.1 added support for asymmetric padding. Due to a bug in 5.1.2, we -// can only use asymmetric padding in convolutions with 5.1.3+. -#if !IS_TRT_VERSION_GE(5, 1, 3, 0) + // Handle asymmetric padding. 
TensorRT 5.1 added support for asymmetric + // padding via setPrePadding and setPostPadding. Due to a bug in 5.1.2, we can + // only use asymmetric padding in convolutions with 5.1.3+. But in 5.1.3, we + // will always use setPaddingMode for simplicity. if (padding[0].first != padding[0].second || padding[1].first != padding[1].second) { - // Handle asymmetric padding. auto pad_layer = params->converter->network()->addPadding( *tensor, nvinfer1::DimsHW(padding[0].first, padding[1].first), nvinfer1::DimsHW(padding[0].second, padding[1].second)); @@ -2006,20 +1821,13 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group, layer->setStride(stride); // TensorRT 5.1.3 added support for padding modes. #if IS_TRT_VERSION_GE(5, 1, 3, 0) + // VALID padding is the default TRT behavior. if (attrs.get("padding") == "SAME") { - VLOG(2) << "Using SAME padding"; // SAME_UPPER means that post padding is preferred. layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER); } - // For VALID padding, we need to manually set the padding. - layer->setPrePadding(nvinfer1::DimsHW{padding[0].first, padding[1].first}); - layer->setPostPadding( - nvinfer1::DimsHW{padding[0].second, padding[1].second}); - VLOG(2) << "Set pre-padding to: " << DebugString(layer->getPrePadding()) - << " and post-padding to: " << DebugString(layer->getPostPadding()); #else layer->setPadding(nvinfer1::DimsHW{padding[0].first, padding[1].first}); - VLOG(2) << "Set padding to: " << DebugString(layer->getPadding()); #endif layer->setName(node_def.name().c_str()); layer->setNbGroups(num_groups); @@ -2033,17 +1841,10 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group, layer->setStride(stride); #if IS_TRT_VERSION_GE(5, 1, 3, 0) if (attrs.get("padding") == "SAME") { - VLOG(2) << "Using SAME padding"; layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER); } - layer->setPrePadding(nvinfer1::DimsHW{padding[0].first, padding[1].first}); - layer->setPostPadding( - nvinfer1::DimsHW{padding[0].second, padding[1].second}); - VLOG(2) << "Set pre-padding to: " << DebugString(layer->getPrePadding()) - << " and post-padding to: " << DebugString(layer->getPostPadding()); #else layer->setPadding(nvinfer1::DimsHW{padding[0].first, padding[1].first}); - VLOG(2) << "Set padding to: " << DebugString(layer->getPadding()); #endif layer->setName(node_def.name().c_str()); layer->setNbGroups(num_groups); @@ -2061,74 +1862,6 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group, return Status::OK(); } -Status BinaryTensorOpTensor(OpConverterParams* params, - const TRT_TensorOrWeights& operand_l, - const TRT_TensorOrWeights& operand_r) { - const auto& node_def = params->node_def; - static const std::unordered_map ops{ - {"Add", nvinfer1::ElementWiseOperation::kSUM}, - {"Mul", nvinfer1::ElementWiseOperation::kPROD}, - {"Sub", nvinfer1::ElementWiseOperation::kSUB}, - {"Div", nvinfer1::ElementWiseOperation::kDIV}, - {"RealDiv", nvinfer1::ElementWiseOperation::kDIV}, - {"Minimum", nvinfer1::ElementWiseOperation::kMIN}, - {"Maximum", nvinfer1::ElementWiseOperation::kMAX}, - {"Pow", nvinfer1::ElementWiseOperation::kPOW}, - }; - auto op_pair = ops.find(node_def.op()); - if (op_pair == ops.end()) { - return errors::Unimplemented("Binary op ", node_def.op(), - " not supported at: ", node_def.name()); - } - - nvinfer1::Dims broadcasted_dims_l, broadcasted_dims_r; - Status status = params->converter->GetTrtBroadcastShape( - operand_l, operand_r, &broadcasted_dims_l, &broadcasted_dims_r); - if (!status.ok()) { - return 
errors::InvalidArgument( - "Unsupported binary op broadcast scheme for op ", node_def.name(), ": ", - status.error_message()); - } - TFAttrs attrs(node_def); - nvinfer1::DataType dtype = attrs.get("T"); - if (dtype == nvinfer1::DataType::kINT32) { - return errors::Unimplemented("Binary op ", node_def.op(), - " does not support INT32, at ", - node_def.name()); - } - if (params->validation_only) return Status::OK(); - - nvinfer1::ITensor* tensor_l = nullptr; - nvinfer1::ITensor* tensor_r = nullptr; - status = params->converter->PrepareTensorForShape( - operand_l, broadcasted_dims_l, /*validation_only=*/false, &tensor_l); - if (status.ok()) { - status = params->converter->PrepareTensorForShape( - operand_r, broadcasted_dims_r, /*validation_only=*/false, &tensor_r); - } - if (!status.ok()) { - return errors::Internal("Failed to convert binary op ", node_def.name(), - ": ", status.error_message()); - } - - // Check type consistency. - TFTRT_CHECK_EQ_TYPE(tensor_l->getType(), dtype) - << DebugString(tensor_l->getType()) << " vs " << DebugString(dtype); - TFTRT_CHECK_EQ_TYPE(tensor_r->getType(), dtype) - << DebugString(tensor_r->getType()) << " vs " << DebugString(dtype); - - // Add ElementWise layer. - nvinfer1::IElementWiseLayer* layer = - params->converter->network()->addElementWise(*tensor_l, *tensor_r, - op_pair->second); - TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); - nvinfer1::ITensor* output_tensor = layer->getOutput(0); - - // Pass the output - params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); - return Status::OK(); -} - Status ConvertPlugin(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; @@ -2777,6 +2510,8 @@ Status ConvertPool(OpConverterParams* params) { const auto tf_kernel = attrs.get>("ksize"); const nvinfer1::DimsHW ksize(tf_kernel[h_index], tf_kernel[w_index]); +// Before TRT 5.1.3, we have to calculate padding ourselves. +#if !IS_TRT_VERSION_GE(5, 1, 3, 0) auto tensor_dim = tensor->getDimensions(); std::vector> padding; if (padding_type == "SAME") { @@ -2789,13 +2524,13 @@ Status ConvertPool(OpConverterParams* params) { } else if (padding_type == "VALID") { padding = {{0, 0}, {0, 0}}; } - -// TensorRT 5.1 added support for asymmetric padding. +#endif +// TensorRT 5.1 added support for asymmetric padding. Before that, we need an +// extra padding layer. #if !IS_TRT_VERSION_GE(5, 1, 0, 0) + // Asymmetric padding case. if (padding[0].first != padding[0].second || padding[1].first != padding[1].second) { - VLOG(2) << "Padding!!!: " << padding[0].first << padding[0].second - << padding[1].first << padding[1].second; auto pad_layer = params->converter->network()->addPadding( *tensor, nvinfer1::DimsHW(padding[0].first, padding[1].first), nvinfer1::DimsHW(padding[0].second, padding[1].second)); @@ -2817,16 +2552,13 @@ Status ConvertPool(OpConverterParams* params) { layer->getOutput(0)); layer->setStride(stride); -// TensorRT 5.1.3 added support for padding modes. #if IS_TRT_VERSION_GE(5, 1, 3, 0) + // VALID padding is the default TRT behavior. if (attrs.get("padding") == "SAME") { // SAME_UPPER means that post padding is preferred. layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER); } -#endif -// TensorRT 5.1 has support for asymmetric padding. -#if IS_TRT_VERSION_GE(5, 1, 0, 0) - // If padding mode is not SAME, then these values will be used instead. 
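For the pre-5.1.3 branches above, the padding pairs follow TensorFlow's SAME rule: the total padding for a spatial dimension is max((ceil(input/stride) - 1)*stride + effective_kernel - input, 0), split with the smaller half in front, which is also why kSAME_UPPER (post padding preferred) is the matching TensorRT padding mode on 5.1.3+. A minimal sketch of that split (helper name is illustrative, not from the patch):

    #include <algorithm>
    #include <utility>

    // SAME-padding split for one spatial dimension: returns {pad_before,
    // pad_after}. The extra pixel, if the total is odd, goes to the end,
    // matching SAME_UPPER (post padding preferred).
    std::pair<int, int> SamePaddingForDim(int input_size, int stride,
                                          int effective_kernel_size) {
      const int output_size = (input_size + stride - 1) / stride;  // ceil
      const int total_pad = std::max(
          0, (output_size - 1) * stride + effective_kernel_size - input_size);
      const int pad_before = total_pad / 2;
      return {pad_before, total_pad - pad_before};
    }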
+#elif IS_TRT_VERSION_GE(5, 1, 0, 0) layer->setPrePadding(nvinfer1::DimsHW{padding[0].first, padding[1].first}); layer->setPostPadding(nvinfer1::DimsHW{padding[0].second, padding[1].second}); #else @@ -3350,9 +3082,6 @@ Status ConvertIdentity(OpConverterParams* params) { Status ConvertBinary(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; - // TODO(tmorris): Enable once false is updated to mean either tensor or weight - // TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}, {"y", - // false}})); if (inputs.size() != 2) { return errors::InvalidArgument(node_def.op(), " got ", inputs.size(), " inputs but expected 2, at ", @@ -3368,33 +3097,45 @@ Status ConvertBinary(OpConverterParams* params) { "both input as constant at: ", node_def.name()); } + const TRT_TensorOrWeights& operand_l = inputs.at(0); + const TRT_TensorOrWeights& operand_r = inputs.at(1); - // TODO(tmorris): TRT plans to deprecate IScaleLayer and will replace it with - // IElementwiseLayer. At that point, we can remove BinaryTensorOpWeight. For - // now, the performance will be slightly better with IScaleLayer because it - // can be fused in more situations. However, most of the benefits of - // IScaleLayer are when the layer performs both a shift and a scale, which we - // don't do except for convolutions. - // - // Try to convert into Scale layer first (for better performance). - // Since scale layer supports restricted broadcast policy and op types, we - // allow failure and try to handle it through Elementwise op - // (BinaryTensorOpTensor). - Status status = Status::OK(); - if (inputs.at(0).is_tensor() && inputs.at(1).is_weights()) { - status = BinaryTensorOpWeight(params, inputs.at(0).tensor(), - inputs.at(1).weights(), false); - } else if (inputs.at(0).is_weights() && inputs.at(1).is_tensor()) { - status = BinaryTensorOpWeight(params, inputs.at(1).tensor(), - inputs.at(0).weights(), true); + static const std::unordered_map ops{ + {"Add", nvinfer1::ElementWiseOperation::kSUM}, + {"Mul", nvinfer1::ElementWiseOperation::kPROD}, + {"Sub", nvinfer1::ElementWiseOperation::kSUB}, + {"Div", nvinfer1::ElementWiseOperation::kDIV}, + {"RealDiv", nvinfer1::ElementWiseOperation::kDIV}, + {"Minimum", nvinfer1::ElementWiseOperation::kMIN}, + {"Maximum", nvinfer1::ElementWiseOperation::kMAX}, + {"Pow", nvinfer1::ElementWiseOperation::kPOW}, + }; + auto op_pair = ops.find(node_def.op()); + if (op_pair == ops.end()) { + return errors::Unimplemented("Binary op ", node_def.op(), + " not supported at: ", node_def.name()); } - // If both input are tensors, or one of them is weights but the conversion - // above failed, try the conversion using BinaryTensorOpTensor. - if ((inputs.at(0).is_tensor() && inputs.at(1).is_tensor()) || !status.ok()) { - if (!status.ok()) VLOG(2) << status; - status = BinaryTensorOpTensor(params, inputs.at(0), inputs.at(1)); - } - return status; + + nvinfer1::Dims broadcasted_dims_l, broadcasted_dims_r; + TF_RETURN_IF_ERROR(GetTrtBroadcastShape( + operand_l, operand_r, &broadcasted_dims_l, &broadcasted_dims_r)); + + nvinfer1::ITensor* tensor_l = nullptr; + nvinfer1::ITensor* tensor_r = nullptr; + // This will also convert constants to tensors, and set quantization ranges. 
+ TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( + operand_l, broadcasted_dims_l, params->validation_only, &tensor_l)); + TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( + operand_r, broadcasted_dims_r, params->validation_only, &tensor_r)); + if (params->validation_only) return Status::OK(); + + // Add ElementWise layer. + nvinfer1::IElementWiseLayer* layer = + params->converter->network()->addElementWise(*tensor_l, *tensor_r, + op_pair->second); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0))); + return Status::OK(); } Status ConvertRsqrt(OpConverterParams* params) { @@ -4547,7 +4288,7 @@ Status ConvertSquaredDifference(OpConverterParams* params) { const auto& node_def = params->node_def; // Broadcast inputs. nvinfer1::Dims broadcasted_dims_l, broadcasted_dims_r; - TF_RETURN_IF_ERROR(params->converter->GetTrtBroadcastShape( + TF_RETURN_IF_ERROR(GetTrtBroadcastShape( inputs.at(0), inputs.at(1), &broadcasted_dims_l, &broadcasted_dims_r)); nvinfer1::ITensor* tensor_l = nullptr; nvinfer1::ITensor* tensor_r = nullptr; @@ -4692,8 +4433,8 @@ Status ConvertCombinedNMS(OpConverterParams* params) { TFTRT_RETURN_ERROR_IF_NULLPTR(creator, node_def.name()); // Create plugin - nvinfer1::IPluginV2* plugin = - creator->createPlugin(node_def.name().c_str(), &fc); + TrtUniquePtrType plugin( + creator->createPlugin(node_def.name().c_str(), &fc)); TFTRT_RETURN_ERROR_IF_NULLPTR(plugin, node_def.name()); // Set plugin inputs @@ -4875,7 +4616,8 @@ static void RegisterValidatableOpConverters( for (auto pool_op_type : {"AvgPool", "MaxPool"}) { (*registration)[pool_op_type] = ConvertPool; } - for (auto normalization_op_type : {"FusedBatchNorm", "FusedBatchNormV2"}) { + for (auto normalization_op_type : + {"FusedBatchNorm", "FusedBatchNormV2", "FusedBatchNormV3"}) { (*registration)[normalization_op_type] = ConvertFusedBatchNorm; } for (auto unary_op_pair : *UnaryOperationMap()) { diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h index 763b28b7402..d0f6d5ef1d1 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h @@ -512,13 +512,6 @@ class Converter { const bool validation_only, nvinfer1::ITensor** tensor); - // Return OK if the broadcast scheme is supported and compute the shapes after - // broadcasting. - Status GetTrtBroadcastShape(const TRT_TensorOrWeights& operand_l, - const TRT_TensorOrWeights& operand_r, - nvinfer1::Dims* operand_l_new_dims, - nvinfer1::Dims* operand_r_new_dims) const; - // Creates an IConstantLayer using 'weights' whose dimensions are specified by // 'dims', and returns the output ITensor. nvinfer1::ITensor* CreateConstantLayer(const TRT_ShapedWeights& weights, @@ -592,6 +585,13 @@ class Converter { friend class OpConverterTest; }; +// Return OK if the broadcast scheme is supported and compute the shapes after +// broadcasting. 
+Status GetTrtBroadcastShape(const TRT_TensorOrWeights& operand_l, + const TRT_TensorOrWeights& operand_r, + nvinfer1::Dims* operand_l_new_dims, + nvinfer1::Dims* operand_r_new_dims); + // Map of all supported UnaryOperations const std::unordered_map* UnaryOperationMap(); // Map of all supported ActivationTypes diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index c4ba69c1393..09b7a60c083 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -988,19 +988,17 @@ TEST_F(ConverterTest, GetTrtBroadcastShape) { operand_2_shape, operand_2_is_tensor, operand_2_batch_size); // operand_1 broadcast operand_2 - ExpectStatus( - this->converter_->GetTrtBroadcastShape( - operand_1, operand_2, &operand_1_new_dims, &operand_2_new_dims), - expected_code, expected_error_msg_substr); + ExpectStatus(GetTrtBroadcastShape(operand_1, operand_2, &operand_1_new_dims, + &operand_2_new_dims), + expected_code, expected_error_msg_substr); if (expected_code == error::OK) { ExpectTrtDimsEqualsArray(expected_operand_1_shape, operand_1_new_dims); ExpectTrtDimsEqualsArray(expected_operand_2_shape, operand_2_new_dims); } // operand_2 broadcast operand_1 - ExpectStatus( - this->converter_->GetTrtBroadcastShape( - operand_2, operand_1, &operand_2_new_dims, &operand_1_new_dims), - expected_code, expected_error_msg_substr); + ExpectStatus(GetTrtBroadcastShape(operand_2, operand_1, &operand_2_new_dims, + &operand_1_new_dims), + expected_code, expected_error_msg_substr); if (expected_code == error::OK) { ExpectTrtDimsEqualsArray(expected_operand_1_shape, operand_1_new_dims); ExpectTrtDimsEqualsArray(expected_operand_2_shape, operand_2_new_dims); @@ -1033,18 +1031,29 @@ TEST_F(ConverterTest, GetTrtBroadcastShape) { error::INVALID_ARGUMENT, "Broadcasting beyond batch dimension is not supported " "(tensor #dims 4 vs broadcast #dims 5)"); + symmetric_test({3}, {1, 1, 3}, kIsTensor, kIsNotTensor, {}, {}, + error::INVALID_ARGUMENT, + "Broadcasting beyond batch dimension is not supported " + "(tensor #dims 2 vs broadcast #dims 3)", + /*operand_1_batch_size=*/2); // Both inputs are tensors. 
symmetric_test({1, 1, 1}, {1, 1}, kIsTensor, kIsTensor, {}, {}, error::INVALID_ARGUMENT, "Broadcasting beyond batch dimension is not supported " "(tensor #dims 3 vs broadcast #dims 4)"); + symmetric_test({1, 3}, {3}, kIsTensor, kIsTensor, {}, {}, + error::INVALID_ARGUMENT, + "Broadcasting beyond batch dimension is not supported " + "(tensor #dims 2 vs broadcast #dims 3)"); symmetric_test({1, 3, 4}, {2, 1, 4}, kIsTensor, kIsTensor, {1, 3, 4}, {2, 1, 4}); symmetric_test({1, 1, 1}, {1, 1, 1, 1}, kIsTensor, kIsTensor, {}, {}, error::INVALID_ARGUMENT, "Broadcasting beyond batch dimension is not supported " "(tensor #dims 4 vs broadcast #dims 5)"); + symmetric_test({2, 3}, {7, 5}, kIsTensor, kIsTensor, {}, {}, + error::INVALID_ARGUMENT, "Infeasible broadcast scheme"); } TEST_F(ConverterTest, CreateConstantLayer) { @@ -1070,7 +1079,7 @@ class ConvertGraphDefToEngineTest : public ::testing::Test { int batch_size = -1; for (const NodeDef& node : gdef.node()) { absl::string_view node_name(node.name()); - if (str_util::ConsumePrefix(&node_name, kInputPHName)) { + if (absl::ConsumePrefix(&node_name, kInputPHName)) { int port = -1; EXPECT_TRUE(absl::SimpleAtoi(node_name, &port)) << node.name(); if (input_shapes.size() < port + 1) input_shapes.resize(port + 1); @@ -1351,10 +1360,6 @@ class OpConverterTest : public ::testing::Test { } } - void TestMatMulHelper( - const std::function& get_matmul, - const std::string& op_name); - // Expose quantization_ranges_ for tests std::unordered_map& quantization_ranges() { return converter_->quantization_ranges_; @@ -1682,59 +1687,60 @@ TEST_F(OpConverterTest, ConvertReshape) { // Helper function for testing MatMul and BatchMatMul // get_matmul corresponds to the function used to generate the node. It should // accept (DataType, transpose_a, transpose_b) as parameters. -void OpConverterTest::TestMatMulHelper( +void TestMatMulHelper( + OpConverterTest* test, const std::function& get_matmul, const std::string& op_name) { // HACK: This needs to be done in a better way. const bool is_batch_matmul = op_name == "BatchMatMul"; { // Unsupported data type. - Reset(); + test->Reset(); NodeDef node_def = get_matmul(DT_INT32, false, false); - AddTestTensor("input", {2}, /*batch_size=*/1, nvinfer1::DataType::kINT32); - AddTestWeights("weights", {2, 1}, {3, 5}); - RunValidationAndConversion( + test->AddTestTensor("input", {2}, /*batch_size=*/1, + nvinfer1::DataType::kINT32); + test->AddTestWeights("weights", {2, 1}, {3, 5}); + test->RunValidationAndConversion( node_def, error::UNIMPLEMENTED, - ("Data type int32 is not supported for " + op_name + - ", " - "must be one of [float, half], at my_matmul") + StrCat("Data type int32 is not supported for ", op_name, + ", must be one of [float, half], at my_matmul") .c_str()); } // OK. 
for (bool transpose_a : {false, true}) { for (bool transpose_b : {false, true}) { - Reset(); + test->Reset(); NodeDef node_def = get_matmul(DT_FLOAT, transpose_a, transpose_b); - AddTestTensor("input", {2}, /*batch_size=*/1); - AddTestWeights("weights", {2, 2}, {0, 1, 2, 3}); + test->AddTestTensor("input", {2}, /*batch_size=*/1); + test->AddTestWeights("weights", {2, 2}, {0, 1, 2, 3}); if (is_batch_matmul) { if (transpose_a || transpose_b) { - RunValidationAndConversion( + test->RunValidationAndConversion( node_def, error::INVALID_ARGUMENT, "Input weight attempts to broadcast across batch dimension for " "BatchMatMul, at my_matmul"); } else { - RunValidationAndConversion( + test->RunValidationAndConversion( node_def, error::INVALID_ARGUMENT, "Input weight attempts to broadcast across batch dimension"); } continue; } else if (transpose_a) { - RunValidationAndConversion( + test->RunValidationAndConversion( node_def, error::INVALID_ARGUMENT, "Cannot transpose first input if it is a tensor with fewer than 2 " "non-batch dimensions"); continue; } - RunValidationAndConversion(node_def); + test->RunValidationAndConversion(node_def); TRT_TensorOrWeights output; - TF_EXPECT_OK(GetTensorOrWeights("my_matmul", &output)); + TF_EXPECT_OK(test->GetTensorOrWeights("my_matmul", &output)); ASSERT_TRUE(output.is_tensor()); ExpectTrtDimsEqualsArray({2}, output.tensor()->getDimensions()); const DataVec input_data{{"input", test::AsTensor({0, 1})}}; DataVec output_data{{"my_matmul", ConstructTensor(2)}}; - BuildAndRun(input_data, &output_data); + test->BuildAndRun(input_data, &output_data); if (transpose_b) { EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(1, 3)); } else { @@ -1744,31 +1750,31 @@ void OpConverterTest::TestMatMulHelper( } // OK, 3D inputs for (bool transpose_b : {false, true}) { - Reset(); + test->Reset(); NodeDef node_def = get_matmul(DT_FLOAT, /*transpose_a=*/false, transpose_b); - AddTestTensor("input", {2}, /*batch_size=*/1); - AddTestWeights("weights", {2, 2}, {0, 1, 2, 3}); + test->AddTestTensor("input", {2}, /*batch_size=*/1); + test->AddTestWeights("weights", {2, 2}, {0, 1, 2, 3}); if (is_batch_matmul) { if (transpose_b) { - RunValidationAndConversion( + test->RunValidationAndConversion( node_def, error::INVALID_ARGUMENT, "Input weight attempts to broadcast across batch dimension for " "BatchMatMul, at my_matmul"); } else { - RunValidationAndConversion( + test->RunValidationAndConversion( node_def, error::INVALID_ARGUMENT, "Input weight attempts to broadcast across batch dimension"); } continue; } - RunValidationAndConversion(node_def); + test->RunValidationAndConversion(node_def); TRT_TensorOrWeights output; - TF_EXPECT_OK(GetTensorOrWeights("my_matmul", &output)); + TF_EXPECT_OK(test->GetTensorOrWeights("my_matmul", &output)); ASSERT_TRUE(output.is_tensor()); ExpectTrtDimsEqualsArray({2}, output.tensor()->getDimensions()); const DataVec input_data{{"input", test::AsTensor({0, 1})}}; DataVec output_data{{"my_matmul", ConstructTensor(2)}}; - BuildAndRun(input_data, &output_data); + test->BuildAndRun(input_data, &output_data); if (transpose_b) { EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(1, 3)); } else { @@ -1832,7 +1838,7 @@ TEST_F(OpConverterTest, ConvertMatMul) { node_def, error::INVALID_ARGUMENT, "Cannot currently transpose constant input if it is not 2 dimensional"); } - TestMatMulHelper(get_matmul_nodedef, "MatMul"); + TestMatMulHelper(this, get_matmul_nodedef, "MatMul"); } TEST_F(OpConverterTest, ConvertBatchMatMul) { @@ -1889,7 +1895,7 @@ 
TEST_F(OpConverterTest, ConvertBatchMatMul) { } } - TestMatMulHelper(get_batch_matmul_nodedef, "BatchMatMul"); + TestMatMulHelper(this, get_batch_matmul_nodedef, "BatchMatMul"); } template @@ -2010,250 +2016,82 @@ void CheckAddedLayers(OpConverterTest* test, bool expect_scale_layer) { } template -void TestBinaryTensorOpWeightNoBroadcast(OpConverterTest* test) { - typedef typename EnumToDataType::Type CType; - for (auto swap_inputs : {false, true}) { - test->Reset(); - NodeDef node_def; - if (swap_inputs) { - node_def = GetBinaryOpNodeDef("weights", "input", dtype); - } else { - node_def = GetBinaryOpNodeDef("input", "weights", dtype); - } - - const std::vector operand1{CType(3), CType(7.5)}; - const std::vector operand2{CType(2), CType(3)}; - - // It requires the dims to be at least of rank 3 to apply an IScaleLayer. - test->AddTestTensor("input", /*dims=*/{1, 1, 2}, /*batch_size=*/1, - TfDataTypeToTrt(dtype)); - test->AddTestWeights("weights", /*dims=*/{1, 1, 2}, - /*values=*/swap_inputs ? operand1 : operand2); - test->RunValidationAndConversion(node_def); - - // Make sure it does use BinaryTensorOpWeight, not BinaryTensorOpTensor. - CheckAddedLayers(test, /*expect_scale_layer=*/true); - - // Check the dims of the output ITensor. - TRT_TensorOrWeights output; - TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output)); - ASSERT_TRUE(output.is_tensor()); - ExpectTrtDimsEqualsArray({1, 1, 2}, output.tensor()->getDimensions()); - - const DataVec input_data{ - {"input", test::AsTensor(swap_inputs ? operand2 : operand1)}}; - DataVec output_data{{"my_binary", ConstructTensor(2)}}; - test->BuildAndRun( - input_data, &output_data, - dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32); - if (node_def.op() == "Add") { - EXPECT_THAT(GetSpanForData(output_data[0]), - ElementsAre(CType(5), CType(10.5))); - } else if (node_def.op() == "Sub") { - EXPECT_THAT(GetSpanForData(output_data[0]), - ElementsAre(CType(1), CType(4.5))); - } else if (node_def.op() == "Mul") { - EXPECT_THAT(GetSpanForData(output_data[0]), - ElementsAre(CType(6), CType(22.5))); - } else if (node_def.op() == "Div") { - EXPECT_THAT(GetSpanForData(output_data[0]), - ElementsAre(CType(1.5), CType(2.5))); - } else if (node_def.op() == "RealDiv") { - EXPECT_THAT(GetSpanForData(output_data[0]), - ElementsAre(CType(1.5), CType(2.5))); - } else { - ASSERT_TRUE(false); - } - } -} - -template -void TestBinaryTensorOpWeightWithChannelWiseBroadcast(OpConverterTest* test) { - typedef typename EnumToDataType::Type CType; - const NodeDef node_def = - GetBinaryOpNodeDef("input", "weights", dtype); - const std::vector input{CType(1), CType(2), CType(3), CType(4)}; - const std::vector weights{CType(10), CType(20)}; - // There are two types of valid dim pairs which requires channel-wise - // broadcasting: - // - input dims (X Y Z) vs weights dims (X 1 1) - // - input dims (X Y Z) vs weights dims (Z) - // Here X=Z=2 and Y=1. - for (auto weights_dims : std::vector>{{2, 1, 1}, {2}}) { - test->Reset(); - test->AddTestTensor("input", /*dims=*/{2, 1, 2}, /*batch_size=*/1, - TfDataTypeToTrt(dtype)); - test->AddTestWeights("weights", weights_dims, weights); - test->RunValidationAndConversion(node_def); - - // Make sure it does use BinaryTensorOpWeight, not BinaryTensorOpTensor. - CheckAddedLayers(test, /*expect_scale_layer=*/true); - - // Check the dims of the output ITensor. 
- TRT_TensorOrWeights output; - TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output)); - ASSERT_TRUE(output.is_tensor()); - ExpectTrtDimsEqualsArray({2, 1, 2}, output.tensor()->getDimensions()); - - const DataVec input_data{{"input", test::AsTensor(input)}}; - DataVec output_data{{"my_binary", ConstructTensor(4)}}; - test->BuildAndRun(input_data, &output_data); - if (weights_dims.size() == 1) { - EXPECT_THAT(GetSpanForData(output_data[0]), - ElementsAre(CType(11), CType(22), CType(13), CType(24))); - } else { - EXPECT_THAT(GetSpanForData(output_data[0]), - ElementsAre(CType(11), CType(12), CType(23), CType(24))); - } - } -} - -template -void TestBinaryTensorOpWeightWithUniformlyBroadcast(OpConverterTest* test) { - typedef typename EnumToDataType::Type CType; - const NodeDef node_def = - GetBinaryOpNodeDef("input", "weights", dtype); - const std::vector input{CType(1), CType(2), CType(3), CType(4)}; - const std::vector weights{CType(10)}; - test->Reset(); - test->AddTestTensor("input", /*dims=*/{2, 1, 2}, /*batch_size=*/1, - TfDataTypeToTrt(dtype)); - test->AddTestWeights("weights", {1, 1, 1, 1}, weights); - test->RunValidationAndConversion(node_def); - - // Make sure it does use BinaryTensorOpWeight, not BinaryTensorOpTensor. - CheckAddedLayers(test, /*expect_scale_layer=*/true); - - // Check the dims of the output ITensor. - TRT_TensorOrWeights output; - TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output)); - ASSERT_TRUE(output.is_tensor()); - ExpectTrtDimsEqualsArray({2, 1, 2}, output.tensor()->getDimensions()); - - const DataVec input_data{{"input", test::AsTensor(input)}}; - DataVec output_data{{"my_binary", ConstructTensor(4)}}; - test->BuildAndRun(input_data, &output_data); - EXPECT_THAT(GetSpanForData(output_data[0]), - ElementsAre(CType(11), CType(12), CType(13), CType(14))); -} - -template -void TestBinaryTensorOpWeightFallback(OpConverterTest* test, - const std::vector& input_dims, - const std::vector& weights_dims, - error::Code code = error::OK, - const char* error_msg_substr = nullptr, - const int input_batch_size = 1) { - const DataType dtype = DT_FLOAT; - typedef typename EnumToDataType::Type CType; - const size_t num_inputs = TrtTensorDimsNumElements(GetTestDims(input_dims)); - const size_t num_weights = - TrtWeightDimsNumElements(GetTestDims(weights_dims)); - - test->Reset(); - const NodeDef node_def = - GetBinaryOpNodeDef("input", "weights", dtype); - test->AddTestTensor("input", /*dims=*/input_dims, input_batch_size, - TfDataTypeToTrt(dtype)); - test->AddTestWeights( - "weights", /*dims=*/weights_dims, - /*values=*/std::vector(num_weights, CType(1))); - test->RunValidationAndConversion(node_def, code, error_msg_substr); - if (code != error::OK) return; - - // Make sure it does use BinaryTensorOpTensor, not BinaryTensorOpWeight. - CheckAddedLayers(test, /*expect_scale_layer=*/false); - - TRT_TensorOrWeights output; - TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output)); - ASSERT_TRUE(output.is_tensor()); - - // Check the dims of the output ITensor. - std::vector expected_output_dims = input_dims; - for (int i = expected_output_dims.size() - 1, j = weights_dims.size() - 1; - i >= 0 && j >= 0; --i, --j) { - if (expected_output_dims[i] == 1) { - expected_output_dims[i] = weights_dims[j]; - } - } - ExpectTrtDimsEqualsArray(expected_output_dims, - output.tensor()->getDimensions()); - - // Check the result of running the engine. 
- const int expected_num_outputs = - TrtTensorDimsNumElements(GetTestDims(expected_output_dims)); - const DataVec input_data{ - {"input", ConstructTensor(num_inputs, CType(2))}}; - DataVec output_data{ - {"my_binary", ConstructTensor(expected_num_outputs)}}; - test->BuildAndRun(input_data, &output_data); - if (node_def.op() == "Add") { - EXPECT_THAT( - GetSpanForData(output_data[0]), - ElementsAreArray(std::vector(expected_num_outputs, CType(3)))); - } else if (node_def.op() == "Minimum") { - EXPECT_THAT( - GetSpanForData(output_data[0]), - ElementsAreArray(std::vector(expected_num_outputs, CType(1)))); - } else { - ASSERT_TRUE(false); - } -} - -template -void TestBinaryTensorOpTensor(OpConverterTest* test) { +void TestBinaryOp(OpConverterTest* test, bool operand_1_is_tensor, + bool operand_2_is_tensor) { typedef typename EnumToDataType::Type CType; test->Reset(); const NodeDef node_def = GetBinaryOpNodeDef("input1", "input2", dtype); - test->AddTestTensor("input1", /*dims=*/{1, 2}, /*batch_size=*/1, - TfDataTypeToTrt(dtype)); - test->AddTestTensor("input2", /*dims=*/{2, 1}, /*batch_size=*/1, - TfDataTypeToTrt(dtype)); + if (operand_1_is_tensor) { + test->AddTestTensor("input1", /*dims=*/{1, 2}, /*batch_size=*/2, + TfDataTypeToTrt(dtype)); + } else { + test->AddTestWeights("input1", /*dims=*/{1, 2}, + /*values=*/std::vector{CType(3), CType(6)}); + } + if (operand_2_is_tensor) { + test->AddTestTensor("input2", /*dims=*/{2, 1}, /*batch_size=*/2, + TfDataTypeToTrt(dtype)); + } else { + test->AddTestWeights("input2", /*dims=*/{2, 1}, + /*values=*/std::vector{CType(2), CType(3)}); + } test->RunValidationAndConversion(node_def); - // Make sure it does use BinaryTensorOpTensor, not BinaryTensorOpWeight. - CheckAddedLayers(test, /*expect_scale_layer=*/false); - + DataVec input_data; + if (operand_1_is_tensor) { + input_data.push_back( + {"input1", + test::AsTensor({CType(3), CType(6), CType(3), CType(6)})}); + } + if (operand_2_is_tensor) { + input_data.push_back( + {"input2", + test::AsTensor({CType(2), CType(3), CType(2), CType(3)})}); + } + DataVec output_data{{"my_binary", ConstructTensor(8)}}; // Check output dims. TRT_TensorOrWeights output; TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output)); ASSERT_TRUE(output.is_tensor()); ExpectTrtDimsEqualsArray({2, 2}, output.tensor()->getDimensions()); - - const DataVec input_data{ - {"input1", test::AsTensor({CType(3), CType(6)})}, - {"input2", test::AsTensor({CType(2), CType(3)})}}; - DataVec output_data{{"my_binary", ConstructTensor(4)}}; // After broadcasting first input becomes {3, 6, 3, 6} and second input // becomes {2, 3, 2, 3}. test->BuildAndRun( input_data, &output_data, - dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32); + dtype == DT_HALF ? 
TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32, + /*batch_size=*/2); if (node_def.op() == "Add") { - EXPECT_THAT(GetSpanForData(output_data[0]), - ElementsAre(CType(5), CType(8), CType(6), CType(9))); + EXPECT_THAT( + GetSpanForData(output_data[0]), + ElementsAreArray(CastTestVector({5, 8, 6, 9, 5, 8, 6, 9}))); } else if (node_def.op() == "Sub") { - EXPECT_THAT(GetSpanForData(output_data[0]), - ElementsAre(CType(1), CType(4), CType(0), CType(3))); + EXPECT_THAT( + GetSpanForData(output_data[0]), + ElementsAreArray(CastTestVector({1, 4, 0, 3, 1, 4, 0, 3}))); } else if (node_def.op() == "Mul") { EXPECT_THAT(GetSpanForData(output_data[0]), - ElementsAre(CType(6), CType(12), CType(9), CType(18))); + ElementsAreArray( + CastTestVector({6, 12, 9, 18, 6, 12, 9, 18}))); } else if (node_def.op() == "Div") { EXPECT_THAT(GetSpanForData(output_data[0]), - ElementsAre(CType(1.5), CType(3), CType(1), CType(2))); + ElementsAreArray(CastTestVector( + {1.5, 3, 1, 2, 1.5, 3, 1, 2}))); } else if (node_def.op() == "RealDiv") { EXPECT_THAT(GetSpanForData(output_data[0]), - ElementsAre(CType(1.5), CType(3), CType(1), CType(2))); + ElementsAreArray(CastTestVector( + {1.5, 3, 1, 2, 1.5, 3, 1, 2}))); } else if (node_def.op() == "Minimum") { - EXPECT_THAT(GetSpanForData(output_data[0]), - ElementsAre(CType(2), CType(2), CType(3), CType(3))); + EXPECT_THAT( + GetSpanForData(output_data[0]), + ElementsAreArray(CastTestVector({2, 2, 3, 3, 2, 2, 3, 3}))); } else if (node_def.op() == "Maximum") { - EXPECT_THAT(GetSpanForData(output_data[0]), - ElementsAre(CType(3), CType(6), CType(3), CType(6))); + EXPECT_THAT( + GetSpanForData(output_data[0]), + ElementsAreArray(CastTestVector({3, 6, 3, 6, 3, 6, 3, 6}))); } else if (node_def.op() == "Pow") { ExpectArrayNear( - std::vector{CType(9), CType(36), CType(27), CType(216)}, + CastTestVector({9, 36, 27, 216, 9, 36, 27, 216}), GetSpanForData(output_data[0])); } else { ASSERT_TRUE(false); @@ -2287,58 +2125,48 @@ TEST_F(OpConverterTest, ConvertBinary) { "both input as constant at: my_add"); } - // Test BinaryTensorOpWeight() without broadcasting. - TestBinaryTensorOpWeightNoBroadcast(this); - TestBinaryTensorOpWeightNoBroadcast(this); - TestBinaryTensorOpWeightNoBroadcast(this); - TestBinaryTensorOpWeightNoBroadcast(this); - TestBinaryTensorOpWeightNoBroadcast(this); - - TestBinaryTensorOpWeightNoBroadcast(this); - TestBinaryTensorOpWeightNoBroadcast(this); - TestBinaryTensorOpWeightNoBroadcast(this); - TestBinaryTensorOpWeightNoBroadcast(this); - TestBinaryTensorOpWeightNoBroadcast(this); - - // Test BinaryTensorOpWeight() with channel-wise broadcasting. - TestBinaryTensorOpWeightWithChannelWiseBroadcast(this); - - // Test BinaryTensorOpWeight() with uniformly broadcasting. - TestBinaryTensorOpWeightWithUniformlyBroadcast(this); - - // Test BinaryTensorOpWeight() falling back to BinaryTensorOpTensor(). - // Unsupported op. - TestBinaryTensorOpWeightFallback(this, {1, 1, 1}, {1}); - // Rank of input tensor dimension <3. - TestBinaryTensorOpWeightFallback(this, {1, 1}, {1}); - // Broadcast on batch dimension, should fail. - TestBinaryTensorOpWeightFallback( - this, {1, 1, 1}, {2, 1, 1, 1}, error::INVALID_ARGUMENT, - "Unsupported binary op broadcast scheme for op my_binary", - /*input_batch_size=*/2); - // Incompatible dims with per-channel mode. - TestBinaryTensorOpWeightFallback(this, {1, 1, 1}, {1, 2, 1}); - // Incompatible dims. - TestBinaryTensorOpWeightFallback(this, {1, 2, 1}, {2}); - - // Test BinaryTensorOpTensor() with broadcasting. 
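The expected values above come from the usual right-aligned broadcast of the two non-batch shapes: {1, 2} against {2, 1} yields {2, 2}, so each per-batch output is a 2x2 grid. A small stand-alone illustration of that rule (this is not the converter's GetTrtBroadcastShape, which additionally enforces the implicit-batch restrictions tested earlier):

    #include <algorithm>
    #include <stdexcept>
    #include <vector>

    // Right-aligned, numpy-style broadcast of two shapes (batch dim excluded).
    std::vector<int> BroadcastShape(std::vector<int> a, std::vector<int> b) {
      if (a.size() < b.size()) a.insert(a.begin(), b.size() - a.size(), 1);
      if (b.size() < a.size()) b.insert(b.begin(), a.size() - b.size(), 1);
      std::vector<int> out(a.size());
      for (size_t i = 0; i < a.size(); ++i) {
        if (a[i] != b[i] && a[i] != 1 && b[i] != 1) {
          throw std::invalid_argument("Infeasible broadcast scheme");
        }
        out[i] = std::max(a[i], b[i]);
      }
      return out;  // BroadcastShape({1, 2}, {2, 1}) == {2, 2}
    }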
- TestBinaryTensorOpTensor(this); - TestBinaryTensorOpTensor(this); - TestBinaryTensorOpTensor(this); - TestBinaryTensorOpTensor(this); - TestBinaryTensorOpTensor(this); - TestBinaryTensorOpTensor(this); - TestBinaryTensorOpTensor(this); - TestBinaryTensorOpTensor(this); - - TestBinaryTensorOpTensor(this); - TestBinaryTensorOpTensor(this); - TestBinaryTensorOpTensor(this); - TestBinaryTensorOpTensor(this); - TestBinaryTensorOpTensor(this); - TestBinaryTensorOpTensor(this); - TestBinaryTensorOpTensor(this); - TestBinaryTensorOpTensor(this); + // Test combinations of tensor vs weight inputs (except when both inputs are + // weights). + for (const bool operand_1_is_tensor : {true, false}) { + for (const bool operand_2_is_tensor : {true, false}) { + if (!operand_1_is_tensor && !operand_2_is_tensor) continue; + // FP32 tests + TestBinaryOp(this, operand_1_is_tensor, + operand_2_is_tensor); + TestBinaryOp(this, operand_1_is_tensor, + operand_2_is_tensor); + TestBinaryOp(this, operand_1_is_tensor, + operand_2_is_tensor); + TestBinaryOp(this, operand_1_is_tensor, + operand_2_is_tensor); + TestBinaryOp(this, operand_1_is_tensor, + operand_2_is_tensor); + TestBinaryOp(this, operand_1_is_tensor, + operand_2_is_tensor); + TestBinaryOp(this, operand_1_is_tensor, + operand_2_is_tensor); + TestBinaryOp(this, operand_1_is_tensor, + operand_2_is_tensor); + // FP16 tests + // TODO(tmorris): Use templates to avoid duplication. + TestBinaryOp(this, operand_1_is_tensor, + operand_2_is_tensor); + TestBinaryOp(this, operand_1_is_tensor, + operand_2_is_tensor); + TestBinaryOp(this, operand_1_is_tensor, + operand_2_is_tensor); + TestBinaryOp(this, operand_1_is_tensor, + operand_2_is_tensor); + TestBinaryOp(this, operand_1_is_tensor, + operand_2_is_tensor); + TestBinaryOp(this, operand_1_is_tensor, + operand_2_is_tensor); + TestBinaryOp(this, operand_1_is_tensor, + operand_2_is_tensor); + TestBinaryOp(this, operand_1_is_tensor, + operand_2_is_tensor); + } + } } TEST_F(OpConverterTest, ConvertQuantize) { @@ -2583,7 +2411,6 @@ TEST_F(OpConverterTest, ConvertCombinedNMS) { // implementation that, the extra output classes that are outside of the // range specified by valid_detections[i] are not zeros but -1s. TestParams{{1, 1, 4}, {1, 3}, 3, 2, .5f, 0, {2, 4}, {2}, {2}}}; - const int batch_size = 1; for (int i = 0; i < kCombinedNMSOKCases; ++i) { Reset(); diff --git a/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc index d325d11dfff..0e5ecc72c60 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc @@ -14,6 +14,8 @@ limitations under the License. 
#include "tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h" +#include "absl/strings/ascii.h" +#include "absl/strings/escaping.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/tf2tensorrt/convert/convert_graph.h" #include "tensorflow/compiler/tf2tensorrt/convert/utils.h" @@ -32,9 +34,9 @@ namespace tensorflow { namespace tensorrt { namespace convert { // TODO(sami): Remove VLOG messages once the code matures +using absl::AsciiStrToUpper; using absl::StrAppend; using absl::StrCat; -using str_util::Uppercase; Status TRTOptimizationPass::Init( const RewriterConfig_CustomGraphOptimizer* config) { @@ -67,7 +69,7 @@ Status TRTOptimizationPass::Init( } if (params.count("precision_mode")) { TF_RETURN_IF_ERROR(TrtPrecisionModeFromName( - Uppercase(params.at("precision_mode").s()), &precision_mode_)); + AsciiStrToUpper(params.at("precision_mode").s()), &precision_mode_)); } if (params.count("use_calibration")) { use_calibration_ = params.at("use_calibration").b(); diff --git a/tensorflow/compiler/tf2tensorrt/kernels/get_serialized_resource_op_test.cc b/tensorflow/compiler/tf2tensorrt/kernels/get_serialized_resource_op_test.cc index ec038ebda07..d54cbf7836e 100644 --- a/tensorflow/compiler/tf2tensorrt/kernels/get_serialized_resource_op_test.cc +++ b/tensorflow/compiler/tf2tensorrt/kernels/get_serialized_resource_op_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include #include + #include #include @@ -68,9 +69,9 @@ TEST_F(GetSerializedResourceOpTest, Basic) { TF_ASSERT_OK(RunOpKernel()); // Verify the result. - // TODO(laigd): OpsTestBase::GetOutput() doesn't work. - Tensor* output = context_->mutable_output(0); - EXPECT_EQ("my_serialized_str", output->scalar()()); + // string type output will remain on CPU, so we're not using GetOutput() here. + EXPECT_EQ("my_serialized_str", + context_->mutable_output(0)->scalar()()); } } // namespace tensorrt diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc index b62fdc5dc4b..d4077692235 100644 --- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc +++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc @@ -87,16 +87,11 @@ TYPED_TEST(TRTEngineOpTest, Basic) { TF_ASSERT_OK(OpsTestBase::RunOpKernel()); // Verify the result. - // TODO(laigd): OpsTestBase::GetOutput() doesn't work. - Tensor* output = OpsTestBase::context_->mutable_output(0); - const auto& tensor_map = output->flat(); - std::vector output_data(tensor_map.size()); - ASSERT_EQ(0, cudaDeviceSynchronize()); - ASSERT_EQ(0, cudaMemcpy(output_data.data(), tensor_map.data(), - sizeof(TypeParam) * tensor_map.size(), - cudaMemcpyDeviceToHost)); - EXPECT_THAT(absl::Span(output_data), - ElementsAre(TypeParam(0.0f), TypeParam(2.0f))); + Tensor* output = OpsTestBase::GetOutput(0); + EXPECT_THAT( + absl::Span(output->template flat().data(), + output->NumElements()), + ElementsAre(TypeParam(0.0f), TypeParam(2.0f))); } } // namespace tensorrt diff --git a/tensorflow/compiler/tf2tensorrt/segment/segment.cc b/tensorflow/compiler/tf2tensorrt/segment/segment.cc index 5d9a1b25210..932966534b7 100644 --- a/tensorflow/compiler/tf2tensorrt/segment/segment.cc +++ b/tensorflow/compiler/tf2tensorrt/segment/segment.cc @@ -681,31 +681,33 @@ Status SegmentGraph(const Graph* tf_graph, << " with parent=" << segment_root << ":" << s; } - // Don't use small segments. 
- if (static_cast<int>(segment_nodes.size()) < options.minimum_segment_size) { + const int num_effective_nodes = std::count_if( + segment_nodes.begin(), segment_nodes.end(), [](const Node* node) { + static auto noops = + new std::set<string>{"Identity", "Snapshot", "StopGradient"}; + return noops->count(node->type_string()) == 0; + }); + + // Don't use segments whose number of effective nodes is small. + if (num_effective_nodes < options.minimum_segment_size) { VLOG(1) << "Segment " << segments->size() << " has only " - << segment_nodes.size() << " nodes, dropping"; + << num_effective_nodes << " effective nodes, dropping"; continue; } - // TODO(sami): Make segmenter placement aware once trtscopes are in place const auto& dev_itr = device_maps.find(segment_root); if (dev_itr == device_maps.end() || dev_itr->second.empty()) { VLOG(1) << "No device assigned to segment " << segments->size(); - segments->emplace_back(std::make_pair(segment_nodes, string())); } else if (dev_itr->second.size() > 1) { - string s("Segment "); - StrAppend(&s, segments->size(), " has multiple devices attached: "); + string s = StrCat("Segment ", segments->size(), + " has multiple devices attached: "); for (const auto& dev : dev_itr->second) { StrAppend(&s, dev, ", "); } - LOG(WARNING) << s << " choosing " << *(dev_itr->second.begin()); - segments->emplace_back( - std::make_pair(segment_nodes, *(dev_itr->second.begin()))); - } else { - segments->emplace_back( - std::make_pair(segment_nodes, *(dev_itr->second.begin()))); + LOG(WARNING) << s; } + + segments->emplace_back(segment_nodes); } if (VLOG_IS_ON(1)) { for (const auto& d : device_maps) { diff --git a/tensorflow/compiler/tf2tensorrt/segment/segment.h b/tensorflow/compiler/tf2tensorrt/segment/segment.h index e31f1a989d9..77c0af223c8 100644 --- a/tensorflow/compiler/tf2tensorrt/segment/segment.h +++ b/tensorflow/compiler/tf2tensorrt/segment/segment.h @@ -31,10 +31,8 @@ namespace tensorflow { namespace tensorrt { namespace segment { -// Vector of segments, each entry contains a set of node pointers and a device -// name in the segment. -using SegmentNodesVector = - std::vector<std::pair<std::set<const Node*>, string>>; +// Vector of segments, each entry contains a set of node pointers. +using SegmentNodesVector = std::vector<std::set<const Node*>>; struct SegmentOptions { // Segment must contain at least this many nodes.
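Since segments are now plain node sets, the minimum-segment-size check above counts only "effective" nodes and ignores pass-through ops that give TensorRT no real work. A self-contained sketch of that filter (the lightweight node type here is a stand-in for tensorflow::Node, not part of the patch):

    #include <algorithm>
    #include <set>
    #include <string>
    #include <vector>

    struct FakeNode {
      std::string type;
      const std::string& type_string() const { return type; }
    };

    // Counts nodes that contribute real work to a candidate TensorRT segment.
    int CountEffectiveNodes(const std::vector<FakeNode>& segment_nodes) {
      static const std::set<std::string> kNoOps = {"Identity", "Snapshot",
                                                   "StopGradient"};
      return static_cast<int>(
          std::count_if(segment_nodes.begin(), segment_nodes.end(),
                        [](const FakeNode& node) {
                          return kNoOps.count(node.type_string()) == 0;
                        }));
    }
    // A chain made only of Identity ops has zero effective nodes, which is why
    // the IdentityOps test added below expects no segment to be formed.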
diff --git a/tensorflow/compiler/tf2tensorrt/segment/segment_test.cc b/tensorflow/compiler/tf2tensorrt/segment/segment_test.cc index 84b690ecba6..cb038e58126 100644 --- a/tensorflow/compiler/tf2tensorrt/segment/segment_test.cc +++ b/tensorflow/compiler/tf2tensorrt/segment/segment_test.cc @@ -77,7 +77,7 @@ class SegmentTest : public ::testing::Test { EXPECT_EQ(expected_segments.size(), segments.size()); for (int i = 0; i < segments.size(); ++i) { std::set<string> segment_node_names; - for (const Node* node : segments[i].first) { + for (const Node* node : segments[i]) { segment_node_names.insert(node->name()); } const auto& expected = expected_segments[i]; @@ -262,6 +262,23 @@ TEST_F(SegmentTest, BigIfElse) { {{"add0", "add1"}, {"add3", "add4", "add5", "add6", "add7"}}); } +TEST_F(SegmentTest, IdentityOps) { + Scope s = Scope::NewRootScope(); + auto feed = ops::Placeholder(s.WithOpName("feed"), DT_FLOAT); + auto identity0 = ops::Identity(s.WithOpName("identity0"), feed); + auto identity1 = ops::Identity(s.WithOpName("identity1"), identity0); + auto identity2 = ops::Identity(s.WithOpName("identity2"), identity1); + auto identity3 = ops::Identity(s.WithOpName("identity3"), identity2); + Graph g(OpRegistry::Global()); + TF_EXPECT_OK(s.ToGraph(&g)); + + const std::set<string> all_identities = {"identity0", "identity1", + "identity2", "identity3"}; + // Identity ops are not counted as effective ops in the segment, so no segment + // will be formed in this case. + RunTest(&g, all_identities, all_identities, all_identities, {}); +} + } // namespace test } // namespace segment } // namespace tensorrt diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index dcce43cbe70..2bc8ab45a51 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -1,7 +1,10 @@ -licenses(["notice"]) # Apache 2.0 - load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", "tf_cuda_cc_test") +package( + default_visibility = [":internal"], + licenses = ["notice"], # Apache 2.0 +) + package_group( name = "internal", packages = [ @@ -23,15 +26,12 @@ package_group( ], ) -package( - default_visibility = [":internal"], -) - load( "//tensorflow/core:platform/default/cuda_build_defs.bzl", "if_cuda_is_configured", ) load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library", "xla_py_proto_library") +load("//tensorflow:tensorflow.bzl", "tf_portable_proto_library") cc_library( name = "tf2xla_supported_ops_lib", @@ -67,6 +67,19 @@ xla_proto_library( ], ) +# A proto library that is minimal in size and dependencies for platforms like Android.
+tf_portable_proto_library( + name = "portable_tf2xla_proto", + config_string = "allow_all:true", + header_outs = ["//tensorflow/compiler/tf2xla/tf2xla.proto.h"], + portable_deps = ["//tensorflow/core:android_proto_lib"], + proto_deps = [ + ":tf2xla_proto", + "//tensorflow/core:protos_all_cc", + ], + visibility = ["//visibility:public"], +) + xla_py_proto_library( name = "tf2xla_py", has_services = False, diff --git a/tensorflow/compiler/tf2xla/cc/BUILD b/tensorflow/compiler/tf2xla/cc/BUILD index adcdb6c8f76..fb7c8c56ac7 100644 --- a/tensorflow/compiler/tf2xla/cc/BUILD +++ b/tensorflow/compiler/tf2xla/cc/BUILD @@ -1,9 +1,8 @@ package( default_visibility = ["//tensorflow/compiler/tf2xla:internal"], + licenses = ["notice"], # Apache 2.0 ) -licenses(["notice"]) # Apache 2.0 - load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_cc") tf_gen_op_wrapper_cc( diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc index 6e093400e47..3aaa2eed432 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc @@ -918,10 +918,16 @@ string Conditional::name() const { Status FunctionalizeCond::AddIdentityNode(const Node* replacee, Node* if_node, int port) { + NodeBuilder id_builder(replacee->name(), "Identity"); + id_builder.Input(if_node, port); + string outside_compilation; + if (GetNodeAttr(if_node->def(), kXlaOutsideCompilationAttrName, + &outside_compilation) + .ok()) { + id_builder.Attr(kXlaOutsideCompilationAttrName, outside_compilation); + } Node* id; - TF_RETURN_IF_ERROR(NodeBuilder(replacee->name(), "Identity") - .Input(if_node, port) - .Finalize(graph_, &id)); + TF_RETURN_IF_ERROR(id_builder.Finalize(graph_, &id)); state_map_.ResetCondId(id, state_map_.LookupCondId(if_node)); state_map_.ResetAncestorId(id, state_map_.LookupAncestorId(if_node)); return Status::OK(); diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 89d5a860179..294a104b3b5 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -247,8 +247,8 @@ Status FunctionalizeControlFlowPass::Run( // multiple times, and we want to avoid functionalize it again. static std::map* kNodeTypeToFunctionAttrMapping = new std::map{ - // TPUReplicate ops are generated by EncapsulateTPUComputationsPass. - {"TPUReplicate", "computation"}, + // _TPUReplicate ops are generated by EncapsulateTPUComputationsPass. + {"_TPUReplicate", "computation"}, // XlaLaunch ops are generated by EncapsulateXlaComputationsPass. 
{"XlaLaunch", "function"}, }; diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index d6dfa39e658..06376f7174e 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -1,9 +1,8 @@ load("//tensorflow:tensorflow.bzl", "tf_copts", "tf_kernel_library") -licenses(["notice"]) # Apache 2.0 - package( default_visibility = ["//tensorflow/compiler/tf2xla:internal"], + licenses = ["notice"], # Apache 2.0 ) tf_kernel_library( @@ -195,6 +194,7 @@ tf_kernel_library( "//tensorflow/core/kernels:training_ops", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/compiler/tf2xla/kernels/assert_op.cc b/tensorflow/compiler/tf2xla/kernels/assert_op.cc index af4ab5e8ef6..94543686b47 100644 --- a/tensorflow/compiler/tf2xla/kernels/assert_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/assert_op.cc @@ -43,7 +43,7 @@ class AssertOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(AssertOp); }; -REGISTER_XLA_OP(Name("Assert"), AssertOp); +REGISTER_XLA_OP(Name("Assert").CompilationOnly(), AssertOp); } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc index 84eda80fc25..013a5734863 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc @@ -39,7 +39,10 @@ class FusedBatchNormOp : public XlaOpKernel { is_on_gpu_ = ctx->device_type().type_string() == DEVICE_GPU_XLA_JIT; } - void Compile(XlaOpKernelContext* ctx) override { + void Compile(XlaOpKernelContext* ctx) override { CompileImpl(ctx); } + + protected: + virtual void CompileImpl(XlaOpKernelContext* ctx) { xla::PrimitiveType input_type; OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(ctx->input_type(0), &input_type)); @@ -116,8 +119,29 @@ class FusedBatchNormOp : public XlaOpKernel { bool is_on_gpu_; }; +class FusedBatchNormOpV3 : public FusedBatchNormOp { + public: + explicit FusedBatchNormOpV3(OpKernelConstruction* ctx) + : FusedBatchNormOp(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + FusedBatchNormOp::CompileImpl(ctx); + if (!ctx->status().ok()) { + return; + } + ctx->SetConstantOutput(5, Tensor()); + } + + private: + float epsilon_; + TensorFormat data_format_; + bool is_training_; + bool is_on_gpu_; +}; + REGISTER_XLA_OP(Name("FusedBatchNorm"), FusedBatchNormOp); REGISTER_XLA_OP(Name("FusedBatchNormV2"), FusedBatchNormOp); +REGISTER_XLA_OP(Name("FusedBatchNormV3"), FusedBatchNormOpV3); class FusedBatchNormGradOp : public XlaOpKernel { public: @@ -233,6 +257,7 @@ class FusedBatchNormGradOp : public XlaOpKernel { REGISTER_XLA_OP(Name("FusedBatchNormGrad"), FusedBatchNormGradOp); REGISTER_XLA_OP(Name("FusedBatchNormGradV2"), FusedBatchNormGradOp); +REGISTER_XLA_OP(Name("FusedBatchNormGradV3"), FusedBatchNormGradOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc index d801d560040..258d8f75cde 100644 --- a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/util/tensor_format.h" @@ -150,6 +151,15 @@ class ExtractImagePatchesOp : public XlaOpKernel { xla::XlaOp conv = xla::ConvGeneralDilated(ctx->Input(0), filter, window_strides, padding, lhs_dilation, rhs_dilation, dims, depth); + // Feature group convolution, will end up with the kernel_size change more + // rapidly than the depth. Reshape, transpose and reshape to reorder them. + auto conv_dims = builder->GetShape(conv).ValueOrDie().dimensions(); + conv_dims.back() = depth; + conv_dims.push_back(kernel_size); + conv = xla::TransposeInMinorDims(xla::Reshape(conv, conv_dims)); + conv_dims.pop_back(); + conv_dims.back() *= kernel_size; + conv = xla::Reshape(conv, conv_dims); ctx->SetOutput(0, conv); } diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc index 6472045265e..489ffd3fdad 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + +#include "absl/types/optional.h" #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" @@ -20,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/slicing.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" @@ -148,15 +152,22 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape, class GatherOp : public XlaOpKernel { public: - explicit GatherOp(OpKernelConstruction* context) : XlaOpKernel(context) {} + explicit GatherOp(OpKernelConstruction* context) : XlaOpKernel(context) { + // Set batch_dims_ to 0 if the attribute does not exist. 
+ if (context->HasAttr("batch_dims")) { + OP_REQUIRES_OK(context, context->GetAttr("batch_dims", &batch_dims_)); + } else { + batch_dims_ = 0; + } + } void Compile(XlaOpKernelContext* context) override { - xla::XlaBuilder* builder = context->builder(); auto input = context->Input(0); auto input_shape = context->InputShape(0); auto indices = context->Input(1); auto indices_shape = context->InputShape(1); - int64 axis = 0; + + absl::optional axis; if (context->num_inputs() == 3) { const TensorShape axis_shape = context->InputShape(2); OP_REQUIRES(context, TensorShapeUtils::IsScalar(axis_shape), @@ -165,31 +176,73 @@ class GatherOp : public XlaOpKernel { OP_REQUIRES(context, axis_type == DT_INT32 || axis_type == DT_INT64, errors::InvalidArgument("axis must be int32 or int64")); - OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &axis)); + int64 axis_input; + OP_REQUIRES_OK(context, + context->ConstantInputAsIntScalar(2, &axis_input)); + const auto params_dims = input_shape.dims(); - OP_REQUIRES( - context, -params_dims <= axis && axis < params_dims, - errors::InvalidArgument("Expected axis in the range [", -params_dims, - ", ", params_dims, "), but got ", axis)); - if (axis < 0) { - axis += params_dims; + OP_REQUIRES(context, + -params_dims <= axis_input && axis_input < params_dims, + errors::InvalidArgument("Expected axis in the range [", + -params_dims, ", ", params_dims, + "), but got ", axis_input)); + if (axis_input < 0) { + axis_input += params_dims; } + axis = axis_input; } + if (batch_dims_ != 0) { + if (batch_dims_ < 0) { + batch_dims_ = indices_shape.dims() + batch_dims_; + } + + axis = axis.value_or(batch_dims_); + + OP_REQUIRES(context, + batch_dims_ >= -indices_shape.dims() && + batch_dims_ < indices_shape.dims(), + errors::InvalidArgument("Expected batch_dims in the range [", + -indices_shape.dims(), ", ", + indices_shape.dims(), "), but got ", + batch_dims_)); + + OP_REQUIRES(context, batch_dims_ < input_shape.dims(), + errors::InvalidArgument("batch_dims (", batch_dims_, + ") must be less than rank(input) (", + input_shape.dims(), ").")); + + OP_REQUIRES(context, *axis >= batch_dims_, + errors::InvalidArgument("batch_dims (", batch_dims_, + ") must be less than or equal to ", + "axis (", *axis, ").")); + } + + axis = axis.value_or(0); DataType index_type = input_type(1); OP_REQUIRES(context, index_type == DT_INT32 || index_type == DT_INT64, errors::InvalidArgument("indices must be int32 or int64")); xla::XlaOp gather; - OP_REQUIRES_OK( - context, XlaGather(input, input_shape, indices, indices_shape, axis, - /*indices_are_nd=*/false, input_type(0), index_type, - builder, &gather)); + if (batch_dims_ > 0) { + gather = xla::TorchIndexSelect(input, indices, *axis, batch_dims_); + } else { + // XlaGather() manages degenerate cases, like empty-indices, which are + // error conditions and caught above if batch_dims is not 0. + OP_REQUIRES_OK( + context, XlaGather(input, input_shape, indices, indices_shape, *axis, + /*indices_are_nd=*/false, input_type(0), + index_type, context->builder(), &gather)); + } context->SetOutput(0, gather); } private: TF_DISALLOW_COPY_AND_ASSIGN(GatherOp); + + // The number of batch dimensions, as passed in the batch_dims attribute. + // It must be less than rank(indices). 
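For the gather change above, when batch_dims is nonzero the kernel routes through xla::TorchIndexSelect, which gathers along the axis independently for each leading batch. A rough reference of what that computes in the simplest case (batch_dims = 1, axis = 1, rank-2 params and indices); plain C++, illustrative only:

#include <cstdio>
#include <vector>

// Batched gather with batch_dims = 1 and axis = 1:
//   out[b][m] = params[b][indices[b][m]]
// Each batch row of `indices` selects from the matching row of `params`.
std::vector<std::vector<float>> BatchedGather(
    const std::vector<std::vector<float>>& params,
    const std::vector<std::vector<int>>& indices) {
  std::vector<std::vector<float>> out(params.size());
  for (size_t b = 0; b < params.size(); ++b) {
    for (int idx : indices[b]) {
      out[b].push_back(params[b][idx]);
    }
  }
  return out;
}

int main() {
  std::vector<std::vector<float>> params = {{10, 11, 12}, {20, 21, 22}};
  std::vector<std::vector<int>> indices = {{2, 0}, {1, 1}};
  auto out = BatchedGather(params, indices);
  for (const auto& row : out) {
    for (float v : row) std::printf("%g ", v);  // 12 10 / 21 21
    std::printf("\n");
  }
  return 0;
}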
+ int32 batch_dims_ = 0; }; REGISTER_XLA_OP(Name("Gather"), GatherOp); diff --git a/tensorflow/compiler/tf2xla/kernels/in_topk_op.cc b/tensorflow/compiler/tf2xla/kernels/in_topk_op.cc index 9c6fcf429d4..246d3f6da94 100644 --- a/tensorflow/compiler/tf2xla/kernels/in_topk_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/in_topk_op.cc @@ -81,20 +81,21 @@ class InTopKOp : public XlaOpKernel { xla::CreateScalarAddComputation(xla::F32, xla_builder), {1}); // Calculate in each row of `predictions`, how many values are larger than - // the value of target class. Then return the result whether the count <= k, + // the value of target class. Then return the result whether the count < k, // which indicates the target is in topk. - xla::XlaOp ge_r2 = xla::Ge(predictions_r2, targets_values_r1, {0}); + xla::XlaOp gt_r2 = xla::Gt(predictions_r2, targets_values_r1, {0}); xla::XlaOp zero_r0 = xla::Zero(xla_builder, xla::S32); xla::XlaOp zero_r2 = xla::Broadcast(zero_r0, predictions_shape.dim_sizes()); xla::XlaOp one_r0 = xla::One(xla_builder, xla::S32); xla::XlaOp one_r2 = xla::Broadcast(one_r0, predictions_shape.dim_sizes()); - xla::XlaOp one_hot_r2 = xla::Select(ge_r2, one_r2, zero_r2); - xla::XlaOp num_ge_r1 = xla::Reduce( + xla::XlaOp one_hot_r2 = xla::Select(gt_r2, one_r2, zero_r2); + xla::XlaOp num_gt_r1 = xla::Reduce( one_hot_r2, zero_r0, xla::CreateScalarAddComputation(xla::S32, xla_builder), {1}); xla::XlaOp result = - xla::Le(num_ge_r1, xla::ConstantR0(xla_builder, k)); + xla::And(xla::Lt(num_gt_r1, xla::ConstantR0(xla_builder, k)), + xla::IsFinite(targets_values_r1)); context->SetOutput(0, result); } diff --git a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc index f36e0025250..a3fcb4d4b8f 100644 --- a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc @@ -67,9 +67,9 @@ class MatMulOp : public XlaOpKernel { OP_REQUIRES(ctx, a_shape.dim_size(first_index) == b_shape.dim_size(second_index), - errors::InvalidArgument("Matrix size-compatible: In[0]: ", - a_shape.DebugString(), ", In[1]: ", - b_shape.DebugString())); + errors::InvalidArgument( + "Matrix size-incompatible: In[0]: ", a_shape.DebugString(), + ", In[1]: ", b_shape.DebugString())); xla::XlaOp a = ctx->Input(0); xla::XlaOp b = ctx->Input(1); diff --git a/tensorflow/compiler/tf2xla/kernels/select_op.cc b/tensorflow/compiler/tf2xla/kernels/select_op.cc index ed303ba2774..70e4f96c0da 100644 --- a/tensorflow/compiler/tf2xla/kernels/select_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/select_op.cc @@ -15,6 +15,7 @@ limitations under the License. #include +#include "tensorflow/compiler/tf2xla/lib/broadcast.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -22,6 +23,7 @@ limitations under the License. 
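The in_topk_op change above tightens the membership test: a target is in the top k if fewer than k predictions are strictly greater than the prediction at the target class, and that prediction value is itself finite (non-finite target values are rejected). A scalar reference version of the predicate, as a standalone sketch:

#include <cmath>
#include <cstdio>
#include <vector>

// Returns true if predictions[target] is in the top k: fewer than k entries
// are strictly greater than the target's value, and that value is finite.
bool InTopK(const std::vector<float>& predictions, int target, int k) {
  float target_value = predictions[target];
  if (!std::isfinite(target_value)) return false;
  int num_greater = 0;
  for (float p : predictions) {
    if (p > target_value) ++num_greater;
  }
  return num_greater < k;
}

int main() {
  std::vector<float> predictions = {0.1f, 0.8f, 0.8f, 0.3f};
  // Classes 1 and 2 both count as top-1: nothing is strictly greater than 0.8,
  // so ties do not push either of them out.
  std::printf("%d %d %d\n", InTopK(predictions, 1, 1),
              InTopK(predictions, 2, 1), InTopK(predictions, 0, 1));  // 1 1 0
  return 0;
}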
#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/util/bcast.h" namespace tensorflow { namespace { @@ -77,5 +79,58 @@ class SelectOp : public XlaOpKernel { REGISTER_XLA_OP(Name("Select"), SelectOp); +class SelectOpV2 : public XlaOpKernel { + public: + explicit SelectOpV2(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape cond_shape = ctx->InputShape(0); + const TensorShape then_shape = ctx->InputShape(1); + const TensorShape else_shape = ctx->InputShape(2); + + // Compute the output shape from the broadcast of the two data inputs, with + // the broadcast of the conditional. + // Then Broadcast all three inputs to the output shape and emit a select. + + BCast bcast_then_else(BCast::FromShape(then_shape), + BCast::FromShape(else_shape), + /*fewer_dims_optimization=*/false); + if (!bcast_then_else.IsValid()) { + ctx->SetStatus(errors::InvalidArgument( + "Incompatible shapes: ", then_shape.DebugString(), " vs. ", + else_shape.DebugString())); + return; + } + BCast bcast(bcast_then_else.output_shape(), BCast::FromShape(cond_shape), + /*fewer_dims_optimization=*/false); + if (!bcast.IsValid()) { + ctx->SetStatus(errors::InvalidArgument( + "Incompatible shapes: ", + BCast::ToShape(bcast_then_else.output_shape()).DebugString(), " vs. ", + cond_shape.DebugString())); + return; + } + + auto bcasted_cond = BroadcastTo(ctx->Input(0), bcast.output_shape()); + OP_REQUIRES_OK(ctx, bcasted_cond.status()); + auto cond_handle = bcasted_cond.ValueOrDie(); + + auto bcasted_then = BroadcastTo(ctx->Input(1), bcast.output_shape()); + OP_REQUIRES_OK(ctx, bcasted_then.status()); + auto then_handle = bcasted_then.ValueOrDie(); + + auto bcasted_else = BroadcastTo(ctx->Input(2), bcast.output_shape()); + OP_REQUIRES_OK(ctx, bcasted_else.status()); + auto else_handle = bcasted_else.ValueOrDie(); + + ctx->SetOutput(0, xla::Select(cond_handle, then_handle, else_handle)); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(SelectOpV2); +}; + +REGISTER_XLA_OP(Name("SelectV2"), SelectOpV2); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc index 20da8033536..dc1b0c21096 100644 --- a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc @@ -16,6 +16,7 @@ limitations under the License. // XLA-specific Ops for softmax. #include "absl/strings/match.h" +#include "tensorflow/compiler/tf2xla/lib/broadcast.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -145,23 +146,36 @@ class SoftmaxXentWithLogitsOp : public XlaOpKernel { : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - const TensorShape logits_shape = ctx->InputShape(0); - const TensorShape labels_shape = ctx->InputShape(1); - OP_REQUIRES(ctx, logits_shape.IsSameSize(labels_shape), - errors::InvalidArgument( - "logits and labels must be same size: logits_size=", - logits_shape.DebugString(), - " labels_size=", labels_shape.DebugString())); - OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_shape), - errors::InvalidArgument("logits must be 2-dimensional")); - // As we already tested that both inputs have the same shape no need to - // check that "labels" is a matrix too. 
- const DataType type = input_type(0); const xla::PrimitiveType xla_type = ctx->input_xla_type(0); auto logits = ctx->Input(0); auto labels = ctx->Input(1); + const TensorShape logits_shape = ctx->InputShape(0); + const TensorShape labels_shape = ctx->InputShape(1); + OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_shape), + errors::InvalidArgument("logits must be 2-dimensional")); + OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(labels_shape), + errors::InvalidArgument("labels must be 2-dimensional")); + + // Confirm that any necessary broadcasting to make the shapes the same will + // succeed. + for (int dim = 0; dim < 2; dim++) { + OP_REQUIRES( + ctx, + labels_shape.dim_size(dim) == 1 || + logits_shape.dim_size(dim) == labels_shape.dim_size(dim), + errors::InvalidArgument("logits and labels must be same size after " + "broadcasting of labels: logits_size=", + logits_shape.DebugString(), + " labels_size=", labels_shape.DebugString())); + } + if (!logits_shape.IsSameSize(labels_shape)) { + auto labels_or = BroadcastTo(labels, logits_shape.dim_sizes()); + OP_REQUIRES_OK(ctx, labels_or.status()); + labels = labels_or.ConsumeValueOrDie(); + } + xla::XlaOp loss, backprop; std::tie(loss, backprop) = CrossEntropyWithLogits(ctx, type, xla_type, logits, labels); diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD index f9ce50be6e3..5b1f92b24c8 100644 --- a/tensorflow/compiler/tf2xla/lib/BUILD +++ b/tensorflow/compiler/tf2xla/lib/BUILD @@ -1,9 +1,8 @@ # Utilities for building XLA computations. -licenses(["notice"]) # Apache 2.0 - package( default_visibility = ["//tensorflow/compiler/tf2xla:friends"], + licenses = ["notice"], # Apache 2.0 ) # Filegroup used to collect source files for dependency checking. diff --git a/tensorflow/compiler/tf2xla/ops/BUILD b/tensorflow/compiler/tf2xla/ops/BUILD index 4f1f3d7c326..17a62e83d5f 100644 --- a/tensorflow/compiler/tf2xla/ops/BUILD +++ b/tensorflow/compiler/tf2xla/ops/BUILD @@ -1,9 +1,8 @@ package( default_visibility = ["//tensorflow:internal"], + licenses = ["notice"], # Apache 2.0 ) -licenses(["notice"]) # Apache 2.0 - load( "//tensorflow:tensorflow.bzl", "tf_custom_op_library", diff --git a/tensorflow/compiler/tf2xla/python/BUILD b/tensorflow/compiler/tf2xla/python/BUILD index c6f57b386eb..c731d52ea2b 100644 --- a/tensorflow/compiler/tf2xla/python/BUILD +++ b/tensorflow/compiler/tf2xla/python/BUILD @@ -1,9 +1,8 @@ -licenses(["notice"]) # Apache 2.0 - package( default_visibility = [ "//visibility:public", ], + licenses = ["notice"], # Apache 2.0 ) load( diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index b8eda1de94a..dcdf5acdace 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -550,6 +550,7 @@ std::unique_ptr XlaCompiler::GetGraph(const FunctionBody* fbody) { }; GraphOptimizer::Options graph_optimizer_options; graph_optimizer_options.cf_consider_fn = cf_consider_fn; + graph_optimizer_options.inline_multi_device_functions = true; optimizer.Optimize(flib_runtime_, flib_runtime_->env(), /*device=*/nullptr, &graph, graph_optimizer_options); diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index 95d1bf25150..7c6c53a225f 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -116,9 +116,12 @@ class XlaOpRegistry { // If we should cluster operations returning DT_VARIANT. 
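The softmax cross-entropy change above stops requiring logits and labels to have identical shapes; instead, each label dimension must either equal the corresponding logits dimension or be 1, after which the labels are broadcast up to the logits shape. A minimal sketch of that per-dimension compatibility check for the 2-D case:

#include <cstdio>

// Labels of shape [lr, lc] can be broadcast to logits of shape [r, c] iff
// each label dimension is either 1 or equal to the logits dimension.
bool LabelsBroadcastableToLogits(int logits_rows, int logits_cols,
                                 int labels_rows, int labels_cols) {
  bool rows_ok = labels_rows == 1 || labels_rows == logits_rows;
  bool cols_ok = labels_cols == 1 || labels_cols == logits_cols;
  return rows_ok && cols_ok;
}

int main() {
  std::printf("%d\n", LabelsBroadcastableToLogits(8, 10, 1, 10));  // 1
  std::printf("%d\n", LabelsBroadcastableToLogits(8, 10, 8, 10));  // 1
  std::printf("%d\n", LabelsBroadcastableToLogits(8, 10, 4, 10));  // 0
  return 0;
}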
bool cluster_variant_ops = false; - // Whether ops known to be slow or to have correctness issues should be + // Whether ops known to be slow should be auto-clustered. + bool cluster_slow_ops = false; + + // Whether ops known to have numerical accuracy issues should be // auto-clustered. - bool cluster_slow_and_inaccurate_ops = false; + bool cluster_inaccurate_ops = false; }; // Registers an XLA backend. `compilation_device_name` is the name of the diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 91f33ff914e..60c8c857f0e 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -1,6 +1,7 @@ -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = ["//tensorflow:internal"]) +package( + default_visibility = ["//tensorflow:internal"], + licenses = ["notice"], # Apache 2.0 +) package_group( name = "friends", @@ -575,6 +576,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":types", + "@com_google_absl//absl/strings", ], ) @@ -881,6 +883,26 @@ tf_cc_test( ], ) +cc_library( + name = "refcounting_hash_map", + hdrs = ["refcounting_hash_map.h"], + deps = [ + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/synchronization", + ], +) + +tf_cc_test( + name = "refcounting_hash_map_test", + srcs = ["refcounting_hash_map_test.cc"], + deps = [ + ":refcounting_hash_map", + ":test", + "//tensorflow/core:test_main", + ], +) + # ----------------------------------------------------------------------------- # This is a headers target that extra XLA devices can use to prevent circular dependencies. Devices that are compiled as separate shared objects can also use it to prevent linking of library code. diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index b800229bd90..806521756dc 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -1,9 +1,10 @@ # Description: # XLA client libraries. -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = ["//visibility:public"]) +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) package_group( name = "friends", diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index 4a99debbe70..acf59c47f3c 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -2,9 +2,10 @@ load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites", "xla_test") -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = ["//tensorflow/compiler/xla/client:friends"]) +package( + default_visibility = ["//tensorflow/compiler/xla/client:friends"], + licenses = ["notice"], # Apache 2.0 +) # Filegroup used to collect source files for dependency checking. filegroup( @@ -472,11 +473,6 @@ cc_library( xla_test( name = "svd_test", srcs = ["svd_test.cc"], - # Blacklisted because the tests are flaky. 
- blacklisted_backends = [ - "cpu", - "gpu", - ], real_hardware_only = True, shard_count = 10, tags = ["optonly"], diff --git a/tensorflow/compiler/xla/client/lib/matrix.cc b/tensorflow/compiler/xla/client/lib/matrix.cc index 93f3d3ab131..902269d9412 100644 --- a/tensorflow/compiler/xla/client/lib/matrix.cc +++ b/tensorflow/compiler/xla/client/lib/matrix.cc @@ -46,23 +46,34 @@ XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m, return ConvertElementType(indicator, type); } +XlaOp GetDiagonalMask(XlaOp x, int diagonal) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + auto n_dims = static_cast(shape.rank()); + TF_RET_CHECK(n_dims >= 2); + auto m = shape.dimensions(n_dims - 2); + auto n = shape.dimensions(n_dims - 1); + absl::Span major_dims = + AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2); + auto a = Iota(builder, S32, n); + auto b = Iota(builder, S32, m) + ConstantR0WithType(builder, S32, diagonal); + auto indicator = Eq(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); + auto mask = Broadcast(indicator, major_dims); + return mask; + }); +} + XlaOp GetMatrixDiagonal(XlaOp x, int k) { XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); - const int64 n_dims = shape.rank(); + auto n_dims = static_cast(shape.rank()); TF_RET_CHECK(n_dims >= 2); const int64 m = shape.dimensions(n_dims - 2); const int64 n = shape.dimensions(n_dims - 1); - auto offset = ConstantR0WithType(builder, S32, k); - - absl::Span major_dims = - AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2); - auto a = Iota(builder, S32, n); - auto b = Iota(builder, S32, m) + offset; - auto indicator = Eq(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); - auto mask = Broadcast(indicator, major_dims); + auto mask = GetDiagonalMask(x, k); // TPUs don't support S64 add reduction at the moment. But fortunately // OR-reductions work just as well for integers. diff --git a/tensorflow/compiler/xla/client/lib/matrix.h b/tensorflow/compiler/xla/client/lib/matrix.h index 5f1ca964a41..541ce2897f5 100644 --- a/tensorflow/compiler/xla/client/lib/matrix.h +++ b/tensorflow/compiler/xla/client/lib/matrix.h @@ -31,6 +31,10 @@ namespace xla { // else. XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m, int64 n); +// Returns a mask where the 'diagonal'-th diagonal is true and everything else +// is false. +XlaOp GetDiagonalMask(XlaOp x, int diagonal = 0); + // Get the diagonals of the last two dimensions. Use k>0 for diagonals above the // main diagonal, and k<0 for diagonals below the main diagonal. // diff --git a/tensorflow/compiler/xla/client/lib/svd.cc b/tensorflow/compiler/xla/client/lib/svd.cc index 53a23872709..646875a20a2 100644 --- a/tensorflow/compiler/xla/client/lib/svd.cc +++ b/tensorflow/compiler/xla/client/lib/svd.cc @@ -75,11 +75,6 @@ struct OneSidedJacobiRotation { JacobiRotation rot_r; }; -struct FrobeniusNorms { - XlaOp off_diagonal_norm; - XlaOp total_norm; -}; - // Householder reflection on the trailing elements of a vector. 
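GetDiagonalMask, added in the matrix.cc change above, materializes a boolean mask that is true exactly on the requested diagonal of the trailing two dimensions (diagonal > 0 selects superdiagonals, < 0 subdiagonals). The same predicate for a single m x n matrix, written out directly as a standalone sketch:

#include <cstdio>
#include <vector>

// mask[i][j] is true iff element (i, j) lies on the requested diagonal:
// diagonal == 0 is the main diagonal, > 0 is above it, < 0 below it.
std::vector<std::vector<bool>> DiagonalMask(int m, int n, int diagonal) {
  std::vector<std::vector<bool>> mask(m, std::vector<bool>(n, false));
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      mask[i][j] = (j - i == diagonal);
    }
  }
  return mask;
}

int main() {
  auto mask = DiagonalMask(3, 4, 1);  // first superdiagonal of a 3x4 matrix
  for (const auto& row : mask) {
    for (bool b : row) std::printf("%d ", b ? 1 : 0);
    std::printf("\n");
  }
  // 0 1 0 0
  // 0 0 1 0
  // 0 0 0 1
  return 0;
}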
// // H = I - beta * [1, v]' * [1, v] @@ -567,27 +562,26 @@ StatusOr OneSidedJacobiUpdate(SVDResult svd_result, XlaOp p, XlaOp q, return svd_result; } -StatusOr ComputeFrobeniusNorms(XlaOp w) { +StatusOr ComputeToleranceComparison(XlaOp w, XlaOp epsilon) { XlaBuilder* builder = w.builder(); TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(w)); - const int64 num_dims = shape.rank(); - auto frobenius_norm = - Sqrt(Reduce(Square(w), ScalarLike(w, 0.0), - CreateScalarAddComputation(shape.element_type(), builder), - {num_dims - 2, num_dims - 1})); - auto diag = GetMatrixDiagonal(w); - auto diag_square = - Reduce(Square(diag), ScalarLike(w, 0.0), - CreateScalarAddComputation(shape.element_type(), builder), - {num_dims - 2}); - - FrobeniusNorms frobenius_norms; - - frobenius_norms.off_diagonal_norm = - Sqrt(Max(Square(frobenius_norm) - diag_square, ScalarLike(w, 0.0))); - frobenius_norms.total_norm = frobenius_norm; - - return frobenius_norms; + auto num_dims = static_cast(shape.rank()); + int64 n = shape.dimensions(num_dims - 1); + shape.set_dimensions(num_dims - 2, n); + auto w_sliced = SliceInMinorDims(w, {0, 0}, {n, n}); + auto diag = GetMatrixDiagonal(w_sliced); + diag = Select(Lt(diag, ZerosLike(diag)), -diag, diag); + std::vector broadcasted_dims(num_dims - 1); + std::iota(broadcasted_dims.begin(), broadcasted_dims.end(), 0); + auto broadcast_to_rows = + BroadcastInDim(diag, shape.dimensions(), broadcasted_dims); + broadcasted_dims.back() = num_dims - 1; + auto broadcast_to_columns = + BroadcastInDim(diag, shape.dimensions(), broadcasted_dims); + // Compute w_{i,i} * w_{j,j} * epsilon^2 < (w_{i,j})^2 + return Lt( + broadcast_to_rows * broadcast_to_columns * epsilon * epsilon, + Square(Select(GetDiagonalMask(w_sliced), ZerosLike(w_sliced), w_sliced))); } // Main boby of One-sided Jacobi Method. @@ -603,13 +597,13 @@ StatusOr> WhileLoopFn( auto max_sweeps = ScalarLike(k, max_sweep_updates); auto sweep_update_cond = Gt(max_sweeps, k); - auto norms = ComputeFrobeniusNorms(values[3]).ValueOrDie(); - auto tol = norms.total_norm * values[4]; - auto tol_cond = ReduceAll(Lt(tol, norms.off_diagonal_norm), - xla::ConstantR0(cond_builder, false), - CreateScalarOrComputation(PRED, cond_builder)); + TF_ASSIGN_OR_RETURN(auto tolerance_comparison, + ComputeToleranceComparison(values[3], values[4])); + auto tolerance_cond = ReduceAll( + tolerance_comparison, xla::ConstantR0(cond_builder, false), + CreateScalarOrComputation(PRED, cond_builder)); - return And(sweep_update_cond, tol_cond); + return And(sweep_update_cond, tolerance_cond); }; auto while_body_fn = diff --git a/tensorflow/compiler/xla/client/lib/svd_test.cc b/tensorflow/compiler/xla/client/lib/svd_test.cc index a987f7fcaf6..a39238548fc 100644 --- a/tensorflow/compiler/xla/client/lib/svd_test.cc +++ b/tensorflow/compiler/xla/client/lib/svd_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/client/lib/svd.h" + #include #include "tensorflow/compiler/xla/array2d.h" @@ -183,12 +184,14 @@ XLA_TEST_F(SVDTest, TestSingleValuesMatchNumpy) { ErrorSpec(1e-3, 1e-3)); } -XLA_TEST_F(SVDTest, Various_Size_Random_Matrix_512x128) { +// Too slow on the interpreter backend. 
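ComputeToleranceComparison in the svd.cc change above replaces the Frobenius-norm convergence test with a per-element one: an off-diagonal entry w_ij still counts as too large while |w_ii| * |w_jj| * eps^2 < w_ij^2, and the Jacobi sweep loop keeps going as long as any entry satisfies that. A scalar sketch of the predicate over a small square matrix (diagonal entries taken by absolute value, as in the patch; the leading-block slicing for rectangular inputs is omitted):

#include <cmath>
#include <cstdio>
#include <vector>

// Returns true if any off-diagonal element is still too large relative to
// the corresponding diagonal entries, i.e. |w_ii| * |w_jj| * eps^2 < w_ij^2.
bool NeedsMoreSweeps(const std::vector<std::vector<double>>& w, double eps) {
  const size_t n = w.size();
  for (size_t i = 0; i < n; ++i) {
    for (size_t j = 0; j < n; ++j) {
      if (i == j) continue;
      double dii = std::fabs(w[i][i]);
      double djj = std::fabs(w[j][j]);
      if (dii * djj * eps * eps < w[i][j] * w[i][j]) return true;
    }
  }
  return false;
}

int main() {
  std::vector<std::vector<double>> w = {{4.0, 1e-7}, {1e-7, 2.0}};
  std::printf("%d\n", NeedsMoreSweeps(w, 1e-4));  // 0: converged at eps = 1e-4
  std::printf("%d\n", NeedsMoreSweeps(w, 1e-9));  // 1: not converged at 1e-9
  return 0;
}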
+XLA_TEST_F(SVDTest, + DISABLED_ON_INTERPRETER(Various_Size_Random_Matrix_512x128)) { XlaBuilder builder(TestName()); Array2D a_val = GenerateRandomMatrix(512, 128); XlaOp a; auto a_data = CreateR2Parameter(a_val, 0, "a", &builder, &a); - auto result = SVD(a, 100, 1e-6); + auto result = SVD(a, 100, 1e-4); GetAverageAbsoluteError(ComputeMatmulUDVT(result, &builder), a, &builder); ComputeAndCompareR0(&builder, 1e-3, {a_data.get()}, @@ -200,7 +203,7 @@ XLA_TEST_F(SVDTest, Various_Size_Random_Matrix_128x256) { Array2D a_val = GenerateRandomMatrix(128, 256); XlaOp a; auto a_data = CreateR2Parameter(a_val, 0, "a", &builder, &a); - auto result = SVD(a, 100, 1e-6); + auto result = SVD(a, 100, 1e-4); GetAverageAbsoluteError(ComputeMatmulUDVT(result, &builder), a, &builder); ComputeAndCompareR0(&builder, 1e-3, {a_data.get()}, @@ -212,38 +215,44 @@ XLA_TEST_F(SVDTest, Various_Size_Random_Matrix_256x128) { Array2D a_val = GenerateRandomMatrix(256, 128); XlaOp a; auto a_data = CreateR2Parameter(a_val, 0, "a", &builder, &a); - auto result = SVD(a, 100, 1e-6); + auto result = SVD(a, 100, 1e-4); GetAverageAbsoluteError(ComputeMatmulUDVT(result, &builder), a, &builder); ComputeAndCompareR0(&builder, 1e-3, {a_data.get()}, ErrorSpec(1e-3, 1e-3)); } -XLA_TEST_F(SVDTest, Various_Size_Random_Matrix_128x512) { +// Too slow on the interpreter backend. +XLA_TEST_F(SVDTest, + DISABLED_ON_INTERPRETER(Various_Size_Random_Matrix_128x512)) { XlaBuilder builder(TestName()); Array2D a_val = GenerateRandomMatrix(128, 512); XlaOp a; auto a_data = CreateR2Parameter(a_val, 0, "a", &builder, &a); - auto result = SVD(a, 100, 1e-6); + auto result = SVD(a, 100, 1e-4); GetAverageAbsoluteError(ComputeMatmulUDVT(result, &builder), a, &builder); ComputeAndCompareR0(&builder, 1e-3, {a_data.get()}, ErrorSpec(1e-3, 1e-3)); } -XLA_TEST_F(SVDTest, Various_Size_Random_Matrix_512x256) { +// Too slow on the interpreter and CPU backends. +XLA_TEST_F(SVDTest, DISABLED_ON_CPU(DISABLED_ON_INTERPRETER( + Various_Size_Random_Matrix_512x256))) { XlaBuilder builder(TestName()); Array2D a_val = GenerateRandomMatrix(512, 256); XlaOp a; auto a_data = CreateR2Parameter(a_val, 0, "a", &builder, &a); - auto result = SVD(a, 100, 1e-6); + auto result = SVD(a, 100, 1e-4); GetAverageAbsoluteError(ComputeMatmulUDVT(result, &builder), a, &builder); ComputeAndCompareR0(&builder, 1e-3, {a_data.get()}, ErrorSpec(1e-3, 1e-3)); } -XLA_TEST_F(SVDTest, Various_Size_Random_Matrix_512x512) { +// Too slow on the CPU, GPU and interpreter backends. +XLA_TEST_F(SVDTest, DISABLED_ON_GPU(DISABLED_ON_CPU(DISABLED_ON_INTERPRETER( + Various_Size_Random_Matrix_512x512)))) { XlaBuilder builder(TestName()); Array2D a_val = GenerateRandomMatrix(512, 512); XlaOp a; diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 508f16a945f..b5fa1b6ced8 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -150,7 +150,7 @@ class XlaBuilder { // result, OpMetadata is set on the Computation Builder. All subsequent // instructions generated via this Computation Builder will have the same // OpMetadata attached until a call to ClearOpMetadata. - void SetOpMetadata(const OpMetadata& metadata) { metadata_ = metadata; } + void SetOpMetadata(OpMetadata metadata) { metadata_ = std::move(metadata); } // Clears the HloMetadata state. 
void ClearOpMetadata() { metadata_.Clear(); } diff --git a/tensorflow/compiler/xla/executable_run_options.cc b/tensorflow/compiler/xla/executable_run_options.cc index 39c90b60a09..1cfb449ebd0 100644 --- a/tensorflow/compiler/xla/executable_run_options.cc +++ b/tensorflow/compiler/xla/executable_run_options.cc @@ -15,8 +15,21 @@ limitations under the License. #include "tensorflow/compiler/xla/executable_run_options.h" +#include + +#include "absl/strings/str_cat.h" + namespace xla { +RunId::RunId() { + static std::atomic counter{0}; + data_ = counter.fetch_add(1); +} + +bool operator==(const RunId& a, const RunId& b) { return a.data_ == b.data_; } + +std::string RunId::ToString() const { return absl::StrCat("RunId: ", data_); } + ExecutableRunOptions& ExecutableRunOptions::set_device_ordinal( int device_ordinal) { device_ordinal_ = device_ordinal; @@ -94,4 +107,11 @@ ExecutableRunOptions& ExecutableRunOptions::set_rng_seed(int rng_seed) { int ExecutableRunOptions::rng_seed() const { return rng_seed_; } +ExecutableRunOptions& ExecutableRunOptions::set_run_id(RunId id) { + run_id_ = id; + return *this; +} + +RunId ExecutableRunOptions::run_id() const { return run_id_; } + } // namespace xla diff --git a/tensorflow/compiler/xla/executable_run_options.h b/tensorflow/compiler/xla/executable_run_options.h index 84629593953..4de8148451b 100644 --- a/tensorflow/compiler/xla/executable_run_options.h +++ b/tensorflow/compiler/xla/executable_run_options.h @@ -16,6 +16,10 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_EXECUTABLE_RUN_OPTIONS_H_ #define TENSORFLOW_COMPILER_XLA_EXECUTABLE_RUN_OPTIONS_H_ +#include + +#include "tensorflow/compiler/xla/types.h" + // These classes are forward declared so that ExecutableRunOptions can be linked // into an XLA-compiled binary without having to link all of the pointed-to // objects (e.g., for an ahead-of-time compiled CPU binary, the gpu tools don't @@ -35,6 +39,31 @@ namespace xla { class DeviceAssignment; class ExecutionProfile; +// A unique identifier for a particular "logical execution" of an XLA model. +// +// A logical execution might encompass multiple executions of one or more +// HloModules. Runs that are part of the same logical execution can +// communicate via collective ops (e.g. kAllToAll), whereas runs that are part +// of different logical executions are isolated. +class RunId { + public: + // Creates a new, unique RunId. + RunId(); + + RunId(const RunId&) = default; + RunId& operator=(const RunId&) = default; + friend bool operator==(const RunId& a, const RunId& b); + std::string ToString() const; + + template + friend H AbslHashValue(H h, const RunId& id) { + return H::combine(std::move(h), id.data_); + } + + private: + int64 data_; +}; + // Class containing options for running a LocalExecutable. 
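RunId, introduced above, only needs to be unique within a process, so it is a snapshot of a monotonically increasing atomic counter; copies compare equal, while independently constructed RunIds never do. A standalone version of the same idea (member and class names simply mirror the patch for readability):

#include <atomic>
#include <cstdint>
#include <cstdio>
#include <string>

// A process-unique identifier for one logical execution. Copies of a RunId
// compare equal; two separately constructed RunIds never do.
class RunId {
 public:
  RunId() {
    static std::atomic<std::int64_t> counter{0};
    data_ = counter.fetch_add(1);
  }
  friend bool operator==(const RunId& a, const RunId& b) {
    return a.data_ == b.data_;
  }
  std::string ToString() const { return "RunId: " + std::to_string(data_); }

 private:
  std::int64_t data_;
};

int main() {
  RunId a, b;
  RunId a_copy = a;
  std::printf("%s %s\n", a.ToString().c_str(), b.ToString().c_str());
  std::printf("%d %d\n", a == a_copy, a == b);  // 1 0
  return 0;
}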
class ExecutableRunOptions { public: @@ -87,6 +116,9 @@ class ExecutableRunOptions { ExecutableRunOptions& set_rng_seed(int rng_seed); int rng_seed() const; + ExecutableRunOptions& set_run_id(RunId id); + RunId run_id() const; + private: stream_executor::DeviceMemoryAllocator* allocator_ = nullptr; int device_ordinal_ = -1; @@ -96,6 +128,7 @@ class ExecutableRunOptions { ExecutionProfile* execution_profile_ = nullptr; int rng_seed_ = 0; stream_executor::Stream* host_to_device_stream_ = nullptr; + RunId run_id_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/experimental/xla_sharding/BUILD b/tensorflow/compiler/xla/experimental/xla_sharding/BUILD index a26b20c8618..57eeb25bb49 100644 --- a/tensorflow/compiler/xla/experimental/xla_sharding/BUILD +++ b/tensorflow/compiler/xla/experimental/xla_sharding/BUILD @@ -1,9 +1,10 @@ # Description: # Python API for shardings in XLA. -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = ["//tensorflow:internal"]) +package( + default_visibility = ["//tensorflow:internal"], + licenses = ["notice"], # Apache 2.0 +) py_library( name = "xla_sharding", diff --git a/tensorflow/compiler/xla/index_util.cc b/tensorflow/compiler/xla/index_util.cc index eebd8245abe..463a8d95fc5 100644 --- a/tensorflow/compiler/xla/index_util.cc +++ b/tensorflow/compiler/xla/index_util.cc @@ -119,6 +119,8 @@ namespace xla { int64 limit = shape.dimensions(dimno); if (indices[dimno] + 1 < limit) { indices[dimno]++; + // Whenever an index of a dimension is increased, it means that all + // following dimensions have maxed out, so they must go to 0. std::fill(indices.begin() + dimno + 1, indices.end(), 0); return true; } diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h index c810ae9cbae..3c53592d040 100644 --- a/tensorflow/compiler/xla/literal.h +++ b/tensorflow/compiler/xla/literal.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc index 0431bb3d54a..dc11f7caa2c 100644 --- a/tensorflow/compiler/xla/literal_comparison.cc +++ b/tensorflow/compiler/xla/literal_comparison.cc @@ -662,8 +662,11 @@ Status EqualHelper(const LiteralSlice& expected, const LiteralSlice& actual) { case PRED: result = Equal(expected, actual, index, 0); break; - case U8: - result = Equal(expected, actual, index, 0); + case S8: + result = Equal(expected, actual, index, 0); + break; + case S16: + result = Equal(expected, actual, index, 0); break; case S32: result = Equal(expected, actual, index, 0); @@ -671,6 +674,12 @@ Status EqualHelper(const LiteralSlice& expected, const LiteralSlice& actual) { case S64: result = Equal(expected, actual, index, 0); break; + case U8: + result = Equal(expected, actual, index, 0); + break; + case U16: + result = Equal(expected, actual, index, 0); + break; case U32: result = Equal(expected, actual, index, 0); break; diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index 45a3a264fd6..49f41d232a2 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -1,6 +1,7 @@ -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = ["//tensorflow:internal"]) +package( + default_visibility = ["//tensorflow:internal"], + licenses = ["notice"], # Apache 2.0 +) load("//tensorflow/core:platform/default/build_config.bzl", "pyx_library") load("//tensorflow/compiler/xla:xla.bzl", "xla_python_default_plugins") @@ -145,7 +146,6 @@ cc_library( ":shared_device_buffer", ":types", ":worker_thread", - "//tensorflow/compiler/jit:xla_launch_util", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", @@ -166,6 +166,7 @@ cc_library( "//tensorflow/core:gpu_mem_allocator", "//tensorflow/core:lib", "//tensorflow/core/profiler/lib:traceme", + "//tensorflow/stream_executor:tf_allocator_adapter", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", diff --git a/tensorflow/compiler/xla/python/local_client.cc b/tensorflow/compiler/xla/python/local_client.cc index facc61d515d..e13637c2fd9 100644 --- a/tensorflow/compiler/xla/python/local_client.cc +++ b/tensorflow/compiler/xla/python/local_client.cc @@ -71,7 +71,6 @@ limitations under the License. #include "absl/synchronization/notification.h" #include "absl/time/time.h" #include "include/pybind11/pybind11.h" -#include "tensorflow/compiler/jit/xla_launch_util.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/executable_run_options.h" @@ -88,6 +87,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/gpu/gpu_mem_allocator.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/lib/traceme.h" +#include "tensorflow/stream_executor/tf_allocator_adapter.h" namespace xla { @@ -162,10 +162,25 @@ Device::Device(se::StreamExecutor* executor, bool use_multiple_streams, } Device::~Device() { + Status status = SynchronizeAllActivity(); + if (!status.ok()) { + LOG(ERROR) << "Error when closing device: " << status; + } +} + +Status Device::SynchronizeAllActivity() { + Status status; + // TODO(phawkins): in theory the call to SynchronizeAllActivity below should + // suffice. 
However on the Host platform SynchronizeAllActivity is a dummy + // implementation that doesn't actually block. To make sure activity has + // stopped, also block on the compute stream. If SynchronizeAllActivity is + // fixed, we could remove the BlockHostUntilDone call. + status.Update(compute_stream_->BlockHostUntilDone()); bool ok = compute_stream_->parent()->SynchronizeAllActivity(); if (!ok) { - LOG(ERROR) << "SynchronizeAllActivity failed when destroying Device."; + status.Update(Unknown("SynchronizeAllActivity failed.")); } + return status; } void Device::ThenExecuteOnWorkerThread(se::Stream* stream, @@ -174,18 +189,17 @@ void Device::ThenExecuteOnWorkerThread(se::Stream* stream, [this, callback]() { worker_thread_->Schedule(std::move(callback)); }); } -static StatusOr> -CreateBFCAllocator(se::Platform* platform, LocalClient* client, - double memory_fraction) { +static StatusOr> CreateBFCAllocator( + se::Platform* platform, LocalClient* client, double memory_fraction) { CHECK_GT(client->backend().device_count(), 0); std::vector> allocators; for (se::StreamExecutor* executor : client->backend().stream_executors()) { int device_ordinal = executor->device_ordinal(); - tensorflow::GPUMemAllocator* sub_allocator = - new tensorflow::GPUMemAllocator( - executor, tensorflow::PlatformGpuId(device_ordinal), - /*use_unified_memory=*/false, /*alloc_visitors=*/{}, - /*free_visitors=*/{}); + auto sub_allocator = absl::make_unique( + executor, tensorflow::PlatformGpuId(device_ordinal), + /*use_unified_memory=*/false, + /*alloc_visitors=*/std::vector(), + /*free_visitors=*/std::vector()); int64 free_memory; int64 total_memory; @@ -198,13 +212,13 @@ CreateBFCAllocator(se::Platform* platform, LocalClient* client, << total_memory << " bytes on device " << device_ordinal << " for BFCAllocator."; - tensorflow::BFCAllocator* gpu_bfc_allocator = new tensorflow::BFCAllocator( - sub_allocator, allocator_memory, /*allow_growth=*/false, + auto gpu_bfc_allocator = absl::make_unique( + sub_allocator.release(), allocator_memory, /*allow_growth=*/false, absl::StrCat("GPU_", device_ordinal, "_bfc")); - allocators.emplace_back(gpu_bfc_allocator); + allocators.emplace_back(std::move(gpu_bfc_allocator)); } - return absl::make_unique( - platform, std::move(allocators)); + return absl::make_unique(platform, + std::move(allocators)); } StatusOr> PyLocalClient::Get( @@ -250,8 +264,7 @@ PyLocalClient::PyLocalClient( allocator_ = client_->backend().memory_allocator(); } devices_.reserve(client->device_count()); - // TODO(phawkins): enable multistream mode on GPU too. 
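The Device destructor change above follows a common pattern: drain outstanding work through a Status-returning helper, and because a destructor cannot usefully propagate the error, log it instead. A generic sketch of that pattern; the Status stand-in and the names below are purely illustrative:

#include <cstdio>
#include <string>

// Minimal stand-in for a Status type: ok() is true when the message is empty.
class Status {
 public:
  Status() {}
  explicit Status(std::string message) : message_(std::move(message)) {}
  bool ok() const { return message_.empty(); }
  const std::string& message() const { return message_; }

 private:
  std::string message_;
};

class Device {
 public:
  ~Device() {
    // Destructors cannot return errors, so surface failures in the log only.
    Status status = SynchronizeAllActivity();
    if (!status.ok()) {
      std::fprintf(stderr, "Error when closing device: %s\n",
                   status.message().c_str());
    }
  }

 private:
  Status SynchronizeAllActivity() {
    // Block until queued work is done; report an error if that fails.
    // (A real implementation would wait on the device's streams here.)
    return Status();
  }
};

int main() {
  { Device d; }  // the destructor runs here and would log any failure
  return 0;
}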
- bool use_multiple_streams = (platform_name == "tpu"); + bool use_multiple_streams = (platform_name_ != "cpu"); bool synchronous_deallocation = !use_multiple_streams; for (int i = 0; i < client->device_count(); ++i) { se::StreamExecutor* executor = @@ -281,7 +294,7 @@ StatusOr PyLocalClient::TransferFromOutfeed( return LiteralToPython(absl::make_unique(std::move(literal))); } -static StatusOr TransferHostToDeviceAsync( +static StatusOr> TransferHostToDeviceAsync( const PythonBufferTree& tree, int device_ordinal, std::shared_ptr client, const Device& device) { se::DeviceMemoryAllocator* allocator = client->allocator(); @@ -315,8 +328,9 @@ static StatusOr TransferHostToDeviceAsync( } std::shared_ptr definition_event; if (device.use_multiple_streams()) { - definition_event = std::make_shared( - device.host_to_device_stream()->parent()); + TF_ASSIGN_OR_RETURN(definition_event, + BufferDefinitionEvent::Create( + device.host_to_device_stream()->parent())); definition_event->RecordOnStream(device.host_to_device_stream()); } std::shared_ptr device_buffer = @@ -326,11 +340,12 @@ static StatusOr TransferHostToDeviceAsync( device.ThenReleaseOnWorkerThread(device.host_to_device_stream(), device_buffer); } - return PyLocalBuffer(shape, std::move(device_buffer), std::move(client)); + return absl::make_unique(shape, std::move(device_buffer), + std::move(client)); } /* static */ -StatusOr PyLocalBuffer::FromPython( +StatusOr> PyLocalBuffer::FromPython( const py::object& argument, std::shared_ptr client, int device_ordinal) { tensorflow::profiler::TraceMe traceme("PyLocalBuffer::FromPython"); @@ -349,7 +364,7 @@ StatusOr PyLocalBuffer::FromPython( << " device ordinal: " << device_ordinal; const Device& device = client->device(device_ordinal); - TF_ASSIGN_OR_RETURN(PyLocalBuffer buffer, + TF_ASSIGN_OR_RETURN(std::unique_ptr buffer, TransferHostToDeviceAsync(tree, device_ordinal, std::move(client), device)); @@ -357,20 +372,20 @@ StatusOr PyLocalBuffer::FromPython( return buffer; } -/*static */ StatusOr> +/*static */ StatusOr>> PyLocalBuffer::FromPythonValues( const std::vector>& arguments, std::shared_ptr client) { tensorflow::profiler::TraceMe traceme("PyLocalBuffer::FromPythonValues"); int num_arguments = static_cast(arguments.size()); - std::vector outputs(num_arguments); + std::vector> outputs(num_arguments); if (num_arguments == 0) { return outputs; } struct H2DTransfer { PythonBufferTree tree; - StatusOr buffer; + StatusOr> buffer; PythonRefManager::ManagedPyObjects py_buffer_refs; }; @@ -385,7 +400,7 @@ PyLocalBuffer::FromPythonValues( // We are done manipulating Python objects; release the GIL. 
py::gil_scoped_release gil_release; - auto transfer_h2d = [&](int i) -> StatusOr { + auto transfer_h2d = [&](int i) -> StatusOr> { int device_ordinal = arguments[i].second; return TransferHostToDeviceAsync(transfers[i].tree, device_ordinal, client, client->device(device_ordinal)); @@ -420,18 +435,24 @@ PyLocalBuffer::FromPythonValues( return outputs; } -/* static */ StatusOr PyLocalBuffer::MakeTuple( - const std::vector buffers, +/* static */ StatusOr> PyLocalBuffer::MakeTuple( + const std::vector buffers, std::shared_ptr client, int device_ordinal) { std::vector host_shapes; std::vector> device_buffers; host_shapes.reserve(buffers.size()); device_buffers.reserve(buffers.size()); - for (const PyLocalBuffer& buffer : buffers) { - TF_RET_CHECK(buffer.device_buffer()->device_memory().device_ordinal() == - device_ordinal); - host_shapes.push_back(buffer.on_host_shape()); - device_buffers.push_back(buffer.device_buffer()); + for (const PyLocalBuffer* buffer : buffers) { + TF_RET_CHECK(buffer->device_ordinal() == device_ordinal); + std::shared_ptr device_buffer = + buffer->DeviceBuffer(); + if (!device_buffer) { + return InvalidArgument( + "Invalid buffer passed to MakeTuple() as argument %d.", + device_buffers.size()); + } + host_shapes.push_back(buffer->on_host_shape()); + device_buffers.push_back(std::move(device_buffer)); } se::DeviceMemoryAllocator* allocator = client->allocator(); TransferManager* transfer_manager = @@ -439,19 +460,20 @@ PyLocalBuffer::FromPythonValues( const Device& device = client->device(device_ordinal); std::shared_ptr definition_event; if (device.use_multiple_streams()) { - definition_event = std::make_shared( - device.host_to_device_stream()->parent()); + TF_ASSIGN_OR_RETURN(definition_event, + BufferDefinitionEvent::Create( + device.host_to_device_stream()->parent())); } TF_ASSIGN_OR_RETURN(std::shared_ptr tuple_buffer, PySharedDeviceBuffer::MakeTuple( device_buffers, transfer_manager, allocator, device_ordinal, definition_event)); - PyLocalBuffer buffer(ShapeUtil::MakeTupleShape(host_shapes), tuple_buffer, - std::move(client)); + auto buffer = absl::make_unique( + ShapeUtil::MakeTupleShape(host_shapes), tuple_buffer, std::move(client)); // TODO(phawkins): extend TransferManager so we do not need to form a full // ShapedBuffer just to write the root tuple index table. 
- ShapedBuffer shaped_buffer = buffer.AsShapedBuffer(); + TF_ASSIGN_OR_RETURN(ShapedBuffer shaped_buffer, buffer->AsShapedBuffer()); if (device.use_multiple_streams() && !transfer_manager->CanShapedBufferBeAccessedNow( device.host_to_device_stream()->parent(), shaped_buffer)) { @@ -476,21 +498,33 @@ PyLocalBuffer::PyLocalBuffer( std::shared_ptr client) : client_(std::move(client)), on_host_shape_(std::move(on_host_shape)), + device_ordinal_(device_buffer->device_ordinal()), device_buffer_(std::move(device_buffer)) {} +void PyLocalBuffer::Delete() { + absl::MutexLock lock(&mu_); + device_buffer_ = nullptr; +} + StatusOr PyLocalBuffer::ToPython() const { tensorflow::profiler::TraceMe traceme("PyLocalBuffer::ToPython"); - auto literal = absl::make_unique(on_host_shape()); + std::shared_ptr device_buffer = DeviceBuffer(); + if (!device_buffer) { + return InvalidArgument("ToPython() called on invalid buffer."); + } + + auto literal = absl::make_unique(on_host_shape_); client_->py_ref_manager().CollectGarbage(); { py::gil_scoped_release gil_release; - se::Stream* stream = client_->device(device_buffer_->device_ordinal()) + se::Stream* stream = client_->device(device_buffer->device_ordinal()) .device_to_host_stream(); - WaitForBufferDefinitionEventsOnStream(*device_buffer_, stream); + WaitForBufferDefinitionEventsOnStream(*device_buffer, stream); absl::Notification done; Status status; + TF_ASSIGN_OR_RETURN(ShapedBuffer shaped_buffer, AsShapedBuffer()); client_->client()->backend().transfer_manager()->TransferLiteralFromDevice( - stream, AsShapedBuffer(), *literal, [&](Status done_status) { + stream, shaped_buffer, *literal, [&](Status done_status) { status = done_status; done.Notify(); }); @@ -499,28 +533,64 @@ StatusOr PyLocalBuffer::ToPython() const { return LiteralToPython(std::move(literal)); } -ShapedBuffer PyLocalBuffer::AsShapedBuffer() const { +std::shared_ptr PyLocalBuffer::DeviceBuffer() const { + absl::MutexLock lock(&mu_); + return device_buffer_; +} + +StatusOr PyLocalBuffer::AsShapedBuffer() const { + absl::MutexLock lock(&mu_); + if (!device_buffer_) { + return InvalidArgument( + "Attempted to fetch value of invalid/deleted buffer."); + } return device_buffer_->AsShapedBuffer(on_host_shape_); } -StatusOr> PyLocalBuffer::DestructureTuple() { +StatusOr>> +PyLocalBuffer::DestructureTuple() { tensorflow::profiler::TraceMe traceme("PyLocalBuffer::DestructureTuple"); - if (!on_host_shape().IsTuple()) { + absl::MutexLock lock(&mu_); + if (!on_host_shape_.IsTuple()) { return InvalidArgument( "Attemped to destructure a PyLocalBuffer that did not have a tuple " "shape; shape: %s", - ShapeUtil::HumanString(on_host_shape())); + ShapeUtil::HumanString(on_host_shape_)); } - int num_children = ShapeUtil::TupleElementCount(on_host_shape()); - std::vector results; + if (!device_buffer_) { + return InvalidArgument("Attempted to destructure a deleted buffer."); + } + int num_children = ShapeUtil::TupleElementCount(on_host_shape_); + std::vector> results; results.reserve(num_children); for (int64 i = 0; i < num_children; ++i) { - results.push_back(PyLocalBuffer(on_host_shape().tuple_shapes(i), - device_buffer_->children().at(i), client_)); + results.push_back(absl::make_unique( + on_host_shape_.tuple_shapes(i), device_buffer_->children().at(i), + client_)); } return results; } +Status PyLocalBuffer::BlockHostUntilReady() { + tensorflow::profiler::TraceMe traceme("PyLocalBuffer::BlockHostUntilReady"); + std::shared_ptr device_buffer = DeviceBuffer(); + if (!device_buffer) { + return 
InvalidArgument("BlockHostUntilReady() called on invalid buffer."); + } + + client_->py_ref_manager().CollectGarbage(); + py::gil_scoped_release gil_release; + + // This code waits at least until the buffer is ready, but it may wait longer + // if there are other device to host transfers scheduled. If this proves to + // be an issue, we could either use a separate stream for this purpose, or + // poll for the buffer definition events. + se::Stream* stream = + client_->device(device_buffer->device_ordinal()).device_to_host_stream(); + WaitForBufferDefinitionEventsOnStream(*device_buffer, stream); + return stream->BlockHostUntilDone(); +} + PyLocalExecutable::PyLocalExecutable( std::shared_ptr executable, DeviceAssignment device_assignment, std::shared_ptr client) @@ -538,7 +608,7 @@ std::vector PyLocalExecutable::DeviceOrdinals() const { return device_ordinals; } -StatusOr PyLocalExecutable::ExecuteHelper( +StatusOr> PyLocalExecutable::ExecuteHelper( absl::Span argument_handles, int replica) { const int device_ordinal = device_assignment_(replica, 0); tensorflow::profiler::TraceMe traceme("LocalExecutable::Execute"); @@ -546,28 +616,34 @@ StatusOr PyLocalExecutable::ExecuteHelper( << " mapped to device ordinal for execution: " << device_ordinal; absl::flat_hash_set events; + std::vector> device_buffers; std::vector argument_buffers; std::vector argument_buffer_ptrs; + device_buffers.reserve(argument_handles.size() + 1); argument_buffers.reserve(argument_handles.size()); argument_buffer_ptrs.reserve(argument_handles.size()); - for (auto& handle : argument_handles) { - if (handle->device_buffer() == nullptr) { + for (int i = 0; i < argument_handles.size(); ++i) { + PyLocalBuffer* handle = argument_handles[i]; + std::shared_ptr device_buffer = + handle->DeviceBuffer(); + if (!device_buffer) { return InvalidArgument( "Deleted buffer passed to Execute() as argument " "%d to replica %d", - argument_buffers.size(), replica); + i, replica); } - if (handle->device_buffer()->device_ordinal() != device_ordinal) { + if (device_buffer->device_ordinal() != device_ordinal) { return InvalidArgument( "Buffer passed to Execute() as argument %d to replica %d is on " "device %d, but replica is assigned to device %d.", - argument_buffers.size(), replica, - handle->device_buffer()->device_ordinal(), device_ordinal); + i, replica, device_buffer->device_ordinal(), device_ordinal); } - argument_buffers.push_back(handle->AsShapedBuffer()); + TF_ASSIGN_OR_RETURN(ShapedBuffer shaped_buffer, handle->AsShapedBuffer()); + argument_buffers.push_back(std::move(shaped_buffer)); argument_buffer_ptrs.push_back(&argument_buffers.back()); - GetDeviceBufferDefinitionEvents(*handle->device_buffer(), &events); - VLOG(4) << "Argument " << argument_buffers.size() - 1 + GetDeviceBufferDefinitionEvents(*device_buffer, &events); + device_buffers.push_back(std::move(device_buffer)); + VLOG(4) << "Argument " << i << " buffer: " << argument_buffers.back().ToString(); } @@ -603,8 +679,9 @@ StatusOr PyLocalExecutable::ExecuteHelper( std::shared_ptr definition_event; if (device.use_multiple_streams()) { - definition_event = std::make_shared( - device.compute_stream()->parent()); + TF_ASSIGN_OR_RETURN( + definition_event, + BufferDefinitionEvent::Create(device.compute_stream()->parent())); definition_event->RecordOnStream(device.compute_stream()); } Shape on_host_shape = result_buffer.ValueOrDie().on_host_shape(); @@ -613,20 +690,16 @@ StatusOr PyLocalExecutable::ExecuteHelper( std::move(result_buffer.ValueOrDie()), definition_event); if 
(device.synchronous_deallocation()) { - std::vector> buffers; - buffers.reserve(argument_handles.size() + 1); - for (auto& handle : argument_handles) { - buffers.push_back(handle->device_buffer()); - } - buffers.push_back(out_buffer); + device_buffers.push_back(out_buffer); device.ThenReleaseOnWorkerThread(device.compute_stream(), - std::move(buffers)); - device.ThenReleaseOnWorkerThread(device.compute_stream(), executable_); + std::move(device_buffers)); } - return PyLocalBuffer(on_host_shape, std::move(out_buffer), client_); + device.ThenReleaseOnWorkerThread(device.compute_stream(), executable_); + return absl::make_unique(on_host_shape, std::move(out_buffer), + client_); } -StatusOr PyLocalExecutable::Execute( +StatusOr> PyLocalExecutable::Execute( absl::Span argument_handles) { if (num_replicas() != 1) { return InvalidArgument( @@ -636,7 +709,8 @@ StatusOr PyLocalExecutable::Execute( return ExecuteHelper(argument_handles, /*replica=*/0); } -StatusOr> PyLocalExecutable::ExecutePerReplica( +StatusOr>> +PyLocalExecutable::ExecutePerReplica( absl::Span> argument_handles) { tensorflow::profiler::TraceMe traceme("LocalExecutable::ExecutePerReplica"); const int num_devices = client_->device_count(); @@ -654,7 +728,7 @@ StatusOr> PyLocalExecutable::ExecutePerReplica( VLOG(1) << "Executing replicated computation; num_replicas=" << num_replicas(); - std::vector> results(num_replicas()); + std::vector>> results(num_replicas()); if (num_replicas() == 1) { // Fast-path if there is only one replica — run the computation on the // current thread. @@ -710,7 +784,7 @@ StatusOr> PyLocalExecutable::ExecutePerReplica( } VLOG(1) << "Replicated execution complete."; - std::vector wrapped_results(num_replicas()); + std::vector> wrapped_results(num_replicas()); for (int replica = 0; replica < num_replicas(); ++replica) { auto& statusor = results[replica]; if (!statusor.ok()) { @@ -728,12 +802,45 @@ StatusOr> PyLocalExecutable::ExecutePerReplica( /*static*/ StatusOr> PyLocalExecutable::Compile(const XlaComputation& computation, - std::vector argument_layouts, + absl::optional> argument_layouts, const ExecutableBuildOptions* build_options, - std::shared_ptr client) { + std::shared_ptr client, + absl::optional device_assignment) { tensorflow::profiler::TraceMe traceme("LocalExecutable::Compile"); + + ExecutableBuildOptions options; + if (build_options != nullptr) { + options = *build_options; + } + + if (device_assignment) { + if (device_assignment->replica_count() != options.num_replicas()) { + return InvalidArgument( + "Mismatched number of replicas for device " + "assignment and computation (%d vs %d).", + device_assignment->replica_count(), options.num_replicas()); + } else if (device_assignment->computation_count() != 1) { + return Unimplemented( + "Only 1 computation per replica supported, %d requested.", + device_assignment->computation_count()); + } + } else { + TF_ASSIGN_OR_RETURN( + device_assignment, + client->client()->backend().computation_placer()->AssignDevices( + options.num_replicas(), /*computation_count=*/1)); + } + + if (!argument_layouts) { + TF_ASSIGN_OR_RETURN(ProgramShape program_shape, + computation.GetProgramShape()); + argument_layouts = program_shape.parameters(); + for (Shape& shape : *argument_layouts) { + LayoutUtil::ClearLayout(&shape); + } + } std::vector argument_layout_pointers; - argument_layout_pointers.reserve(argument_layouts.size()); + argument_layout_pointers.reserve(argument_layouts->size()); // Assign a default layout to any array subshapes that are missing layouts. 
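Compile above now takes the argument layouts and the device assignment as optionals: when the caller omits them, layouts are derived from the computation's program shape (with layouts cleared) and the assignment comes from the computation placer, while an explicitly passed assignment is validated against the replica count in the build options. A small sketch of that optional-with-computed-default-plus-validation shape, covering only the device-assignment half and using invented placeholder types:

#include <cstdio>
#include <optional>
#include <string>
#include <vector>

// Placeholder for a device assignment: one entry per replica.
struct DeviceAssignment {
  int replica_count;
};

// Validates an explicit assignment against the requested replica count, or
// synthesizes a default one when none was provided.
bool ResolveDeviceAssignment(std::optional<DeviceAssignment> provided,
                             int num_replicas, DeviceAssignment* out,
                             std::string* error) {
  if (provided) {
    if (provided->replica_count != num_replicas) {
      *error = "Mismatched number of replicas for device assignment and "
               "computation (" + std::to_string(provided->replica_count) +
               " vs " + std::to_string(num_replicas) + ").";
      return false;
    }
    *out = *provided;
    return true;
  }
  // Default: assign one device per replica (stand-in for the placer).
  out->replica_count = num_replicas;
  return true;
}

int main() {
  DeviceAssignment resolved;
  std::string error;
  std::printf("%d\n",
              ResolveDeviceAssignment(std::nullopt, 2, &resolved, &error));   // 1
  std::printf("%d\n",
              ResolveDeviceAssignment(DeviceAssignment{4}, 2, &resolved, &error));  // 0
  std::printf("%s\n", error.c_str());
  return 0;
}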
auto assign_layouts = [client](Shape* shape) { @@ -751,16 +858,11 @@ PyLocalExecutable::Compile(const XlaComputation& computation, }); }; - for (Shape& layout : argument_layouts) { + for (Shape& layout : *argument_layouts) { argument_layout_pointers.push_back(&layout); TF_RETURN_IF_ERROR(assign_layouts(&layout)); } - ExecutableBuildOptions options; - if (build_options != nullptr) { - options = *build_options; - } - Shape result_layout; if (options.result_layout()) { result_layout = *options.result_layout(); @@ -776,14 +878,10 @@ PyLocalExecutable::Compile(const XlaComputation& computation, TF_ASSIGN_OR_RETURN(std::unique_ptr local_executable, client->client()->Compile( computation, argument_layout_pointers, options)); - TF_ASSIGN_OR_RETURN( - DeviceAssignment device_assignment, - client->client()->backend().computation_placer()->AssignDevices( - options.num_replicas(), /*computation_count=*/1)); return absl::make_unique( std::shared_ptr(std::move(local_executable)), - std::move(device_assignment), std::move(client)); + std::move(*device_assignment), std::move(client)); } } // namespace xla diff --git a/tensorflow/compiler/xla/python/local_client.h b/tensorflow/compiler/xla/python/local_client.h index 1ad0f933007..e70567ff6b0 100644 --- a/tensorflow/compiler/xla/python/local_client.h +++ b/tensorflow/compiler/xla/python/local_client.h @@ -169,6 +169,8 @@ class Device { } private: + Status SynchronizeAllActivity(); + bool use_multiple_streams_; bool synchronous_deallocation_; bool asynchronous_; @@ -242,49 +244,67 @@ class PyLocalClient { }; // Holds a reference from Python to one or more device buffers. +// A PyLocalBuffer can be either valid or invalid. An invalid buffer is one that +// has never been initialized, or a buffer that has been deleted (e.g., by +// calling Delete). We allow PyLocalBuffer objects to outlive the underlying +// device buffers so we can decouple buffer lifetimes from the corresponding +// Python references if needed. +// Thread-safe. class PyLocalBuffer { public: - static StatusOr FromPython( + static StatusOr> FromPython( const pybind11::object& argument, std::shared_ptr client, int device_ordinal); // Converts multiple (python object, device ordinal) pairs into // PyLocalBuffers in parallel. - static StatusOr> FromPythonValues( + static StatusOr>> FromPythonValues( const std::vector>& argument, std::shared_ptr client); - static StatusOr MakeTuple( - const std::vector buffers, + static StatusOr> MakeTuple( + const std::vector buffers, std::shared_ptr client, int device_ordinal); PyLocalBuffer() = default; PyLocalBuffer(Shape on_host_shape, std::shared_ptr device_buffer, std::shared_ptr client); + + PyLocalBuffer(const PyLocalBuffer&) = delete; + PyLocalBuffer(PyLocalBuffer&&) = delete; + PyLocalBuffer& operator=(const PyLocalBuffer&) = delete; + PyLocalBuffer& operator=(PyLocalBuffer&&) = delete; + StatusOr ToPython() const; const Shape& on_host_shape() const { return on_host_shape_; } - const std::shared_ptr& device_buffer() const { - return device_buffer_; - } - int device_ordinal() const { return device_buffer_->device_ordinal(); } + int device_ordinal() const { return device_ordinal_; } - void Delete() { - device_buffer_ = nullptr; - client_ = nullptr; - } + // Returns the associated device buffer. Returns a nullptr if the buffer is + // invalid. + std::shared_ptr DeviceBuffer() const; + + // Deletes the device memory associated with this buffer, leaving it in an + // invalid state. 
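The PyLocalBuffer comments above describe the new thread-safety scheme: the underlying device buffer is held as a shared_ptr behind a mutex, readers take a snapshot of the pointer, and Delete() simply drops the reference, so later calls can detect an invalid buffer without racing the deletion. A condensed sketch of that pattern; the Payload type and names are illustrative:

#include <cstdio>
#include <memory>
#include <mutex>
#include <string>

struct Payload {
  std::string contents;
};

class Buffer {
 public:
  explicit Buffer(std::shared_ptr<Payload> payload)
      : payload_(std::move(payload)) {}

  // Readers grab a snapshot; the snapshot keeps the payload alive even if
  // Delete() runs concurrently on another thread.
  std::shared_ptr<Payload> payload() const {
    std::lock_guard<std::mutex> lock(mu_);
    return payload_;
  }

  // Marks the buffer invalid; the payload is freed once all snapshots die.
  void Delete() {
    std::lock_guard<std::mutex> lock(mu_);
    payload_ = nullptr;
  }

  bool deleted() const { return payload() == nullptr; }

 private:
  mutable std::mutex mu_;
  std::shared_ptr<Payload> payload_;
};

int main() {
  Buffer buffer(std::make_shared<Payload>(Payload{"device memory"}));
  std::shared_ptr<Payload> snapshot = buffer.payload();
  buffer.Delete();
  // The buffer is deleted, but the earlier snapshot is still usable.
  std::printf("%d %s\n", buffer.deleted(), snapshot->contents.c_str());
  return 0;
}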
+ void Delete(); // Returns a view of the PyLocalBuffer DAG as a ShapedBuffer. The // PyLocalBuffer retains ownership of the device buffers. - ShapedBuffer AsShapedBuffer() const; + StatusOr AsShapedBuffer() const; // Destructures a tuple-valued PyLocalBuffer into its constituent elements. - StatusOr> DestructureTuple(); + StatusOr>> DestructureTuple(); + + // Blocks the host until the buffer's value has been computed and is ready for + // immediate use on the device. Useful in particular for timing benchmarks. + Status BlockHostUntilReady(); private: - std::shared_ptr client_ = nullptr; - Shape on_host_shape_; - std::shared_ptr device_buffer_; + const std::shared_ptr client_; + const Shape on_host_shape_; + const int device_ordinal_; + mutable absl::Mutex mu_; + std::shared_ptr device_buffer_ GUARDED_BY(mu_); }; // Represents a compiled computation that can be executed given handles to @@ -293,9 +313,11 @@ class PyLocalExecutable { public: // Compiles a computation to an executable. static StatusOr> Compile( - const XlaComputation& computation, std::vector argument_layouts, + const XlaComputation& computation, + absl::optional> argument_layouts, const ExecutableBuildOptions* build_options, - std::shared_ptr client); + std::shared_ptr client, + absl::optional device_assignment); PyLocalExecutable(std::shared_ptr executable, DeviceAssignment device_assignment, @@ -312,19 +334,19 @@ class PyLocalExecutable { return device_assignment_; } - StatusOr Execute( + StatusOr> Execute( absl::Span argument_handles); // Execute on many replicas. Takes a sequence of argument lists (one argument // list per replica) and returns a tuple of results (one result per replica). // The number of argument lists must be equal to the replica count. - StatusOr> ExecutePerReplica( + StatusOr>> ExecutePerReplica( absl::Span> argument_handles); void Delete() { executable_ = nullptr; } private: - StatusOr ExecuteHelper( + StatusOr> ExecuteHelper( absl::Span argument_handles, int replica); std::shared_ptr const client_; diff --git a/tensorflow/compiler/xla/python/shared_device_buffer.cc b/tensorflow/compiler/xla/python/shared_device_buffer.cc index 8d7ce0088a4..23cf99f682e 100644 --- a/tensorflow/compiler/xla/python/shared_device_buffer.cc +++ b/tensorflow/compiler/xla/python/shared_device_buffer.cc @@ -15,10 +15,20 @@ limitations under the License. #include "tensorflow/compiler/xla/python/shared_device_buffer.h" +#include + #include "tensorflow/stream_executor/device_memory_allocator.h" namespace xla { +/*static*/ StatusOr> +BufferDefinitionEvent::Create(se::StreamExecutor* executor) { + auto event = std::make_shared(executor); + TF_RET_CHECK(event->event_.Init()) + << "Buffer definition event initialization failed"; + return event; +} + BufferDefinitionEvent::BufferDefinitionEvent(se::StreamExecutor* executor) : event_(executor) {} diff --git a/tensorflow/compiler/xla/python/shared_device_buffer.h b/tensorflow/compiler/xla/python/shared_device_buffer.h index 31cab5ade45..98f8e6a9e13 100644 --- a/tensorflow/compiler/xla/python/shared_device_buffer.h +++ b/tensorflow/compiler/xla/python/shared_device_buffer.h @@ -51,6 +51,9 @@ namespace xla { class BufferDefinitionEvent { public: // Creates a new definition event whose event has not yet been triggered. + static StatusOr> Create( + se::StreamExecutor* executor); + explicit BufferDefinitionEvent(se::StreamExecutor* executor); // Records the definition event on the tail of 'stream'. 
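For reference, a minimal sketch (not part of the patch) of how the reworked PyLocalBuffer API declared above might be exercised from C++. The helper name is hypothetical and error handling is abbreviated; only calls that appear in the header above (FromPython, BlockHostUntilReady, DeviceBuffer, Delete) are used, and the exact template arguments are assumed from the surrounding declarations.

#include <memory>

#include "tensorflow/compiler/xla/python/local_client.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/lib/core/errors.h"

namespace xla {

// Transfers `argument` to device 0, waits for the transfer to finish, then
// deletes the device memory and shows that the wrapper object survives in the
// "invalid" state described above.
Status DemoBufferLifecycle(std::shared_ptr<PyLocalClient> client,
                           const pybind11::object& argument) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<PyLocalBuffer> buffer,
                      PyLocalBuffer::FromPython(argument, client,
                                                /*device_ordinal=*/0));

  // Useful for benchmarks now that the GPU client defaults to asynchronous
  // dispatch: block until the value is actually resident on the device.
  TF_RETURN_IF_ERROR(buffer->BlockHostUntilReady());

  TF_RET_CHECK(buffer->DeviceBuffer() != nullptr);  // Still valid.
  buffer->Delete();
  TF_RET_CHECK(buffer->DeviceBuffer() == nullptr);  // Now invalid.
  return Status::OK();
}

}  // namespace xla

After Delete() the wrapper object stays alive, but DeviceBuffer() returns nullptr, which is exactly what the is_deleted binding below relies on.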
diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc index a592b0823be..298a57d32ff 100644 --- a/tensorflow/compiler/xla/python/xla.cc +++ b/tensorflow/compiler/xla/python/xla.cc @@ -59,7 +59,7 @@ Uniquer* GetUniquer() { return uniquer; } -static string UniquifyName(const string& name) { +static std::string UniquifyName(const std::string& name) { Uniquer* uniquer = GetUniquer(); absl::MutexLock lock(&uniquer->mu); return uniquer->name_uniquer.GetUniqueName(name); @@ -246,7 +246,7 @@ PYBIND11_MODULE(xla_extension, m) { // Device assignments py::class_(m, "DeviceAssignment") - .def_static("Create", + .def_static("create", [](py::array_t array) -> StatusOr { if (array.ndim() != 2) { return InvalidArgument( @@ -295,11 +295,12 @@ PYBIND11_MODULE(xla_extension, m) { .def_static("make_tuple", &PyLocalBuffer::MakeTuple) .def("delete", &PyLocalBuffer::Delete) .def("destructure", &PyLocalBuffer::DestructureTuple) + .def("block_host_until_ready", &PyLocalBuffer::BlockHostUntilReady) .def("to_py", &PyLocalBuffer::ToPython) .def("shape", &PyLocalBuffer::on_host_shape) .def("device", &PyLocalBuffer::device_ordinal) .def("is_deleted", [](const PyLocalBuffer& buffer) { - return buffer.device_buffer() == nullptr; + return buffer.DeviceBuffer() == nullptr; }); py::class_(m, "LocalExecutable") @@ -441,10 +442,8 @@ PYBIND11_MODULE(xla_extension, m) { ops.def("Outfeed", &Outfeed, py::arg("operand"), py::arg("shape_with_layout"), py::arg("outfeed_config") = ""); ops.def("Pad", &Pad); - ops.def( - "Parameter", - static_cast( - &Parameter)); + ops.def("Parameter", static_cast(&Parameter)); ops.def("QR", [](XlaOp a, bool full_matrices) -> StatusOr> { TF_ASSIGN_OR_RETURN(auto qr, QRDecomposition(a, full_matrices)); diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index e208cacc19c..4fde9e0da74 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -108,16 +108,14 @@ class LocalBackend(Backend): def compile(self, c_computation, compile_options): options = _xla.ExecutableBuildOptions() options.num_replicas = compile_options.num_replicas - if compile_options.argument_layouts: - argument_layouts = compile_options.argument_layouts - else: - argument_layouts = c_computation.GetProgramShape().parameter_shapes() if compile_options.result_layout: options.result_layout = compile_options.result_layout options.debug_options.xla_cpu_fast_math_honor_infs = True options.debug_options.xla_cpu_fast_math_honor_nans = True - return _xla.LocalExecutable.Compile(c_computation, argument_layouts, - options, self.client) + return _xla.LocalExecutable.Compile(c_computation, + compile_options.argument_layouts, + options, self.client, + compile_options.device_assignment) def _cpu_backend_factory(): @@ -145,7 +143,7 @@ def _gpu_backend_factory(): config.memory_fraction = float(memory_fraction) client = _xla.LocalClient.Get( - platform='gpu', xla_platform_id='CUDA', asynchronous=False, + platform='gpu', xla_platform_id='CUDA', asynchronous=True, allocator_config=config) return LocalBackend(platform='gpu', client=client) @@ -362,6 +360,8 @@ class Buffer(object): # def delete(self): # def destructure(self) -> [Buffer] # def is_deleted(self) -> bool: + # def block_host_until_ready(self): + # """Blocks the calling thread until the buffer is ready on device.""" # # TODO(phawkins): remove Buffer and its static methods completely, have # clients call methods on Backend to create buffers. 
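A hedged sketch of the new compile path from the C++ side: it builds a DeviceAssignment with the same ComputationPlacer call used inside Compile and forwards it through the new optional device_assignment parameter. The wrapper function is illustrative only; ExecutableBuildOptions::set_num_replicas is assumed to be the C++ counterpart of the num_replicas option set in xla_client.py.

#include <memory>
#include <utility>

#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/python/local_client.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/status_macros.h"

namespace xla {

// Compiles `computation` for `num_replicas` replicas with an explicitly
// constructed DeviceAssignment (one computation per replica, matching the
// constraint enforced in PyLocalExecutable::Compile).
StatusOr<std::unique_ptr<PyLocalExecutable>> CompileReplicatedSketch(
    const XlaComputation& computation, std::shared_ptr<PyLocalClient> client,
    int num_replicas) {
  ExecutableBuildOptions build_options;
  build_options.set_num_replicas(num_replicas);

  TF_ASSIGN_OR_RETURN(
      DeviceAssignment device_assignment,
      client->client()->backend().computation_placer()->AssignDevices(
          num_replicas, /*computation_count=*/1));

  // argument_layouts == nullopt exercises the new default path in which
  // Compile derives layout-cleared argument shapes from the program shape.
  return PyLocalExecutable::Compile(computation,
                                    /*argument_layouts=*/absl::nullopt,
                                    &build_options, std::move(client),
                                    std::move(device_assignment));
}

}  // namespace xla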
@@ -419,6 +419,27 @@ def transfer_from_outfeed(shape, device_ordinal=0): shape.with_major_to_minor_layout_if_absent(), device_ordinal) +DeviceAssignment = _xla.DeviceAssignment +DeviceAssignment.__doc__ = """ +A DeviceAssignment is a C++ object with the following signature. + +def create(assignment): + '''Builds a device assignment. + + Args: + assignment: a 2D numpy array of device ordinal integers, indexed by + [replica][computation_in_replica]. + Returns: + A device assignment. + ''' + +def replica_count(): + '''Returns the number of replicas.''' +def computation_count(): + '''Returns the number of computations per replica.''' +""" + + class CompileOptions(object): """Python object for XLA compile options. @@ -436,6 +457,7 @@ class CompileOptions(object): self.num_replicas = 1 self.argument_layouts = None self.result_layout = None + self.device_assignment = None class Computation(object): diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index 682a6c099a6..f553601a561 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -520,6 +520,13 @@ class BufferTest(ComputationTest): self.assertEqual(xla_shape.dimensions(), (1, 2)) self.assertEqual(np.dtype(xla_shape.element_type()), np.dtype(np.float32)) + def testBlockHostUntilReadyWorks(self): + arg = np.array([[1., 2.]], np.float32) + arg_buffer = xla_client.Buffer.from_pyval(arg) + arg_buffer.block_host_until_ready() + # This test merely checks that nothing goes awry when we call + # block_host_until_ready(); it's difficult to test anything else. + class SingleOpTest(ComputationTest): """Tests for single ops. diff --git a/tensorflow/compiler/xla/python/xrt.cc b/tensorflow/compiler/xla/python/xrt.cc index 5292de8a079..147aafc356a 100644 --- a/tensorflow/compiler/xla/python/xrt.cc +++ b/tensorflow/compiler/xla/python/xrt.cc @@ -148,8 +148,14 @@ void AddXrtSubmodule(py::module* module) { }) .def("delete", &XrtBuffer::Delete) .def("destructure", &XrtBuffer::DestructureTuple) + .def("device", &XrtBuffer::xrt_device_ordinal) + .def("shape", &XrtBuffer::shape) .def("is_deleted", - [](const XrtBuffer& buffer) { return !buffer.handle().valid(); }); + [](const XrtBuffer& buffer) { return !buffer.handle().valid(); }) + .def("block_host_until_ready", [](const XrtBuffer& buffer) { + return errors::Unimplemented( + "block_host_until_ready not implemented in XRT backend."); + }); py::class_>(m, "XrtExecutable") .def_static("Compile", diff --git a/tensorflow/compiler/xla/python/xrt.py b/tensorflow/compiler/xla/python/xrt.py index 76a99f20481..40dea45e442 100644 --- a/tensorflow/compiler/xla/python/xrt.py +++ b/tensorflow/compiler/xla/python/xrt.py @@ -65,7 +65,7 @@ class XrtBackend(xla_client.Backend): return _xla.xrt.XrtBuffer.from_literal(self.context, device, pyval) def make_tuple(self, buffers, device_ordinal): - return _xla.xrt.XrtBuffer.make_tuple(self.context, buffers) + return _xla.xrt.XrtBuffer.make_tuple(self.context, buffers, device_ordinal) def compile(self, computation, compile_options): # pylint: disable=protected-access diff --git a/tensorflow/compiler/xla/python_api/BUILD b/tensorflow/compiler/xla/python_api/BUILD index d790c4db6c4..348a80abe2c 100644 --- a/tensorflow/compiler/xla/python_api/BUILD +++ b/tensorflow/compiler/xla/python_api/BUILD @@ -1,9 +1,10 @@ # Description: # Python API for XLA. 
-licenses(["notice"]) # Apache 2.0 - -package(default_visibility = ["//tensorflow:internal"]) +package( + default_visibility = ["//tensorflow:internal"], + licenses = ["notice"], # Apache 2.0 +) py_library( name = "types", diff --git a/tensorflow/compiler/xla/refcounting_hash_map.h b/tensorflow/compiler/xla/refcounting_hash_map.h new file mode 100644 index 00000000000..19b27d6fc3a --- /dev/null +++ b/tensorflow/compiler/xla/refcounting_hash_map.h @@ -0,0 +1,115 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_REFCOUNTING_HASH_MAP_H_ +#define TENSORFLOW_COMPILER_XLA_REFCOUNTING_HASH_MAP_H_ + +#include +#include + +#include "absl/container/node_hash_map.h" +#include "absl/memory/memory.h" +#include "absl/synchronization/mutex.h" + +namespace xla { + +// RefcountingHashMap is an "eager, thread-safe cache". +// +// Given a key k you can retrieve a shared_ptr to a value v. If k is not +// already in the map, we construct a new V; if it is already in the map, we'll +// return the existing v. Once all shared_ptrs are destroyed, the entry is +// removed from the map. +// +// This class is thread-safe. +// +// Word to the wise: You might want an erase() function here that removes a +// value from the map but leaves existing shared_ptrs intact. My experience is, +// this is extremely complicated to implement correctly. +template +class RefcountingHashMap { + public: + // Default-constructs new values. + RefcountingHashMap() + : value_factory_([](const K&) { return absl::make_unique(); }) {} + + // Constructs new values according to the given factory function. + explicit RefcountingHashMap( + std::function(const K&)> value_factory) + : value_factory_(std::move(value_factory)) {} + + // Not copyable or movable because this contains internal pointers (namely, + // instances of Deleter contain pointers to `this` and into `map_`). + RefcountingHashMap(const RefcountingHashMap&) = delete; + RefcountingHashMap(RefcountingHashMap&&) = delete; + RefcountingHashMap& operator=(const RefcountingHashMap&) = delete; + RefcountingHashMap& operator=(RefcountingHashMap&&) = delete; + + // Gets the value for the given key. + // + // If the map doesn't contain a live value for the key, constructs one + // according to the factory passed to the map's constructor. + std::shared_ptr operator[](const K& key) { + absl::MutexLock lock(&mu_); + auto it = map_.find(key); + if (it == map_.end()) { + // Create entry in the map and then set its value, so the value can + // contain a pointer back into the map. + it = map_.emplace(key, std::weak_ptr()).first; + std::shared_ptr value(value_factory_(key).release(), + Deleter{&it->first, this}); + it->second = value; // Set the weak ptr to the shared ptr. + return value; + } + return it->second.lock(); + } + + // Runs a function over every key/value in the map. 
+ // + // Touching the map from within this function may deadlock; don't do it. + // + // Function signature must be compatible with + // void fn(const K&, std::shared_ptr) + // + template + void ForEach(Fn&& fn) { + absl::MutexLock lock(&mu_); + for (const auto& kv : map_) { + fn(kv.first, kv.second.lock()); + } + } + + private: + struct Deleter { + const K* key; // Points into parent->map_. + RefcountingHashMap* parent; + + void operator()(V* v) { + delete v; + absl::MutexLock lock(&parent->mu_); + auto it = parent->map_.find(*key); + CHECK(it != parent->map_.end()); + CHECK(it->second.expired()); + parent->map_.erase(it); + } + }; + + std::function(const K&)> value_factory_; + absl::Mutex mu_; + absl::node_hash_map> map_ GUARDED_BY(mu_); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_REFCOUNTING_HASH_MAP_H_ diff --git a/tensorflow/compiler/xla/refcounting_hash_map_test.cc b/tensorflow/compiler/xla/refcounting_hash_map_test.cc new file mode 100644 index 00000000000..65120ba3df4 --- /dev/null +++ b/tensorflow/compiler/xla/refcounting_hash_map_test.cc @@ -0,0 +1,101 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/refcounting_hash_map.h" + +#include + +#include "tensorflow/compiler/xla/test.h" + +namespace xla { +namespace { + +struct DeleteNotifier { + DeleteNotifier() = default; + DeleteNotifier(const DeleteNotifier&) = delete; + DeleteNotifier& operator=(const DeleteNotifier&) = delete; + DeleteNotifier(DeleteNotifier&& o) noexcept : fn(std::move(o.fn)) { + o.fn = nullptr; + } + DeleteNotifier& operator=(DeleteNotifier&& o) noexcept { + fn = o.fn; + o.fn = nullptr; + return *this; + } + + ~DeleteNotifier() { + if (fn) { + fn(); + } + } + + std::function fn; +}; + +TEST(RefcountingHashMapTest, PointerIdentity) { + RefcountingHashMap m; + std::shared_ptr a = m[0]; + std::shared_ptr b = m[0]; + std::shared_ptr c = m[1]; + EXPECT_EQ(a.get(), b.get()); + EXPECT_NE(a.get(), c.get()); +} + +TEST(RefcountingHashMapTest, DefaultInitialized) { + RefcountingHashMap m; + EXPECT_EQ(*m[42], 0); +} + +TEST(RefcountingHashMapTest, DeletesEagerly) { + RefcountingHashMap m; + bool deleted = false; + auto handle = m[0]; + handle->fn = [&] { deleted = true; }; + EXPECT_FALSE(deleted); + handle = nullptr; + EXPECT_TRUE(deleted); +} + +TEST(RefcountingHashMapTest, CustomFactory) { + RefcountingHashMap m( + [](const int& x) { return absl::make_unique(x + 1); }); + EXPECT_EQ(*m[0], 1); + EXPECT_EQ(*m[100], 101); +} + +TEST(RefcountingHashMapTest, ForEachEmpty) { + RefcountingHashMap m; + int64 count = 0; + m.ForEach([&](const int&, std::shared_ptr) { ++count; }); + EXPECT_EQ(count, 0); +} + +TEST(RefcountingHashMapTest, ForEachNonempty) { + RefcountingHashMap m; + auto a = m[0]; + auto b = m[1]; + + std::vector seen_keys; + std::vector seen_values; + m.ForEach([&](const int& k, std::shared_ptr v) { + seen_keys.push_back(k); + 
seen_values.push_back(v.get()); + }); + EXPECT_THAT(seen_keys, testing::UnorderedElementsAre(0, 1)); + EXPECT_THAT(seen_values, testing::UnorderedElementsAre(a.get(), b.get())); +} + +} // anonymous namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index 08b78ee2448..59b60e2b9c5 100644 --- a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -161,24 +161,24 @@ ReferenceUtil::ReduceWindow1DGeneric( const std::function& reduce_func, absl::Span window, absl::Span stride, absl::Span> padding) { - std::vector dim_lengths{static_cast(operand.size())}; - std::vector window_counts(window.size(), 0); - std::vector pad_low(window.size(), 0); - for (int64 i = 0; i < window.size(); ++i) { - int64 padded_width = padding[i].first + dim_lengths[i] + padding[i].second; - window_counts[i] = - window_util::StridedBound(padded_width, window[i], stride[i]); - pad_low[i] = padding[i].first; - } - auto result = absl::make_unique>(window_counts[0]); + CHECK_EQ(window.size(), 1); + CHECK_EQ(stride.size(), 1); + CHECK_EQ(padding.size(), 1); + + int64 padded_width = padding[0].first + operand.size() + padding[0].second; + int64 stride_amount = stride[0]; + int64 window_size = window[0]; + int64 result_size = + window_util::StridedBound(padded_width, window_size, stride_amount); + int64 pad_low = padding[0].first; + auto result = absl::make_unique>(result_size); // Do a full 1D reduce window. - for (int64 i0 = 0; i0 < window_counts[0]; ++i0) { - int64 i0_base = i0 * stride[0] - pad_low[0]; - + for (int64 i0 = 0; i0 < result_size; ++i0) { + int64 i0_base = i0 * stride_amount - pad_low; float val = init; - for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) { - if (i0_base + i0_win >= 0 && i0_base + i0_win < dim_lengths[0]) { + for (int64 i0_win = 0; i0_win < window_size; ++i0_win) { + if (i0_base + i0_win >= 0 && i0_base + i0_win < operand.size()) { val = reduce_func(val, operand[i0_base + i0_win]); } } @@ -199,57 +199,6 @@ ReferenceUtil::ReduceWindow1DAdd(absl::Span operand, float init, xla::MakePadding(dim_lengths, window, stride, padding)); } -/* static */ std::unique_ptr> -ReferenceUtil::ReduceWindow2DGeneric( - const Array2D& operand, float init, - const std::function& reduce_func, - absl::Span window, absl::Span stride, - absl::Span> padding) { - std::vector dim_lengths{operand.height(), operand.width()}; - - std::vector window_counts(window.size(), 0); - std::vector pad_low(window.size(), 0); - for (int64 i = 0; i < window.size(); ++i) { - int64 padded_width = padding[i].first + dim_lengths[i] + padding[i].second; - window_counts[i] = - window_util::StridedBound(padded_width, window[i], stride[i]); - pad_low[i] = padding[i].first; - } - auto result = - absl::make_unique>(window_counts[0], window_counts[1]); - - // Do a full 2D reduce window. 
- for (int64 i0 = 0; i0 < window_counts[0]; ++i0) { - for (int64 i1 = 0; i1 < window_counts[1]; ++i1) { - int64 i0_base = i0 * stride[0] - pad_low[0]; - int64 i1_base = i1 * stride[1] - pad_low[1]; - - float val = init; - for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) { - for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) { - if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 && - i0_base + i0_win < operand.n1() && - i1_base + i1_win < operand.n2()) { - val = reduce_func(val, operand(i0_base + i0_win, i1_base + i1_win)); - } - } - } - (*result)(i0, i1) = val; - } - } - return result; -} - -/* static */ std::unique_ptr> ReferenceUtil::ReduceWindow2DAdd( - const Array2D& operand, float init, absl::Span window, - absl::Span stride, Padding padding) { - const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; }; - std::vector dim_lengths{operand.height(), operand.width()}; - return ReduceWindow2DGeneric( - operand, init, add_reduce, window, stride, - xla::MakePadding(dim_lengths, window, stride, padding)); -} - /* static */ std::unique_ptr> ReferenceUtil::ReduceWindow3DAdd( const Array3D& operand, float init, absl::Span window, absl::Span stride, Padding padding) { diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h index 8654fbb9b5e..00920aa8e6a 100644 --- a/tensorflow/compiler/xla/reference_util.h +++ b/tensorflow/compiler/xla/reference_util.h @@ -180,9 +180,6 @@ class ReferenceUtil { absl::Span operand, float init, absl::Span window, absl::Span stride, Padding padding); - static std::unique_ptr> ReduceWindow2DAdd( - const Array2D& operand, float init, absl::Span window, - absl::Span stride, Padding padding); static std::unique_ptr> ReduceWindow3DAdd( const Array3D& operand, float init, absl::Span window, absl::Span stride, Padding padding); @@ -196,11 +193,6 @@ class ReferenceUtil { const std::function& reduce_func, absl::Span window, absl::Span stride, absl::Span> padding); - static std::unique_ptr> ReduceWindow2DGeneric( - const Array2D& operand, float init, - const std::function& reduce_func, - absl::Span window, absl::Span stride, - absl::Span> padding); static std::unique_ptr> ReduceWindow4DGeneric( const Array4D& operand, float init, const std::function& reduce_func, diff --git a/tensorflow/compiler/xla/rpc/BUILD b/tensorflow/compiler/xla/rpc/BUILD index 26affbcceb3..a0cb479fbdc 100644 --- a/tensorflow/compiler/xla/rpc/BUILD +++ b/tensorflow/compiler/xla/rpc/BUILD @@ -1,6 +1,7 @@ -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = ["//tensorflow:internal"]) +package( + default_visibility = ["//tensorflow:internal"], + licenses = ["notice"], # Apache 2.0 +) load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow:tensorflow.bzl", "tf_cc_binary") diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 1e7a924e350..be917d6763b 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -9,9 +9,10 @@ load( ) load("//tensorflow:tensorflow.bzl", "tf_cc_test") -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = [":friends"]) +package( + default_visibility = [":friends"], + licenses = ["notice"], # Apache 2.0 +) package_group( name = "friends", @@ -402,6 +403,27 @@ tf_cc_test( ], ) +xla_test( + name = "dynamic_update_slice_test", + srcs = ["dynamic_update_slice_test.cc"], + backends = [ + "cpu", + "gpu", + ], + deps = [ + ":hlo_parser", + "//tensorflow/compiler/xla:execution_options_util", + 
"//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service/cpu:cpu_executable", + "//tensorflow/compiler/xla/service/cpu:parallel_task_assignment", + "//tensorflow/compiler/xla/service/cpu:target_machine_features", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + tf_cc_test( name = "dfs_hlo_visitor_with_default_test", srcs = ["dfs_hlo_visitor_with_default_test.cc"], @@ -857,6 +879,7 @@ cc_library( name = "shaped_buffer", srcs = ["shaped_buffer.cc"], hdrs = ["shaped_buffer.h"], + visibility = ["//visibility:public"], deps = [ "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:shape_util", @@ -1127,6 +1150,9 @@ cc_library( ":buffer_value_containers", ":heap_simulator", ":hlo", + ":hlo_alias_analysis", + ":hlo_buffer", + ":hlo_dataflow_analysis", ":hlo_proto", ":logical_buffer", ":tuple_points_to_analysis", diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc index ea56c75b2f2..cc501161ce9 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.cc +++ b/tensorflow/compiler/xla/service/allocation_tracker.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/stream_executor/device_memory_allocator.h" @@ -237,7 +238,7 @@ Status AllocationTracker::DecrementRefCount(se::DeviceMemoryBase device_memory, Allocation& allocation = it->second; TF_RET_CHECK(allocation.ref_count >= 1); if (allocation.ref_count == 1) { - allocation.device_memory.Free(); + TF_RETURN_IF_ERROR(allocation.device_memory.Free()); allocation_map.erase(it); } else { allocation.ref_count--; diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner.cc b/tensorflow/compiler/xla/service/ar_crs_combiner.cc index dbd89911d92..24f910caa7c 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner.cc +++ b/tensorflow/compiler/xla/service/ar_crs_combiner.cc @@ -68,29 +68,31 @@ absl::optional ArCrsCombiner::MatchesArCrsPattern( Match(c->root_instruction(), m::Add(m::Parameter(), m::Parameter())); }; - if (!instruction->IsCrossModuleAllReduce() || - !computation_is_addition(instruction->called_computations()[0]) || - instruction->user_count() != 1) { - return absl::nullopt; - } - auto next = instruction->users()[0]; - int64 distance = 1; - while (!next->IsCrossReplicaAllReduce()) { - if (can_ar_move_past_instruction(next)) { - next = next->users()[0]; - } else { - return absl::nullopt; + // We only support combining cross-partition all-reduce where each replica + // belongs to its own group, since the later cross-replica all-reduce combines + // along the replica dimension. 
+ if (instruction->IsCrossModuleAllReduce() && + instruction->replica_groups().size() == num_replicas_ && + computation_is_addition(instruction->called_computations()[0]) && + instruction->user_count() == 1) { + auto next = instruction->users()[0]; + int64 distance = 1; + while (!next->IsCrossReplicaAllReduce()) { + if (can_ar_move_past_instruction(next)) { + next = next->users()[0]; + } else { + return absl::nullopt; + } + ++distance; + } + if (!Cast(next)->IsNoop() && + computation_is_addition(next->called_computations()[0])) { + ArCrsPair pair(instruction, next, distance); + VLOG(2) << "ArCrsPair matching pattern: " << pair.ToString(); + return pair; } - ++distance; - } - if (!Cast(next)->IsNoop() && - computation_is_addition(next->called_computations()[0])) { - ArCrsPair pair(instruction, next, distance); - VLOG(2) << "ArCrsPair matching pattern: " << pair.ToString(); - return pair; - } else { - return absl::nullopt; } + return absl::nullopt; } absl::optional ArCrsCombiner::WhileFromBodyParameter( @@ -238,7 +240,7 @@ bool ArCrsCombiner::TupleElementsComputeSameValue( /* static */ bool ArCrsCombiner::TestInstructionsComputeSameValue(HloInstruction* i1, HloInstruction* i2) { - ArCrsCombiner combiner(/*num_spatial_partitions=*/2); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/1); auto module = i1->parent()->parent(); CHECK_EQ(module, i2->parent()->parent()); combiner.call_graph_ = CallGraph::Build(module); diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner.h b/tensorflow/compiler/xla/service/ar_crs_combiner.h index 4d17d5d8a31..250252b6390 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner.h +++ b/tensorflow/compiler/xla/service/ar_crs_combiner.h @@ -69,8 +69,9 @@ namespace xla { // class ArCrsCombiner : public HloModulePass { public: - ArCrsCombiner(int num_spatial_partitions) - : num_spatial_partitions_(num_spatial_partitions) {} + ArCrsCombiner(int num_spatial_partitions, int num_replicas) + : num_spatial_partitions_(num_spatial_partitions), + num_replicas_(num_replicas) {} absl::string_view name() const override { return "ar-crs-combiner"; } StatusOr Run(HloModule* module) override; @@ -160,6 +161,8 @@ class ArCrsCombiner : public HloModulePass { int num_spatial_partitions_; + int num_replicas_; + // Map from all-reduce ids to the AR/CRS pairs. 
absl::flat_hash_map> all_reduce_map_; diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc b/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc index 0ea26f63b95..0be31899d53 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc +++ b/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc @@ -452,7 +452,7 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) { auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); - ArCrsCombiner combiner(2); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2); auto changed = combiner.Run(module.get()).ValueOrDie(); EXPECT_TRUE(changed); EXPECT_THAT(module->entry_computation()->root_instruction(), @@ -520,7 +520,7 @@ ENTRY %entrycomp (p: f32[2,1]) -> (f32[2], f32[2]) { auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); - ArCrsCombiner combiner(2); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2); auto changed = combiner.Run(module.get()).ValueOrDie(); EXPECT_TRUE(changed); EXPECT_THAT(module->entry_computation()->root_instruction(), @@ -587,7 +587,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); - ArCrsCombiner combiner(2); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2); auto changed = combiner.Run(module.get()).ValueOrDie(); EXPECT_TRUE(changed); EXPECT_THAT( @@ -668,7 +668,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); - ArCrsCombiner combiner(2); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2); auto changed = combiner.Run(module.get()).ValueOrDie(); EXPECT_TRUE(changed); EXPECT_THAT( @@ -750,7 +750,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(module_str)); - ArCrsCombiner combiner(2); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2); auto changed = combiner.Run(module.get()).ValueOrDie(); EXPECT_FALSE(changed); } @@ -810,7 +810,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); - ArCrsCombiner combiner(2); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2); auto changed = combiner.Run(module.get()).ValueOrDie(); EXPECT_TRUE(changed); EXPECT_THAT(module->entry_computation()->root_instruction(), @@ -884,7 +884,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); - ArCrsCombiner combiner(2); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2); auto changed = combiner.Run(module.get()).ValueOrDie(); EXPECT_TRUE(changed); EXPECT_THAT(module->entry_computation()->root_instruction(), @@ -957,7 +957,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); - ArCrsCombiner combiner(2); + ArCrsCombiner 
combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2); auto changed = combiner.Run(module.get()).ValueOrDie(); EXPECT_TRUE(changed); EXPECT_THAT( @@ -1047,7 +1047,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); - ArCrsCombiner combiner(2); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2); auto changed = combiner.Run(module.get()).ValueOrDie(); EXPECT_TRUE(changed); EXPECT_THAT(module->entry_computation()->root_instruction(), @@ -1139,7 +1139,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); - ArCrsCombiner combiner(2); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2); auto changed = combiner.Run(module.get()).ValueOrDie(); EXPECT_TRUE(changed); EXPECT_THAT( @@ -1217,7 +1217,7 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(module_str)); - ArCrsCombiner combiner(2); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2); auto changed = combiner.Run(module.get()).ValueOrDie(); EXPECT_FALSE(changed); } @@ -1264,5 +1264,37 @@ ENTRY Parameters1.v4 { EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(f0, f1)); } +TEST_F(ArCrsCombinerTest, AllReduceWithReplicas) { + const char* module_str = R"( +HloModule foobar + +%sum.f32 (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) { + %p = bf16[] parameter(0) + %all-reduce.0 = f32[] all-reduce(%p), all_reduce_id=1, replica_groups={{0,1}}, + to_apply=%sum.f32, sharding={maximal device=0} + %all-reduce.1 = f32[] all-reduce(%p), all_reduce_id=1, replica_groups={{0,1}}, + to_apply=%sum.f32, sharding={maximal device=1} + %all-reduce.2 = f32[] all-reduce(%all-reduce.0), replica_groups={{0,1}}, + to_apply=%sum.f32, sharding={maximal device=0} + %all-reduce.3 = f32[] all-reduce(%all-reduce.1), replica_groups={{0,1}}, + to_apply=%sum.f32, sharding={maximal device=1} + ROOT %tuple = (f32[], f32[]) tuple(%all-reduce.2, %all-reduce.3), + sharding={{maximal device=0}, {maximal device=1}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_FALSE(changed); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/backend.cc b/tensorflow/compiler/xla/service/backend.cc index d859f647ea0..40283d12314 100644 --- a/tensorflow/compiler/xla/service/backend.cc +++ b/tensorflow/compiler/xla/service/backend.cc @@ -126,16 +126,11 @@ Backend::Backend(se::Platform* platform, Compiler* compiler, : platform_(platform), compiler_(compiler), transfer_manager_(transfer_manager), - computation_placer_(computation_placer) { - // The given set of stream executors set may include invalid executors. - for (se::StreamExecutor* exec : stream_executors) { - if (exec != nullptr) { - stream_executors_.push_back(exec); - } - } + computation_placer_(computation_placer), + stream_executors_(stream_executors.begin(), stream_executors.end()) { // Create a memory allocator for the valid stream executors. 
memory_allocator_ = absl::make_unique( - platform, stream_executors); + platform, stream_executors_); CHECK(!stream_executors_.empty()) << "Service found no devices for backend " << platform_->Name() << '.'; diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index aa57f28448e..5cbe6c44622 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include #include -#include #include #include "absl/container/flat_hash_map.h" @@ -32,6 +31,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_value_containers.h" #include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h" +#include "tensorflow/compiler/xla/service/hlo_buffer.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -50,30 +51,6 @@ using absl::StrAppend; using absl::StrAppendFormat; using ::tensorflow::strings::HumanReadableNumBytes; -template -string ColocatedBufferSetsToString(const T& container, const char* title) { - string result; - StrAppend(&result, title, "\n"); - for (const auto& it : container) { - StrAppend(&result, "\t", it->ToString(), "\n"); - } - return result; -} - -// Checks that points-to set of 'instruction' is unambiguous and distinct -// (ensured by CopyInsertion), then adds the buffer from the points-to set at -// 'index' to 'colocated_set'. -const LogicalBuffer* AddBufferToColocatedSet( - const HloInstruction* instruction, const ShapeIndex& index, - const TuplePointsToAnalysis& points_to_analysis, - std::vector* colocated_set) { - // CopyInsertion ensures root points-to set is unambiguous and distinct. - const auto& points_to = points_to_analysis.GetPointsToSet(instruction); - DCHECK(!points_to.IsAmbiguous()); - colocated_set->push_back(points_to.element(index)[0]); - return colocated_set->back(); -} - // Given the interference map of a graph (the list of interfering node indices // for each node), perform graph coloring such that interfering nodes are // assigned to different colors. Returns the assigned color of the nodes, where @@ -226,14 +203,15 @@ string BufferAllocation::Slice::ToString() const { } BufferAllocation::Slice BufferAllocation::GetSlice( - const LogicalBuffer& buffer) const { + const BufferValue& buffer) const { const OffsetSize os = FindOrDie(assigned_buffers_, &buffer); return Slice(this, os.offset, os.size); } -void BufferAllocation::AddAssignment(const LogicalBuffer& buffer, int64 offset, +void BufferAllocation::AddAssignment(const BufferValue& buffer, int64 offset, int64 size) { - VLOG(4) << "Trying to add " << buffer << " to allocation #" << index(); + VLOG(4) << "Adding the following buffer to allocation #" << index() << ": " + << buffer; CHECK(!assigned_buffers_.contains(&buffer)) << "LogicalBuffer " << buffer << " already assigned to allocation " << index_; @@ -306,15 +284,14 @@ string BufferAllocation::ToString() const { } StrAppend(&output, ":\n"); // Dump the assigned buffers ordered by id. 
- std::vector sorted_buffers; + std::vector sorted_buffers; for (const auto& buffer_offset_size : assigned_buffers_) { sorted_buffers.push_back(buffer_offset_size.first); } - absl::c_sort(sorted_buffers, - [](const LogicalBuffer* a, const LogicalBuffer* b) { - return a->id() < b->id(); - }); - for (const LogicalBuffer* buffer : sorted_buffers) { + absl::c_sort(sorted_buffers, [](const BufferValue* a, const BufferValue* b) { + return a->id() < b->id(); + }); + for (const BufferValue* buffer : sorted_buffers) { const OffsetSize& offset_size = FindOrDie(assigned_buffers_, buffer); StrAppend(&output, absl::StrFormat( " %s [%d,%d]: %s\n", buffer->ToString(), @@ -339,28 +316,37 @@ const PointsToSet& BufferAssignment::GetPointsToSet( return points_to_analysis().GetPointsToSet(instruction); } -bool BufferAssignment::HasAllocation(const LogicalBuffer& buffer) const { - TF_CHECK_OK(points_to_analysis().VerifyBuffer(buffer)); - return allocation_index_for_buffer_.contains(&buffer); +bool BufferAssignment::HasAllocation(const BufferValue& value) const { + return allocation_index_for_value_.contains(&value); +} + +bool BufferAssignment::HasAllocation(const HloBuffer& buffer) const { + return allocation_index_for_value_.contains(buffer.values()[0]); } const BufferAllocation& BufferAssignment::GetAssignedAllocation( - const LogicalBuffer& buffer) const { - CHECK(HasAllocation(buffer)); - return GetAllocation(allocation_index_for_buffer_.at(&buffer)); + const BufferValue& value) const { + CHECK(HasAllocation(value)); + return GetAllocation(allocation_index_for_value_.at(&value)); +} + +const BufferAllocation& BufferAssignment::GetAssignedAllocation( + const HloBuffer& hlo_buffer) const { + return GetAssignedAllocation(*hlo_buffer.values()[0]); } BufferAllocation* BufferAssignment::GetMutableAssignedAllocation( - const LogicalBuffer& buffer) { + const HloBuffer& buffer) { return const_cast(&GetAssignedAllocation(buffer)); } std::set BufferAssignment::GetAllSlices( const HloInstruction* instruction, const ShapeIndex& index) const { std::set result; - for (const LogicalBuffer* buffer : GetSourceBuffers(instruction, index)) { - if (HasAllocation(*buffer)) { - result.insert(GetAssignedAllocation(*buffer).GetSlice(*buffer)); + for (const BufferValue* value : + dataflow_analysis().GetValueSet(instruction, index).values()) { + if (HasAllocation(*value)) { + result.insert(GetAssignedAllocation(*value).GetSlice(*value)); } } return result; @@ -375,15 +361,15 @@ const BufferAllocation& BufferAssignment::GetAllocation( const BufferAllocation* BufferAssignment::GetInstructionAllocation( const HloInstruction* hlo, const ShapeIndex& shape_index) const { - const PointsToSet& points_to_set = points_to_analysis().GetPointsToSet(hlo); - const LogicalBuffer* buffer = points_to_set.element(shape_index)[0]; + const BufferValue* value = + dataflow_analysis().GetValueSet(hlo, shape_index).values()[0]; - if (!HasAllocation(*buffer)) { + if (!HasAllocation(*value)) { return nullptr; } const BufferAllocation& instruction_allocation = - GetAssignedAllocation(*buffer); + GetAssignedAllocation(*value); return &instruction_allocation; } @@ -394,9 +380,9 @@ BufferAllocation* BufferAssignment::GetMutableAllocation( bool BufferAssignment::HasAllocationAt(const HloInstruction* instruction, const ShapeIndex& index) const { - for (const LogicalBuffer* buffer : - GetPointsToSet(instruction).element(index)) { - if (allocation_index_for_buffer_.contains(buffer)) { + for (const BufferValue* value : + dataflow_analysis().GetValueSet(instruction, 
index).values()) { + if (allocation_index_for_value_.contains(value)) { return true; } } @@ -413,13 +399,13 @@ StatusOr BufferAssignment::GetUniqueSlice( VLOG(3) << "Trying to find unique slice for " << instruction->name() << " [" << index << "]"; BufferAllocation::Slice result; - for (const LogicalBuffer* buffer : - GetPointsToSet(instruction).element(index)) { - VLOG(3) << "Examining buffer " << *buffer; - if (HasAllocation(*buffer)) { + for (const BufferValue* value : + dataflow_analysis().GetValueSet(instruction, index).values()) { + VLOG(3) << "Examining value " << *value; + if (HasAllocation(*value)) { VLOG(3) << "Has allocation"; const BufferAllocation::Slice slice = - GetAssignedAllocation(*buffer).GetSlice(*buffer); + GetAssignedAllocation(*value).GetSlice(*value); if (result.allocation() == nullptr) { result = slice; } else if (result != slice) { @@ -500,39 +486,55 @@ BufferAllocation* BufferAssignment::NewEmptyAllocation( return allocation; } -BufferAllocation* BufferAssignment::NewAllocation(const LogicalBuffer& buffer, +BufferAllocation* BufferAssignment::NewAllocation(const HloBuffer& buffer, int64 size) { BufferAllocation* allocation = NewEmptyAllocation(size, buffer.color()); AddAssignment(allocation, buffer, /*offset=*/0, size); - allocation->peak_buffers_.push_back(&buffer); + allocation->peak_buffers_.push_back(buffer.values()[0]); return allocation; } -// Adds an instruction to the set assigned to the given buffer. void BufferAssignment::AddAssignment(BufferAllocation* allocation, - const LogicalBuffer& buffer, int64 offset, + const HloBuffer& buffer, int64 offset, int64 size) { - CHECK(!allocation_index_for_buffer_.contains(&buffer)) - << "LogicalBuffer " << buffer << " already has an allocation."; CHECK(allocation->is_reusable() || allocation->assigned_buffers().empty()) << "Non-reusable allocation already assigned a buffer: " << allocation->ToString(); - TF_CHECK_OK(points_to_analysis().VerifyBuffer(buffer)); + for (const BufferValue* buffer_value : buffer.values()) { + CHECK(!allocation_index_for_value_.contains(buffer_value)) + << "BufferValue " << buffer_value << " already has an allocation."; + allocation->AddAssignment(*buffer_value, offset, size); + allocation_index_for_value_[buffer_value] = allocation->index(); + } - allocation->AddAssignment(buffer, offset, size); - if (liveness().MaybeLiveOut(buffer)) { + if (alias_analysis().BufferLivesOut(buffer)) { + VLOG(3) << "HloBuffer lives out" << buffer.ToString(); + VLOG(3) << "Set maybe live out: " << allocation->ToString(); + allocation->set_maybe_live_out(true); + } +} + +void BufferAssignment::AddAssignment(BufferAllocation* allocation, + const BufferValue& value, int64 offset, + int64 size) { + allocation->AddAssignment(value, offset, size); + allocation_index_for_value_[&value] = allocation->index(); + const HloValue& hlo_value = + *CHECK_NOTNULL(dynamic_cast(&value)); + if (alias_analysis().ValueLivesOut(hlo_value)) { + VLOG(3) << "HloValue lives out: " << hlo_value.ToString(); + VLOG(3) << "Set maybe live out: " << allocation->ToString(); allocation->set_maybe_live_out(true); } - allocation_index_for_buffer_[&buffer] = allocation->index(); } // Combines allocations of temporary buffers of the same color into one big // BufferAllocation. 
void BufferAssignment::CombineTempAllocations() { VLOG(1) << "CombineTempAllocations()"; - flat_hash_map + flat_hash_map combined_allocation_map; // Move all temp allocations into a single run at the end of the allocations @@ -548,7 +550,7 @@ void BufferAssignment::CombineTempAllocations() { if (first_temp_it != allocations_.end()) { for (auto it = first_temp_it; it != allocations_.end(); ++it) { const BufferAllocation& temp_allocation = *it; - LogicalBuffer::Color color = temp_allocation.color(); + BufferValue::Color color = temp_allocation.color(); auto combined_it = combined_allocation_map.find(color); if (combined_it == combined_allocation_map.end()) { // We have found the first temp allocation of this color. Collect @@ -571,15 +573,16 @@ void BufferAssignment::CombineTempAllocations() { RoundUpToNearest(combined_allocation->size(), alignment); combined_allocation->set_size(base + temp_allocation.size()); for (const auto& buffer_offset_size : temp_allocation.assigned_buffers_) { - const LogicalBuffer* buffer = buffer_offset_size.first; + const BufferValue* value = buffer_offset_size.first; const int64 offset = buffer_offset_size.second.offset; const int64 size = buffer_offset_size.second.size; - combined_allocation->AddAssignment(*buffer, base + offset, size); + combined_allocation->AddAssignment(*value, base + offset, size); } if (!temp_allocation.HeapTraces().empty()) { CHECK_EQ(temp_allocation.HeapTraces().size(), 1); combined_allocation->AddHeapTrace(temp_allocation.HeapTraces().front()); } + combined_allocation->peak_buffers_.insert( combined_allocation->peak_buffers_.end(), temp_allocation.peak_buffers_.begin(), @@ -595,14 +598,14 @@ void BufferAssignment::CombineTempAllocations() { } // Update allocation indices to their new positions. - allocation_index_for_buffer_.erase(allocation_index_for_buffer_.begin(), - allocation_index_for_buffer_.end()); + allocation_index_for_value_.erase(allocation_index_for_value_.begin(), + allocation_index_for_value_.end()); for (size_t index = 0; index < allocations_.size(); ++index) { BufferAllocation* allocation = &allocations_[index]; allocation->set_index(index); for (const auto& buffer_offset_size : allocation->assigned_buffers_) { - const LogicalBuffer* buffer = buffer_offset_size.first; - allocation_index_for_buffer_[buffer] = index; + const BufferValue* value = buffer_offset_size.first; + allocation_index_for_value_[value] = index; } } } @@ -694,30 +697,28 @@ string BufferAssignment::ToString() const { BufferAssignmentProto BufferAssignment::ToProto() const { BufferAssignmentProto proto; - // NOTE: TuplePointsToAnalysis state is serialized here in BufferAssigment, + // NOTE: DataflowAnalysis state is serialized here in BufferAssignment, // because we need to do the HasAllocation check for each buffer. Otherwise // the buffer_size_ call might fail for some backends. - const TuplePointsToAnalysis& points_to_analysis = - liveness_->points_to_analysis(); - for (LogicalBuffer::Id id = 0; id < points_to_analysis.num_logical_buffers(); - id++) { - auto& buffer = points_to_analysis.logical_buffer(id); - if (HasAllocation(buffer)) { - LogicalBufferProto proto_buffer = buffer.ToProto(buffer_size_); + const HloDataflowAnalysis& dataflow = this->dataflow_analysis(); + for (BufferValue::Id id = 0; id < dataflow.values().size(); id++) { + auto& value = dataflow.values().at(id); + if (HasAllocation(*value)) { + LogicalBufferProto proto_buffer = value->ToProto(buffer_size_); proto.add_logical_buffers()->Swap(&proto_buffer); // Fill buffer aliases. 
- for (const BufferAlias& alias : - points_to_analysis.GetBufferAliases(buffer)) { - if (alias.instruction() == buffer.instruction() && - alias.index() == buffer.index()) { + for (const HloValue* alias : + alias_analysis().GetBufferContainingValue(*value).values()) { + if (alias->instruction() == value->instruction() && + alias->index() == value->index()) { continue; // skip self-aliases } BufferAssignmentProto::BufferAlias* proto_alias = proto.add_buffer_aliases(); LogicalBufferProto::Location proto_alias_location = - BufferValue::ToLocationProto(*alias.instruction(), alias.index()); - proto_alias->set_source_buffer_id(buffer.id()); + BufferValue::ToLocationProto(*alias->instruction(), alias->index()); + proto_alias->set_source_buffer_id(value->id()); proto_alias->mutable_location()->Swap(&proto_alias_location); } } @@ -735,114 +736,70 @@ BufferAssignmentProto BufferAssignment::ToProto() const { /* static */ StatusOr> BufferAssigner::Run( const HloModule* module, std::unique_ptr hlo_ordering, - LogicalBuffer::SizeFunction buffer_size, + BufferValue::SizeFunction buffer_size, LogicalBuffer::AlignmentFunction color_alignment, - bool allow_input_output_aliasing, bool allocate_buffers_for_constants, - BufferLiveness::Colorer colorer, ReuseAllocationFunction reuse_checker, - ReuseColocatedAllocationForTempChecker reuse_colocated_checker) { + bool allocate_buffers_for_constants, BufferAssigner::Colorer colorer, + const absl::flat_hash_set& reuse_checker, + HloDataflowAnalysis::FusionCanShareBufferFunction fusion_can_share_buffer) { BufferAssigner assigner(allocate_buffers_for_constants, std::move(colorer), - std::move(reuse_checker), - std::move(reuse_colocated_checker)); - return assigner.CreateAssignment(module, std::move(hlo_ordering), - std::move(buffer_size), - std::move(color_alignment)); + reuse_checker); + return assigner.CreateAssignment( + module, std::move(hlo_ordering), std::move(buffer_size), + std::move(color_alignment), std::move(fusion_can_share_buffer)); } namespace { -// a and b are in different subcomputations. Check for the case -// where a is inside the while body, and b is outside, part of the same while's -// init-operand or while-result. -bool MayInterfereAcrossSubcomputations(BufferAssignment* assignment, - const LogicalBuffer& a_buffer, - const LogicalBuffer& b_buffer) { - const CallGraph& call_graph = - assignment->liveness().hlo_ordering().call_graph(); - const HloInstruction* a_ancestor; - const HloInstruction* b_ancestor; - std::tie(a_ancestor, b_ancestor) = - call_graph.NearestAncestorsInSameComputation(a_buffer.instruction(), - b_buffer.instruction()); - if (a_ancestor == nullptr) { - // No common ancestor. 
- return true; - } - if (a_ancestor->opcode() == HloOpcode::kWhile && - call_graph.InstructionIsNestedIn(a_buffer.instruction(), - a_ancestor->while_body())) { - const PointsToSet& init_set = - assignment->liveness().points_to_analysis().GetPointsToSet( - a_ancestor->operand(0)); - if (init_set.ContainsBuffer(b_buffer)) { - VLOG(4) << "Can't interfere: " << a_buffer << " and " << b_buffer - << " (part of while-operand)"; - return false; - } - const PointsToSet& while_set = - assignment->liveness().points_to_analysis().GetPointsToSet(a_ancestor); - if (while_set.ContainsBuffer(b_buffer)) { - VLOG(4) << "Can't interfere: " << a_buffer << " and " << b_buffer - << " (part of while)"; - return false; +void ConvertHeapSimulatorResultToHloValue( + HeapSimulator::Result* result, const HloDataflowAnalysis& dataflow_analysis, + const TuplePointsToAnalysis& points_to) { + absl::flat_hash_map + chunk_map_with_hlo_value; + for (auto& value_to_chunk : result->chunk_map) { + const BufferValue* value = value_to_chunk.first; + if (!dataflow_analysis.ValueIsDefinedAt(value->instruction(), + value->index())) { + continue; } + const HloValue& hlo_value = dataflow_analysis.GetValueDefinedAt( + value->instruction(), value->index()); + chunk_map_with_hlo_value[&hlo_value] = value_to_chunk.second; } - return true; -} - -// Return true, if a and b can't possibly interfere (and therefore further -// checking for interference can be skipped). This function checks for special -// cases where copy insertion guarantees no interference, but the regular buffer -// liveness is too conservative: -// -// Operations inside a while-body can't interfere with operations outside the -// while op if their last use is at the while-loop itself as part of the -// while-init op, or the while-result. For ops that are live across a -// while-loop, copy insertion will already insert the necessary copies to avoid -// such interference. -// -// This allows sharing buffers in cases like this: -// init = {...} -// while (init): -// p = param(0) -// gte = get-tuple-element(p), index=i -// t1 = op1 (gte) -// t2 = op2 (t1) -// ROOT tuple = {..., t2, ...} -// -// where t1 and t2 can share the same buffer. -bool MaySkipInterferenceCheck(BufferAssignment* assignment, - const LogicalBuffer& a_buffer, - const LogicalBuffer& b_buffer) { - if (a_buffer.instruction()->parent() == b_buffer.instruction()->parent()) { - // Ops within the same computation are not handled here. Assume that they - // may interfere. - return false; + result->chunk_map = chunk_map_with_hlo_value; + // Set up debug trace. 
+ for (int64 i = 0; i < result->debug_trace.events_size(); ++i) { + int64 buffer_id = result->debug_trace.mutable_events(i)->buffer_id(); + const LogicalBuffer& logical_buffer = points_to.GetBuffer(buffer_id); + const HloValue* hlo_value = + dataflow_analysis + .GetValueSet(logical_buffer.instruction(), logical_buffer.index()) + .values()[0]; + result->debug_trace.mutable_events(i)->set_buffer_id(hlo_value->id()); } - return !MayInterfereAcrossSubcomputations(assignment, a_buffer, b_buffer) || - !MayInterfereAcrossSubcomputations(assignment, b_buffer, a_buffer); } } // namespace bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation, - const LogicalBuffer& buffer, + const HloBuffer& hlo_buffer, BufferAssignment* assignment) { - const LogicalBuffer::SizeFunction& buffer_size = assignment->buffer_size_; + CHECK(!assignment->HasAllocation(hlo_buffer)) + << "buffer " << hlo_buffer << " already has an allocation assigned."; - CHECK(!assignment->HasAllocation(buffer)) - << "buffer " << buffer << " already has an allocation assigned."; + VLOG(4) << "Trying to assign " << hlo_buffer << " size " + << assignment->HloBufferSize(hlo_buffer) + << " to allocation: " << *allocation; - VLOG(4) << "Trying to assign " << buffer << " to allocation: " << *allocation; - - if (buffer.color() != allocation->color()) { - VLOG(4) << "Can't assign: buffer has color" << buffer.color() + if (hlo_buffer.color() != allocation->color()) { + VLOG(4) << "Can't assign: buffer has color" << hlo_buffer.color() << " and allocation has color " << allocation->color() << "."; return false; } - if (buffer_size(buffer) > allocation->size()) { + if (assignment->HloBufferSize(hlo_buffer) > allocation->size()) { VLOG(4) << "Can't assign: buffer is larger than allocation (" - << buffer_size(buffer) << " > " << allocation->size() << ")"; + << assignment->HloBufferSize(hlo_buffer) << " > " + << allocation->size() << ")"; return false; } @@ -851,10 +808,33 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation, return false; } - if (reuse_checker_ != nullptr && - !reuse_checker_(*assignment, *allocation, buffer)) { - VLOG(4) << "Can't assign: reuse_checker_(allocation, buffer) == false"; - return false; + if (!must_not_live_out_.empty()) { + if (allocation->maybe_live_out()) { + // If a buffer maybe live out, the allocation cannot contain any node from + // the "must_not_live_out_" set. + for (const HloValue* value : hlo_buffer.values()) { + if (must_not_live_out_.count(value->instruction()->opcode()) > 0) { + VLOG(4) << "Can't assign: " << value->instruction()->ToString() + << " cannot live out of the module"; + return false; + } + } + } + // The above check is not enough -- There could be the case where an + // allocation can be not live out and contains an instruction with opcode + // from the "must_not_live_out_" set, but assigning a live out buffer to + // that allocation makes the allocation live out and also contains + // instruction from the "must_not_live_out_" set. 
+ if (assignment->alias_analysis().BufferLivesOut(hlo_buffer)) { + for (const auto& buffer_offset_size : allocation->assigned_buffers()) { + if (must_not_live_out_.count( + buffer_offset_size.first->instruction()->opcode()) > 0) { + VLOG(4) << "Can't assign: " << buffer_offset_size.first->instruction() + << " cannot live out of the module"; + return false; + } + } + } } if (!allocation->is_reusable()) { @@ -863,299 +843,218 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation, } for (const auto& buffer_offset_size : allocation->assigned_buffers()) { - const LogicalBuffer& assigned_buffer = *buffer_offset_size.first; - if (MaySkipInterferenceCheck(assignment, buffer, assigned_buffer)) { - continue; - } - if (assignment->liveness().MayInterfere(assigned_buffer, buffer)) { - VLOG(4) << "Can't assign: assignee " << assigned_buffer - << " may interfere with " << buffer; - return false; - } - // Copy instruction don't share a buffer with their input operand. - if (buffer.instruction()->IsUserOf(assigned_buffer.instruction()) && - buffer.instruction()->opcode() == HloOpcode::kCopy) { - VLOG(4) << "Can't assign: assignee " << assigned_buffer - << " is used at copy instruction " << buffer; - return false; + // Pairwise compare. + const HloValue& assigned_buffer = + *CHECK_NOTNULL(dynamic_cast(buffer_offset_size.first)); + for (const HloValue* new_value : hlo_buffer.values()) { + if (assignment->liveness().hlo_ordering().MayInterfere( + assigned_buffer, *new_value, assignment->dataflow_analysis())) { + VLOG(4) << "Can't assign: assignee " << assigned_buffer + << " may interfere with " << new_value; + return false; + } + + for (const HloPosition& assgiend_buffer_position : + assigned_buffer.positions()) { + // Copy instruction don't share a buffer with their input operand. + if (new_value->instruction()->IsUserOf( + assgiend_buffer_position.instruction) && + new_value->instruction()->opcode() == HloOpcode::kCopy) { + VLOG(4) << "Can't assign: assignee " << assigned_buffer + << " is used at copy instruction " << new_value; + return false; + } + } } } // If the buffer is live out of the computation then it should only be // assigned a buffer which exactly fits the result to avoid wasting memory // (result buffers can have arbitrary lifetimes). - if (assignment->liveness().MaybeLiveOut(buffer) && - allocation->size() != buffer_size(buffer)) { - VLOG(4) << "Can't assign: buffer " << buffer + if (assignment->alias_analysis().BufferLivesOut(hlo_buffer) && + allocation->size() != assignment->HloBufferSize(hlo_buffer)) { + VLOG(4) << "Can't assign: buffer " << hlo_buffer << "is live out and size not the same as allocation"; return false; } - assignment->AddAssignment(allocation, buffer, /*offset=*/0, - buffer_size(buffer)); + assignment->AddAssignment(allocation, hlo_buffer, /*offset=*/0, + assignment->HloBufferSize(hlo_buffer)); return true; -} +} // namespace xla -Status BufferAssigner::AssignBuffersForComputation( - const HloComputation* computation, bool is_thread_local, - const flat_hash_set& colocated_buffers, - const flat_hash_set& colocated_allocations, - flat_hash_map>* - buffers_to_assign_sequentially, - BufferAssignment* assignment) { - // Buffers are sorted and assigned to BufferAllocations in decreasing order of - // size. - std::vector sorted_buffers; - for (auto* instruction : computation->instructions()) { - // Add all buffers which this instruction defines. 
Instruction which don't - // define buffers (eg, bitcast which just forwards a pointer) don't need - // any allocations. - for (const LogicalBuffer* buffer : - assignment->points_to_analysis().GetBuffersDefinedByInstruction( - instruction)) { - sorted_buffers.push_back(buffer); - } - } - - // Generate a post order sort of instructions for sorting of the - // LogicalBuffers. - flat_hash_map post_order_position; - int position = 0; - for (auto* instruction : computation->MakeInstructionPostOrder()) { - post_order_position.emplace(instruction, position); - position++; - } - - // If there is a sequential instruction ordering, we'll delay assignment of - // temp buffers until after the main assignment loop. - const BufferLiveness& liveness = assignment->liveness(); - const bool has_sequential_order = - liveness.hlo_ordering().SequentialOrder(*computation) != nullptr; - if (has_sequential_order && buffers_to_assign_sequentially != nullptr) { - // Every sequential computation must get an entry in the - // buffers_to_assign_sequentially map, even if we end up with an empty set - // of buffers. This ensures we can correctly determine whether to run - // whole-module heap simulation. - buffers_to_assign_sequentially->emplace( - computation, flat_hash_set()); - } - - // Sort the LogicalBuffers first by size. We assign the larger LogicalBuffers - // first for simplicity. This means any previously created BufferAllocation is - // necessarily large enough to hold the output of the current Buffer in - // consideration. +Status BufferAssigner::MergeInplaceOpBuffers(BufferAssignment* assignment) { + // Try allocate same buffer for dynamic update slice's operand and output. // - // As a secondary sorting criteria, if the instructions are sequentially - // ordered, we assign live-out buffers before others. Note that for sequential - // computations, we'll take temp buffers that can't re-use any allocations and - // assign them via a heap scheduler. By assigning live-out buffers first, we - // increase the odds that temp buffers can re-use an allocation. - // - // As a final tiebreaker use post order position of the HLO instruction which - // defines the buffer. This means an instruction will appear after its - // operands (assuming operands are the same/larger size) enabling the - // important reuse case where an elementwise instruction reuses one of its - // operand's buffer. This improves locality. - absl::c_sort(sorted_buffers, - [has_sequential_order, &liveness, &post_order_position, - assignment](const LogicalBuffer* a, const LogicalBuffer* b) { - // Primary sort is by decreasing buffer size. - const int64 a_size = assignment->buffer_size_(*a); - const int64 b_size = assignment->buffer_size_(*b); - if (a_size != b_size) { - return a_size > b_size; // use ">" for decreasing size. - } - // Otherwise live out buffers come before others, if the - // instructions are sequentially ordered. - if (has_sequential_order) { - const bool a_live_out = liveness.MaybeLiveOut(*a); - const bool b_live_out = liveness.MaybeLiveOut(*b); - if (a_live_out != b_live_out) { - return a_live_out; - } - } - // Final tiebreaker is in instruction post order. - return post_order_position.at(a->instruction()) < - post_order_position.at(b->instruction()); - }); + // TODO(yunxing): Moving this logic to alias analysis and add must-alias rule + // to operations that can be done in place. 
+ for (HloComputation* computation : assignment->module().computations()) { + for (HloInstruction* instruction : computation->instructions()) { + if (!(instruction->opcode() == HloOpcode::kDynamicUpdateSlice || + (instruction->opcode() == HloOpcode::kFusion && + (instruction->fused_expression_root()->opcode() == + HloOpcode::kDynamicUpdateSlice)))) { + continue; + } + if (instruction->parent()->IsFusionComputation()) { + continue; + } + if (instruction->operand_count() == 0) { + continue; + } + // Can't share the buffer. + if (!assignment->dataflow_analysis().CanShareOperandBufferWithUser( + instruction->mutable_operand(0), {}, instruction, {})) { + continue; + } + HloBuffer& instruction_buffer = + assignment->alias_analysis().GetUniqueBufferAt(instruction, {}); - // BufferAllocations are necessarily created in decreasing size order. Keep - // indices of previously created BufferAllocations in new_allocation_indices. - std::vector new_allocation_indices; + HloBuffer& operand_buffer = + assignment->alias_analysis().GetUniqueBufferAt( + instruction->operand(0), {}); - // A sorted multimap from size to indices of colocated allocations. - std::multimap - colocated_allocation_size_to_indices; - { - std::priority_queue sorted_colocated_indices; - for (auto index : colocated_allocations) { - bool consider_reusing = true; - // Output tuple table may be allocated at run-time, so make sure we don't - // overwrite them. - for (const auto& buffer_offset_size : - assignment->GetAllocation(index).assigned_buffers()) { - if (buffer_offset_size.first->shape().IsTuple()) { - consider_reusing = false; - break; + // Already have the same buffer. No need to merge those. + if (instruction_buffer.id() == operand_buffer.id()) { + continue; + } + + bool interfere = false; + + for (const HloValue* instruction_value : instruction_buffer.values()) { + for (const HloValue* operand_value : operand_buffer.values()) { + if (assignment->liveness().hlo_ordering().MayInterfere( + *instruction_value, *operand_value, + assignment->dataflow_analysis())) { + interfere = true; + break; + } } } - if (consider_reusing) { - sorted_colocated_indices.push(index); + if (interfere) { + continue; } - } - while (!sorted_colocated_indices.empty()) { - auto index = sorted_colocated_indices.top(); - sorted_colocated_indices.pop(); - colocated_allocation_size_to_indices.emplace( - assignment->GetAllocation(index).size(), index); + if (assignment->alias_analysis().BufferLivesOut(instruction_buffer)) { + continue; + } + if (instruction_buffer.color() != operand_buffer.color()) { + continue; + } + VLOG(3) << "Merging inplace " << instruction_buffer << " and " + << operand_buffer; + assignment->alias_analysis().MergeBuffers(instruction_buffer, + operand_buffer); } } - for (const LogicalBuffer* buffer : sorted_buffers) { - VLOG(3) << "Assigning allocation to: " << *buffer; - if (colocated_buffers.contains(buffer)) { - // Colocated buffers are currently assigned in an earlier pass. 
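// A simplified standalone model (not the XLA implementation) of the merge that
// MergeInplaceOpBuffers performs above: the output and operand buffers of an
// in-place op, each a set of values, may share storage only if no value of one
// interferes with any value of the other, the colors match, and the output
// buffer does not live out. Interference is modeled here as live-range
// overlap, which is a simplification of HloOrdering::MayInterfere.
#include <cstdint>
#include <vector>

struct LiveRange {
  int64_t start = 0;
  int64_t end = 0;  // Exclusive.
};

struct SketchHloBuffer {
  int color = 0;
  bool lives_out = false;
  std::vector<LiveRange> value_ranges;
};

static bool RangesOverlap(const LiveRange& a, const LiveRange& b) {
  return a.start < b.end && b.start < a.end;
}

// Returns true if the two buffers may be merged under the rules above.
bool CanMergeInplace(const SketchHloBuffer& output,
                     const SketchHloBuffer& operand) {
  if (output.color != operand.color) return false;
  if (output.lives_out) return false;
  for (const LiveRange& a : output.value_ranges) {
    for (const LiveRange& b : operand.value_ranges) {
      if (RangesOverlap(a, b)) return false;  // Pairwise interference check.
    }
  }
  return true;
}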
- VLOG(3) << "Skipping colocated buffer: " << *buffer; - continue; - } + return Status::OK(); +} - TF_RET_CHECK(!assignment->HasAllocation(*buffer)); - - const HloInstruction* instruction = buffer->instruction(); - const int64 buffer_size = assignment->buffer_size_(*buffer); - - if (instruction->opcode() == HloOpcode::kConstant) { +Status BufferAssigner::AssignSingleHloBuffer( + const HloBuffer* hlo_buffer, bool is_thread_local, + absl::flat_hash_map>* + buffers_to_assign_sequentially, + std::vector* allocation_indices, + BufferAssignment* assignment) { + const int64 buffer_size = assignment->HloBufferSize(*hlo_buffer); + for (const HloValue* value : hlo_buffer->values()) { + if (value->instruction()->opcode() == HloOpcode::kConstant) { if (allocate_buffers_for_constants_) { BufferAllocation* allocation = - assignment->NewAllocation(*buffer, buffer_size); + assignment->NewAllocation(*hlo_buffer, buffer_size); allocation->set_constant(true); VLOG(3) << "New allocation #" << allocation->index() << " for constant " - << *buffer; + << *hlo_buffer << " value ptr: " << value; } - continue; + VLOG(3) << "Not allocating buffer for constant"; + return Status::OK(); } + const HloInstruction* instruction = value->instruction(); const bool is_entry_parameter = instruction->opcode() == HloOpcode::kParameter && - computation == computation->parent()->entry_computation(); + instruction->parent() == + instruction->parent()->parent()->entry_computation(); + if (is_entry_parameter) { - // If the LogicalBuffer is part of an external parameter, creates a new + bool parameter_has_alias = + assignment->module().input_output_alias_config().ParameterHasAlias( + instruction->parameter_number(), value->index()); + // If the hlo buffer is part of an external parameter, creates a new // allocation and sets its parameter number. Parameters of non-entry // computations do not need special allocations because they live inside // callers. 
BufferAllocation* allocation = - assignment->NewAllocation(*buffer, buffer_size); - bool parameter_has_alias = - assignment->module().input_output_alias_config().ParameterHasAlias( - instruction->parameter_number(), buffer->index()); + assignment->NewAllocation(*hlo_buffer, buffer_size); + allocation->set_entry_computation_parameter( - instruction->parameter_number(), buffer->index(), - parameter_has_alias); - VLOG(3) << "Mark allocation #" << allocation->index() - << " as entry computation parameter: " << *buffer; - continue; - } - - if (is_thread_local) { - BufferAllocation* allocation = - assignment->NewAllocation(*buffer, buffer_size); - allocation->set_is_thread_local(true); + instruction->parameter_number(), value->index(), parameter_has_alias); + if (parameter_has_alias) { + allocation_indices->push_back(allocation->index()); + } VLOG(3) << "New allocation #" << allocation->index() - << " for thread-local: " << *buffer; - continue; + << " marked as entry computation parameter: " << *hlo_buffer; + return Status::OK(); } + } - if (buffer->shape().IsTuple()) { + if (is_thread_local) { + BufferAllocation* allocation = + assignment->NewAllocation(*hlo_buffer, buffer_size); + allocation->set_is_thread_local(true); + VLOG(3) << "New allocation #" << allocation->index() + << " for thread-local: " << *hlo_buffer; + return Status::OK(); + } + + for (const HloValue* value : hlo_buffer->values()) { + if (value->shape().IsTuple()) { BufferAllocation* allocation = - assignment->NewAllocation(*buffer, buffer_size); + assignment->NewAllocation(*hlo_buffer, buffer_size); allocation->set_is_tuple(true); VLOG(3) << "New allocation #" << allocation->index() - << " for tuple-shaped buffer: " << *buffer; - continue; + << " for tuple-shaped buffer: " << *hlo_buffer; + return Status::OK(); } - // First try to assign a LogicalBuffer to one of its operand allocations to - // improve locality. This is only possible with elementwise operations - // (checked in liveness analysis) which are necessarily top-level - // array-shaped buffers. - if (buffer->IsTopLevel() && !buffer->IsTuple()) { + if (value->IsTopLevel() && !value->IsTuple()) { + const HloInstruction* instruction = value->instruction(); for (auto* operand : instruction->operands()) { - bool assigned_operand = false; for (const auto& operand_slice : assignment->GetAllSlices(operand, /*index=*/{})) { BufferAllocation* allocation = assignment->GetMutableAllocation(operand_slice.index()); - if (!colocated_allocations.contains(allocation->index())) { - // TODO(b/32491382) Colocated buffers are currently assigned in an - // earlier pass, and so can break the "increasing allocation size" - // invariant in this function (causing this CHECK to fail). However, - // the call to MaybeAssignBuffer is safe as it returns false if - // allocation.size < buffer.size. - CHECK_GE(allocation->size(), buffer_size); - } - if (MaybeAssignBuffer(allocation, *buffer, assignment)) { + if (MaybeAssignBuffer(allocation, *hlo_buffer, assignment)) { VLOG(3) << "Reusing (operand) allocation #" << allocation->index() - << " for: " << *buffer; - assigned_operand = true; - break; + << " for: " << *hlo_buffer; + return Status::OK(); } } - if (assigned_operand) { - break; - } } } + } - if (reuse_colocated_checker_ != nullptr && - reuse_colocated_checker_(*buffer, buffer_size) && - !assignment->HasAllocation(*buffer)) { - // Find the smallest buffer which can be reused iterating from the lower - // bound of the buffer size in colocated_allocation_size_to_indices. 
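// A condensed sketch (hypothetical types, not the real AssignSingleHloBuffer)
// of the order in which special cases are tried above for one buffer:
// constants, entry-computation parameters, thread-local buffers, and tuples
// each get a dedicated allocation; everything else falls through to the reuse
// paths shown in the surrounding code (operand allocations first, then
// previously created allocations).
enum class AllocationKind {
  kConstant,
  kEntryParameter,
  kThreadLocal,
  kTuple,
  kTryReuse,  // Fall through to operand / existing-allocation reuse.
};

struct SketchBufferTraits {
  bool is_constant = false;
  bool is_entry_parameter = false;
  bool is_tuple_shaped = false;
};

AllocationKind ClassifyBuffer(const SketchBufferTraits& traits,
                              bool is_thread_local) {
  if (traits.is_constant) return AllocationKind::kConstant;
  if (traits.is_entry_parameter) return AllocationKind::kEntryParameter;
  if (is_thread_local) return AllocationKind::kThreadLocal;
  if (traits.is_tuple_shaped) return AllocationKind::kTuple;
  return AllocationKind::kTryReuse;
}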
- auto it = colocated_allocation_size_to_indices.lower_bound(buffer_size); - while (it != colocated_allocation_size_to_indices.end()) { - CHECK_GE(it->first, buffer_size); - BufferAllocation* allocation = - assignment->GetMutableAllocation(it->second); - if (MaybeAssignBuffer(allocation, *buffer, assignment)) { - VLOG(3) << "Reusing allocation #" << allocation->index() - << " for: " << *buffer; - // We remove the assigned allocation from - // colocated_allocation_size_to_indices to prevent putting too many - // buffers into collocated allocations, and to reduce the search space - // for subsequent buffers. This is to avoid excessive pairwise checks - // for interference that may slow down compilation. The heap simulator - // is more efficient in live range checks. - // - // Another benefit of removing the allocation is that the reused - // allocation will be less likely to contain interferences that - // prevent operand-output reuse, which is important for in-place - // dynamic update slices. - colocated_allocation_size_to_indices.erase(it); - break; - } - ++it; - } + // Find the smallest buffer which can be reused iterating from end of + // allocation_indices (smallest) to beginning (largest). + for (int allocation_index = allocation_indices->size() - 1; + allocation_index >= 0; allocation_index--) { + BufferAllocation* allocation = assignment->GetMutableAllocation( + allocation_indices->at(allocation_index)); + if (MaybeAssignBuffer(allocation, *hlo_buffer, assignment)) { + VLOG(3) << "Reusing allocation #" << allocation->index() + << " for: " << *hlo_buffer; + return Status::OK(); } + } - if (!assignment->HasAllocation(*buffer)) { - // Find the smallest buffer which can be reused iterating from end of - // new_allocation_indices (smallest) to beginning (largest). - for (int allocation_index = new_allocation_indices.size() - 1; - allocation_index >= 0; allocation_index--) { - BufferAllocation* allocation = assignment->GetMutableAllocation( - new_allocation_indices[allocation_index]); - // Instructions are iterated in increasing buffer size, so any - // previously create allocation must be large enough to hold this - // instruction's output. - if (MaybeAssignBuffer(allocation, *buffer, assignment)) { - VLOG(3) << "Reusing allocation #" << allocation->index() - << " for: " << *buffer; - break; - } - } - } - - if (!assignment->HasAllocation(*buffer) && has_sequential_order && - !liveness.MaybeLiveOut(*buffer)) { + if (hlo_buffer->values().size() == 1) { + HloComputation* computation = + hlo_buffer->values()[0]->instruction()->parent(); + const bool has_sequential_order = + assignment->liveness().hlo_ordering().SequentialOrder(*computation) != + nullptr; + if (!assignment->HasAllocation(*hlo_buffer) && has_sequential_order && + !assignment->alias_analysis().BufferLivesOut(*hlo_buffer)) { // There is a sequential instruction ordering, so we delay assignment of // temp buffers until after the loop. We do this right before we decide to // create a new allocation, to ensure we've exhausted all the buffer @@ -1164,30 +1063,124 @@ Status BufferAssigner::AssignBuffersForComputation( // Entry parameters and thread local buffers were already handled earlier // in this loop iteration. See BufferAllocation::IsPreallocatedTempBuffer // for the definition of temp buffers. 
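// A standalone sketch (not XLA code) of the reuse walk in the new code above:
// because buffers are assigned in decreasing size order, allocations are
// created from largest to smallest, so scanning the recorded allocation
// indices from the back finds the smallest allocation that can still fit the
// buffer. The real code additionally checks interference via MaybeAssignBuffer.
#include <cstdint>
#include <vector>

struct SketchAlloc {
  int64_t size = 0;
};

// Returns the index into `allocations` of the smallest allocation with
// size >= buffer_size, or -1 if none fits. `creation_order` lists allocation
// indices from largest (front) to smallest (back).
int FindSmallestFit(const std::vector<SketchAlloc>& allocations,
                    const std::vector<int>& creation_order,
                    int64_t buffer_size) {
  for (int i = static_cast<int>(creation_order.size()) - 1; i >= 0; --i) {
    const int index = creation_order[i];
    if (allocations[index].size >= buffer_size) {
      return index;  // First fit from the small end is the smallest fit.
    }
  }
  return -1;
}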
- CHECK(!is_entry_parameter) << *buffer; - CHECK(!is_thread_local) << *buffer; - (*buffers_to_assign_sequentially)[computation].insert(buffer); - VLOG(3) << "Delaying assignment of temp buffer: " << *buffer; - continue; - } - - if (!assignment->HasAllocation(*buffer)) { - BufferAllocation* allocation = - assignment->NewAllocation(*buffer, buffer_size); - new_allocation_indices.push_back(allocation->index()); - VLOG(3) << "New allocation #" << allocation->index() - << " for: " << *buffer; + (*buffers_to_assign_sequentially)[computation].insert( + hlo_buffer->values()[0]); + VLOG(3) << "Delaying assignment of temp buffer: " << *hlo_buffer; + return Status::OK(); } } + if (!assignment->HasAllocation(*hlo_buffer)) { + BufferAllocation* allocation = + assignment->NewAllocation(*hlo_buffer, buffer_size); + allocation_indices->push_back(allocation->index()); + VLOG(3) << "New allocation #" << allocation->index() + << " for: " << *hlo_buffer; + } + + TF_RET_CHECK(assignment->HasAllocation(*hlo_buffer)); return Status::OK(); } -flat_hash_map, +Status BufferAssigner::AssignBuffersForComputations( + const std::vector& computations, + bool is_thread_local, + absl::flat_hash_map>* + buffers_to_assign_sequentially, + BufferAssignment* assignment) { + if (computations.empty()) { + return Status::OK(); + } + std::vector sorted_buffers; + + const HloAliasAnalysis& alias_analysis = assignment->alias_analysis(); + for (const HloBuffer& buffer : alias_analysis.buffers()) { + TF_RET_CHECK(!buffer.values().empty()); + const HloComputation* comp = buffer.values()[0]->instruction()->parent(); + if (absl::c_linear_search(computations, comp)) { + sorted_buffers.push_back(&buffer); + } + } + + // Generate a post order sort of instructions for sorting of the + // HloBuffers. + flat_hash_map post_order_position; + int position = 0; + std::vector reverse_post_order_computations; + std::unique_ptr call_graph = + CallGraph::Build(computations[0]->parent()); + TF_RETURN_IF_ERROR(call_graph->VisitNodes([&](const CallGraphNode& node) { + if (absl::c_linear_search(computations, node.computation())) { + reverse_post_order_computations.push_back(node.computation()); + } + return Status::OK(); + })); + absl::c_reverse(reverse_post_order_computations); + for (auto* computation : reverse_post_order_computations) { + for (auto* instruction : computation->MakeInstructionPostOrder()) { + post_order_position.emplace(instruction, position); + position++; + } + } + + const BufferLiveness& liveness = assignment->liveness(); + for (const HloComputation* computation : computations) { + const bool has_sequential_order = + liveness.hlo_ordering().SequentialOrder(*computation) != nullptr; + if (has_sequential_order && buffers_to_assign_sequentially != nullptr) { + // Every sequential computation must get an entry in the + // buffers_to_assign_sequentially map, even if we end up with an empty + // set of buffers. This ensures we can correctly determine whether to + // run whole-module heap simulation. + buffers_to_assign_sequentially->emplace( + computation, flat_hash_set()); + } + } + + absl::c_sort( + sorted_buffers, [&post_order_position, &alias_analysis, assignment]( + const HloBuffer* a, const HloBuffer* b) { + // Primary sort is by decreasing buffer size. + const int64 a_size = assignment->HloBufferSize(*a); + const int64 b_size = assignment->HloBufferSize(*b); + if (a_size != b_size) { + return a_size > b_size; // use ">" for decreasing size. 
+ } + + const bool a_live_out = alias_analysis.BufferLivesOut(*a); + const bool b_live_out = alias_analysis.BufferLivesOut(*b); + if (a_live_out != b_live_out) { + return a_live_out; + } + auto compare = [&post_order_position](const HloValue* value1, + const HloValue* value2) { + return post_order_position.at(value1->instruction()) < + post_order_position.at(value2->instruction()); + }; + const HloValue* a_min = *absl::c_min_element(a->values(), compare); + const HloValue* b_min = *absl::c_min_element(b->values(), compare); + return post_order_position.at(a_min->instruction()) < + post_order_position.at(b_min->instruction()); + }); + + std::vector allocation_indices; + + for (const HloBuffer* buffer : sorted_buffers) { + VLOG(3) << "================================================="; + VLOG(3) << "Assigning buffer for " << *buffer; + TF_RETURN_IF_ERROR(AssignSingleHloBuffer(buffer, is_thread_local, + buffers_to_assign_sequentially, + &allocation_indices, assignment)); + } + return Status::OK(); +} + +flat_hash_map, LogicalBuffer::Color::Hasher> BufferAssigner::SplitBuffersByColor( - const flat_hash_set& buffers) { - flat_hash_map, + const flat_hash_set& buffers) { + flat_hash_map, LogicalBuffer::Color::Hasher> color_map; for (auto buffer : buffers) { @@ -1196,14 +1189,76 @@ BufferAssigner::SplitBuffersByColor( return color_map; } +BufferValueFlatSet BufferAssigner::HloValueSetToLogicalBufferSet( + const absl::flat_hash_set& hlo_value_set, + const TuplePointsToAnalysis& points_to_analysis) { + BufferValueFlatSet output; + for (const BufferValue* buffer_value : hlo_value_set) { + const HloValue& hlo_value = + *CHECK_NOTNULL(dynamic_cast(buffer_value)); + + for (const HloPosition& position : hlo_value.positions()) { + if (!points_to_analysis.InstructionDefinesBufferAtIndex( + position.instruction, position.index)) { + continue; + } + int64 buffer_id = + points_to_analysis + .GetBufferDefinedAt(position.instruction, position.index) + .ValueOrDie() + ->id(); + LogicalBuffer& logical_buffer = + points_to_analysis.logical_buffer(buffer_id); + if (hlo_value.has_color()) { + logical_buffer.set_color(hlo_value.color()); + } + output.insert(&logical_buffer); + } + } + return output; +} + +std::vector BufferAssigner::BuildMustAliasLogicalBufferSet( + BufferAssignment* assignment) { + VLOG(1) << "Building must alias groups."; + std::vector output; + for (const HloBuffer& hlo_buffer : assignment->alias_analysis().buffers()) { + std::vector positions = hlo_buffer.ComputePositions(); + if (positions.size() <= 1) { + continue; + } + VLOG(2) << " Must alias group:"; + BufferValueFlatSet must_alias; + for (const HloPosition& hlo_position : positions) { + VLOG(2) << " hlo_position:" << hlo_position.ToString(); + + StatusOr logical_buffer = + assignment->points_to_analysis().GetBufferDefinedAt( + hlo_position.instruction, hlo_position.index); + if (!logical_buffer.ok()) { + // Buffer is not defined at this position. + continue; + } + + VLOG(2) << " logical buffer:" + << logical_buffer.ValueOrDie()->ToString(); + must_alias.insert(logical_buffer.ValueOrDie()); + } + if (must_alias.size() > 1) { + output.push_back(must_alias); + } + } + return output; +} + Status BufferAssigner::AssignBuffersWithSequentialOrdering( const flat_hash_map>& + flat_hash_set>& buffers_to_assign_sequentially, bool run_whole_module_heap_simulation, BufferAssignment* assignment) { - // Run the sequence of instructions through the heap simulator. 
The heuristic - // that seems to give the best results is lazy-best-fit, with all runs of - // alloc / free calls sorted in decreasing size order. + // Run the sequence of instructions through the heap simulator. The + // heuristic that seems to give the best results is lazy-best-fit, with all + // runs of alloc / free calls sorted in decreasing size order. const HloOrdering& hlo_ordering = assignment->liveness().hlo_ordering(); // Returns a heap algorithm that chooses the best result from several @@ -1218,17 +1273,23 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( return absl::make_unique(std::move(algorithms)); }; + // The API of heap simulator is currently logical buffer based and buffer + // assignment currently uses HloValue. As an intermediate step, we convert + // between logical buffer and HloValue around the API boundary. + // + // TODO(yunxing): Update heap simulator to use HloValue and remove the + // conversions. if (run_whole_module_heap_simulation) { - // Run the heap simulation over the whole module. This reduces memory usage, - // since buffers for kCall, kWhile, and kConditional sub-computations are - // only live for the duration of their calling instructions. + // Run the heap simulation over the whole module. This reduces memory + // usage, since buffers for kCall, kWhile, and kConditional + // sub-computations are only live for the duration of their calling + // instructions. VLOG(1) << "Running whole-module heap simulation"; HloSchedule schedule(&assignment->module()); - flat_hash_set all_buffers_to_assign; + flat_hash_set all_buffers_to_assign; for (const auto& pair : buffers_to_assign_sequentially) { const HloComputation* computation = pair.first; - const flat_hash_set& buffers_to_assign = - pair.second; + const flat_hash_set& buffers_to_assign = pair.second; const HloInstructionSequence* instruction_sequence = hlo_ordering.SequentialOrder(*computation); CHECK(instruction_sequence != nullptr) << computation->name(); @@ -1243,15 +1304,22 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( int64 alignment = assignment->color_alignment_(color); HeapSimulator::Options options; options.alloc_constants = allocate_buffers_for_constants_; - BufferValueFlatSet buffer_value_set = - ToBufferValueFlatSet(single_colored_set.second); + // At the API boundary between buffer_assignment and heap simulator, + // TuplePointsTo and LogicalBuffer are expected. 
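// A toy, self-contained illustration (not the real HeapSimulator) of the kind
// of placement the simulation invoked just below computes: each buffer has a
// live range and a size, and is packed at the lowest offset that does not
// collide with any already-placed buffer whose live range overlaps. The real
// simulator uses stronger heuristics (e.g. best-fit over a sorted run of
// alloc/free events) and picks the best of several algorithms.
#include <cstdint>
#include <vector>

struct ToyBuffer {
  int64_t size = 0;
  int64_t live_start = 0;
  int64_t live_end = 0;  // Exclusive.
};

struct ToyChunk {
  int64_t offset = 0;
  int64_t size = 0;
};

std::vector<ToyChunk> NaivePack(const std::vector<ToyBuffer>& buffers) {
  std::vector<ToyChunk> chunks(buffers.size());
  for (size_t i = 0; i < buffers.size(); ++i) {
    int64_t offset = 0;
    bool moved = true;
    // Bump the candidate offset until it clears every temporally overlapping,
    // already-placed buffer.
    while (moved) {
      moved = false;
      for (size_t j = 0; j < i; ++j) {
        const bool time_overlap = buffers[i].live_start < buffers[j].live_end &&
                                  buffers[j].live_start < buffers[i].live_end;
        const bool space_overlap = offset < chunks[j].offset + chunks[j].size &&
                                   chunks[j].offset < offset + buffers[i].size;
        if (time_overlap && space_overlap) {
          offset = chunks[j].offset + chunks[j].size;
          moved = true;
        }
      }
    }
    chunks[i] = ToyChunk{offset, buffers[i].size};
  }
  return chunks;
}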
+ BufferValueFlatSet buffer_value_set = HloValueSetToLogicalBufferSet( + single_colored_set.second, assignment->points_to_analysis()); options.buffers_to_assign = &buffer_value_set; + + options.must_alias_sets = BuildMustAliasLogicalBufferSet(assignment); TF_ASSIGN_OR_RETURN( - const HeapSimulator::Result result, + HeapSimulator::Result result, HeapSimulator::Run(get_heap_algorithm(alignment), assignment->module(), schedule, assignment->points_to_analysis(), assignment->buffer_size_, options)); + ConvertHeapSimulatorResultToHloValue(&result, + assignment->dataflow_analysis(), + assignment->points_to_analysis()); AssignBuffersFromHeapSimulator(result, assignment, single_colored_set.first); } @@ -1262,8 +1330,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( VLOG(1) << "Running per-computation heap simulation"; for (const auto& pair : buffers_to_assign_sequentially) { const HloComputation* computation = pair.first; - const flat_hash_set& buffers_to_assign = - pair.second; + const flat_hash_set& buffers_to_assign = pair.second; const HloInstructionSequence* instruction_sequence = hlo_ordering.SequentialOrder(*computation); CHECK(instruction_sequence != nullptr) << computation->name(); @@ -1273,15 +1340,21 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( VLOG(2) << "Simulating heap for color " << color; int64 alignment = assignment->color_alignment_(color); HeapSimulator::Options options; - BufferValueFlatSet buffer_value_set = - ToBufferValueFlatSet(single_colored_set.second); + // At the API boundary between buffer_assignment and heap simulator, + // TuplePointsTo and LogicalBuffer are expected. + BufferValueFlatSet buffer_value_set = HloValueSetToLogicalBufferSet( + single_colored_set.second, assignment->points_to_analysis()); options.buffers_to_assign = &buffer_value_set; + options.must_alias_sets = BuildMustAliasLogicalBufferSet(assignment); TF_ASSIGN_OR_RETURN( - const HeapSimulator::Result result, + HeapSimulator::Result result, HeapSimulator::Run(get_heap_algorithm(alignment), *computation, *instruction_sequence, assignment->points_to_analysis(), assignment->buffer_size_, options)); + ConvertHeapSimulatorResultToHloValue(&result, + assignment->dataflow_analysis(), + assignment->points_to_analysis()); AssignBuffersFromHeapSimulator(result, assignment, single_colored_set.first); } @@ -1291,35 +1364,37 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( } namespace { - -// Computes and returns the set of logical buffers live at the point of maximal -// liveness in the given heap trace. LogicalBuffers are (stabily) sorted by id. -std::vector ComputePeakMemoryLogicalBuffers( +// Computes and returns the set of logical buffers live at the point of +// maximal liveness in the given heap trace. LogicalBuffers are (stabily) +// sorted by id. +std::vector ComputePeakMemoryLogicalBuffers( const BufferAllocation& allocation, const HeapSimulatorTrace& heap_trace) { // Create a map from LogicalBuffer::Id to LogicalBuffer* for the logical // buffers in this allocation. 
- absl::flat_hash_map id_to_buffer; - absl::flat_hash_map buffer_sizes; + absl::flat_hash_map id_to_value; + absl::flat_hash_map buffer_sizes; for (const auto& pair : allocation.assigned_buffers()) { - const LogicalBuffer* buffer = pair.first; + const BufferValue* value = pair.first; const BufferAllocation::OffsetSize& offset_size = pair.second; - id_to_buffer[buffer->id()] = buffer; - buffer_sizes[buffer] = offset_size.size; + id_to_value[value->id()] = value; + buffer_sizes[value] = offset_size.size; } + VLOG(1) << "Compute peak memory logical buffers"; // Returns how much the given event increases the total size of live // buffers. Can be negative. - auto memory_delta = [&id_to_buffer, &buffer_sizes]( + auto memory_delta = [&id_to_value, &buffer_sizes]( const HeapSimulatorTrace::Event& event) -> int64 { - const LogicalBuffer* buffer = id_to_buffer.at(event.buffer_id()); + if (event.kind() == HeapSimulatorTrace::Event::SHARE_WITH) { + // Sharing a buffer does not change the live set size for the purposes + // of the heap simulator. Even though the shared-with buffer may be + // smaller, the entire allocation remains live. + return 0; + } + const BufferValue* buffer = id_to_value.at(event.buffer_id()); const int64 buffer_size = buffer_sizes.at(buffer); if (event.kind() == HeapSimulatorTrace::Event::ALLOC) { return buffer_size; - } else if (event.kind() == HeapSimulatorTrace::Event::SHARE_WITH) { - // Sharing a buffer does not change the live set size for the purposes of - // the heap simulator. Even though the shared-with buffer may be smaller, - // the entire allocation remains live. - return 0; } else if (event.kind() == HeapSimulatorTrace::Event::FREE) { return -1 * buffer_size; } @@ -1338,43 +1413,48 @@ std::vector ComputePeakMemoryLogicalBuffers( // Next gather the set of logical buffers live at the earliest point of // maximal live set size. - absl::flat_hash_set live_buffers; + absl::flat_hash_set live_values; live_size = 0; for (const auto& event : heap_trace.events()) { - const LogicalBuffer* buffer = id_to_buffer.at(event.buffer_id()); - if (event.kind() == HeapSimulatorTrace::Event::ALLOC) { - InsertOrDie(&live_buffers, buffer); - } else if (event.kind() == HeapSimulatorTrace::Event::SHARE_WITH) { - // Nothing to do. - } else if (event.kind() == HeapSimulatorTrace::Event::FREE) { - CHECK(ContainsKey(live_buffers, buffer)); - live_buffers.erase(buffer); + if (event.kind() == HeapSimulatorTrace::Event::SHARE_WITH) { + continue; + } + const BufferValue* value = id_to_value.at(event.buffer_id()); + if (event.kind() == HeapSimulatorTrace::Event::ALLOC) { + InsertOrDie(&live_values, value); + } else if (event.kind() == HeapSimulatorTrace::Event::FREE) { + CHECK(ContainsKey(live_values, value)); + live_values.erase(value); } - live_size += memory_delta(event); + if (live_size == max_live_size) { break; } } CHECK_EQ(live_size, max_live_size); - std::vector live_buffers_vector; - live_buffers_vector.insert(live_buffers_vector.end(), live_buffers.begin(), - live_buffers.end()); + std::vector live_values_vector; + live_values_vector.insert(live_values_vector.end(), live_values.begin(), + live_values.end()); // Stabily sort the live buffers. 
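// A standalone sketch (not the XLA version above) of computing the peak live
// set from a trace of ALLOC/FREE events, using the same two-pass structure:
// first find the maximal running size, then re-walk the trace and stop at the
// earliest point where that size is reached, returning the buffers live there.
// Using std::set keeps the ids ordered, paralleling the stable sort by id.
#include <algorithm>
#include <cstdint>
#include <set>
#include <vector>

struct TraceEvent {
  enum Kind { kAlloc, kFree } kind;
  int buffer_id;
  int64_t size;
};

std::set<int> PeakLiveBuffers(const std::vector<TraceEvent>& trace) {
  // Pass 1: find the maximum total size of simultaneously live buffers.
  int64_t live = 0;
  int64_t max_live = 0;
  for (const TraceEvent& e : trace) {
    live += (e.kind == TraceEvent::kAlloc) ? e.size : -e.size;
    max_live = std::max(max_live, live);
  }
  // Pass 2: stop at the earliest event after which the live size equals the
  // maximum, and return the set of buffers live at that point.
  std::set<int> live_set;
  live = 0;
  for (const TraceEvent& e : trace) {
    if (e.kind == TraceEvent::kAlloc) {
      live_set.insert(e.buffer_id);
      live += e.size;
    } else {
      live_set.erase(e.buffer_id);
      live -= e.size;
    }
    if (live == max_live) break;
  }
  return live_set;
}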
- absl::c_sort(live_buffers_vector, - [](const LogicalBuffer* a, const LogicalBuffer* b) { + absl::c_sort(live_values_vector, + [](const BufferValue* a, const BufferValue* b) { return a->id() < b->id(); }); - return live_buffers_vector; + VLOG(4) << "Peak memory buffer:"; + for (auto value : live_values_vector) { + VLOG(4) << " " << value->ToString(); + } + return live_values_vector; } } // namespace void BufferAssigner::AssignBuffersFromHeapSimulator( const HeapSimulator::Result& result, BufferAssignment* assignment, - LogicalBuffer::Color color) { + BufferValue::Color color) { if (assignment->stats_.preallocated_temp_fragmentation_bytes == -1) { assignment->stats_.preallocated_temp_fragmentation_bytes = result.fragmentation_size; @@ -1386,499 +1466,96 @@ void BufferAssigner::AssignBuffersFromHeapSimulator( BufferAllocation* allocation = assignment->NewEmptyAllocation(result.heap_size, color); for (const auto& buffer_chunk : result.chunk_map) { - // TODO(lauj) Remove this down_cast after downstream users of - // BufferAllocation::assigned_buffers() are updated to use BufferValue. - const LogicalBuffer& buffer = - *CHECK_NOTNULL(dynamic_cast(buffer_chunk.first)); + const BufferValue& value = *buffer_chunk.first; const HeapSimulator::Chunk& chunk = buffer_chunk.second; - assignment->AddAssignment(allocation, buffer, chunk.offset, chunk.size); + assignment->AddAssignment(allocation, value, chunk.offset, chunk.size); } allocation->peak_buffers_ = ComputePeakMemoryLogicalBuffers(*allocation, result.debug_trace); - VLOG(1) << "Ran heap simulation for allocation: " << allocation->ToString(); + VLOG(1) << "Ran heap simulation for allocation: "; + XLA_VLOG_LINES(2, allocation->ToString()); + allocation->AddHeapTrace(result.debug_trace); } -// Adds the 'colocated_set' of buffers to 'colocated_buffer_sets', maintaining -// the invariant that all sets in 'colocated_buffer_sets' are disjoint. -// -// A practical example of when this is necessary is a chain of kCall ops: -// computation.entry -// %a = call() -> computation.1 -// computation.1 -// %b = call() -> computation.2 -// computation.2 -// %c = parameter() -// This yields the logical sets {%a,%b} {%b,%c} {%c}, which need to be merged -// into a single set {%a,%b,%c} -void BufferAssigner::AddSetToColocatedBufferSets( - const std::vector& colocated_set, - std::vector* colocated_buffer_sets) { - if (colocated_set.empty()) { - return; - } - VLOG(5) << ColocatedBufferSetsToString(colocated_set, - "Adding colocated buffer set"); - // Find existing sets that overlap with at least one buffer from the - // colocated_set. The resulting 'overlap_set_indices' will have at most - // colocated_buffer_sets->size() entries, and will be in increasing order. - std::vector overlap_set_indices; - for (size_t index = 0; index < colocated_buffer_sets->size(); ++index) { - for (const LogicalBuffer* buffer : colocated_set) { - if ((*colocated_buffer_sets)[index].contains(buffer)) { - VLOG(5) << "Found overlap with existing set on buffer " - << buffer->ToString() << "\n" - << ColocatedBufferSetsToString((*colocated_buffer_sets)[index], - "Overlapping set"); - overlap_set_indices.push_back(index); - break; - } - } - } - - // If there is no overlap with existing sets, create a new set. 
- if (overlap_set_indices.empty()) { - colocated_buffer_sets->emplace_back(); - colocated_buffer_sets->back().insert(colocated_set.begin(), - colocated_set.end()); - VLOG(5) << "No overlap found, new group created"; - return; - } - - // Merge all overlap sets and the colocated set into the first overlap set. - ColocatedBufferSet* first = &(*colocated_buffer_sets)[overlap_set_indices[0]]; - for (size_t index = 1; index < overlap_set_indices.size(); ++index) { - const ColocatedBufferSet& overlap_set = - (*colocated_buffer_sets)[overlap_set_indices[index]]; - first->insert(overlap_set.begin(), overlap_set.end()); - } - first->insert(colocated_set.begin(), colocated_set.end()); - VLOG(5) << ColocatedBufferSetsToString( - *first, "Result of the colocated buffer set merging"); - - // Remove overlap sets that we just merged. The offset accounts for the fact - // that as elements are erased, the indices need to be adjusted. Keep in mind - // that overlap_set_indices is in increasing order. - for (size_t index = 1; index < overlap_set_indices.size(); ++index) { - const size_t offset = overlap_set_indices[index] - index + 1; - colocated_buffer_sets->erase(colocated_buffer_sets->begin() + offset); - } -} - -std::vector -BufferAssigner::MergeColocatedBufferSets( - const std::vector& colocated_buffer_sets, - const BufferLiveness& buffer_liveness, - const LogicalBuffer::SizeFunction& buffer_size) { - VLOG(1) << "colocation sets count before coalescing:" - << colocated_buffer_sets.size(); - - // Returns true if the given buffer is for the entry parameter. - auto is_readonly_entry_parameter = [](const LogicalBuffer& buffer) { - auto* instruction = buffer.instruction(); - auto* computation = instruction->parent(); - auto* module = computation->parent(); - return instruction->opcode() == HloOpcode::kParameter && - computation == module->entry_computation() && - !module->input_output_alias_config().ParameterHasAlias( - instruction->parameter_number(), buffer.index()); - }; - - std::vector set_can_be_merged(colocated_buffer_sets.size(), true); - - // Do not merge if one of the sets includes live outs, entry parameters or - // constants. - // - // Buffer liveness does not report the correct live range for entry - // parameter and live out buffers so we have to special case them here. On - // backends that support constant buffer allocations, constant buffers are - // assigned globals in readonly storage so we can't merge colocated buffer - // sets containing constants with colocated buffer sets containing writing - // instructions or other constants. - // - // Moreover (on the CPU/GPU backends) the entry parameter buffers belong to - // the caller of the executable so we can't write to entry parameters - // either, and the argument for not merging constants also applies to entry - // parameters. - for (int64 i = 0; i < colocated_buffer_sets.size(); ++i) { - for (auto& buffer : colocated_buffer_sets[i]) { - if (buffer_liveness.MaybeLiveOut(*buffer) || - is_readonly_entry_parameter(*buffer) || - buffer->instruction()->opcode() == HloOpcode::kConstant) { - set_can_be_merged[i] = false; - break; - } - } - } - - // Returns true if the two colocated buffer sets (specified by their indices - // into the colocated_buffer_sets) can be merged into a single set. 
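// A simplified model (hypothetical helper, not the removed XLA routine) of the
// disjoint-set maintenance performed by the AddSetToColocatedBufferSets code
// above: a new colocated set is merged with every existing set it overlaps, so
// the collection stays pairwise disjoint, e.g. {a,b} + {b,c} => {a,b,c}.
#include <set>
#include <vector>

void AddDisjointSet(const std::set<int>& new_set,
                    std::vector<std::set<int>>* sets) {
  if (new_set.empty()) return;
  std::set<int> merged(new_set);
  std::vector<std::set<int>> remaining;
  for (const std::set<int>& existing : *sets) {
    bool overlaps = false;
    for (int element : existing) {
      if (merged.count(element) > 0) {
        overlaps = true;
        break;
      }
    }
    if (overlaps) {
      merged.insert(existing.begin(), existing.end());  // Fold into the merge.
    } else {
      remaining.push_back(existing);  // Keep disjoint sets untouched.
    }
  }
  remaining.push_back(merged);
  *sets = remaining;
}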
- auto cannot_merge_buffer_sets = [&colocated_buffer_sets, &buffer_liveness, - &buffer_size, - &set_can_be_merged](int64 i, int64 j) { - if (!set_can_be_merged[i] || !set_can_be_merged[j]) { - return true; - } - - // Colocated sets satisfy the invariant that all buffers within a set have - // the same size. That means we need to check whether the size is the same - // between the two sets, but also that it's enough to look at just one - // buffer within each set. - if (buffer_size(**colocated_buffer_sets[i].begin()) != - buffer_size(**colocated_buffer_sets[j].begin())) { - return true; - } - - // Do not merge if some pair of buffers interferes with each other. - for (auto& buffer_a : colocated_buffer_sets[i]) { - for (auto& buffer_b : colocated_buffer_sets[j]) { - if (buffer_a->id() != buffer_b->id() && - buffer_liveness.MayInterfere(*buffer_a, *buffer_b)) { - return true; - } - } - } - - return false; - }; - - // Build the interference map among the colocated buffer sets (nodes), by - // adding an edge between any two nodes that cannot be merged into a single - // colocated buffer set. - std::vector> interference_map( - colocated_buffer_sets.size()); - for (int64 i = 0; i < colocated_buffer_sets.size(); ++i) { - for (int64 j = i + 1; j < colocated_buffer_sets.size(); ++j) { - if (cannot_merge_buffer_sets(i, j)) { - interference_map[i].push_back(j); - interference_map[j].push_back(i); - } - } - } - - // Assign a color to each colocation set in colocated_buffer_sets, such that - // the sets that can be merged are assigned with the same color. - auto assigned_colors = ColorInterferenceGraph(interference_map); - - // Merge the buffer sets with the same color. - CHECK(!assigned_colors.empty()); - int64 num_sets = - *std::max_element(assigned_colors.begin(), assigned_colors.end()) + 1; - std::vector new_colocated_buffer_sets(num_sets); - for (int64 i = 0; i < colocated_buffer_sets.size(); ++i) { - const auto& buffer_set = colocated_buffer_sets[i]; - new_colocated_buffer_sets[assigned_colors[i]].insert(buffer_set.begin(), - buffer_set.end()); - } - - VLOG(1) << "colocation sets count after coalescing:" - << colocated_buffer_sets.size(); - return new_colocated_buffer_sets; -} - -// Builds sets of buffers in 'colocated_buffer_sets' which should be colocated -// in the same allocation (currently just supports kWhile, kCall, and -// kConditional and input output aliasing). -void BufferAssigner::BuildColocatedBufferSets( - const HloModule* module, const BufferLiveness& buffer_liveness, - const LogicalBuffer::SizeFunction& buffer_size, - std::vector* colocated_buffer_sets) { - const TuplePointsToAnalysis& points_to_analysis = - buffer_liveness.points_to_analysis(); - - // Set up colocated buffer set for input and output. 
- VLOG(4) << "Input/Output Alias Config: "; - VLOG(4) << module->input_output_alias_config(); - module->input_output_alias_config().ForEachAlias( - [&](const ShapeIndex& output_index, - const HloInputOutputAliasConfig::Alias& alias) { - std::vector colocated_set; - AddBufferToColocatedSet(module->entry_computation()->root_instruction(), - output_index, points_to_analysis, - &colocated_set); - AddBufferToColocatedSet( - module->entry_computation()->parameter_instruction( - alias.parameter_number), - alias.parameter_index, points_to_analysis, &colocated_set); - AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets); - }); - - for (const HloComputation* computation : module->MakeComputationPostOrder()) { - if (computation->IsFusionComputation()) { - continue; - } - for (const HloInstruction* instruction : - computation->MakeInstructionPostOrder()) { - const HloOpcode opcode = instruction->opcode(); - if (opcode == HloOpcode::kWhile) { - const HloInstruction* while_hlo = instruction; - ShapeUtil::ForEachSubshape( - while_hlo->shape(), - [this, while_hlo, &points_to_analysis, buffer_size, - colocated_buffer_sets](const Shape& /*subshape*/, - const ShapeIndex& index) { - std::vector colocated_set; - // Add while.init. - AddBufferToColocatedSet(while_hlo->operand(0), index, - points_to_analysis, &colocated_set); - // Add while.result. - AddBufferToColocatedSet(while_hlo, index, points_to_analysis, - &colocated_set); - // Add while.cond.parameter. - AddBufferToColocatedSet( - while_hlo->while_condition()->parameter_instruction(0), index, - points_to_analysis, &colocated_set); - // Add while.body.parameter. - AddBufferToColocatedSet( - while_hlo->while_body()->parameter_instruction(0), index, - points_to_analysis, &colocated_set); - // Add while.body.root. - AddBufferToColocatedSet( - while_hlo->while_body()->root_instruction(), index, - points_to_analysis, &colocated_set); - AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets); - }); - } else if (opcode == HloOpcode::kCall) { - const HloInstruction* call_hlo = instruction; - const HloComputation* callee = call_hlo->to_apply(); - const HloInstruction* root_hlo = callee->root_instruction(); - for (int64 i = 0; i < call_hlo->operand_count(); i++) { - const HloInstruction* call_param = callee->parameter_instruction(i); - const HloInstruction* call_operand = call_hlo->operand(i); - ShapeUtil::ForEachSubshape( - call_operand->shape(), - [&](const Shape& /*subshape*/, const ShapeIndex& index) { - std::vector colocated_set; - AddBufferToColocatedSet(call_param, index, points_to_analysis, - &colocated_set); - AddBufferToColocatedSet(call_operand, index, points_to_analysis, - &colocated_set); - AddSetToColocatedBufferSets(colocated_set, - colocated_buffer_sets); - }); - } - ShapeUtil::ForEachSubshape( - call_hlo->shape(), - [this, call_hlo, root_hlo, &points_to_analysis, - colocated_buffer_sets](const Shape& /*subshape*/, - const ShapeIndex& index) { - std::vector colocated_set; - // Add call.result. - AddBufferToColocatedSet(call_hlo, index, points_to_analysis, - &colocated_set); - // Add call.subcomputation.root. 
- AddBufferToColocatedSet(root_hlo, index, points_to_analysis, - &colocated_set); - AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets); - }); - } else if (opcode == HloOpcode::kConditional) { - const HloInstruction* conditional = instruction; - ShapeUtil::ForEachSubshape( - conditional->shape(), - [this, conditional, &points_to_analysis, colocated_buffer_sets]( - const Shape& /*subshape*/, const ShapeIndex& index) { - std::vector colocated_set; - // Add cond.result. - AddBufferToColocatedSet(conditional, index, points_to_analysis, - &colocated_set); - for (int j = 0; j < conditional->branch_count(); ++j) { - // Add each cond.branch_computation[j].root. - AddBufferToColocatedSet( - conditional->branch_computation(j)->root_instruction(), - index, points_to_analysis, &colocated_set); - } - AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets); - }); - - for (int j = 0; j < conditional->branch_count(); ++j) { - // Add branch_operand[j] (which is operand[j+1]) and - // cond.branch_computation[j].parameter(0) as a colocated - // buffer set. Note that this has to be done for each subshape in the - // branch_operand of the case. - ShapeUtil::ForEachSubshape( - conditional->operand(j + 1)->shape(), - [this, j, conditional, &points_to_analysis, - colocated_buffer_sets](const Shape& /*subshape*/, - const ShapeIndex& index) { - std::vector branch_set; - // Add cond.operand[j+1]. - AddBufferToColocatedSet(conditional->operand(j + 1), index, - points_to_analysis, &branch_set); - // Add cond.branch_computation[j].parameter_instruction(0). - AddBufferToColocatedSet( - conditional->branch_computation(j)->parameter_instruction( - 0), - index, points_to_analysis, &branch_set); - AddSetToColocatedBufferSets(branch_set, colocated_buffer_sets); - }); - } - } - } - } - - if (colocated_buffer_sets->empty()) { - return; - } - - int64 i = 0; - for (const auto& colocated_set : *colocated_buffer_sets) { - VLOG(4) << "Colocated set " << i++ << ":"; - for (const auto& buffer : colocated_set) { - VLOG(4) << " " << buffer->ToString(); - } - } - // Try to find more coalescing opportunities among the colocated buffer sets. - // - // TODO(b/32491382): We should be able to remove this by using the - // module-level liveness analysis, which would let us directly detect buffer - // sharing opportunities between the while instruction buffer and the buffers - // from the predicate and body computation, as well as sharing across - // different while instructions. - std::vector new_colocated_buffer_sets = - MergeColocatedBufferSets(*colocated_buffer_sets, buffer_liveness, - buffer_size); - std::swap(*colocated_buffer_sets, new_colocated_buffer_sets); -} - -// Assigns all colocated buffer sets in 'colocated_buffer_sets' to the same -// allocation in 'assignment'. -void BufferAssigner::AssignColocatedBufferSets( - const std::vector& colocated_buffer_sets, - BufferAssignment* assignment, - flat_hash_set* colocated_buffers, - flat_hash_set* colocated_allocations) { - for (const ColocatedBufferSet& colocated_buffer_set : colocated_buffer_sets) { - BufferAllocation* allocation = nullptr; - // Set 'entry_parameter_number' and 'entry_parameter_shape_idx' if entry - // param in 'colocated_buffer_set'. 
- int64 entry_parameter_number = -1; - const ShapeIndex* entry_parameter_shape_idx = nullptr; - bool is_constant = false; - for (const LogicalBuffer* buffer : colocated_buffer_set) { - const HloInstruction* instruction = buffer->instruction(); - const HloComputation* computation = instruction->parent(); - if (instruction->opcode() == HloOpcode::kParameter && - computation == computation->parent()->entry_computation()) { - entry_parameter_number = instruction->parameter_number(); - entry_parameter_shape_idx = &buffer->index(); - } else if (instruction->opcode() == HloOpcode::kConstant) { - is_constant = true; - } - } - - CHECK(!is_constant || entry_parameter_number == -1) - << "Copy insertion should have inserted copies to prevent this."; - - for (const LogicalBuffer* buffer : colocated_buffer_set) { - const int64 buffer_size = assignment->buffer_size_(*buffer); - if (allocation == nullptr) { - // TODO(b/32491382) Avoid current trivial solution of using new - // allocations for each colocated buffer set. When liveness has - // module-level scope, we can allow buffers to be shared across - // computations (in some cases). - allocation = assignment->NewAllocation(*buffer, buffer_size); - if (is_constant) { - allocation->set_constant(true); - } - colocated_allocations->insert(allocation->index()); - } else { - CHECK_EQ(buffer_size, allocation->size()) - << "Buffer: " << *buffer << " size mismatch in colocated buffer " - << "allocation: " << *allocation; - assignment->AddAssignment(allocation, *buffer, /*offset=*/0, - buffer_size); - } - colocated_buffers->insert(buffer); - } - - // If an allocation contains a parameter, set corresponding fields. - if (entry_parameter_number >= 0) { - bool parameter_has_alias = - assignment->module().input_output_alias_config().ParameterHasAlias( - entry_parameter_number, *entry_parameter_shape_idx); - allocation->set_entry_computation_parameter(entry_parameter_number, - *entry_parameter_shape_idx, - parameter_has_alias); - } - } -} - StatusOr> BufferAssigner::CreateAssignment( const HloModule* module, std::unique_ptr hlo_ordering, - LogicalBuffer::SizeFunction buffer_size, - LogicalBuffer::AlignmentFunction color_alignment) { + BufferValue::SizeFunction buffer_size, + LogicalBuffer::AlignmentFunction color_alignment, + HloDataflowAnalysis::FusionCanShareBufferFunction fusion_can_share_buffer) { TF_ASSIGN_OR_RETURN(std::unique_ptr liveness, BufferLiveness::Run(module, std::move(hlo_ordering))); + TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, + HloAliasAnalysis::Run(module, fusion_can_share_buffer)); + VLOG(1) << "Assigning buffers to module " << module->name(); XLA_VLOG_LINES(2, module->ToString()); - XLA_VLOG_LINES(3, liveness->ToString()); - XLA_VLOG_LINES(3, liveness->points_to_analysis().ToString()); + XLA_VLOG_LINES(3, alias_analysis->ToString()); + XLA_VLOG_LINES(3, alias_analysis->dataflow_analysis().ToString()); + VLOG(1) << "Number of buffers to assign: " + << alias_analysis->buffers().size(); // Can't use absl::make_unique because BufferAssignment constructor is // private. - std::unique_ptr assignment( - new BufferAssignment(module, std::move(liveness), std::move(buffer_size), - std::move(color_alignment))); + std::unique_ptr assignment(new BufferAssignment( + module, std::move(liveness), std::move(buffer_size), + std::move(color_alignment), std::move(alias_analysis))); - // Assign buffers with the tightest constraints first (colocated buffer sets). 
- // Once b/32491382 enables module-level liveness analysis, we may be able - // to assign colocated buffers (or at least reuse their allocation for - // buffers outside of the set) in AssignBuffersForComputation. - flat_hash_set colocated_buffers; - flat_hash_set colocated_allocations; - std::vector colocated_buffer_sets; - BuildColocatedBufferSets(module, assignment->liveness(), - assignment->buffer_size_, &colocated_buffer_sets); - TF_RETURN_IF_ERROR(colorer_(assignment->liveness())); + TF_RETURN_IF_ERROR(colorer_(&assignment->alias_analysis(), + assignment->liveness().hlo_ordering())); VLOG(3) << "After coloring:"; - XLA_VLOG_LINES(3, assignment->points_to_analysis().ToString()); - - AssignColocatedBufferSets(colocated_buffer_sets, assignment.get(), - &colocated_buffers, &colocated_allocations); + XLA_VLOG_LINES(3, + assignment->alias_analysis().dataflow_analysis().ToString()); + TF_RETURN_IF_ERROR(MergeInplaceOpBuffers(assignment.get())); std::vector thread_local_computations; std::vector global_computations; TF_RETURN_IF_ERROR(GatherComputationsByAllocationType( module, &thread_local_computations, &global_computations)); - // First assign buffers for global computatations. Temporary buffers for - // sequential computations are collected in 'buffers_to_assign_sequentially'. - flat_hash_map> + // First assign buffers for global computations. Temporary buffers for + // sequential computations are collected in + // 'buffers_to_assign_sequentially'. + flat_hash_map> buffers_to_assign_sequentially; - for (auto* computation : global_computations) { - TF_RETURN_IF_ERROR(AssignBuffersForComputation( - computation, - /*is_thread_local=*/false, colocated_buffers, colocated_allocations, - &buffers_to_assign_sequentially, assignment.get())); - } + TF_RETURN_IF_ERROR(AssignBuffersForComputations( + global_computations, + /*is_thread_local=*/false, &buffers_to_assign_sequentially, + assignment.get())); // Assign buffers with sequential ordering, if any. If all global computations // are sequential, we can run heap simuation on the whole module, which // reduces memory usage. const bool run_whole_module_heap_simulation = buffers_to_assign_sequentially.size() == global_computations.size(); + VLOG(2) << "Running whole module heap simulation" + << run_whole_module_heap_simulation; TF_RETURN_IF_ERROR(AssignBuffersWithSequentialOrdering( buffers_to_assign_sequentially, run_whole_module_heap_simulation, assignment.get())); + std::vector thread_local_computations_no_fusion; // Now assign buffers for thread-local computations. All LogicalBuffers get // their own BufferAllocation. + for (auto* computation : thread_local_computations) { TF_RET_CHECK(computation != module->entry_computation()); if (computation->IsFusionComputation()) { continue; } - TF_RETURN_IF_ERROR(AssignBuffersForComputation( - computation, - /*is_thread_local=*/true, colocated_buffers, colocated_allocations, - /*buffers_to_assign_sequentially=*/nullptr, assignment.get())); + thread_local_computations_no_fusion.push_back(computation); } + TF_RETURN_IF_ERROR(AssignBuffersForComputations( + thread_local_computations_no_fusion, + /*is_thread_local=*/true, + /*buffers_to_assign_sequentially=*/nullptr, assignment.get())); + // Mark all buffers which may be live out of the entry computation as // "liveout". 
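// A minimal sketch (not the real Colorer signature) of what the coloring step
// invoked above does conceptually: every value is assigned a color, and
// buffers are later partitioned by color so that each color group is assigned
// and aligned independently. The default policy added in the header further
// below gives every value color 0, i.e. a single partition.
#include <map>
#include <vector>

// value_colors[i] is the color assigned to value i; the result groups value
// indices by color, analogous to SplitBuffersByColor above.
std::map<int, std::vector<int>> SplitByColor(
    const std::vector<int>& value_colors) {
  std::map<int, std::vector<int>> by_color;
  for (int i = 0; i < static_cast<int>(value_colors.size()); ++i) {
    by_color[value_colors[i]].push_back(i);
  }
  return by_color;
}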
- for (const LogicalBuffer* buffer : - assignment->liveness().maybe_live_out_buffers()) { + for (const HloBuffer* buffer : + assignment->alias_analysis().LiveOutBuffers()) { VLOG(3) << "maybe_live_out LogicalBuffer: " << *buffer; if (assignment->HasAllocation(*buffer)) { BufferAllocation* alloc = @@ -1897,6 +1574,7 @@ StatusOr> BufferAssigner::CreateAssignment( XLA_VLOG_LINES(2, assignment->ToString()); TF_RETURN_IF_ERROR(assignment->ComputeSummaryStats()); XLA_VLOG_LINES(1, assignment->GetStats().ToString()); + VLOG(1) << "Buffer assignment done."; return std::move(assignment); } diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index 41adf1b80a5..ee56e826eaf 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -28,7 +28,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" @@ -152,8 +154,8 @@ class BufferAllocation { // Access to the logical buffers assigned to this allocation, and their // associated logical offsets and sizes. - const absl::flat_hash_map& - assigned_buffers() const { + const absl::flat_hash_map& assigned_buffers() + const { return assigned_buffers_; } @@ -206,7 +208,7 @@ class BufferAllocation { // GetSlice returns the Slice of contiguous memory that holds the value // described by the given 'buffer'. // REQUIRES: 'buffer' must be assigned to this allocation. - Slice GetSlice(const LogicalBuffer& buffer) const; + Slice GetSlice(const BufferValue& buffer) const; string ToString() const; BufferAllocationProto ToProto() const; @@ -248,9 +250,9 @@ class BufferAllocation { // for this allocation. The point of peak memory usage is the point at which // the total size of all live logical buffers is maximal. If peak memory is // reached at multiple points, the set of logical buffers live at the earliest - // maximal point is returned. The vector is stabily sorted by - // LogicalBuffer::Index. - const std::vector& PeakMemoryLogicalBuffers() const { + // maximal point is returned. The vector is stably sorted by + // BufferValue::Index. + const std::vector& PeakMemoryLogicalBuffers() const { return peak_buffers_; } @@ -275,7 +277,7 @@ class BufferAllocation { friend class BufferAssignment; // Adds a LogicalBuffer to the set assigned to this buffer. - void AddAssignment(const LogicalBuffer& buffer, int64 offset, int64 size); + void AddAssignment(const BufferValue& buffer, int64 offset, int64 size); void set_entry_computation_parameter(int64 parameter_number, ShapeIndex param_shape_index, @@ -333,13 +335,13 @@ class BufferAllocation { // Mapping from the set of buffers assigned to this allocation to their // logical offsets and sizes. - absl::flat_hash_map assigned_buffers_; + absl::flat_hash_map assigned_buffers_; int64 fragmentation_bytes_ = 0; std::vector heap_traces_; // Set of buffers live at the point of peak memory usage for this allocation. 
- std::vector peak_buffers_; + std::vector peak_buffers_; }; // Add stream operators for nicer output of CHECK/RET_CHECK failures. @@ -361,12 +363,16 @@ class BufferAssignment { } // Returns whether the given buffer has been assigned an allocation. - bool HasAllocation(const LogicalBuffer& buffer) const; + bool HasAllocation(const BufferValue& value) const; + + bool HasAllocation(const HloBuffer& buffer) const; // Returns the allocation that a particular LogicalBuffer has been assigned // to. CHECKs if buffer has not been assigned an allocation. + const BufferAllocation& GetAssignedAllocation(const BufferValue& value) const; + const BufferAllocation& GetAssignedAllocation( - const LogicalBuffer& buffer) const; + const HloBuffer& hlo_buffer) const; // Returns the allocation with the given index. CHECKs if no allocation exists // with the given index. @@ -405,11 +411,11 @@ class BufferAssignment { // computation). StatusOr GetUniqueTopLevelOutputSlice() const; - // Returns the set LogicalBuffers which may be the source of the value at the + // Returns the set BufferValues which may be the source of the value at the // given index and instruction. - const PointsToSet::BufferList& GetSourceBuffers( + const std::vector& GetSourceBuffers( const HloInstruction* instruction, const ShapeIndex& index) const { - return GetPointsToSet(instruction).element(index); + return dataflow_analysis().GetValueSet(instruction, index).values(); } // Returns true if 'hlo_a{shape_index_a}' and 'hlo_b{shape_index_b}' @@ -439,6 +445,12 @@ class BufferAssignment { return liveness_->points_to_analysis(); } + const HloDataflowAnalysis& dataflow_analysis() const { + return alias_analysis_->dataflow_analysis(); + } + + HloAliasAnalysis& alias_analysis() const { return *alias_analysis_; } + // Returns the BufferLiveness object used to construct this assignment. const BufferLiveness& liveness() const { return *liveness_; } @@ -472,12 +484,14 @@ class BufferAssignment { BufferAssignment(const HloModule* module, std::unique_ptr liveness, - LogicalBuffer::SizeFunction buffer_size, - LogicalBuffer::AlignmentFunction color_alignment) + BufferValue::SizeFunction buffer_size, + LogicalBuffer::AlignmentFunction color_alignment, + std::unique_ptr alias_analysis) : module_(module), liveness_(std::move(liveness)), buffer_size_(std::move(buffer_size)), - color_alignment_(std::move(color_alignment)) {} + color_alignment_(std::move(color_alignment)), + alias_analysis_(std::move(alias_analysis)) {} // Creates and returns a new BufferAllocation, with no assigned // LogicalBuffers. Ownership is maintained internally. @@ -485,10 +499,13 @@ class BufferAssignment { // Helper that calls NewEmptyAllocation and AddAssignment in one call, // creating an allocation containing a single LogicalBuffer. - BufferAllocation* NewAllocation(const LogicalBuffer& buffer, int64 size); + BufferAllocation* NewAllocation(const HloBuffer& buffer, int64 size); // Adds a LogicalBuffer to the set assigned to the given allocation. - void AddAssignment(BufferAllocation* allocation, const LogicalBuffer& buffer, + void AddAssignment(BufferAllocation* allocation, const HloBuffer& buffer, + int64 offset, int64 size); + + void AddAssignment(BufferAllocation* allocation, const BufferValue& value, int64 offset, int64 size); // Returns the HloModule used to construct this assignment. @@ -499,9 +516,17 @@ class BufferAssignment { const PointsToSet& GetPointsToSet(const HloInstruction* instruction) const; // Mutable accessors for allocations. 
- BufferAllocation* GetMutableAssignedAllocation(const LogicalBuffer& buffer); + BufferAllocation* GetMutableAssignedAllocation(const HloBuffer& buffer); BufferAllocation* GetMutableAllocation(BufferAllocation::Index index); + int64 HloBufferSize(const HloBuffer& buffer) { + int64 result = buffer_size_(*buffer.values()[0]); + for (const HloValue* value : buffer.values()) { + DCHECK_EQ(result, buffer_size_(*value)); + } + return result; + } + // Combines allocations of temporary buffers into one big BufferAllocation. void CombineTempAllocations(); @@ -515,18 +540,20 @@ class BufferAssignment { int64 temp_allocation_total_size_ = 0; // Maps Buffers to the index of the BufferAllocation which holds the buffer. - absl::flat_hash_map - allocation_index_for_buffer_; + absl::flat_hash_map + allocation_index_for_value_; const HloModule* module_; const std::unique_ptr liveness_; // Function which returns the buffer size for a given logical buffer (shape). - LogicalBuffer::SizeFunction buffer_size_; + BufferValue::SizeFunction buffer_size_; // Function which returns the alignment for a given logical buffer color. LogicalBuffer::AlignmentFunction color_alignment_; + std::unique_ptr alias_analysis_; + Stats stats_; TF_DISALLOW_COPY_AND_ASSIGN(BufferAssignment); @@ -535,61 +562,86 @@ class BufferAssignment { // A class which constructs a buffer assignment. class BufferAssigner { public: - // Returns false if a buffer cannot be assigned to given allocation. - using ReuseAllocationFunction = std::function; + using Colorer = std::function; - // Returns whether a logical buffer can be considered reusing memory for - // colocated buffers. - using ReuseColocatedAllocationForTempChecker = - std::function; + static Colorer DefaultColorer() { + return [](HloAliasAnalysis* alias_analysis, const HloOrdering&) { + for (HloValue* value : alias_analysis->dataflow_analysis().values()) { + value->set_color(BufferValue::Color(0)); + } + return Status::OK(); + }; + } + + // Returns false if a buffer cannot be assigned to given allocation. // Build and return a BufferAssignment for the given module. The given // HloOrdering is used to determine buffer liveness. buffer_size and // color_alignment are functions which returns the size and alignment of a - // LogicalBuffer. allow_input_output_aliasing specifies whether input buffer - // are allowed to be reused as outbut buffers by the client code. + // LogicalBuffer. 
static StatusOr> Run( const HloModule* module, std::unique_ptr hlo_ordering, - LogicalBuffer::SizeFunction buffer_size, + BufferValue::SizeFunction buffer_size, LogicalBuffer::AlignmentFunction color_alignment, - bool allow_input_output_aliasing = false, bool allocate_buffers_for_constants = false, - BufferLiveness::Colorer colorer = BufferLiveness::DefaultColorer(), - ReuseAllocationFunction reuse_checker = nullptr, - ReuseColocatedAllocationForTempChecker reuse_colocated_checker = nullptr); + Colorer colorer = DefaultColorer(), + const absl::flat_hash_set& must_not_live_out = {}, + HloDataflowAnalysis::FusionCanShareBufferFunction + fusion_can_share_buffer = nullptr); private: - BufferAssigner(bool allocate_buffers_for_constants, - BufferLiveness::Colorer colorer, - ReuseAllocationFunction reuse_checker, - ReuseColocatedAllocationForTempChecker reuse_colocated_checker) + BufferAssigner(bool allocate_buffers_for_constants, Colorer colorer, + const absl::flat_hash_set& must_not_live_out) : allocate_buffers_for_constants_(allocate_buffers_for_constants), - colorer_(std::move(colorer)), - reuse_checker_(std::move(reuse_checker)), - reuse_colocated_checker_(std::move(reuse_colocated_checker)) {} + colorer_(colorer), + must_not_live_out_(must_not_live_out) {} virtual ~BufferAssigner() = default; // Create a buffer assignment. StatusOr> CreateAssignment( const HloModule* module, std::unique_ptr hlo_ordering, - LogicalBuffer::SizeFunction buffer_size, - LogicalBuffer::AlignmentFunction color_alignment); + BufferValue::SizeFunction buffer_size, + LogicalBuffer::AlignmentFunction color_alignment, + HloDataflowAnalysis::FusionCanShareBufferFunction + fusion_can_share_buffer); - // Assigns buffers to the instructions in the given computation. "assignment" + // Assigns buffers to the instructions in the given computations. "assignment" // is modified to reflect the new buffer assignments. If is_thread_local is // true, then all assigned buffers have the is_thread_local flag set to // true. - Status AssignBuffersForComputation( - const HloComputation* computation, bool is_thread_local, - const absl::flat_hash_set& colocated_buffers, - const absl::flat_hash_set& colocated_allocations, + Status AssignBuffersForComputations( + const std::vector& computations, + bool is_thread_local, absl::flat_hash_map>* + absl::flat_hash_set>* buffers_to_assign_sequentially, BufferAssignment* assignment); + // Converts a HloValueSet to LogicalBufferSet, this is needed for buffer + // assignment, which uses dataflow analysis, to talk to heap simulator that + // still uses tuple-points-to analysis. + BufferValueFlatSet HloValueSetToLogicalBufferSet( + const absl::flat_hash_set& hlo_value_set, + const TuplePointsToAnalysis& points_to_analysis); + + // Creates sets of buffer values that must be aliased with each other (e.g., + // while init and loop body parameter). + std::vector BuildMustAliasLogicalBufferSet( + BufferAssignment* assignment); + + // Promotes operations (DUS, scatter) to be done in place: If an operation can + // be done in place, merge its buffer with its operand buffer. + Status MergeInplaceOpBuffers(BufferAssignment* assignment); + + // Assigns a single hlo buffer to an HLO allocation. 
+ Status AssignSingleHloBuffer( + const HloBuffer* hlo_buffer, bool is_thread_local, + absl::flat_hash_map>* + buffers_to_assign_sequentially, + std::vector* allocation_indices, + BufferAssignment* assignment); + // Assigns 'buffers_to_assign_sequentially' using heap simulation, assuming // the HLO instructions will be executed in the sequential order given by // assignment->liveness().hlo_ordering().SequentialOrder. If @@ -597,7 +649,7 @@ class BufferAssigner { // assuming all global computations are sequentially ordered. Status AssignBuffersWithSequentialOrdering( const absl::flat_hash_map>& + absl::flat_hash_set>& buffers_to_assign_sequentially, bool run_whole_module_heap_simulation, BufferAssignment* assignment); @@ -609,64 +661,24 @@ class BufferAssigner { // Tries to assign the given instruction to the given buffer. Returns if the // assignment was successful. - bool MaybeAssignBuffer(BufferAllocation* allocation, - const LogicalBuffer& buffer, + bool MaybeAssignBuffer(BufferAllocation* allocation, const HloBuffer& buffer, BufferAssignment* assignment); - // Colocated buffers are logical buffers from different computations which - // alias. Explicitly handling these colocated buffers is necessary because - // points-to analysis is computation level scope and does not recognize - // aliasing across computations (b/32491382). - using ColocatedBufferSet = absl::flat_hash_set; - - // Returns a vector of ColocatedBufferSet objects, where each - // ColocatedBufferSet aggregates a set of related LogicalBuffers from 'module' - // which should be colocated in the same buffer allocation. - void BuildColocatedBufferSets( - const HloModule* module, const BufferLiveness& buffer_liveness, - const LogicalBuffer::SizeFunction& buffer_size, - std::vector* colocated_buffer_sets); - - // For each buffer set in 'colocated_buffer_sets', assigns all buffers in the - // same set to the same buffer allocation in 'assignment'. - void AssignColocatedBufferSets( - const std::vector& colocated_buffer_sets, - BufferAssignment* assignment, - absl::flat_hash_set* colocated_buffers, - absl::flat_hash_set* colocated_allocations); - - // Adds the 'colocated_set' of buffers to 'colocated_buffer_sets', maintaining - // the invariant that all sets in 'colocated_buffer_sets' are disjoint. - void AddSetToColocatedBufferSets( - const std::vector& colocated_set, - std::vector* colocated_buffer_sets); - - // Given a list of colocated buffer sets (each colocated buffer set represents - // the logical buffers that would be assigned to the same physical buffer), - // try to merge the sets if the buffers can be shared. Returns the merged set. - std::vector MergeColocatedBufferSets( - const std::vector& colocated_buffer_sets, - const BufferLiveness& buffer_liveness, - const LogicalBuffer::SizeFunction& buffer_size); - // Split a set of buffers into several sets, each of which contains buffers // colored with the same color. absl::flat_hash_map, + absl::flat_hash_set, LogicalBuffer::Color::Hasher> - SplitBuffersByColor(const absl::flat_hash_set& buffers); + SplitBuffersByColor(const absl::flat_hash_set& buffers); // If true, allocate buffers for constant instructions. bool allocate_buffers_for_constants_; // Functor used to assign colors to newly allocated logical buffers. - BufferLiveness::Colorer colorer_; + Colorer colorer_; - // Functor to check if a buffer can reuse an allocation. - ReuseAllocationFunction reuse_checker_; - - // Functor to check if a temp buffer can reuse a colocated allocation. 
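With the reuse-checker callbacks removed, reuse restrictions are now expressed as a set of opcodes that must not live out of a computation. A condensed caller-side sketch of the new Run() overload; the ordering, size function, and opcode set here are illustrative, not taken from the patch:

StatusOr<std::unique_ptr<BufferAssignment>> RunAssignerExample(HloModule* module) {
  return BufferAssigner::Run(
      module, absl::make_unique<DependencyHloOrdering>(module),
      /*buffer_size=*/
      [](const BufferValue& value) {
        return ShapeUtil::ByteSizeOf(value.shape(), /*pointer_size=*/8);
      },
      /*color_alignment=*/[](LogicalBuffer::Color) { return 1; },
      /*allocate_buffers_for_constants=*/true,
      /*colorer=*/BufferAssigner::DefaultColorer(),
      /*must_not_live_out=*/{HloOpcode::kAdd});
}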
- ReuseColocatedAllocationForTempChecker reuse_colocated_checker_; + // A set of hlo opcodes that can't live out of a computation. + absl::flat_hash_set must_not_live_out_; TF_DISALLOW_COPY_AND_ASSIGN(BufferAssigner); }; diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index acdf5d25e1d..8837e6d9344 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -92,7 +92,6 @@ class BufferAssignmentTest : public HloTestBase { module, absl::make_unique(module), backend().compiler()->BufferSizeBytesFunction(), [alignment](LogicalBuffer::Color) { return alignment; }, - /*allow_input_output_aliasing=*/false, /*allocate_buffers_for_constants=*/true) .ConsumeValueOrDie(); } @@ -103,36 +102,30 @@ class BufferAssignmentTest : public HloTestBase { module, absl::make_unique(module), backend().compiler()->BufferSizeBytesFunction(), [alignment](LogicalBuffer::Color) { return alignment; }, - /*allow_input_output_aliasing=*/false, /*allocate_buffers_for_constants=*/false) .ConsumeValueOrDie(); } std::unique_ptr RunBufferAssignmentNoBuffersReuseForAdd( HloModule* module, int64 alignment = 1) { - auto reuse_checker = [](const BufferAssignment& assignment, - const BufferAllocation& alloc, - const LogicalBuffer& buffer) { - return (buffer.instruction()->opcode() != HloOpcode::kAdd); - }; + absl::flat_hash_set must_not_live_out = {HloOpcode::kAdd}; + return BufferAssigner::Run( module, absl::make_unique(module), backend().compiler()->BufferSizeBytesFunction(), [alignment](LogicalBuffer::Color) { return alignment; }, - /*allow_input_output_aliasing=*/false, /*allocate_buffers_for_constants=*/false, - /*colorer=*/BufferLiveness::DefaultColorer(), - /*reuse_checker=*/reuse_checker) + /*colorer=*/BufferAssigner::DefaultColorer(), + /*must_not_live_out=*/must_not_live_out) .ConsumeValueOrDie(); } std::unique_ptr RunColoredBufferAssignment( - HloModule* module, BufferLiveness::Colorer colorer, int64 alignment = 1) { + HloModule* module, BufferAssigner::Colorer colorer, int64 alignment = 1) { return BufferAssigner::Run( module, absl::make_unique(module), backend().compiler()->BufferSizeBytesFunction(), [alignment](LogicalBuffer::Color) { return alignment; }, - /*allow_input_output_aliasing=*/false, /*allocate_buffers_for_constants=*/true, std::move(colorer)) .ConsumeValueOrDie(); } @@ -146,29 +139,10 @@ class BufferAssignmentTest : public HloTestBase { module, absl::make_unique(schedule), backend().compiler()->BufferSizeBytesFunction(), [alignment](LogicalBuffer::Color) { return alignment; }, - /*allow_input_output_aliasing=*/false, /*allocate_buffers_for_constants=*/true) .ConsumeValueOrDie(); } - std::unique_ptr - RunBufferAssignmentWithReusingColocatedBuffersForTemp(HloModule* module, - int64 alignment = 1) { - return BufferAssigner::Run( - module, absl::make_unique(module), - backend().compiler()->BufferSizeBytesFunction(), - [alignment](LogicalBuffer::Color) { return alignment; }, - /*allow_input_output_aliasing=*/false, - /*allocate_buffers_for_constants=*/true, - /*colorer=*/BufferLiveness::DefaultColorer(), - /*reuse_checker=*/nullptr, - /*reuse_colocated_checker=*/ - [](const LogicalBuffer& buffer, int64 byte_size) { - return true; - }) - .ConsumeValueOrDie(); - } - // Builds an x+1.0 computation to use in a Map. 
std::unique_ptr BuildMapComputationPlus1(const string& name) { auto builder = HloComputation::Builder(name); @@ -518,77 +492,8 @@ TEST_F(BufferAssignmentTest, AliasedParamCanBeReused) { EXPECT_EQ(neg_2_buffer.index(), neg_1_buffer.index()); } -TEST_F(BufferAssignmentTest, ReuseColocatedBuffersForTemp) { - const char* const hlo_string = R"( -HloModule test - -sum (a: f32[], b: f32[]) -> f32[] { - a = f32[] parameter(0) - b = f32[] parameter(1) - ROOT add = f32[] add(a, b) -} - -while_body { - state = (s32[], f32[1280,1,128]{2,1,0}) parameter(0) - get-tuple-element.4 = f32[1280,1,128]{2,1,0} get-tuple-element(state), index=1 - get-tuple-element.3 = s32[] get-tuple-element(state), index=0 - constant.2 = s32[] constant(128) - add.5 = s32[] add(get-tuple-element.3, constant.2) - broadcast = f32[2,1280,1,128]{3,2,1,0} broadcast(get-tuple-element.4), dimensions={1,2,3} - constant.3 = s32[] constant(0) - reduce = f32[1280,1,128]{2,1,0} reduce(broadcast, constant.3), dimensions={3}, to_apply=sum - ROOT tuple.85 = (s32[], f32[1280,1,128]{2,1,0}) tuple(add.5, reduce) -} - -while_condition { - state = (s32[], f32[1280,1,128]{2,1,0}) parameter(0) - get-tuple-element = s32[] get-tuple-element(state), index=0 - get-tuple-element.1 = s32[] constant(3) - ROOT less-than.339.338 = pred[] compare(get-tuple-element, get-tuple-element.1), direction=LT -} - -sum.1 (a: f32[], b: f32[]) -> f32[] { - a = f32[] parameter(0) - b = f32[] parameter(1) - ROOT add = f32[] add(a, b) -} - -ENTRY entry_computation { - parameter = f32[2,1280,1,128]{3,2,1,0} parameter(0) - constant.6 = f32[] constant(0) - reduce.1 = f32[1280,1,128]{2,1,0} reduce(parameter, constant.6), dimensions={3}, to_apply=sum.1 - constant.7 = s32[] constant(0) - tuple.1 = (s32[], f32[1280,1,128]{2,1,0}) tuple(constant.7, reduce.1) - while.0 = (s32[], f32[1280,1,128]{2,1,0}) while(tuple.1), condition=while_condition, body=while_body - get-tuple-element.1 = f32[1280,1,128] get-tuple-element(while.0), index=1 - ROOT broadcast.1 = f32[2,1280,1,128]{3,2,1,0} broadcast(get-tuple-element.1), dimensions={1,2,3} -} - -)"; - auto module_or_status = - HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); - auto module = module_or_status.ConsumeValueOrDie(); - - TF_ASSERT_OK(module->input_output_alias_config().SetUpAlias( - {}, 0, {}, HloInputOutputAliasConfig::kUserAlias)); - - auto assignment = - RunBufferAssignmentWithReusingColocatedBuffersForTemp(module.get()); - // Get BufferAllocation for root instruction. - auto broadcast = FindInstruction(module.get(), "broadcast"); - auto broadcast_alloc_slice = - assignment->GetUniqueTopLevelSlice(broadcast).ConsumeValueOrDie(); - auto parameter = FindInstruction(module.get(), "parameter"); - auto parameter_alloc_slice = - assignment->GetUniqueTopLevelSlice(parameter).ConsumeValueOrDie(); - - EXPECT_EQ(broadcast_alloc_slice.allocation(), - parameter_alloc_slice.allocation()); - EXPECT_EQ(broadcast_alloc_slice, parameter_alloc_slice); -} - TEST_F(BufferAssignmentTest, AddCannotReuse) { - // Pass in a special rule to indicate that "add" cannot reuse any buffer. + // Pass in a special rule to indicate that "add" cannot be live out. // // paramscalar ------- (mul) -- (add) -- (sub) // / / / @@ -625,13 +530,13 @@ TEST_F(BufferAssignmentTest, AddCannotReuse) { EXPECT_NE(param0_buffer.index(), param1_buffer.index()); // The mul node has a valid buffer assigned, doesn't share with input. 
- const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul); - EXPECT_NE(mul_buffer.index(), param0_buffer.index()); + const BufferAllocation& sub_buffer = GetTopLevelAllocation(*buffers, sub); + EXPECT_NE(sub_buffer.index(), param0_buffer.index()); // The add node cannot reuse the mul node's buffer since we told buffer // assignment so. const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add); - EXPECT_NE(add_buffer.index(), mul_buffer.index()); + EXPECT_NE(add_buffer.index(), sub_buffer.index()); // The sub node has a valid output buffer assigned. GetAssignedOutputAllocation(*buffers, sub); @@ -663,14 +568,12 @@ TEST_F(BufferAssignmentTest, BasicUniquelyColored) { auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto colorer = [](const BufferLiveness& buffer_liveness) { + auto colorer = [](HloAliasAnalysis* alias_analysis, const HloOrdering&) { int color = 0; - - for (LogicalBuffer::Id id = 0; - id < buffer_liveness.points_to_analysis().num_logical_buffers(); - id++) { - auto& buffer = buffer_liveness.points_to_analysis().logical_buffer(id); - buffer.set_color(LogicalBuffer::Color(color++)); + for (HloValue::Id id = 0; + id < alias_analysis->dataflow_analysis().values().size(); id++) { + auto& value = alias_analysis->dataflow_analysis().GetValue(id); + value.set_color(BufferValue::Color(color++)); } return Status::OK(); }; @@ -724,21 +627,19 @@ TEST_F(BufferAssignmentTest, BasicPartiallyColored) { auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto colorer = [](const BufferLiveness& buffer_liveness) { - for (LogicalBuffer::Id id = 0; - id < buffer_liveness.points_to_analysis().num_logical_buffers(); - id++) { - auto& buffer = buffer_liveness.points_to_analysis().logical_buffer(id); - const auto& aliases = - buffer_liveness.points_to_analysis().GetBufferAliases(buffer); - for (const auto& alias : aliases) { - if (alias.instruction()->opcode() == HloOpcode::kAdd || - alias.instruction()->opcode() == HloOpcode::kMultiply) { - buffer.set_color(LogicalBuffer::Color(1)); + auto colorer = [](HloAliasAnalysis* alias_analysis, const HloOrdering&) { + for (HloValue::Id id = 0; + id < alias_analysis->dataflow_analysis().values().size(); id++) { + auto& value = alias_analysis->dataflow_analysis().GetValue(id); + auto& buffer = alias_analysis->GetBufferContainingValue(value); + for (const auto& alias : buffer.values()) { + if (alias->instruction()->opcode() == HloOpcode::kAdd || + alias->instruction()->opcode() == HloOpcode::kMultiply) { + value.set_color(LogicalBuffer::Color(1)); } } - if (!buffer.has_color()) { - buffer.set_color(LogicalBuffer::Color(0)); + if (!value.has_color()) { + value.set_color(LogicalBuffer::Color(0)); } } return Status::OK(); @@ -1734,7 +1635,7 @@ TEST_F(BufferAssignmentTest, OneTempAllocation) { } TEST_F(BufferAssignmentTest, TrivialPeakBuffers) { - // paramscalar ------- (mul) -- (add) -- (sub) + // paramscalar -(bc)- (mul) -- (add) -- (sub) // / / / // param0[100] -------/ / / // / / @@ -1752,7 +1653,7 @@ TEST_F(BufferAssignmentTest, TrivialPeakBuffers) { f32vec100_, HloOpcode::kMultiply, broadcast, param0)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1)); - builder.AddInstruction(HloInstruction::CreateBinary( + auto sub = builder.AddInstruction(HloInstruction::CreateBinary( f32vec100_, HloOpcode::kSubtract, add, param1)); auto module = CreateNewVerifiedModule(); 
module->AddEntryComputation(builder.Build()); @@ -1760,10 +1661,10 @@ TEST_F(BufferAssignmentTest, TrivialPeakBuffers) { auto buffers = RunBufferAssignment(module.get()); const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul); - const std::vector& peak_buffers = + const std::vector& peak_buffers = mul_buffer.PeakMemoryLogicalBuffers(); ASSERT_EQ(peak_buffers.size(), 1); - EXPECT_EQ(peak_buffers[0]->instruction(), broadcast); + EXPECT_EQ(peak_buffers[0]->instruction(), sub); } TEST_F(BufferAssignmentTest, PeakBuffers) { @@ -1807,81 +1708,18 @@ TEST_F(BufferAssignmentTest, PeakBuffers) { EXPECT_TRUE(buffer.IsPreallocatedTempBuffer()); ASSERT_EQ(buffer.assigned_buffers().size(), 4); - const std::vector& peak_buffers = + const std::vector& peak_buffers = buffer.PeakMemoryLogicalBuffers(); // The peak live set should be concat and its inputs. ASSERT_EQ(peak_buffers.size(), 3); std::vector peak_instructions; - for (const LogicalBuffer* logical_buffer : peak_buffers) { + for (const BufferValue* logical_buffer : peak_buffers) { peak_instructions.push_back(logical_buffer->instruction()); } EXPECT_THAT(peak_instructions, UnorderedElementsAre(rev, neg, concat)); } -TEST_F(BufferAssignmentTest, PeakBuffersWhile) { - auto module = CreateNewVerifiedModule(); - const Shape shape = ShapeUtil::MakeShape(F32, {123, 123}); - HloComputation* condition; - { - auto b = HloComputation::Builder(TestName() + ".cond"); - b.AddInstruction(HloInstruction::CreateParameter(0, shape, "x")); - b.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); - condition = module->AddEmbeddedComputation(b.Build()); - } - HloComputation* body; - { - auto b = HloComputation::Builder(TestName() + ".body"); - auto param = - b.AddInstruction(HloInstruction::CreateParameter(0, shape, "x")); - b.AddInstruction( - HloInstruction::CreateUnary(shape, HloOpcode::kNegate, param)); - body = module->AddEmbeddedComputation(b.Build()); - } - auto builder = HloComputation::Builder(TestName()); - auto param = - builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0")); - auto copy = builder.AddInstruction( - HloInstruction::CreateUnary(shape, HloOpcode::kCopy, param)); - auto while_op = builder.AddInstruction( - HloInstruction::CreateWhile(shape, condition, body, copy)); - // This broadcast should get a temporary allocation which is merged with the - // allocation for the while. Peak buffers should include the while and the - // broadcast. - auto bcast = builder.AddInstruction(HloInstruction::CreateBroadcast( - ShapeUtil::MakeShape(F32, {123, 123, 123}), while_op, {0, 1})); - builder.AddInstruction(HloInstruction::CreateReverse( - ShapeUtil::MakeShape(F32, {123, 123, 123}), bcast, {0})); - module->AddEntryComputation(builder.Build()); - - auto buffers = RunBufferAssignment(module.get()); - const BufferAllocation& buffer = GetTopLevelAllocation(*buffers, bcast); - const std::vector& peak_buffers = - buffer.PeakMemoryLogicalBuffers(); - ASSERT_EQ(peak_buffers.size(), 2); - - // The peak buffers should include the broadcast and one of the colocated - // buffers of the while (body param, condition param, body root, or the while - // itself). 
- const LogicalBuffer* bcast_buffer; - const LogicalBuffer* nonbcast_buffer; - if (peak_buffers[0]->instruction() == bcast) { - bcast_buffer = peak_buffers[0]; - nonbcast_buffer = peak_buffers[1]; - } else { - bcast_buffer = peak_buffers[1]; - nonbcast_buffer = peak_buffers[0]; - } - EXPECT_EQ(bcast_buffer->instruction(), bcast); - EXPECT_TRUE( - nonbcast_buffer->instruction() == copy || - nonbcast_buffer->instruction() == while_op || - nonbcast_buffer->instruction() == body->parameter_instruction(0) || - nonbcast_buffer->instruction() == body->root_instruction() || - nonbcast_buffer->instruction() == condition->parameter_instruction(0)); -} - TEST_F(BufferAssignmentTest, ConstantBuffersAreNotReused) { const char* hlo_text = R"( HloModule Module @@ -1980,7 +1818,6 @@ class WhileBufferAssignmentTest : public HloTestBase { module, absl::make_unique(schedule), ByteSizeOf, [alignment](LogicalBuffer::Color) { return alignment; }, - /*allow_input_output_aliasing=*/false, /*allocate_buffers_for_constants=*/true) .ConsumeValueOrDie(); } @@ -2300,7 +2137,6 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { module.get(), absl::make_unique(schedule), backend().compiler()->BufferSizeBytesFunction(), [](LogicalBuffer::Color) { return 1; }, - /*allow_input_output_aliasing=*/false, /*allocate_buffers_for_constants=*/true)); // The result tuple elements must be assigned with different buffers. @@ -2533,7 +2369,6 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { BufferAssigner::Run( module.get(), absl::make_unique(schedule), ByteSizeOf, [](LogicalBuffer::Color) { return 1; }, - /*allow_input_output_aliasing=*/false, /*allocate_buffers_for_constants=*/true) .ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/service/conditional_simplifier.cc b/tensorflow/compiler/xla/service/conditional_simplifier.cc index 301ac9cc3d4..f6dac508e5f 100644 --- a/tensorflow/compiler/xla/service/conditional_simplifier.cc +++ b/tensorflow/compiler/xla/service/conditional_simplifier.cc @@ -46,7 +46,7 @@ StatusOr TryRemoveConditional(HloInstruction* conditional) { CHECK_EQ(conditional->opcode(), HloOpcode::kConditional); // Do not remove conditionals that contain side-effecting instructions or // have control predecessors/successors in either true/false computation. - if (!conditional->parent()->IsRemovable(conditional) || + if (!conditional->parent()->IsSafelyRemovable(conditional) || conditional->HasSideEffect()) { VLOG(2) << "Not attempting to remove conditional as it is not removable or " "has side effect: " @@ -188,7 +188,7 @@ StatusOr ConditionalSimplifier::Run(HloModule* module) { // instructions as we iterate. 
std::vector conditional_ops; for (auto* comp : module->computations()) { - for (auto* instr : comp->instructions()) { + for (auto* instr : comp->MakeInstructionPostOrder()) { if (instr->opcode() == HloOpcode::kConditional) { conditional_ops.push_back(instr); } diff --git a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc index 9759526c6e0..a584aba816f 100644 --- a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc @@ -212,6 +212,65 @@ ENTRY main { .size(), 2); } + +TEST_F(ConditionalSimplifierTest, + TwoConditionalsCreatedInReversedLexicalOrder) { + absl::string_view hlo_string = R"( + HloModule DeadConditional + computation.1 { + param.1 = s64[] parameter(0) + constant.1 = s64[] constant(1) + ROOT add.1 = s64[] add(param.1, constant.1) + } + + computation.2 { + param.2 = s64[] parameter(0) + constant.2 = s64[] constant(2) + ROOT add.2 = s64[] add(param.2, constant.2) + } + + computation.3 { + param.3 = s64[] parameter(0) + constant.3 = s64[] constant(3) + ROOT add.3 = s64[] add(param.3, constant.3) + } + + computation.4 { + param.4 = s64[] parameter(0) + constant.4 = s64[] constant(4) + ROOT add.4 = s64[] add(param.4, constant.4) + } + + ENTRY KernelEntry { + param.1 = s64[] parameter(0) + param.2 = s64[] parameter(1) + param.3 = s64[] parameter(2) + param.4 = pred[] parameter(3) + + conditional_1 = s64[] conditional(param.4, param.3, param.2), + true_computation=computation.3, false_computation=computation.4 + constant.1 = pred[] constant(false) + ROOT conditional_2 = s64[] conditional(constant.1, conditional_1, + param.1), true_computation=computation.1, + false_computation=computation.2 + })"; + auto status = ParseHloString(hlo_string); + TF_ASSERT_OK(status.status()); + std::unique_ptr module = status.ConsumeValueOrDie(); + HloVerifier v(false, false); + TF_ASSERT_OK(v.Run(module.get()).status()); + + // Replace conditional_1 with a clone that is created after conditional_2. + HloInstruction* conditional_1 = + FindInstruction(module.get(), "conditional_1"); + HloInstruction* conditional_1_clone = + conditional_1->parent()->AddInstruction(conditional_1->Clone()); + TF_ASSERT_OK(conditional_1->ReplaceAllUsesWith(conditional_1_clone)); + TF_ASSERT_OK(conditional_1->parent()->RemoveInstruction(conditional_1)); + + EXPECT_TRUE(ConditionalSimplifier().Run(module.get()).ValueOrDie()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/copy_insertion.h b/tensorflow/compiler/xla/service/copy_insertion.h index f7e19970feb..988b93b557f 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.h +++ b/tensorflow/compiler/xla/service/copy_insertion.h @@ -94,12 +94,12 @@ class CopyInsertion : public HloModulePass { virtual Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module); - private: - Status AddCopiesToResolveInterference(HloModule* module); - // Backend specific function that decides whether a fusion can share buffer // with its operand. 
HloDataflowAnalysis::FusionCanShareBufferFunction fusion_can_share_buffer_; + + private: + Status AddCopiesToResolveInterference(HloModule* module); }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 09f5c859af4..227d8ffb1a0 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -9,10 +9,9 @@ load( load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") load(":build_defs.bzl", "runtime_copts") -licenses(["notice"]) # Apache 2.0 - package( default_visibility = [":friends"], + licenses = ["notice"], # Apache 2.0 ) package_group( @@ -905,6 +904,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_cost_analysis", "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/compiler/xla/service/llvm_ir:dynamic_update_slice_util", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 06ea1e2f8bd..a3e224824ba 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -297,6 +297,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( pass.AddInvariantChecker(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); + pass.AddPass(); pass.AddPass( /*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, @@ -340,8 +341,6 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( pipeline.AddPass(); - pipeline.AddPass(); - ReducePrecisionInsertion::AddPasses( &pipeline, module->config().debug_options(), ReducePrecisionInsertion::PassTiming::AFTER_FUSION); @@ -658,7 +657,6 @@ StatusOr> CpuCompiler::RunBackend( BufferAssigner::Run(module.get(), absl::make_unique(schedule), BufferSizeBytesFunction(), memory_alignment, - /*allow_input_output_aliasing=*/false, /*allocate_buffers_for_constants=*/true)); DumpHloModuleIfEnabled(*module, *assignment, "after_optimizations"); @@ -851,7 +849,6 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, BufferAssigner::Run(module, absl::make_unique(schedule), BufferSizeBytesFunction(), memory_alignment, - /*allow_input_output_aliasing=*/false, /*allocate_buffers_for_constants=*/true)); // BufferAssignment::ToString() includes a header, so no need for us to // print one ourselves. diff --git a/tensorflow/compiler/xla/service/cpu/disassembler.cc b/tensorflow/compiler/xla/service/cpu/disassembler.cc index c3c6847b7b7..e95f29fc889 100644 --- a/tensorflow/compiler/xla/service/cpu/disassembler.cc +++ b/tensorflow/compiler/xla/service/cpu/disassembler.cc @@ -89,13 +89,14 @@ StatusOr Disassembler::DisassembleObjectFile( }); // Construct ArrayRef pointing to section contents. - llvm::StringRef section_content_string; - if (section.getContents(section_content_string)) { + llvm::Expected section_content_string = + section.getContents(); + if (!section_content_string) { continue; } llvm::ArrayRef section_content_bytes( - reinterpret_cast(section_content_string.data()), - section_content_string.size()); + reinterpret_cast(section_content_string->data()), + section_content_string->size()); // Use int types from LLVM (eg, uint64_t) for values passed to and returned // from the LLVM API. 
These values map to different types in LLVM and diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc index 234fa91fe3e..23312e40f7e 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h" namespace xla { namespace cpu { @@ -74,8 +75,9 @@ class DefaultCostModel : public ParallelCostModel { // Limit max parallelism for I/O bound instructions by assuming a // sub-linear scaling function (fit based on empirical benchmark results). // TODO(b/29630486) Develop system bandwidth model. - max_parallelism = - std::ceil(std::sqrt(tensorflow::port::NumSchedulableCPUs())); + max_parallelism = std::min( + max_parallelism_, + std::ceil(std::sqrt(tensorflow::port::NumSchedulableCPUs()))); // Use shape size instruction cost and L2 cache size min per-thread cost. instruction_cost = shape_size_(instruction->shape()); min_cost_per_thread = 256LL << 10; // 256KB L2 Cache size. @@ -134,6 +136,10 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount( // *) Emit custom loops (kSelectAndScatter). // *) Operations that are not thread safe (like infeed and rng). // *) Tuple-shaped. + // *) Operations that might be implemented as an in-place + // dynamic-update-slice, because we can't know how many output elements + // they will write (out-of-place will touch the whole output buffer, while + // in-place will only touch the updated elements). // TODO(b/27458679) Parallelize instructions which are skipped here. auto opcode = instruction->opcode(); if (opcode == HloOpcode::kParameter || opcode == HloOpcode::kConstant || @@ -147,6 +153,7 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount( PotentiallyImplementedAsEigenConvolution(*instruction, target_machine_features_)) || (opcode == HloOpcode::kFusion && !instruction->IsLoopFusion()) || + llvm_ir::MayBeImplementedAsInPlaceDynamicUpdateSlice(instruction) || instruction->shape().IsTuple()) { return 1; } diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc index 35ae62b42df..e2c93568b74 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc @@ -125,5 +125,50 @@ TEST_F(ParallelTaskAssignmentTest, InfeedOutfeedOperationNotParallelized) { EXPECT_FALSE(changed); } +TEST_F(ParallelTaskAssignmentTest, InPlaceDynamicUpdateSliceNotParallelized) { + // A dynamic-update-slice within a while loop. This construction is an easy + // way to make a DUS which can be run "in-place" (i.e. the input and output + // are the same buffer, and running the DUS only writes to the updated + // elements). 
+ const string hlo_string = R"( + HloModule test + + body { + zero = s32[] constant(0) + one = s32[] constant(1) + ten = s32[] constant(10) + loop_carry = (s32[], u32[1,100], u32[10000,100]) parameter(0) + i = s32[] get-tuple-element(loop_carry), index=0 + i_plus_ten = s32[] add(i, ten) + update = u32[1,100] get-tuple-element(loop_carry), index=1 + data = u32[10000,100] get-tuple-element(loop_carry), index=2 + new_data = u32[10000,100] dynamic-update-slice(data, update, i_plus_ten, zero) + new_i = s32[] add(i, one) + ROOT tuple = (s32[], u32[1,100], u32[10000,100]) tuple(new_i, update, new_data) + } + + cond { + loop_carry = (s32[], u32[1,100], u32[10000,100]) parameter(0) + two = s32[] constant(2) + i = s32[] get-tuple-element(loop_carry), index=0 + ROOT less-than = pred[] compare(i, two), direction=LT + } + + ENTRY test { + zero = s32[] constant(0) + initial_i = s32[] parameter(0) + update = u32[1,100] parameter(1) + data = u32[10000,100] parameter(2) + tuple = (s32[], u32[1,100], u32[10000,100]) tuple(initial_i, update, data) + ROOT while = (s32[], u32[1,100], u32[10000,100]) while(tuple), condition=cond, body=body + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(m.get())); + EXPECT_FALSE(changed); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD index 382dfd0d99d..1fa2c56abd0 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD @@ -1,10 +1,9 @@ # Description: # Tests for LLVM-based CPU backend for XLA. -licenses(["notice"]) # Apache 2.0 - package( default_visibility = [":friends"], + licenses = ["notice"], # Apache 2.0 ) package_group( diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc index b2563f9949e..1a31f5471de 100644 --- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc @@ -222,54 +222,84 @@ Status DynamicDimensionInferenceVisitor::HandleReduce(HloInstruction* hlo) { } Status DynamicDimensionInferenceVisitor::HandleDot(HloInstruction* hlo) { - return ForEachOperandDynamicDimension( - hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, - int64 operand_index, HloInstruction* dynamic_size) { - HloInstruction* dot = hlo; - const DotDimensionNumbers& dimension_numbers = - dot->dot_dimension_numbers(); - // A map from the operand dimensions to result dimension. - absl::flat_hash_map result_dim_mapping; - int64 current_result_dims = 0; - std::unordered_set batch_dims( - dimension_numbers.rhs_batch_dimensions().begin(), - dimension_numbers.rhs_batch_dimensions().end()); + return ForEachOperandDynamicDimension(hlo, [&](HloInstruction* operand, + ShapeIndex operand_shape_index, + int64 operand_dimension, + int64 operand_index, + HloInstruction* dynamic_size) { + // There are three types of dimensions in a dot: + // A. batch dims + // B. contracting dims + // C. non-batch non-contracting dims. + // The output dimemsions of a dot has three parts with the following order: + // [(type A), (lhs type C), (rhs type C)] + // + // Note that both lhs and rhs have the same dimension sizes for batch, + // but the dimension index could be different. 
+    //
+    // Given one dynamic input dimension, either lhs or rhs, we use a
+    // mapping to find the corresponding output dimension.
+    HloInstruction* dot = hlo;
+    const DotDimensionNumbers& dimension_numbers = dot->dot_dimension_numbers();
+    // A map from the operand dimensions to result dimension.
+    absl::flat_hash_map result_dim_mapping;
+    int64 current_result_dims = 0;
-        for (int64 i : dimension_numbers.rhs_batch_dimensions()) {
-          result_dim_mapping[i] = current_result_dims++;
-        }
+    bool lhs = operand_index == 0;
-        for (int64 i = 0; i < dot->operand(0)->shape().rank(); i++) {
-          if (!absl::c_linear_search(
-                  dimension_numbers.lhs_contracting_dimensions(), i)) {
-            if (operand_index == 0) {
-              result_dim_mapping[i] = current_result_dims;
-            }
-            current_result_dims++;
-          }
-        }
+    // The first loop keeps track of the batch dimensions. RHS and LHS could
+    // have different batch dimension numbers.
+    if (lhs) {
+      for (int64 i : dimension_numbers.lhs_batch_dimensions()) {
+        result_dim_mapping[i] = current_result_dims++;
+      }
+    } else {
+      for (int64 i : dimension_numbers.rhs_batch_dimensions()) {
+        result_dim_mapping[i] = current_result_dims++;
+      }
+    }
-        for (int64 i = 0; i < dot->operand(1)->shape().rank(); i++) {
-          if (!absl::c_linear_search(
-                  dimension_numbers.rhs_contracting_dimensions(), i) &&
-              !absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(),
-                                     i)) {
-            if (operand_index == 1) {
-              result_dim_mapping[i] = current_result_dims;
-            }
-            current_result_dims++;
-          }
-        }
+    // Handle dimensions in the lhs.
+    for (int64 i = 0; i < dot->operand(0)->shape().rank(); i++) {
+      // Look for non-contracting and non-batching dimension.
+      if (absl::c_linear_search(dimension_numbers.lhs_contracting_dimensions(),
+                                i)) {
+        continue;
+      }
+      if (absl::c_linear_search(dimension_numbers.lhs_batch_dimensions(), i)) {
+        continue;
+      }
+      if (lhs) {
+        result_dim_mapping[i] = current_result_dims;
+      }
+      current_result_dims++;
+    }
-        // Check if the operand dim is in the result shape. If so, add another
-        // work item to trace that dimension.
-        auto iter = result_dim_mapping.find(dimension);
-        if (iter != result_dim_mapping.end()) {
-          parent_->SetDynamicSize(dot, {}, iter->second, dynamic_size);
-        }
+    // Handle dimensions in the rhs.
+    for (int64 i = 0; i < dot->operand(1)->shape().rank(); i++) {
+      // Look for non-contracting and non-batching dimension.
+      if (absl::c_linear_search(dimension_numbers.rhs_contracting_dimensions(),
+                                i)) {
+        continue;
+      }
+      if (absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(), i)) {
+        continue;
+      }
+      if (!lhs) {
+        result_dim_mapping[i] = current_result_dims;
+      }
+      current_result_dims++;
+    }
-        return Status::OK();
-      });
+    // Check if the operand dim is in the result shape. If so, add another
+    // work item to trace that dimension.
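// Worked example (illustration only; it mirrors the DotTestBatch case added
// to dynamic_dimension_inference_test.cc below): lhs = f32[4,128,2,8],
// rhs = f32[4,128,2,8], batch dims {0,2} on both operands, contracting dim
// {3}, output = f32[4,2,128,128]. For the lhs operand the loops above build
// result_dim_mapping = {0->0, 2->1, 1->2}, so a dynamic size on lhs
// dimension 0 is forwarded to output dimension 0 by the lookup below.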
+ auto iter = result_dim_mapping.find(operand_dimension); + if (iter != result_dim_mapping.end()) { + parent_->SetDynamicSize(dot, {}, iter->second, dynamic_size); + } + + return Status::OK(); + }); } Status DynamicDimensionInferenceVisitor::HandleTranspose(HloInstruction* hlo) { diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc index a18c0176153..335aff662ec 100644 --- a/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc @@ -344,6 +344,45 @@ TEST_F(DynamicDimensionInferenceTest, DotTest) { EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 1), nullptr); } +TEST_F(DynamicDimensionInferenceTest, DotTestBatch) { + auto builder = HloComputation::Builder(TestName()); + auto lhs_shape = ShapeUtil::MakeShape(F32, {4, 128, 2, 8}); + auto rhs_shape = ShapeUtil::MakeShape(F32, {4, 128, 2, 8}); + auto output_shape = ShapeUtil::MakeShape(F32, {4, 2, 128, 128}); + + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, lhs_shape, "A")); + auto* b_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, rhs_shape, "B")); + auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/2, scalar_shape_, "size_param")); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(3); + dot_dnums.add_rhs_contracting_dimensions(3); + dot_dnums.add_lhs_batch_dimensions(0); + dot_dnums.add_lhs_batch_dimensions(2); + dot_dnums.add_rhs_batch_dimensions(0); + dot_dnums.add_rhs_batch_dimensions(2); + auto dot = builder.AddInstruction( + HloInstruction::CreateDot(output_shape, a_param, b_param, dot_dnums, + HloTestBase::DefaultPrecisionConfig(2))); + + module_->AddEntryComputation(builder.Build()); + + // Set up dynamic parameter binding for batch dimension. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{2, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 0})); + + SCOPED_TRACE(module_->ToString()); + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 0), size_param); + EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 1), nullptr); + EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 2), nullptr); + EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 3), nullptr); +} + TEST_F(DynamicDimensionInferenceTest, ConvolutionTest) { auto builder = HloComputation::Builder(TestName()); constexpr int xdim = 3; diff --git a/tensorflow/compiler/xla/service/dynamic_update_slice_test.cc b/tensorflow/compiler/xla/service/dynamic_update_slice_test.cc new file mode 100644 index 00000000000..a7caab685bf --- /dev/null +++ b/tensorflow/compiler/xla/service/dynamic_update_slice_test.cc @@ -0,0 +1,197 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" + +namespace xla { +namespace { + +class DynamicUpdateSliceTest : public HloTestBase {}; + +XLA_TEST_F(DynamicUpdateSliceTest, ShardedInPlaceDUS) { + // A dynamic-update-slice within a while loop. This construction is an easy + // way to make a DUS which can be run "in-place" (i.e. the input and output + // are the same buffer, and running the DUS only writes to the updated + // elements). + const char kModuleStr[] = R"( + HloModule test + + body { + zero = s32[] constant(0) + one = s32[] constant(1) + ten = s32[] constant(10) + loop_carry = (s32[], u32[1,100], u32[10000,100]) parameter(0) + i = s32[] get-tuple-element(loop_carry), index=0 + i_plus_ten = s32[] add(i, ten) + update = u32[1,100] get-tuple-element(loop_carry), index=1 + data = u32[10000,100] get-tuple-element(loop_carry), index=2 + new_data = u32[10000,100] dynamic-update-slice(data, update, i_plus_ten, zero) + new_i = s32[] add(i, one) + ROOT tuple = (s32[], u32[1,100], u32[10000,100]) tuple(new_i, update, new_data) + } + + cond { + loop_carry = (s32[], u32[1,100], u32[10000,100]) parameter(0) + two = s32[] constant(2) + i = s32[] get-tuple-element(loop_carry), index=0 + ROOT less-than = pred[] compare(i, two), direction=LT + } + + ENTRY test { + zero = s32[] constant(0) + initial_i = s32[] parameter(0) + update = u32[1,100] parameter(1) + data = u32[10000,100] parameter(2) + tuple = (s32[], u32[1,100], u32[10000,100]) tuple(initial_i, update, data) + ROOT while = (s32[], u32[1,100], u32[10000,100]) while(tuple), condition=cond, body=body + } +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + TF_ASSERT_OK_AND_ASSIGN(auto fake_arguments, MakeFakeArguments(module.get())); + fake_arguments[0] = LiteralUtil::CreateR0(0); + + std::vector fake_argument_ptrs; + absl::c_transform( + fake_arguments, std::back_inserter(fake_argument_ptrs), + [](const Literal& literal) { return &const_cast(literal); }); + + ErrorSpec no_error(0, 0); + EXPECT_TRUE(RunAndCompare(std::move(module), fake_argument_ptrs, no_error)); +} + +// Regression test for a dynamic-update-slice involved in the expansion of a +// kScatter op. Apologies for the large testcase, this proved difficult to +// reduce. The bug we're checking for occurs when the dynamic-update-slice is +// run in place but is sharded across cores by ParallelTaskAssigner. 
+XLA_TEST_F(DynamicUpdateSliceTest, ExpandedScatter) { + const char kModuleStr[] = R"( +HloModule TensorFlowScatter + +and.reduce_sub_computation { + lhs = pred[] parameter(0) + rhs = pred[] parameter(1) + ROOT and = pred[] and(lhs, rhs) +} + +while_body { + param.1 = (s32[], f32[8,3,96,1,64]{4,3,2,1,0}, s32[16,4]{1,0}, f32[16,64]{1,0}) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(param.1), index=0 + constant.4 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.4) + get-tuple-element.2 = f32[8,3,96,1,64]{4,3,2,1,0} get-tuple-element(param.1), index=1 + constant.8 = s32[] constant(0) + broadcast.1 = s32[5]{0} broadcast(constant.8), dimensions={} + get-tuple-element.3 = s32[16,4]{1,0} get-tuple-element(param.1), index=2 + constant.5 = s32[] constant(0) + dynamic-slice = s32[1,4]{1,0} dynamic-slice(get-tuple-element.3, get-tuple-element.1, constant.5), dynamic_slice_sizes={1,4} + slice.18 = s32[1,1]{1,0} slice(dynamic-slice), slice={[0:1], [0:1]} + reshape.23 = s32[1]{0} reshape(slice.18) + reshape.4 = s32[4]{0} reshape(dynamic-slice) + slice.19 = s32[3]{0} slice(reshape.4), slice={[1:4]} + constant.6 = s32[1]{0} constant({0}) + concatenate.1 = s32[5]{0} concatenate(reshape.23, slice.19, constant.6), dimensions={0} + compare.1 = pred[5]{0} compare(broadcast.1, concatenate.1), direction=LE + constant.9 = s32[5]{0} constant({7, 2, 95, 0, 0}) + compare.2 = pred[5]{0} compare(constant.9, concatenate.1), direction=GE + and.1 = pred[5]{0} and(compare.1, compare.2) + constant.10 = pred[] constant(true) + reduce = pred[] reduce(and.1, constant.10), dimensions={0}, to_apply=and.reduce_sub_computation + broadcast.2 = pred[1,1,1,1,64]{4,3,2,1,0} broadcast(reduce), dimensions={} + reshape.24 = s32[] reshape(slice.18) + slice.26 = s32[1]{0} slice(reshape.4), slice={[1:2]} + reshape.10 = s32[] reshape(slice.26) + slice.27 = s32[1]{0} slice(reshape.4), slice={[2:3]} + reshape.11 = s32[] reshape(slice.27) + slice.28 = s32[1]{0} slice(reshape.4), slice={[3:4]} + reshape.12 = s32[] reshape(slice.28) + reshape.13 = s32[] reshape(constant.6) + dynamic-slice.2 = f32[1,1,1,1,64]{4,3,2,1,0} dynamic-slice(get-tuple-element.2, reshape.24, reshape.10, reshape.11, reshape.12, reshape.13), dynamic_slice_sizes={1,1,1,1,64} + get-tuple-element.4 = f32[16,64]{1,0} get-tuple-element(param.1), index=3 + constant.7 = s32[] constant(0) + dynamic-slice.1 = f32[1,64]{1,0} dynamic-slice(get-tuple-element.4, get-tuple-element.1, constant.7), dynamic_slice_sizes={1,64} + reshape.28 = f32[1,1,1,1,64]{4,3,2,1,0} reshape(dynamic-slice.1) + add.1 = f32[1,1,1,1,64]{4,3,2,1,0} add(dynamic-slice.2, reshape.28) + select = f32[1,1,1,1,64]{4,3,2,1,0} select(broadcast.2, add.1, dynamic-slice.2) + reshape.29 = s32[] reshape(slice.18) + slice.29 = s32[1]{0} slice(reshape.4), slice={[1:2]} + reshape.15 = s32[] reshape(slice.29) + slice.30 = s32[1]{0} slice(reshape.4), slice={[2:3]} + reshape.16 = s32[] reshape(slice.30) + slice.31 = s32[1]{0} slice(reshape.4), slice={[3:4]} + reshape.17 = s32[] reshape(slice.31) + reshape.18 = s32[] reshape(constant.6) + dynamic-update-slice = f32[8,3,96,1,64]{4,3,2,1,0} dynamic-update-slice(get-tuple-element.2, select, reshape.29, reshape.15, reshape.16, reshape.17, reshape.18) + ROOT tuple.1 = (s32[], f32[8,3,96,1,64]{4,3,2,1,0}, s32[16,4]{1,0}, f32[16,64]{1,0}) tuple(add, dynamic-update-slice, get-tuple-element.3, get-tuple-element.4) +} + +while_cond { + param.0 = (s32[], f32[8,3,96,1,64]{4,3,2,1,0}, s32[16,4]{1,0}, f32[16,64]{1,0}) parameter(0) + get-tuple-element = s32[] 
get-tuple-element(param.0), index=0 + constant.2 = s32[] constant(16) + ROOT compare = pred[] compare(get-tuple-element, constant.2), direction=LT +} + +ENTRY main { + constant = s32[] constant(0) + z = f32[] constant(0) + b = f32[8,3,96,1,64]{4,3,2,1,0} broadcast(z), dimensions={} + i = s32[8,2,4]{2,1,0} parameter(0) + reshape = s32[16,4]{1,0} reshape(i) + u = f32[8,2,64]{2,1,0} parameter(1) + reshape.1 = f32[16,64]{1,0} reshape(u) + tuple = (s32[], f32[8,3,96,1,64]{4,3,2,1,0}, s32[16,4]{1,0}, f32[16,64]{1,0}) tuple(constant, b, reshape, reshape.1) + while = (s32[], f32[8,3,96,1,64]{4,3,2,1,0}, s32[16,4]{1,0}, f32[16,64]{1,0}) while(tuple), condition=while_cond, body=while_body + ROOT get-tuple-element.5 = f32[8,3,96,1,64]{4,3,2,1,0} get-tuple-element(while), index=1 +} +)"; + + Literal updates = + Literal::CreateFromShape(ShapeUtil::MakeShape(F32, {8, 2, 64})); + updates.PopulateWithValue(1.0f); + + Literal indices = + Literal::CreateFromShape(ShapeUtil::MakeShape(S32, {8, 2, 4})); + indices + .Populate([&](absl::Span indices) -> int { + auto i = indices[2] + indices[1] * 4 + indices[0] * 2 * 4; + switch (indices[2]) { + case 0: + return i % 8; + case 1: + return i % 3; + case 2: + return i % 96; + default: + return 0; + } + }) + .IgnoreError(); + + ErrorSpec no_error(0, 0); + EXPECT_TRUE( + RunAndCompare(ParseAndReturnVerifiedModule(kModuleStr).ValueOrDie(), + {&indices, &updates}, no_error)); +} + +} // anonymous namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index d6a7ec90b59..efa44b2a88d 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -45,14 +45,18 @@ Status GenericTransferManager::WriteSingleTupleIndexTable( const Shape& shape, se::DeviceMemoryBase* region) { TF_RET_CHECK(elements.size() == ShapeUtil::TupleElementCount(shape)); - std::vector element_pointers; + auto element_pointers = std::make_shared>(); + element_pointers->reserve(elements.size()); for (const se::DeviceMemoryBase& element : elements) { - element_pointers.push_back(element.opaque()); + element_pointers->push_back(element.opaque()); } TF_RETURN_IF_ERROR(TransferBufferToDevice( - stream, GetByteSizeRequirement(shape), element_pointers.data(), region)); + stream, GetByteSizeRequirement(shape), element_pointers->data(), region)); // Ensure the buffer is transferred before we destroy element_pointers. 
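The change below replaces the blocking host sync with a host callback that simply owns the staging vector, so the transfer stays asynchronous while the host memory outlives it. A generic sketch of the same keep-alive pattern, with illustrative names (not from this change):

// Enqueue an async host-to-device copy; the stream callback holds the last
// shared_ptr reference, so the host staging buffer is released only after
// the stream has consumed it.
void EnqueueCopyKeepAlive(se::Stream* stream, se::DeviceMemoryBase* dst,
                          std::shared_ptr<std::vector<uint8>> staging) {
  stream->ThenMemcpy(dst, staging->data(), staging->size());
  stream->ThenDoHostCallback([staging] { /* keeps 'staging' alive */ });
}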
- return stream->BlockHostUntilDone(); + stream->ThenDoHostCallback([element_pointers]() { + /* holds reference to element_pointers in closure */ + }); + return Status::OK(); } void GenericTransferManager::TransferLiteralFromDevice( @@ -115,7 +119,7 @@ Status GenericTransferManager::TransferLiteralToDeviceAsync( TF_RET_CHECK(stream->parent()->device_ordinal() == device_buffer.device_ordinal()); - TF_RETURN_IF_ERROR(WriteTupleIndexTables(stream, device_buffer)); + TF_RETURN_IF_ERROR(WriteTupleIndexTablesAsync(stream, device_buffer)); return ShapeUtil::ForEachSubshapeWithStatus( device_buffer.on_host_shape(), diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 2ffc6c8fb63..7a4c5ffc742 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -10,9 +10,10 @@ load( load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_cuda_library") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = [":friends"]) +package( + default_visibility = [":friends"], + licenses = ["notice"], # Apache 2.0 +) package_group( name = "friends", @@ -333,6 +334,7 @@ cc_library( deps = [ ":buffer_allocations", ":hlo_execution_profiler", + "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", @@ -347,6 +349,7 @@ tf_cuda_library( ":buffer_allocations", ":hlo_execution_profiler", ":thunk", + "//tensorflow/compiler/xla:refcounting_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/synchronization", "//tensorflow/compiler/xla:util", @@ -451,6 +454,7 @@ cc_library( "//tensorflow/stream_executor:device_memory_allocator", "//tensorflow/stream_executor:kernel", "//tensorflow/stream_executor/cuda:cuda_stream", + "//tensorflow/stream_executor/cuda:curand_plugin", "//tensorflow/stream_executor/gpu:gpu_stream", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", @@ -547,6 +551,7 @@ cc_library( "//tensorflow/stream_executor:device_memory", "//tensorflow/stream_executor:device_memory_allocator", "//tensorflow/stream_executor:stream_executor_headers", + "//tensorflow/stream_executor/cuda:ptxas_utils", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:optional", ], @@ -635,6 +640,9 @@ cc_library( srcs = ["cusolver_context.cc"], hdrs = ["cusolver_context.h"], deps = [ + # LINT.IfChange + "@local_config_cuda//cuda:cublas_headers", + # LINT.ThenChange(//tensorflow/copy.bara.sky:cublas_headers) "@local_config_cuda//cuda:cuda_headers", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", @@ -978,6 +986,7 @@ cc_library( "//tensorflow/core/profiler/lib:traceme", "//tensorflow/stream_executor:stream_executor_headers", "//tensorflow/stream_executor/cuda:cuda_diagnostics", + "//tensorflow/stream_executor/cuda:ptxas_utils", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -1158,6 +1167,7 @@ cc_library( "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core/profiler/lib:traceme", "//tensorflow/stream_executor:kernel_spec", + "//tensorflow/stream_executor/cuda:ptxas_utils", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc 
b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc index 5f3b3b48ef2..31e3eadd69f 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc +++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc @@ -363,10 +363,10 @@ static StatusOr DeviceCompare(se::Stream* stream, se::DeviceMemory rhs_typed(rhs); uint64 buffer_size = lhs_typed.ElementCount(); - PtxCompilationOptions opts(config); - TF_ASSIGN_OR_RETURN( - absl::Span compiled_ptx, - CompilePtxOrGetCached(executor, buffer_compare_ptx, opts)); + TF_ASSIGN_OR_RETURN(absl::Span compiled_ptx, + se::cuda::CompilePtxOrGetCached( + executor->device_ordinal(), buffer_compare_ptx, + PtxOptsFromConfig(config))); TF_ASSIGN_OR_RETURN( std::unique_ptr> comparison_kernel, diff --git a/tensorflow/compiler/xla/service/gpu/cholesky_thunk.cc b/tensorflow/compiler/xla/service/gpu/cholesky_thunk.cc index 7daef16cb62..84970a71ac3 100644 --- a/tensorflow/compiler/xla/service/gpu/cholesky_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/cholesky_thunk.cc @@ -52,7 +52,7 @@ CholeskyThunk::CholeskyThunk(const CholeskyOptions& options, Status CholeskyThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream, - HloExecutionProfiler* profiler) { + const RunId& /*run_id*/, HloExecutionProfiler* profiler) { VLOG(3) << "type=" << PrimitiveType_Name(type_) << " uplo=" << se::blas::UpperLowerString(uplo_) << " batch_size=" << batch_size_ << " n=" << n_ diff --git a/tensorflow/compiler/xla/service/gpu/cholesky_thunk.h b/tensorflow/compiler/xla/service/gpu/cholesky_thunk.h index cde245a7e8b..eb6f02baa8c 100644 --- a/tensorflow/compiler/xla/service/gpu/cholesky_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/cholesky_thunk.h @@ -52,7 +52,7 @@ class CholeskyThunk : public Thunk { CholeskyThunk& operator=(const CholeskyThunk&) = delete; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& run_id, HloExecutionProfiler* profiler) override; private: diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc index ea639249826..90f797e7e15 100644 --- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc @@ -58,7 +58,7 @@ Status ConditionalThunk::Initialize(const GpuExecutable& executable, Status ConditionalThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream, - HloExecutionProfiler* profiler) { + const RunId& run_id, HloExecutionProfiler* profiler) { auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); // Copy the predicate value from device. int32 branch_index = -1; @@ -89,7 +89,7 @@ Status ConditionalThunk::ExecuteOnStream( // Execute the branch computation corresponding to the value of branch_index. 
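Here and throughout the thunks below, a RunId is threaded through ExecuteOnStream so that composite thunks can forward it to their children, exactly as the conditional's branch dispatch does in the next hunk; leaf thunks that have no use for it simply ignore the new argument. A stripped-down sketch of that signature change (these types are simplified stand-ins, not the real XLA classes):

    #include <cstdint>

    struct BufferAllocations {};
    struct Stream {};
    struct RunId { int64_t value; };

    class Thunk {
     public:
      virtual ~Thunk() = default;
      // Every override now receives the RunId of the enclosing execution, even
      // if it ignores it.
      virtual void ExecuteOnStream(const BufferAllocations& allocations,
                                   Stream* stream, const RunId& run_id) = 0;
    };

    class ForwardingThunk : public Thunk {
     public:
      explicit ForwardingThunk(Thunk* body) : body_(body) {}
      void ExecuteOnStream(const BufferAllocations& allocations, Stream* stream,
                           const RunId& run_id) override {
        // Composite thunks (conditional/for/while) pass run_id through to
        // their nested thunk sequences unchanged.
        body_->ExecuteOnStream(allocations, stream, run_id);
      }

     private:
      Thunk* body_;
    };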
profiler->StartHloComputation(); TF_RETURN_IF_ERROR(branch_thunks_[branch_index]->ExecuteOnStream( - buffer_allocations, stream, profiler)); + buffer_allocations, stream, run_id, profiler)); profiler->FinishHloComputation( hlo_instruction()->branch_computation(branch_index)); diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.h b/tensorflow/compiler/xla/service/gpu/conditional_thunk.h index c0093ca6397..ca625f4a97b 100644 --- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.h @@ -54,7 +54,7 @@ class ConditionalThunk : public Thunk { Status Initialize(const GpuExecutable& executable, se::StreamExecutor* executor) override; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& run_id, HloExecutionProfiler* profiler) override; private: diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc index e1dffad3045..265a3f67020 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc @@ -43,7 +43,7 @@ ConvolutionThunk::ConvolutionThunk( Status ConvolutionThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream, - HloExecutionProfiler* profiler) { + const RunId& /*run_id*/, HloExecutionProfiler* profiler) { std::vector operand_se_buffers; for (const auto& buffer : operand_buffers_) { operand_se_buffers.push_back(buffer_allocations.GetDeviceAddress(buffer)); diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h index c71515490c9..4a29164cbe6 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h @@ -54,7 +54,7 @@ class ConvolutionThunk : public Thunk { // Does the convolution for the thunk on "stream". 
Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& run_id, HloExecutionProfiler* profiler) override; private: diff --git a/tensorflow/compiler/xla/service/gpu/copy_thunk.cc b/tensorflow/compiler/xla/service/gpu/copy_thunk.cc index 92e03f94c11..62878cf864d 100644 --- a/tensorflow/compiler/xla/service/gpu/copy_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/copy_thunk.cc @@ -32,7 +32,7 @@ HostToDeviceCopyThunk::HostToDeviceCopyThunk( Status HostToDeviceCopyThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream, - HloExecutionProfiler* profiler) { + const RunId& /*run_id*/, HloExecutionProfiler* profiler) { se::DeviceMemoryBase destination_data = buffer_allocations.GetDeviceAddress(destination_buffer_); auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); @@ -51,7 +51,7 @@ DeviceToDeviceCopyThunk::DeviceToDeviceCopyThunk( Status DeviceToDeviceCopyThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream, - HloExecutionProfiler* profiler) { + const RunId& /*run_id*/, HloExecutionProfiler* profiler) { se::DeviceMemoryBase destination_data = buffer_allocations.GetDeviceAddress(destination_buffer_); se::DeviceMemoryBase source_data = diff --git a/tensorflow/compiler/xla/service/gpu/copy_thunk.h b/tensorflow/compiler/xla/service/gpu/copy_thunk.h index 91564b520ac..30fb71f4c4e 100644 --- a/tensorflow/compiler/xla/service/gpu/copy_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/copy_thunk.h @@ -41,7 +41,7 @@ class HostToDeviceCopyThunk : public Thunk { HostToDeviceCopyThunk& operator=(const HostToDeviceCopyThunk&) = delete; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& run_id, HloExecutionProfiler* profiler) override; private: @@ -65,7 +65,7 @@ class DeviceToDeviceCopyThunk : public Thunk { DeviceToDeviceCopyThunk& operator=(const DeviceToDeviceCopyThunk&) = delete; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& run_id, HloExecutionProfiler* profiler) override; private: diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc index bc3c6f72f67..3147bc66e3f 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc @@ -100,7 +100,7 @@ CudnnBatchNormForwardInferenceThunk::CudnnBatchNormForwardInferenceThunk( Status CudnnBatchNormForwardInferenceThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream, - HloExecutionProfiler* profiler) { + const RunId& /*run_id*/, HloExecutionProfiler* profiler) { dnn::BatchDescriptor operand_desc; dnn::BatchDescriptor scale_offset_desc; std::tie(operand_desc, scale_offset_desc) = @@ -114,17 +114,19 @@ Status CudnnBatchNormForwardInferenceThunk::ExecuteOnStream( se::DeviceMemory(buffer_allocations.GetDeviceAddress(offset_)), se::DeviceMemory(buffer_allocations.GetDeviceAddress(mean_)), se::DeviceMemory(buffer_allocations.GetDeviceAddress(variance_)), - operand_desc, // - scale_offset_desc, // - epsilon_, // - &output, // - /*batch_mean=*/nullptr, // - /*batch_var=*/nullptr, // - /*saved_mean=*/nullptr, // - /*saved_inv_var=*/nullptr, // - /*is_training=*/false, // - /*var_to_inv_var=*/nullptr, // - /*inv_var_to_var=*/nullptr); + operand_desc, // + 
scale_offset_desc, // + epsilon_, // + &output, // + /*batch_mean=*/nullptr, // + /*batch_var=*/nullptr, // + /*saved_mean=*/nullptr, // + /*saved_inv_var=*/nullptr, // + /*is_training=*/false, // + /*var_to_inv_var=*/nullptr, // + /*inv_var_to_var=*/nullptr, // + /*reserve_space_allocator=*/nullptr, // + /*workspace_allocator=*/nullptr); if (!stream->ok()) { return InternalError("BatchNormalizationForward call failed."); @@ -162,7 +164,7 @@ CudnnBatchNormForwardTrainingThunk::CudnnBatchNormForwardTrainingThunk( Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream, - HloExecutionProfiler* profiler) { + const RunId& /*run_id*/, HloExecutionProfiler* profiler) { dnn::BatchDescriptor operand_desc; dnn::BatchDescriptor scale_offset_desc; // The BatchNormTraining HLO outputs a tuple of three elements: output data, @@ -196,7 +198,9 @@ Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream( /*saved_inv_var=*/&output_inv_stddev, // /*is_training=*/true, // /*var_to_inv_var=*/nullptr, // - /*inv_var_to_var=*/nullptr); + /*inv_var_to_var=*/nullptr, // + /*reserve_space_allocator=*/nullptr, // + /*workspace_allocator=*/nullptr); // Write the tuple. void* ptrs[] = {output_data.opaque(), output_mean.opaque(), @@ -246,7 +250,7 @@ CudnnBatchNormBackwardThunk::CudnnBatchNormBackwardThunk( Status CudnnBatchNormBackwardThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream, - HloExecutionProfiler* profiler) { + const RunId& /*run_id*/, HloExecutionProfiler* profiler) { dnn::BatchDescriptor operand_desc; dnn::BatchDescriptor scale_offset_desc; @@ -272,7 +276,7 @@ Status CudnnBatchNormBackwardThunk::ExecuteOnStream( se::DeviceMemory(buffer_allocations.GetDeviceAddress(mean_)), se::DeviceMemory(buffer_allocations.GetDeviceAddress(inv_stddev_)), operand_desc, scale_offset_desc, epsilon_, &output_grad_data, - &output_grad_scale, &output_grad_offset); + &output_grad_scale, &output_grad_offset, nullptr, nullptr); // Write the output tuple. 
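For reference, the inference-mode computation these cuDNN batch-norm calls perform is, elementwise per feature map, y = scale * (x - mean) / sqrt(variance + epsilon) + offset; the training path additionally produces the batch mean and saved inverse stddev, which is why the forward-training thunk wires those extra outputs. A scalar host-side sketch of the inference formula (not the cuDNN path itself):

    #include <cmath>

    // Elementwise batch-norm inference for one activation x in a channel with
    // running statistics (mean, variance) and learned (scale, offset).
    float BatchNormInference(float x, float mean, float variance, float scale,
                             float offset, float epsilon) {
      const float inv_stddev = 1.0f / std::sqrt(variance + epsilon);
      return scale * (x - mean) * inv_stddev + offset;
    }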
void* ptrs[] = {output_grad_data.opaque(), output_grad_scale.opaque(), diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h index d2143b39529..e0e6e86818f 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h @@ -61,7 +61,7 @@ class CudnnBatchNormForwardInferenceThunk : public Thunk { const CudnnBatchNormForwardInferenceThunk&) = delete; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& run_id, HloExecutionProfiler* profiler) override; private: @@ -92,7 +92,7 @@ class CudnnBatchNormForwardTrainingThunk : public Thunk { const CudnnBatchNormForwardTrainingThunk&) = delete; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& run_id, HloExecutionProfiler* profiler) override; private: @@ -126,7 +126,7 @@ class CudnnBatchNormBackwardThunk : public Thunk { delete; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& run_id, HloExecutionProfiler* profiler) override; private: diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc index 9ef5f07d857..1e9c3d83c56 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc @@ -178,8 +178,8 @@ struct ConvCacheStats { int64 cache_misses = 0; void LogStats() { - VLOG(1) << "Cache hits: " << cache_hits; - VLOG(1) << "Cache misses: " << cache_misses; + VLOG(2) << "Cache hits: " << cache_hits; + VLOG(2) << "Cache misses: " << cache_misses; } }; @@ -269,8 +269,7 @@ StatusOr CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache( if (allocator_ != nullptr) { allocator = allocator_; } else { - se_allocator.emplace(stream_exec_->platform(), - absl::Span({stream_exec_})); + se_allocator.emplace(stream_exec_); allocator = &*se_allocator; } @@ -302,12 +301,15 @@ StatusOr CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache( break; } case xla::F32: { - uint32 bits; - memcpy(&bits, &kBroadcastedConstant, sizeof(bits)); - stream.ThenMemset32(&buffer, bits, buffer.size()); + se::DeviceMemory typed_buffer(buffer); + stream.ThenPopulateRandUniform(&typed_buffer); + break; + } + case xla::F64: { + se::DeviceMemory typed_buffer(buffer); + stream.ThenPopulateRandUniform(&typed_buffer); break; } - // TODO(timshen): populate non-zero data for f64. default: stream.ThenMemZero(&buffer, buffer.size()); } @@ -425,6 +427,8 @@ StatusOr CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache( << AlgorithmToString(first_algorithm) << " vs " << AlgorithmToString(alg); PrintPlatformInfo(&stream); + VLOG(1) << "Full module on failure: \n" + << instr->GetModule()->ToString(); auto* fail = result.mutable_failure(); fail->set_kind(AutotuneResult::WRONG_RESULT); auto* reference_conv = fail->mutable_reference_conv(); @@ -462,12 +466,10 @@ StatusOr CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache( *log.mutable_cudnn_version() = GetCudnnVersion(stream_exec_); log.set_device_pci_bus_id( stream_exec_->GetDeviceDescription().pci_bus_id()); + VLOG(1) << "Autotuning result: " << log.ShortDebugString(); // If we crash on checking failure, we are in a testing/benchmark mode, thus - // print more information instead of logging to the logger. 
- if (crash_on_checking_failure) { - LOG(INFO) << "Autotuning result: " << log.ShortDebugString(); - } else { - VLOG(2) << "Autotuning result:\n" << log.DebugString(); + // omitting logging through the logger. + if (!crash_on_checking_failure) { tensorflow::Logger::Singleton()->LogProto(log); } } @@ -527,7 +529,7 @@ StatusOr CudnnConvAlgorithmPicker::RunOnInstruction( } auto best_algo = std::move(best_algo_or).ValueOrDie(); - VLOG(1) << "Setting cudnn conv to use algorithm " + VLOG(2) << "Setting cudnn conv to use algorithm " << best_algo.conv().algorithm() << " and " << NumBytesToString(best_algo.scratch_bytes()) << " of scratch memory: " << instr->ToString() @@ -548,7 +550,7 @@ StatusOr CudnnConvAlgorithmPicker::RunOnInstruction( HloInstruction* new_call = computation->AddInstruction( instr->CloneWithNewOperands(new_call_shape, instr->operands())); - VLOG(1) << "Replacing convolution " << instr->ToString() << " with " + VLOG(2) << "Replacing convolution " << instr->ToString() << " with " << new_call->ToString(); TF_RETURN_IF_ERROR(new_call->set_backend_config(backend_config)); diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc index cd0198e2cb9..c2817e36466 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc @@ -14,10 +14,10 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h" + #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" -#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -39,42 +39,6 @@ using se::dnn::FilterDescriptor; using se::dnn::FilterLayout; using se::dnn::ProfileResult; -struct CudnnConvParams { - // Here are the fields related to cuDNN's fused convolution. The result thus - // is defined as: - // activation(conv_result_scale * conv(x, w) + - // side_input_scale * side_input + broadcast(bias)) - // - // The most common fused conv is conv forward + relu/identity, for example. - // - // bias_buf is a single-dimensional array, with the length equal to the number - // of output features. It'll be broadcasted to the output shape in order to be - // added to the final results. - // - // side_input_buf, if valid, must have the same shape as the output buffer. - struct FusionParams { - se::dnn::ActivationMode mode; - double side_input_scale; - se::DeviceMemoryBase bias_buf; - se::DeviceMemoryBase side_input_buf; // nullable - }; - - CudnnConvKind kind; - const Shape* input_shape; - const Shape* filter_shape; - const Shape* output_shape; - se::DeviceMemoryBase input_buf; - se::DeviceMemoryBase filter_buf; - se::DeviceMemoryBase output_buf; - const Window* window; - const ConvolutionDimensionNumbers* dnums; - int64 feature_group_count; - se::dnn::AlgorithmConfig algorithm; - double conv_result_scale; - - absl::optional fusion; -}; - // A StreamExecutor ScratchAllocator that wraps a single XLA allocation, // returning it (in its entirety) the first time Allocate() is called. 
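The class defined just below hands out its single pre-sized buffer exactly once; a distilled sketch of that one-shot behavior, with error handling reduced to std::optional and hypothetical names:

    #include <cstdint>
    #include <optional>

    // Wraps one fixed buffer and returns it, in its entirety, on the first
    // Allocate() call; any later call (or an oversized request) fails.
    class OneShotScratchAllocator {
     public:
      OneShotScratchAllocator(void* buffer, int64_t size)
          : buffer_(buffer), size_(size) {}

      std::optional<void*> Allocate(int64_t byte_size) {
        if (allocated_ || byte_size > size_) {
          return std::nullopt;  // Already handed out, or request too large.
        }
        allocated_ = true;
        return buffer_;
      }

     private:
      void* buffer_;
      int64_t size_;
      bool allocated_ = false;
    };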
class ScratchBufAllocator : public se::ScratchAllocator { @@ -110,132 +74,19 @@ class ScratchBufAllocator : public se::ScratchAllocator { }; template -Status RunCudnnConvImpl(CudnnConvParams params, +Status RunCudnnConvImpl(const CudnnConvParams& params, se::ScratchAllocator* scratch_allocator, - se::Stream* stream, - se::dnn::ProfileResult* profile_result) { - CudnnConvKind kind = params.kind; - const Shape& input_shape = *params.input_shape; - const Shape& filter_shape = *params.filter_shape; - const Shape& output_shape = *params.output_shape; - DeviceMemory input_buf(params.input_buf); - DeviceMemory filter_buf(params.filter_buf); - DeviceMemory output_buf(params.output_buf); - const Window& window = *params.window; - const ConvolutionDimensionNumbers& dnums = *params.dnums; - int64 feature_group_count = params.feature_group_count; + se::Stream* stream, RunConvOptions options) { + auto input_buf = se::DeviceMemory(params.input_buf); + auto filter_buf = se::DeviceMemory(params.filter_buf); + auto output_buf = se::DeviceMemory(params.output_buf); AlgorithmConfig algorithm = params.algorithm; - VLOG(3) << "Convolution Algorithm: " << algorithm.algorithm()->algo_id(); - VLOG(3) << "tensor_ops_enabled: " - << algorithm.algorithm()->tensor_ops_enabled(); - VLOG(3) << "Convolution kind: " << CudnnConvKindToString(kind); - VLOG(3) << "input shape: " << ShapeUtil::HumanStringWithLayout(input_shape); - VLOG(3) << "filter shape: " << ShapeUtil::HumanStringWithLayout(filter_shape); - VLOG(3) << "Output shape: " << ShapeUtil::HumanStringWithLayout(output_shape); - VLOG(3) << "Window: { " << window.ShortDebugString() << " }"; - VLOG(3) << "Dim nums: { " << dnums.ShortDebugString() << " }"; - - const int num_dimensions = window.dimensions_size(); - CHECK_LE(num_dimensions, 3); - CHECK_GE(num_dimensions, 1); - // cuDNN does not support 1D convolutions. We therefore express 1D - // convolutions as 2D convolutions where the first spatial dimension is 1. - // This matches the behavior of TF (see definition of conv1d in - // tensorflow/python/ops/nn_ops.py). - const int effective_num_dimensions = std::max(2, num_dimensions); - - CHECK_EQ(primitive_util::NativeToPrimitiveType(), - output_shape.element_type()) - << ShapeUtil::HumanString(output_shape); - - // If one dimension is reversed, we need to have all dimensions reversed (so - // we're doing convolution not cross correlation). - const bool dims_reversed = window.dimensions()[0].window_reversal(); - - CHECK_EQ(num_dimensions, dnums.input_spatial_dimensions_size()); - CHECK_EQ(num_dimensions, dnums.kernel_spatial_dimensions_size()); - CHECK_EQ(num_dimensions, dnums.output_spatial_dimensions_size()); - for (const WindowDimension& dim : window.dimensions()) { - CHECK_EQ(dims_reversed, dim.window_reversal()); - CHECK_EQ(dim.padding_low(), dim.padding_high()); - CHECK_EQ(dim.base_dilation(), 1) - << "cudnn does not support base dilation; it " - "must be made explicit with a kPad"; - CHECK_EQ(dim.window_dilation(), 1) - << "XLA does not support window dilation (although cudnn does); it " - "must be made explicit with a kPad"; + if (options.algo_override) { + algorithm = AlgorithmConfig(*options.algo_override); } - // cuDNN's convolution APIs support the BDYX layout for activations/output and - // the OIYX layout for weights. 
- DataLayout input_dl; - FilterLayout filter_dl; - DataLayout output_dl; - - TF_ASSIGN_OR_RETURN(std::tie(input_dl, filter_dl, output_dl), - XlaConvLayoutsToStreamExecutorLayouts( - dnums, input_shape.layout(), filter_shape.layout(), - output_shape.layout())); - - BatchDescriptor input_descriptor(effective_num_dimensions); - input_descriptor.set_layout(input_dl) - .set_feature_map_count( - input_shape.dimensions(dnums.input_feature_dimension())) - .set_count(input_shape.dimensions(dnums.input_batch_dimension())); - for (int dim = 0; dim < num_dimensions; ++dim) { - // Note that the dimensions are reversed. The same holds below. - input_descriptor.set_spatial_dim( - static_cast(effective_num_dimensions - dim - 1), - input_shape.dimensions(dnums.input_spatial_dimensions(dim))); - } - - FilterDescriptor filter_descriptor(effective_num_dimensions); - filter_descriptor.set_layout(filter_dl) - .set_input_feature_map_count( - filter_shape.dimensions(dnums.kernel_input_feature_dimension())) - .set_output_feature_map_count( - filter_shape.dimensions(dnums.kernel_output_feature_dimension())); - for (int dim = 0; dim < num_dimensions; ++dim) { - filter_descriptor.set_spatial_dim( - static_cast(effective_num_dimensions - dim - 1), - filter_shape.dimensions(dnums.kernel_spatial_dimensions(dim))); - } - - ConvolutionDescriptor convolution_descriptor(effective_num_dimensions); - convolution_descriptor.set_group_count(feature_group_count); - convolution_descriptor.set_convolution_not_crosscorr(dims_reversed); - for (int dim = 0; dim < num_dimensions; ++dim) { - convolution_descriptor - .set_zero_padding( - static_cast(effective_num_dimensions - dim - 1), - window.dimensions(dim).padding_low()) - .set_filter_stride( - static_cast(effective_num_dimensions - dim - 1), - window.dimensions(dim).stride()); - } - - BatchDescriptor output_descriptor(effective_num_dimensions); - output_descriptor.set_layout(output_dl) - .set_feature_map_count( - output_shape.dimensions(dnums.output_feature_dimension())) - .set_count(output_shape.dimensions(dnums.output_batch_dimension())); - for (int dim = 0; dim < num_dimensions; ++dim) { - output_descriptor.set_spatial_dim( - static_cast(effective_num_dimensions - dim - 1), - output_shape.dimensions(dnums.output_spatial_dimensions(dim))); - } - - // Add a singleton dimension in the 1D convolution case. 
- if (num_dimensions == 1) { - input_descriptor.set_spatial_dim(static_cast(0), 1); - output_descriptor.set_spatial_dim(static_cast(0), 1); - filter_descriptor.set_spatial_dim(static_cast(0), 1); - convolution_descriptor.set_zero_padding(static_cast(0), 0) - .set_filter_stride(static_cast(0), 1); - } - - switch (kind) { + switch (params.kind) { case CudnnConvKind::kForward: if (params.conv_result_scale != 1) { return InternalError( @@ -243,9 +94,9 @@ Status RunCudnnConvImpl(CudnnConvParams params, params.conv_result_scale); } stream->ThenConvolveWithAlgorithm( - input_descriptor, input_buf, filter_descriptor, filter_buf, - convolution_descriptor, output_descriptor, &output_buf, - scratch_allocator, algorithm, profile_result); + params.input_descriptor, input_buf, params.filter_descriptor, + filter_buf, params.conv_desc, params.output_descriptor, &output_buf, + scratch_allocator, algorithm, options.profile_result); break; case CudnnConvKind::kBackwardInput: if (params.conv_result_scale != 1) { @@ -254,9 +105,9 @@ Status RunCudnnConvImpl(CudnnConvParams params, params.conv_result_scale); } stream->ThenConvolveBackwardDataWithAlgorithm( - filter_descriptor, filter_buf, output_descriptor, output_buf, - convolution_descriptor, input_descriptor, &input_buf, - scratch_allocator, algorithm, profile_result); + params.filter_descriptor, filter_buf, params.output_descriptor, + output_buf, params.conv_desc, params.input_descriptor, &input_buf, + scratch_allocator, algorithm, options.profile_result); break; case CudnnConvKind::kBackwardFilter: if (params.conv_result_scale != 1) { @@ -265,18 +116,17 @@ Status RunCudnnConvImpl(CudnnConvParams params, params.conv_result_scale); } stream->ThenConvolveBackwardFilterWithAlgorithm( - input_descriptor, input_buf, output_descriptor, output_buf, - convolution_descriptor, filter_descriptor, &filter_buf, - scratch_allocator, algorithm, profile_result); + params.input_descriptor, input_buf, params.output_descriptor, + output_buf, params.conv_desc, params.filter_descriptor, &filter_buf, + scratch_allocator, algorithm, options.profile_result); break; case CudnnConvKind::kForwardActivation: { BatchDescriptor bias_desc; bias_desc.set_count(1) .set_height(1) .set_width(1) - .set_feature_map_count( - output_shape.dimensions(dnums.output_feature_dimension())) - .set_layout(output_dl); + .set_feature_map_count(params.output_descriptor.feature_map_count()) + .set_layout(params.output_descriptor.layout()); se::DeviceMemory side_input(params.fusion->side_input_buf); // If there is no side input, use output as the side input. 
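The kForwardActivation case in the next hunk implements the fused form described in the CudnnConvParams comment: activation(conv_result_scale * conv(x, w) + side_input_scale * side_input + broadcast(bias)); when no side input exists, the output buffer stands in for it, presumably with a zero side_input_scale so the result is unaffected. A scalar sketch of just that epilogue, applied to one already-computed convolution output element and assuming ReLU as the activation:

    #include <algorithm>

    // Applies the fused-convolution epilogue to a single output element, given
    // the raw convolution result for that element.
    float FusedConvEpilogue(float conv_result, float conv_result_scale,
                            float side_input, float side_input_scale,
                            float bias) {
      const float pre_activation = conv_result_scale * conv_result +
                                   side_input_scale * side_input + bias;
      return std::max(pre_activation, 0.0f);  // ReLU activation.
    }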
@@ -296,12 +146,12 @@ Status RunCudnnConvImpl(CudnnConvParams params, } stream->ThenFusedConvolveWithAlgorithm( - input_descriptor, input_buf, params.conv_result_scale, - filter_descriptor, filter_buf, convolution_descriptor, side_input, + params.input_descriptor, input_buf, params.conv_result_scale, + params.filter_descriptor, filter_buf, params.conv_desc, side_input, params.fusion->side_input_scale, bias_desc, DeviceMemory(params.fusion->bias_buf), params.fusion->mode, - output_descriptor, &output_buf, scratch_allocator, algorithm, - profile_result); + params.output_descriptor, &output_buf, scratch_allocator, algorithm, + options.profile_result); break; } } @@ -309,14 +159,14 @@ Status RunCudnnConvImpl(CudnnConvParams params, if (!stream->ok()) { return InternalError( "Unable to launch convolution with type %s and algorithm (%d, %d)", - CudnnConvKindToString(kind), algorithm.algorithm()->algo_id(), + CudnnConvKindToString(params.kind), algorithm.algorithm()->algo_id(), algorithm.algorithm_no_scratch()->algo_id()); } return Status::OK(); } -// Returns the cudnn convolution parameters generated from conv, which must be a -// custom-call to a cudnn convolution. +} // anonymous namespace + StatusOr GetCudnnConvParams( const HloCustomCallInstruction* conv, absl::Span operand_buffers, @@ -325,50 +175,46 @@ StatusOr GetCudnnConvParams( TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config, conv->backend_config()); - TF_ASSIGN_OR_RETURN(CudnnConvKind kind, GetCudnnConvKind(conv)); - const auto& lhs_shape = conv->operand(0)->shape(); - const auto& rhs_shape = conv->operand(1)->shape(); - const auto& conv_result_shape = conv->shape().tuple_shapes(0); + TF_ASSIGN_OR_RETURN(params.kind, GetCudnnConvKind(conv)); + const Shape* input_shape; + const Shape* filter_shape; + const Shape* output_shape; - params.kind = kind; - params.window = &conv->window(); - params.dnums = &conv->convolution_dimension_numbers(); - params.feature_group_count = conv->feature_group_count(); params.algorithm = se::dnn::AlgorithmConfig(se::dnn::AlgorithmDesc( backend_config.algorithm(), backend_config.tensor_ops_enabled())); params.conv_result_scale = backend_config.conv_result_scale(); - switch (kind) { + switch (params.kind) { case CudnnConvKind::kForward: - params.input_shape = &lhs_shape; - params.filter_shape = &rhs_shape; - params.output_shape = &conv_result_shape; + input_shape = &conv->operand(0)->shape(); + filter_shape = &conv->operand(1)->shape(); + output_shape = &conv->shape().tuple_shapes(0); params.input_buf = operand_buffers[0]; params.filter_buf = operand_buffers[1]; params.output_buf = result_buffer; break; case CudnnConvKind::kBackwardInput: - params.input_shape = &conv_result_shape; - params.filter_shape = &rhs_shape; - params.output_shape = &lhs_shape; + input_shape = &conv->shape().tuple_shapes(0); + filter_shape = &conv->operand(1)->shape(); + output_shape = &conv->operand(0)->shape(); params.input_buf = result_buffer; params.filter_buf = operand_buffers[1]; params.output_buf = operand_buffers[0]; break; case CudnnConvKind::kBackwardFilter: - params.input_shape = &lhs_shape; - params.filter_shape = &conv_result_shape; - params.output_shape = &rhs_shape; + input_shape = &conv->operand(0)->shape(); + filter_shape = &conv->shape().tuple_shapes(0); + output_shape = &conv->operand(1)->shape(); params.input_buf = operand_buffers[0]; params.filter_buf = result_buffer; params.output_buf = operand_buffers[1]; break; case CudnnConvKind::kForwardActivation: { - params.input_shape = &lhs_shape; - 
params.filter_shape = &rhs_shape; - params.output_shape = &conv_result_shape; + input_shape = &conv->operand(0)->shape(); + filter_shape = &conv->operand(1)->shape(); + output_shape = &conv->shape().tuple_shapes(0); params.fusion.emplace(); - auto& fusion = *params.fusion; + CudnnConvParams::FusionParams& fusion = *params.fusion; if (!se::dnn::ActivationMode_IsValid(backend_config.activation_mode())) { return InternalError("Bad activation mode: %s", backend_config.ShortDebugString()); @@ -385,11 +231,129 @@ StatusOr GetCudnnConvParams( } } } + + const Window& window = conv->window(); + const ConvolutionDimensionNumbers& dnums = + conv->convolution_dimension_numbers(); + + VLOG(3) << "Convolution Algorithm: " + << params.algorithm.algorithm()->algo_id(); + VLOG(3) << "tensor_ops_enabled: " + << params.algorithm.algorithm()->tensor_ops_enabled(); + VLOG(3) << "Convolution kind: " << CudnnConvKindToString(params.kind); + VLOG(3) << "input shape: " << ShapeUtil::HumanStringWithLayout(*input_shape); + VLOG(3) << "filter shape: " + << ShapeUtil::HumanStringWithLayout(*filter_shape); + VLOG(3) << "Output shape: " + << ShapeUtil::HumanStringWithLayout(*output_shape); + VLOG(3) << "Window: { " << window.ShortDebugString() << " }"; + VLOG(3) << "Dim nums: { " << dnums.ShortDebugString() << " }"; + + const int num_dimensions = window.dimensions_size(); + CHECK_LE(num_dimensions, 3) << conv->ToString(); + CHECK_GE(num_dimensions, 1) << conv->ToString(); + // cuDNN does not support 1D convolutions. We therefore express 1D + // convolutions as 2D convolutions where the first spatial dimension is 1. + // This matches the behavior of TF (see definition of conv1d in + // tensorflow/python/ops/nn_ops.py). + const int effective_num_dimensions = std::max(2, num_dimensions); + + // If one dimension is reversed, we need to have all dimensions reversed (so + // we're doing convolution not cross correlation). + const bool dims_reversed = window.dimensions()[0].window_reversal(); + + CHECK_EQ(num_dimensions, dnums.input_spatial_dimensions_size()) + << conv->ToString(); + CHECK_EQ(num_dimensions, dnums.kernel_spatial_dimensions_size()) + << conv->ToString(); + CHECK_EQ(num_dimensions, dnums.output_spatial_dimensions_size()) + << conv->ToString(); + for (const WindowDimension& dim : window.dimensions()) { + CHECK_EQ(dims_reversed, dim.window_reversal()) << conv->ToString(); + CHECK_EQ(dim.padding_low(), dim.padding_high()) << conv->ToString(); + CHECK_EQ(dim.base_dilation(), 1) + << "cudnn does not support base dilation; it " + "must be made explicit with a kPad: " + << conv->ToString(); + } + + // cuDNN's convolution APIs support the BDYX layout for activations/output and + // the OIYX layout for weights. + DataLayout input_dl; + FilterLayout filter_dl; + DataLayout output_dl; + + TF_ASSIGN_OR_RETURN(std::tie(input_dl, filter_dl, output_dl), + XlaConvLayoutsToStreamExecutorLayouts( + dnums, input_shape->layout(), filter_shape->layout(), + output_shape->layout())); + + BatchDescriptor& input_descriptor = params.input_descriptor; + input_descriptor = BatchDescriptor(effective_num_dimensions); + input_descriptor.set_layout(input_dl) + .set_feature_map_count( + input_shape->dimensions(dnums.input_feature_dimension())) + .set_count(input_shape->dimensions(dnums.input_batch_dimension())); + for (int dim = 0; dim < num_dimensions; ++dim) { + // Note that the dimensions are reversed. The same holds below. 
+ input_descriptor.set_spatial_dim( + static_cast(effective_num_dimensions - dim - 1), + input_shape->dimensions(dnums.input_spatial_dimensions(dim))); + } + + FilterDescriptor& filter_descriptor = params.filter_descriptor; + filter_descriptor = FilterDescriptor(effective_num_dimensions); + filter_descriptor.set_layout(filter_dl) + .set_input_feature_map_count( + filter_shape->dimensions(dnums.kernel_input_feature_dimension())) + .set_output_feature_map_count( + filter_shape->dimensions(dnums.kernel_output_feature_dimension())); + for (int dim = 0; dim < num_dimensions; ++dim) { + filter_descriptor.set_spatial_dim( + static_cast(effective_num_dimensions - dim - 1), + filter_shape->dimensions(dnums.kernel_spatial_dimensions(dim))); + } + + params.conv_desc = ConvolutionDescriptor(effective_num_dimensions); + params.conv_desc.set_group_count(conv->feature_group_count()); + params.conv_desc.set_convolution_not_crosscorr(dims_reversed); + for (int dim = 0; dim < num_dimensions; ++dim) { + params.conv_desc + .set_zero_padding( + static_cast(effective_num_dimensions - dim - 1), + window.dimensions(dim).padding_low()) + .set_filter_stride( + static_cast(effective_num_dimensions - dim - 1), + window.dimensions(dim).stride()) + .set_dilation_rate( + static_cast(effective_num_dimensions - dim - 1), + window.dimensions(dim).window_dilation()); + } + + BatchDescriptor& output_descriptor = params.output_descriptor; + output_descriptor = BatchDescriptor(effective_num_dimensions); + output_descriptor.set_layout(output_dl) + .set_feature_map_count( + output_shape->dimensions(dnums.output_feature_dimension())) + .set_count(output_shape->dimensions(dnums.output_batch_dimension())); + for (int dim = 0; dim < num_dimensions; ++dim) { + output_descriptor.set_spatial_dim( + static_cast(effective_num_dimensions - dim - 1), + output_shape->dimensions(dnums.output_spatial_dimensions(dim))); + } + + // Add a singleton dimension in the 1D convolution case. 
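The singleton-dimension code re-added below follows the earlier comment: cuDNN has no 1D convolution, so a 1D problem is posed as a 2D one whose leading spatial dimension has extent 1, matching TF's conv1d. A shape-level sketch of that rewrite (hypothetical helper, NCW/NCHW-style ordering assumed):

    #include <array>
    #include <cstdint>

    // Describes a conv1d problem as an equivalent conv2d problem by inserting
    // a unit spatial dimension in front of the real one.
    struct Conv2dShapes {
      std::array<int64_t, 4> input;   // {batch, features, 1, width}
      std::array<int64_t, 4> filter;  // {out_features, in_features, 1, width}
    };

    Conv2dShapes AsConv2d(int64_t batch, int64_t in_features, int64_t width,
                          int64_t out_features, int64_t filter_width) {
      Conv2dShapes shapes;
      shapes.input = {batch, in_features, 1, width};
      shapes.filter = {out_features, in_features, 1, filter_width};
      return shapes;
    }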
+ if (num_dimensions == 1) { + input_descriptor.set_spatial_dim(static_cast(0), 1); + output_descriptor.set_spatial_dim(static_cast(0), 1); + filter_descriptor.set_spatial_dim(static_cast(0), 1); + params.conv_desc.set_zero_padding(static_cast(0), 0) + .set_filter_stride(static_cast(0), 1); + } + return params; } -} // anonymous namespace - Status RunCudnnConv(const HloCustomCallInstruction* conv, absl::Span operand_buffers, se::DeviceMemoryBase result_buffer, @@ -408,24 +372,20 @@ Status RunCudnnConv(const HloCustomCallInstruction* conv, TF_ASSIGN_OR_RETURN(CudnnConvParams params, GetCudnnConvParams(conv, operand_buffers, result_buffer)); - if (options.algo_override) { - params.algorithm = AlgorithmConfig(*options.algo_override); - } - PrimitiveType output_primitive_type = conv->shape().tuple_shapes(0).element_type(); switch (output_primitive_type) { case F16: return RunCudnnConvImpl(params, scratch_allocator, stream, - options.profile_result); + options); case F32: return RunCudnnConvImpl(params, scratch_allocator, stream, - options.profile_result); + options); case F64: return RunCudnnConvImpl(params, scratch_allocator, stream, - options.profile_result); + options); default: - LOG(FATAL) << ShapeUtil::HumanString(*params.output_shape); + LOG(FATAL) << conv->ToString(); } } diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h b/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h index 25b2461ca61..14124a08369 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_RUNNER_H_ #include "absl/types/optional.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/status.h" @@ -36,6 +37,41 @@ struct RunConvOptions { absl::optional algo_override; }; +// Implementation struct exposed for debugging and log analysis. +struct CudnnConvParams { + // Here are the fields related to cuDNN's fused convolution. The result thus + // is defined as: + // activation(conv_result_scale * conv(x, w) + + // side_input_scale * side_input + broadcast(bias)) + // + // The most common fused conv is conv forward + relu/identity, for example. + // + // bias_buf is a single-dimensional array, with the length equal to the number + // of output features. It'll be broadcasted to the output shape in order to be + // added to the final results. + // + // side_input_buf, if valid, must have the same shape as the output buffer. + struct FusionParams { + se::dnn::ActivationMode mode; + double side_input_scale; + se::DeviceMemoryBase bias_buf; + se::DeviceMemoryBase side_input_buf; // nullable + }; + + CudnnConvKind kind; + se::dnn::BatchDescriptor input_descriptor; + se::dnn::FilterDescriptor filter_descriptor; + se::dnn::BatchDescriptor output_descriptor; + se::DeviceMemoryBase input_buf; + se::DeviceMemoryBase filter_buf; + se::DeviceMemoryBase output_buf; + se::dnn::ConvolutionDescriptor conv_desc; + se::dnn::AlgorithmConfig algorithm; + double conv_result_scale; + + absl::optional fusion; +}; + // This file contains low-level routines for running cudnn convolutions. // Calls into cudnn to run the specified convolution. 
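RunCudnnConv above switches on the output element type (F16, F32, F64) and forwards the shared CudnnConvParams to an implementation templated on the matching native type. A generic sketch of that dispatch pattern, with hypothetical names and a placeholder half type:

    #include <cstdint>

    enum class ElementType { kF16, kF32, kF64 };

    // Placeholder for a 16-bit float type such as Eigen::half.
    struct Half { uint16_t bits; };

    template <typename T>
    int RunConvImpl(const void* params) {
      // A real implementation would view the buffers as DeviceMemory<T> and
      // issue the cuDNN call; here we only report the element width.
      (void)params;
      return sizeof(T);
    }

    int RunConv(ElementType type, const void* params) {
      switch (type) {
        case ElementType::kF16:
          return RunConvImpl<Half>(params);
        case ElementType::kF32:
          return RunConvImpl<float>(params);
        case ElementType::kF64:
          return RunConvImpl<double>(params);
      }
      return 0;
    }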
@@ -62,6 +98,12 @@ Status RunCudnnConv(const HloCustomCallInstruction* conv, se::ScratchAllocator* scratch_allocator, se::Stream* stream, RunConvOptions = {}); +// Implementation details exposed for debugging and log analysis. +StatusOr GetCudnnConvParams( + const HloCustomCallInstruction* conv, + absl::Span operand_buffers, + se::DeviceMemoryBase result_buffer); + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cusolver_context.cc b/tensorflow/compiler/xla/service/gpu/cusolver_context.cc index 923b7bc4528..4103a720c98 100644 --- a/tensorflow/compiler/xla/service/gpu/cusolver_context.cc +++ b/tensorflow/compiler/xla/service/gpu/cusolver_context.cc @@ -91,12 +91,14 @@ StatusOr CusolverContext::Create(se::Stream* stream) { TF_RETURN_IF_ERROR(CusolverStatusToStatus(cusolverDnCreate(&handle))); CusolverContext context(stream, handle); - // StreamExecutor really should just expose the Cuda stream to clients... - const cudaStream_t* cuda_stream = - CHECK_NOTNULL(reinterpret_cast( - stream->implementation()->GpuStreamMemberHack())); - TF_RETURN_IF_ERROR( - CusolverStatusToStatus(cusolverDnSetStream(handle, *cuda_stream))); + if (stream) { + // StreamExecutor really should just expose the Cuda stream to clients... + const cudaStream_t* cuda_stream = + CHECK_NOTNULL(reinterpret_cast( + stream->implementation()->GpuStreamMemberHack())); + TF_RETURN_IF_ERROR( + CusolverStatusToStatus(cusolverDnSetStream(handle, *cuda_stream))); + } return std::move(context); } @@ -131,17 +133,40 @@ CusolverContext::~CusolverContext() { #define DN_SOLVER_FN(method, type_prefix) cusolverDn##type_prefix##method -#define POTRF_BUFFER_SIZE_INSTANCE(T, type_prefix) \ - StatusOr CusolverContext::PotrfBufferSize( \ - se::blas::UpperLower uplo, int n, se::DeviceMemory A, int lda) { \ - int size = -1; \ - TF_RETURN_IF_ERROR(CusolverStatusToStatus(DN_SOLVER_FN( \ - potrf_bufferSize, type_prefix)(handle(), CUDABlasUpperLower(uplo), n, \ - ToDevicePointer(A), lda, &size))); \ - return size; \ +// Note: NVidia have promised that it is safe to pass 'nullptr' as the argument +// buffers to cuSolver buffer size methods and this will be a documented +// behavior in a future cuSolver release. 
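That is why the rewritten PotrfBufferSize below passes nullptr for the matrix argument and needs only the dimension and element type. For context, potrf computes the Cholesky factorization A = L * L^T of a symmetric positive-definite matrix; a small unblocked host reference (row-major, lower-triangular, no error checking) looks roughly like this:

    #include <cmath>
    #include <vector>

    // Unblocked Cholesky factorization of an n x n symmetric positive-definite
    // matrix stored row-major in `a`; returns the lower-triangular factor L.
    std::vector<double> CholeskyLower(const std::vector<double>& a, int n) {
      std::vector<double> l(n * n, 0.0);
      for (int j = 0; j < n; ++j) {
        double diag = a[j * n + j];
        for (int k = 0; k < j; ++k) diag -= l[j * n + k] * l[j * n + k];
        l[j * n + j] = std::sqrt(diag);
        for (int i = j + 1; i < n; ++i) {
          double sum = a[i * n + j];
          for (int k = 0; k < j; ++k) sum -= l[i * n + k] * l[j * n + k];
          l[i * n + j] = sum / l[j * n + j];
        }
      }
      return l;
    }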
+StatusOr CusolverContext::PotrfBufferSize(PrimitiveType type, + se::blas::UpperLower uplo, + int n, int lda) { + int size = -1; + switch (type) { + case F32: { + TF_RETURN_IF_ERROR(CusolverStatusToStatus(cusolverDnSpotrf_bufferSize( + handle(), CUDABlasUpperLower(uplo), n, /*A=*/nullptr, lda, &size))); + break; + } + case F64: { + TF_RETURN_IF_ERROR(CusolverStatusToStatus(cusolverDnDpotrf_bufferSize( + handle(), CUDABlasUpperLower(uplo), n, /*A=*/nullptr, lda, &size))); + break; + } + case C64: { + TF_RETURN_IF_ERROR(CusolverStatusToStatus(cusolverDnCpotrf_bufferSize( + handle(), CUDABlasUpperLower(uplo), n, /*A=*/nullptr, lda, &size))); + break; + } + case C128: { + TF_RETURN_IF_ERROR(CusolverStatusToStatus(cusolverDnZpotrf_bufferSize( + handle(), CUDABlasUpperLower(uplo), n, /*A=*/nullptr, lda, &size))); + break; + } + default: + return InvalidArgument("Invalid type for cholesky decomposition: %s", + PrimitiveType_Name(type)); } - -CALL_LAPACK_TYPES(POTRF_BUFFER_SIZE_INSTANCE); + return size; +} #define POTRF_INSTANCE(T, type_prefix) \ Status CusolverContext::Potrf( \ diff --git a/tensorflow/compiler/xla/service/gpu/cusolver_context.h b/tensorflow/compiler/xla/service/gpu/cusolver_context.h index 68b5fb14c6b..c3d075c47c7 100644 --- a/tensorflow/compiler/xla/service/gpu/cusolver_context.h +++ b/tensorflow/compiler/xla/service/gpu/cusolver_context.h @@ -32,6 +32,8 @@ namespace gpu { class CusolverContext { public: + // stream may be nullptr, in which case the context can only be used for + // buffer size queries. static StatusOr Create(se::Stream* stream); CusolverContext() = default; ~CusolverContext(); @@ -63,17 +65,9 @@ class CusolverContext { se::DeviceMemory> workspace); // Returns the size of the `workspace` required by Potrf, in number of - // elements of size T. - StatusOr PotrfBufferSize(se::blas::UpperLower uplo, int n, - se::DeviceMemory dev_A, int lda); - StatusOr PotrfBufferSize(se::blas::UpperLower uplo, int n, - se::DeviceMemory dev_A, int lda); - StatusOr PotrfBufferSize(se::blas::UpperLower uplo, int n, - se::DeviceMemory> dev_A, - int lda); - StatusOr PotrfBufferSize(se::blas::UpperLower uplo, int n, - se::DeviceMemory> dev_A, - int lda); + // elements of `type`. + StatusOr PotrfBufferSize(PrimitiveType type, se::blas::UpperLower uplo, + int n, int lda); private: CusolverContext(se::Stream* stream, cusolverDnHandle_t handle); diff --git a/tensorflow/compiler/xla/service/gpu/cusolver_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cusolver_rewriter.cc index 2ba6e8fc3c5..64c3c319321 100644 --- a/tensorflow/compiler/xla/service/gpu/cusolver_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/cusolver_rewriter.cc @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" -#include "tensorflow/compiler/xla/service/gpu/scratch_allocator.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -31,7 +30,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/stream_executor/blas.h" namespace xla { @@ -48,7 +46,6 @@ void SetFortranLayout(Shape* shape) { } StatusOr CreateCholesky(CusolverContext* context, - ScratchAllocator* allocator, HloInstruction* operand, const CholeskyOptions& options, const OpMetadata& metadata) { @@ -67,39 +64,8 @@ StatusOr CreateCholesky(CusolverContext* context, se::blas::UpperLower uplo = options.lower() ? se::blas::UpperLower::kLower : se::blas::UpperLower::kUpper; int64 workspace_size; // Number of elements of size a_shape.element_type() - switch (a_shape.element_type()) { - case F32: { - TF_ASSIGN_OR_RETURN(auto a, - allocator->Allocate(context->stream(), n * n)); - TF_ASSIGN_OR_RETURN(workspace_size, - context->PotrfBufferSize(uplo, n, a, n)); - break; - } - case F64: { - TF_ASSIGN_OR_RETURN( - auto a, allocator->Allocate(context->stream(), n * n)); - TF_ASSIGN_OR_RETURN(workspace_size, - context->PotrfBufferSize(uplo, n, a, n)); - break; - } - case C64: { - TF_ASSIGN_OR_RETURN(auto a, allocator->Allocate>( - context->stream(), n * n)); - TF_ASSIGN_OR_RETURN(workspace_size, - context->PotrfBufferSize(uplo, n, a, n)); - break; - } - case C128: { - TF_ASSIGN_OR_RETURN(auto a, allocator->Allocate>( - context->stream(), n * n)); - TF_ASSIGN_OR_RETURN(workspace_size, - context->PotrfBufferSize(uplo, n, a, n)); - break; - } - default: - return InvalidArgument("Invalid type for cholesky decomposition: %s", - a_shape.ToString()); - } + TF_ASSIGN_OR_RETURN(workspace_size, context->PotrfBufferSize( + a_shape.element_type(), uplo, n, n)); // TODO(phawkins): Ideally we would relax this constraint. What we actually // want is that: @@ -131,7 +97,6 @@ StatusOr CreateCholesky(CusolverContext* context, // Tries to rewrite a single convolution into a call to cudnn. StatusOr RunOnInstruction(CusolverContext* context, - ScratchAllocator* allocator, HloInstruction* instruction) { if (instruction->opcode() != HloOpcode::kCholesky) { return false; @@ -139,7 +104,7 @@ StatusOr RunOnInstruction(CusolverContext* context, TF_ASSIGN_OR_RETURN( HloInstruction * custom_call, - CreateCholesky(context, allocator, instruction->mutable_operand(0), + CreateCholesky(context, instruction->mutable_operand(0), instruction->cholesky_options(), instruction->metadata())); VLOG(1) << "Replacing " << instruction->ToString() << " with " @@ -167,41 +132,18 @@ StatusOr CusolverRewriter::RunOnComputation(HloComputation* computation) { return false; } - // Create a stream for us to do our work on. We don't really need to do any - // work, just allocate memory, but that's the cuSolver API. - se::Stream stream{stream_exec_}; - stream.Init(); - const auto device_ordinal = stream_exec_->device_ordinal(); - - // allocator either points to this->allocator_ or, if that's null, to a - // se::StreamExecutorMemoryAllocator for stream_exec_. 
- se::DeviceMemoryAllocator* allocator; - absl::optional se_allocator; - if (allocator_ != nullptr) { - allocator = allocator_; - } else { - se_allocator.emplace(stream_exec_->platform(), - absl::Span({stream_exec_})); - allocator = &*se_allocator; - } - ScratchAllocator scratch_allocator(device_ordinal, allocator); - TF_ASSIGN_OR_RETURN(CusolverContext context, - CusolverContext::Create(&stream)); + CusolverContext::Create(/*stream=*/nullptr)); bool changed = false; for (HloInstruction* instruction : cusolver_calls) { - TF_ASSIGN_OR_RETURN( - bool result, - RunOnInstruction(&context, &scratch_allocator, instruction)); + TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(&context, instruction)); changed |= result; } return changed; } -CusolverRewriter::CusolverRewriter(se::StreamExecutor* stream_exec, - se::DeviceMemoryAllocator* allocator) - : stream_exec_(stream_exec), allocator_(allocator) {} +CusolverRewriter::CusolverRewriter() = default; StatusOr CusolverRewriter::Run(HloModule* module) { bool changed = false; diff --git a/tensorflow/compiler/xla/service/gpu/cusolver_rewriter.h b/tensorflow/compiler/xla/service/gpu/cusolver_rewriter.h index d8c2cc55872..8be7cd5c947 100644 --- a/tensorflow/compiler/xla/service/gpu/cusolver_rewriter.h +++ b/tensorflow/compiler/xla/service/gpu/cusolver_rewriter.h @@ -29,17 +29,13 @@ namespace gpu { // Rewrites Cholesky calls into CustomCall HLOs that call into cuSolver. class CusolverRewriter : public HloModulePass { public: - CusolverRewriter(se::StreamExecutor* stream_exec, - se::DeviceMemoryAllocator* allocator); + CusolverRewriter(); absl::string_view name() const override { return "cusolver-rewriter"; } StatusOr Run(HloModule* module) override; private: StatusOr RunOnComputation(HloComputation* computation); - - se::StreamExecutor* stream_exec_; // never null - se::DeviceMemoryAllocator* allocator_; // may be null }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/custom_call_thunk.cc b/tensorflow/compiler/xla/service/gpu/custom_call_thunk.cc index f0f3152ac98..b521e36108b 100644 --- a/tensorflow/compiler/xla/service/gpu/custom_call_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/custom_call_thunk.cc @@ -50,7 +50,7 @@ CustomCallThunk::CustomCallThunk( Status CustomCallThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream, - HloExecutionProfiler* profiler) { + const RunId& /*run_id*/, HloExecutionProfiler* profiler) { // gpu_stream is CUstream or e.g. the equivalent type in ROCm. 
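The code that follows reinterprets the registered call target and invokes it with the raw GPU stream and the buffer table. A sketch of that kind of trampoline, using a hypothetical GpuStreamHandle and an assumed target signature rather than XLA's exact calling convention:

    #include <cstddef>
    #include <vector>

    // Stand-in for CUstream or the ROCm equivalent.
    using GpuStreamHandle = void*;

    // Assumed target signature for illustration: stream, buffer table, opaque
    // configuration blob.
    using CallTarget = void (*)(GpuStreamHandle stream, void** buffers,
                                const char* opaque, size_t opaque_len);

    void InvokeCustomCall(void* registered_target, GpuStreamHandle stream,
                          std::vector<void*>& buffers, const char* opaque,
                          size_t opaque_len) {
      // The registry stores an untyped pointer; the thunk reinterprets it to
      // the agreed-upon signature and calls through.
      auto typed_target = reinterpret_cast<CallTarget>(registered_target);
      typed_target(stream, buffers.data(), opaque, opaque_len);
    }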
auto gpu_stream = se::gpu::AsGpuStreamValue(stream); auto typed_call_target = diff --git a/tensorflow/compiler/xla/service/gpu/custom_call_thunk.h b/tensorflow/compiler/xla/service/gpu/custom_call_thunk.h index 9011fa26ffa..6db7950e8e0 100644 --- a/tensorflow/compiler/xla/service/gpu/custom_call_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/custom_call_thunk.h @@ -45,7 +45,7 @@ class CustomCallThunk : public Thunk { const HloInstruction* instr); Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& run_id, HloExecutionProfiler* profiler) override; private: diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc index 1609f0d60c4..55300a8d33a 100644 --- a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc @@ -107,7 +107,7 @@ FftThunk::FftThunk(FftType fft_type, absl::Span fft_length, output_shape_(output_shape) {} Status FftThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& /*run_id*/, HloExecutionProfiler* profiler) { VLOG(3) << "FFT type: " << FftTypeToString(fft_type_); VLOG(3) << "Input shape: " << ShapeUtil::HumanStringWithLayout(input_shape_); diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.h b/tensorflow/compiler/xla/service/gpu/fft_thunk.h index f653e4f12fe..12718db873b 100644 --- a/tensorflow/compiler/xla/service/gpu/fft_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.h @@ -73,7 +73,7 @@ class FftThunk : public Thunk { // Does the FFT for the thunk on "stream". Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& run_id, HloExecutionProfiler* profiler) override; private: diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.cc b/tensorflow/compiler/xla/service/gpu/for_thunk.cc index 88f0b4d71c9..ee47fea38c3 100644 --- a/tensorflow/compiler/xla/service/gpu/for_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/for_thunk.cc @@ -41,7 +41,7 @@ Status ForThunk::Initialize(const GpuExecutable& executable, } Status ForThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& run_id, HloExecutionProfiler* profiler) { VLOG(2) << "Executing ForThunk with " << loop_limit_ << " iters for " << (hlo_instruction() ? hlo_instruction()->ToString() : ""); @@ -49,8 +49,8 @@ Status ForThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, for (int64 i = 0; i < loop_limit_; ++i) { profiler->StartHloComputation(); // Invoke loop body thunk sequence. 
- TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(buffer_allocations, - stream, profiler)); + TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream( + buffer_allocations, stream, run_id, profiler)); profiler->FinishHloComputation(hlo_instruction()->while_body()); } return Status::OK(); diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.h b/tensorflow/compiler/xla/service/gpu/for_thunk.h index c2d39071b29..e3bef820e57 100644 --- a/tensorflow/compiler/xla/service/gpu/for_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/for_thunk.h @@ -40,7 +40,7 @@ class ForThunk : public Thunk { Status Initialize(const GpuExecutable& executable, se::StreamExecutor* executor) override; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& run_id, HloExecutionProfiler* profiler) override; private: diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc index 9bbe1ab5a38..237c065cd73 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc @@ -446,7 +446,7 @@ GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer, implements_whole_instruction_(implements_whole_instruction) {} Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& /*run_id*/, HloExecutionProfiler* profiler) { auto fn = [&]() { switch (output_shape_.element_type()) { diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h index e4f07d04820..3cba1d5e169 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h @@ -50,7 +50,7 @@ class GemmThunk : public Thunk { // Does the gemm operation for the thunk on "stream", which must be non-null. 
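The operation behind GemmThunk is the standard GEMM, C = alpha * A * B + beta * C, dispatched by output element type. A naive row-major reference version, handy for host-side result checks:

    #include <vector>

    // Naive row-major GEMM reference: C = alpha * A * B + beta * C, with
    // A of shape (m, k), B of shape (k, n), and C of shape (m, n).
    void NaiveGemm(int m, int n, int k, float alpha, const std::vector<float>& a,
                   const std::vector<float>& b, float beta,
                   std::vector<float>* c) {
      for (int i = 0; i < m; ++i) {
        for (int j = 0; j < n; ++j) {
          float acc = 0.0f;
          for (int p = 0; p < k; ++p) acc += a[i * k + p] * b[p * n + j];
          (*c)[i * n + j] = alpha * acc + beta * (*c)[i * n + j];
        }
      }
    }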
Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& run_id, HloExecutionProfiler* profiler) override; private: diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index dec40c5e49c..8be1655367f 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -119,7 +119,8 @@ Status GpuExecutable::ExecuteThunks( op_annotation.emplace( thunk->hlo_instruction()->ToString(HloPrintOptions::Canonical()), absl::StrCat("#tf_op=", hlo->metadata().op_name(), - ",hlo_op=", hlo->name(), "#")); + ",hlo_op=", hlo->name(), + ",hlo_module=", hlo->GetModule()->name(), "#")); } TF_RETURN_IF_ERROR(thunk->Initialize(*this, executor)); @@ -136,7 +137,8 @@ Status GpuExecutable::ExecuteThunks( << thunk->hlo_instruction()->ToString() << " on stream " << stream_no; TF_RETURN_IF_ERROR( - thunk->ExecuteOnStream(buffer_allocations, stream, &profiler)); + thunk->ExecuteOnStream(buffer_allocations, stream, + run_options->run_options().run_id(), &profiler)); if (thunk_schedule_->Depended(thunk)) { auto finish_event = absl::make_unique(main_stream->parent()); finish_event->Init(); diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc index 676380c3b10..dbf590591c3 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc @@ -28,7 +28,7 @@ InfeedThunk::InfeedThunk( : Thunk(Kind::kInfeed, hlo_instruction), infeed_slices_(infeed_slices) {} Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& /*run_id*/, HloExecutionProfiler* profiler) { VLOG(2) << "Infeeding to GPU: " << hlo_instruction()->ToString(); diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.h b/tensorflow/compiler/xla/service/gpu/infeed_thunk.h index 59487e245b7..50d9c53d957 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.h @@ -41,7 +41,7 @@ class InfeedThunk : public Thunk { InfeedThunk& operator=(const InfeedThunk&) = delete; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& run_id, HloExecutionProfiler* profiler) override; private: diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 957a2f00723..c6919740c87 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -312,17 +312,55 @@ llvm::Value* EmitPrintf(absl::string_view fmt, arguments_ptr}); } +// Helper function to emit call to AMDGPU shfl_down function. +llvm::Value* EmitAMDGPUShflDown(llvm::Value* value, llvm::Value* offset, + llvm::IRBuilder<>* b) { + llvm::Module* module = b->GetInsertBlock()->getModule(); + CHECK_EQ(value->getType()->getPrimitiveSizeInBits(), 32); + auto* i32_ty = b->getInt32Ty(); + llvm::FunctionCallee shfl_fn = module->getOrInsertFunction( + llvm_ir::AsStringRef("__ockl_readuplane_i32"), + llvm::FunctionType::get(/*Result=*/i32_ty, {i32_ty, i32_ty}, + /*isVarArg=*/false)); + // AMDGPU device function requires first argument as i32. 
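// Illustrative clarification (comment only, not part of this patch): because
// __ockl_readuplane_i32 is declared as (i32, i32) -> i32, a 32-bit value of
// any other type (e.g. float) is bitcast to i32 before the call and the i32
// result is bitcast back afterwards, which is what the two CreateBitCast
// calls around this point emit. Roughly, for a float input:
//
//   %as_i32 = bitcast float %value to i32
//   %res    = call i32 @__ockl_readuplane_i32(i32 %as_i32, i32 %offset)
//   %result = bitcast i32 %res to float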
+ llvm::Value* result = + b->CreateCall(shfl_fn, {b->CreateBitCast(value, i32_ty), offset}); + // AMDGPU device function always returns an i32 type. + return b->CreateBitCast(result, value->getType()); +} + +// Helper function to emit call to NVPTX shfl_down intrinsic. +llvm::Value* EmitNVPTXShflDown(llvm::Value* value, llvm::Value* offset, + llvm::IRBuilder<>* b) { + llvm::Module* module = b->GetInsertBlock()->getModule(); + llvm::Intrinsic::ID llvm_intrinsic_id; + CHECK_EQ(value->getType()->getPrimitiveSizeInBits(), 32); + if (value->getType()->isFloatTy()) { + llvm_intrinsic_id = llvm::Intrinsic::nvvm_shfl_sync_down_f32; + } else { + llvm_intrinsic_id = llvm::Intrinsic::nvvm_shfl_sync_down_i32; + } + llvm::Function* intrinsic = + llvm::Intrinsic::getDeclaration(module, llvm_intrinsic_id, {}); + return b->CreateCall( + intrinsic, {b->getInt32(-1), value, offset, b->getInt32(kWarpSize - 1)}); +} + llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset, llvm::IRBuilder<>* builder) { int bit_width = value->getType()->getPrimitiveSizeInBits(); - llvm::Value* all_warps_mask = builder->getInt32(-1); + llvm::Module* module = builder->GetInsertBlock()->getModule(); + llvm::Triple target_triple = llvm::Triple(module->getTargetTriple()); // Special case for efficiency if (value->getType()->isFloatTy() && bit_width == 32) { - return EmitCallToTargetIntrinsic( - TargetIntrinsicID::kShflDownF32, - {all_warps_mask, value, offset, builder->getInt32(kWarpSize - 1)}, {}, - builder); + if (target_triple.isNVPTX()) { + return EmitNVPTXShflDown(value, offset, builder); + } else if (target_triple.getArch() == llvm::Triple::amdgcn) { + return EmitAMDGPUShflDown(value, offset, builder); + } else { + LOG(FATAL) << "Invalid triple " << target_triple.str(); + } } // We must split values wider than 32 bits as the "shfl" instruction operates @@ -334,14 +372,17 @@ llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset, builder->getIntNTy(32 * num_segments)), llvm::VectorType::get(builder->getInt32Ty(), num_segments)); for (int i = 0; i < num_segments; ++i) { - x = builder->CreateInsertElement( - x, - EmitCallToTargetIntrinsic( - TargetIntrinsicID::kShflDownI32, - {all_warps_mask, builder->CreateExtractElement(x, i), offset, - builder->getInt32(kWarpSize - 1)}, - {}, builder), - i); + llvm::Value* insert_val; + if (target_triple.isNVPTX()) { + insert_val = EmitNVPTXShflDown(builder->CreateExtractElement(x, i), + offset, builder); + } else if (target_triple.getArch() == llvm::Triple::amdgcn) { + insert_val = EmitAMDGPUShflDown(builder->CreateExtractElement(x, i), + offset, builder); + } else { + LOG(FATAL) << "Invalid triple " << target_triple.str(); + } + x = builder->CreateInsertElement(x, insert_val, i); } return builder->CreateBitCast( builder->CreateTrunc( diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc index fbe22e3a18e..c85b35ed386 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc @@ -70,7 +70,7 @@ void KernelThunk::SetLaunchDimensions(const LaunchDimensions& launch_dims) { } Status KernelThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& /*run_id*/, HloExecutionProfiler* profiler) { // Load the kernel. 
se::StreamExecutor* executor = stream->parent(); diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h index 2cea89e4e2a..e867904bcf2 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h @@ -63,7 +63,7 @@ class KernelThunk : public Thunk { // Executes the kernel for the thunk on "stream", which must be non-null. Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& run_id, HloExecutionProfiler* profiler) override; private: diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD index ca42807edd1..d025fc99275 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD @@ -1,7 +1,6 @@ -licenses(["notice"]) # Apache 2.0 - package( default_visibility = [":friends"], + licenses = ["notice"], # Apache 2.0 ) package_group( diff --git a/tensorflow/compiler/xla/service/gpu/memset_thunk.cc b/tensorflow/compiler/xla/service/gpu/memset_thunk.cc index 9fd6cf7157e..7a5b14be7b0 100644 --- a/tensorflow/compiler/xla/service/gpu/memset_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/memset_thunk.cc @@ -23,7 +23,7 @@ namespace gpu { Status MemzeroThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream, - HloExecutionProfiler* profiler) { + const RunId& /*run_id*/, HloExecutionProfiler* profiler) { se::DeviceMemoryBase dest_data = buffer_allocations.GetDeviceAddress(dest_); auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); stream->ThenMemZero(&dest_data, dest_data.size()); @@ -32,7 +32,7 @@ Status MemzeroThunk::ExecuteOnStream( Status Memset32BitValueThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream, - HloExecutionProfiler* profiler) { + const RunId& /*run_id*/, HloExecutionProfiler* profiler) { se::DeviceMemoryBase dest_data = buffer_allocations.GetDeviceAddress(dest_); auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); stream->ThenMemset32(&dest_data, value_, dest_data.size()); diff --git a/tensorflow/compiler/xla/service/gpu/memset_thunk.h b/tensorflow/compiler/xla/service/gpu/memset_thunk.h index d1fec0bd76b..727f2441f39 100644 --- a/tensorflow/compiler/xla/service/gpu/memset_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/memset_thunk.h @@ -37,7 +37,7 @@ class MemzeroThunk : public Thunk { : Thunk(Kind::kMemzero, hlo), dest_(dest) {} Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& run_id, HloExecutionProfiler* profiler) override; private: @@ -54,7 +54,7 @@ class Memset32BitValueThunk : public Thunk { : Thunk(Kind::kMemset32BitValue, hlo), value_(value), dest_(dest) {} Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId&, HloExecutionProfiler* profiler) override; private: diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc index c00edae9540..89def76afe3 100644 --- a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc @@ -15,12 +15,24 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h" -#include "tensorflow/compiler/xla/util.h" - #if GOOGLE_CUDA +#include +#include +#include +#include + +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" -#include "absl/synchronization/blocking_counter.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" #include "third_party/nccl/nccl.h" +#include "tensorflow/compiler/xla/refcounting_hash_map.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/blocking_counter.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/stream_executor/cuda/cuda_activation.h" @@ -29,6 +41,25 @@ limitations under the License. namespace xla { namespace gpu { +// This file runs collective ops (i.e. ops that communicate between multiple +// GPUs) using NCCL. Currently only kAllReduce is implemented. +// +// Here's a high-level overview of how running an op works. +// +// - Multiple threads call NcclAllReduceThunk::ExecuteOnStream. +// - All threads that "go together" (i.e. are participating in the "same" +// collective op) choose the same Rendezvous object from a global map. +// - Once all threads have arrived at the Rendezvous, we know exactly which +// GPUs are participating in the op, so we get or create a NcclClique +// containing those GPUs. +// - We perform the NCCL operation using the clique, then destroy the +// Rendezvous. The clique is cached, see below. +// +// Creating NCCL cliques is expensive, so we cache them. Our policy is, a thunk +// keeps alive all cliques it's ever used. When the thunk is destroyed, it +// releases its handle on the cliques, and cliques whose refcounts go to 0 are +// destroyed. + /* static */ bool NcclAllReduceThunk::NcclIsEnabled() { #if GOOGLE_CUDA return true; @@ -40,17 +71,145 @@ namespace gpu { #if GOOGLE_CUDA namespace { -// GPU-replica-driving host threads (i.e. the threads that call -// GpuExecutable::Execute) build up this structure to describe their -// participating replica, and then call to -// GlobalRendezvousManager::SubmitParticipant. -struct ParticipantData { - // Number of replicas particiating in the AllReduce. - int64 replica_count; +// Functions to translate an ncclResult_t/cudaError_t to a Status object. Used +// by the macros below. +Status TranslateStatus(ncclResult_t s, const char* file, int64 line, + const char* expr) { + if (s == ncclSuccess) { + return Status::OK(); + } + return tensorflow::errors::Internal( + absl::StrFormat("%s:%d: NCCL operation %s failed: %s", file, line, expr, + ncclGetErrorString(s))); +} +Status TranslateStatus(cudaError_t s, const char* file, int64 line, + const char* expr) { + if (s == cudaSuccess) { + return Status::OK(); + } + return tensorflow::errors::Internal( + absl::StrFormat("%s:%d: CUDA operation %s failed: %s", file, line, expr, + cudaGetErrorString(s))); +} + +// Macros to return or warn on CUDA/NCCL errors. (The same macro works for both +// NCCL and CUDA errors.) +// +// It's tempting to say these macros belong in an XLA header somewhere, but in +// practice we don't do much direct-to-CUDA-API stuff outside of this file. 
+#define XLA_CUDA_RETURN_IF_ERROR(expr) \ + do { \ + Status s = ::xla::gpu::TranslateStatus(expr, __FILE__, __LINE__, #expr); \ + if (!s.ok()) { \ + return s; \ + } \ + } while (0) + +#define XLA_CUDA_WARN_IF_ERROR(expr) \ + do { \ + Status s = ::xla::gpu::TranslateStatus(expr, __FILE__, __LINE__, #expr); \ + if (!s.ok()) { \ + LOG(ERROR) << s.ToString(); \ + } \ + } while (0) + +// RAII class owning a ncclComm_t, ensuring it doesn't leak. +class NcclComm { + public: + explicit NcclComm(ncclComm_t comm) : comm_(comm) {} + + // Movable, but not copyable. + NcclComm(NcclComm&& c) noexcept : comm_(c.comm_) { c.comm_.reset(); } + NcclComm& operator=(NcclComm&& c) noexcept { + comm_ = c.comm_; + c.comm_.reset(); + return *this; + } + NcclComm(const NcclComm&) = delete; + NcclComm& operator=(const NcclComm&) = delete; + + ~NcclComm() { + if (comm_.has_value() && *comm_ != nullptr) { + VLOG(3) << absl::StreamFormat("Destroying comm %p", *comm_); + XLA_CUDA_WARN_IF_ERROR(ncclCommDestroy(*comm_)); + } + } + + ncclComm_t comm() { return *comm_; } + + private: + absl::optional comm_; +}; + +// Key that identifies a particular Rendezvous object in our global hashtable. +// This determines which calls to ExecuteOnStream communicate with each other. +// The rules are as follows. +// +// * Only ops with the same RunId can communicate with each other. (This is the +// whole purpose of RunId). +// +// * Only ops with the same opcode can communicate with each other. At the +// moment we only support kAllReduce, so we don't check for this explicitly. +// +// * For cross-module all-reduces (i.e. instr->all_reduce_id().has_value()), +// only ops with the same value for all_reduce_id() can communicate with each +// other. +// +// * For cross-replica (i.e. same-module) all-reduces (i.e. +// !all_reduce_id().has_value()), only ops from the same module (as identified +// by its unique_id()) can communicate with each other. +// +struct RendezvousKey { + enum AllReduceKind { + kCrossModule, + kCrossReplica, + }; + + explicit RendezvousKey(const RunId& run_id, + const HloAllReduceInstruction* instr) + : run_id(run_id) { + std::tie(all_reduce_kind, op_id) = + instr->all_reduce_id().has_value() + ? std::make_pair(kCrossModule, instr->all_reduce_id().value()) + : std::make_pair( + kCrossReplica, + static_cast(instr->GetModule()->unique_id())); + } + + template + friend H AbslHashValue(H h, const RendezvousKey& k) { + return H::combine(std::move(h), k.run_id, + static_cast(k.all_reduce_kind), k.op_id); + } + friend bool operator==(const RendezvousKey& a, const RendezvousKey& b) { + return a.run_id == b.run_id && a.all_reduce_kind == b.all_reduce_kind && + a.op_id == b.op_id; + } + friend bool operator!=(const RendezvousKey& a, const RendezvousKey& b) { + return !(a == b); + } + + string ToString() const { + return absl::StrFormat( + "RendezvousKey{run_id=%s, all_reduce_kind=%d, op_id=%d}", + run_id.ToString(), static_cast(all_reduce_kind), op_id); + } + + RunId run_id; + AllReduceKind all_reduce_kind; + int64 op_id; +}; + +// Encapsulates parameters to Rendezvous::SubmitParticipant. +struct ParticipantData { + explicit ParticipantData(RendezvousKey rendezvous_key) + : rendezvous_key(rendezvous_key) {} + + int64 replica_count; // Number of GPUs particiating in the AllReduce. 
int64 element_count; int64 device_ordinal; - int64 generation_counter; + RendezvousKey rendezvous_key; // TODO(b/125951860): We should vet that we're buffer allocating such that // source_buffer == destination_buffer if that avoids a NCCL copy (will depend @@ -60,333 +219,301 @@ struct ParticipantData { se::DeviceMemoryBase destination_data; se::Stream* stream; - NcclAllReduceThunk* originator; - string ToString() const { return absl::StrFormat( "ParticipantData{replica_count=%d, element_count=%d, " - "device_ordinal=%d, generation_counter=%d, stream=%p, originator=%p}", - replica_count, element_count, device_ordinal, generation_counter, - stream, originator); + "rendezvous_key=%s, device_ordinal=%d, stream=%p}", + replica_count, element_count, rendezvous_key.ToString(), device_ordinal, + stream); } }; -// Class that gets instantiated as a singleton in GetGlobalRendezvous() to -// coordinate participating threads in performing an AllReduce operation. -// -// This manager is responsible for establishing communication channels and -// ultimately enqueueing the NCCL library operation onto the participating -// streams. -// -// Implementation note: We make an effort to avoid initializing nccl -// communciation channels too often, as this is expensive. -// -// Ideally, we'd set up a nccl channel between each pair of devices that needs -// to communicate, and close each channel when the GPUs won't be communicating -// again "for a long time" (because channels hold memory on the GPU). As a -// simplification to this ideal, we adopt the following policy. -// -// - We maintain a set of GPUs that are "actively participating" in -// cross-device communications. That set of GPUs is always connected as a -// clique, using ncclCommInitAll. -// -// - When a NcclAllReduceThunk touches a new GPU, we tear down the old clique -// and build a new, bigger one. -// -// - All GPUs ever touched by a thunk are considered "actively in use" by that -// thunk until the thunk is destroyed. Destroying the thunk decrements the -// refcount of the GPUs it's touched, and if that refcount goes to 0 -// (meaning, some GPUs are no longer in use by any thunk), we tear down the -// clique and build a new, smaller one. -// -// This approximation is justified because: -// -// - Currently the only collective operation we support is AllReduce, which -// requires a clique. When we support point-to-point operations, we may not -// want to build a communication clique. -// -// - Tearing down and creating a new thunk is tantamount to running the whole -// XLA:GPU compiler. This is expensive, so shouldn't happen "too often" to -// cause thrashing here. -// -// - XLA executables already keep resources on the GPU tied to the lifetime of -// the executable (e.g. constants stored in GPU memory), so tying the -// lifetime of the nccl communication channels to the lifetime of the -// executable is consistent. -class GlobalRendezvousManager { - public: - // The GpuExecutable-executing threads call this in order to a) establish the - // all-reduce rendezvous and b) enqueue the AllReduce operation on the caller - // thread's associated stream (given in "participant"). - // - // Implementation note: since the rendezvous we're creating here is global, we - // try to be paranoid about the fact that the *correct* one is happening. In - // an ideal world we'd have some StreamExecutor se::Platform level construct - // that we could use for cross-device networking primitives (e.g. 
via a - // NetworkSupport interface) that could be shared between TensorFlow and XLA, - // but this is a reasonable stopgap measure to get multi-GPU-replica up and - // running properly for single-host, single-concurrent-XLA-module usage. - Status SubmitParticipant(ParticipantData participant); - - // Returns the current generation number of AllReduce operations. - // (Currently one AllReduce operation occurs per generation.) - int64 GetCurrentGeneration() { - tensorflow::mutex_lock lock(mutex_); - return current_generation_; +// Key for looking up a particular NCCL clique. This is just a set of unique +// device ordinals (i.e. GPU IDs). +struct NcclCliqueKey { + explicit NcclCliqueKey(absl::Span devices) + : devices(devices.begin(), devices.end()) { + absl::c_sort(this->devices); + CHECK(absl::c_adjacent_find(devices) == devices.end()) + << "Duplicate devices are not allowed: " + << absl::StrJoin(devices, ", "); } - // Increments the refcount of a GPU in our accounting of which devices are - // "actively participating" in cross-device operations. - // - // This doesn't actually do anything other than increment the refcount. If - // the GPU added here is novel, we'll rebuild the nccl communication clique - // when we actually go do the communication. - void AddrefParticipatingDevice(int device_ordinal); + template + friend H AbslHashValue(H h, const NcclCliqueKey& k) { + return H::combine(std::move(h), k.devices); + } + friend bool operator==(const NcclCliqueKey& a, const NcclCliqueKey& b) { + return a.devices == b.devices; + } - // Decrements the refcount of a set of GPUs in our accounting of which devices - // are "actively participating" in cross-device operations. - // - // If one or more GPUs' refcounts to go 0, we immediately destroy the whole - // nccl communication clique. We'll rebuild a new, smaller clique the next - // time it's used. - void DecrefParticipatingDevices(absl::Span device_ordinals); + std::vector devices; +}; - // Gets the set of devices that have a NCCL channel currently open. This is - // primarily for testing. - absl::flat_hash_set DevicesWithOpenNcclChannels() const { - absl::flat_hash_set devices; - tensorflow::mutex_lock lock(mutex_); - for (const auto& kv : comms_) { - devices.insert(kv.first); - } - return devices; +// Owns a clique of NCCL comms which can be used for collective operations among +// a particular set of GPUs. +// +// You must ensure this is not in an error state (i.e. status() is OK) before +// touching any other methods. +// +// (Usually allowing objects to be in a constructed-but-uninitialized state is +// an antipattern. We do it here because it allows us to have a +// RefcountingHashMap which contains and automatically constructs NcclCliques. +// This greatly simplifies the rest of this file.) +// +// Note that if you want to do a collective operation among a subset of these +// GPUs, you'll need a different clique. +class NcclClique { + public: + explicit NcclClique(absl::Span devices) + : devices_(devices.begin(), devices.end()) { + absl::c_sort(devices_); + status_ = Init(); + } + + Status status() { return status_; } + + absl::Span devices() { + TF_CHECK_OK(status_); + return devices_; + } + ncclComm_t comm(int64 device) { + int64 idx = std::distance(devices_.begin(), absl::c_find(devices_, device)); + return comms_.at(idx).comm(); + } + + // These methods let you acquire exclusive access to a NCCL clique, ensuring + // no other NCCL operations are taking place on the clique's comms. 
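  // (Illustrative usage sketch, not part of this patch:
  //
  //    clique->Lock();
  //    ... every participant enqueues its NCCL call on this clique ...
  //    clique->Unlock();
  //
  //  In this file the Lock/Unlock pair is owned by the primary thread of a
  //  Rendezvous; see SubmitParticipant below.)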
+ // + // We disable thread-safety analysis because in common use, only the primary + // thread in a Rendezvous acquires this lock, and that makes thread-safety + // analysis unhappy. Tread carefully, you are playing with fire. + void Lock() NO_THREAD_SAFETY_ANALYSIS { + TF_CHECK_OK(status_); + mu_->lock(); + } + void Unlock() NO_THREAD_SAFETY_ANALYSIS { + TF_CHECK_OK(status_); + mu_->unlock(); } private: - // Destroys the current nccl communication clique and builds a new one - // connecting the given devices. - Status ReinitializeNcclClique(const absl::flat_hash_set& device_ordinals) - EXCLUSIVE_LOCKS_REQUIRED(mutex_); + Status Init() { + VLOG(3) << absl::StreamFormat( + "Initializing nccl comms for participant devices {%s}", + absl::StrJoin(devices_, ", ")); - // Called when all necessary participants are present, the functionality - // that's implemented by all executing threads lives in here. - Status DoAllReduce(ParticipantData data, ncclComm_t comm); + // Restore CUDA device after running this. XLA shouldn't care, but maybe + // another consumer does. + int initial_cuda_device; + XLA_CUDA_RETURN_IF_ERROR(cudaGetDevice(&initial_cuda_device)); + auto cuda_device_restorer = MakeCleanup( + [&] { XLA_CUDA_WARN_IF_ERROR(cudaSetDevice(initial_cuda_device)); }); - // Puts all state back into a "reset" state for the next generation of - // AllReduce requests. - void DeinitializeGeneration() EXCLUSIVE_LOCKS_REQUIRED(mutex_) { - participants_.clear(); - current_generation_++; - initialized_ = false; - done_ = absl::nullopt; + // When using ncclGroupStart/End it seems that the ncclComm_t's are not + // populated until the End() call. This unfortunately makes error handling + // tricky. + std::vector raw_comms(devices_.size(), nullptr); + ncclUniqueId nccl_id; + XLA_CUDA_RETURN_IF_ERROR(ncclGetUniqueId(&nccl_id)); + XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart()); + Status status = [&] { + for (int i = 0; i < devices_.size(); ++i) { + XLA_CUDA_RETURN_IF_ERROR(cudaSetDevice(devices_[i])); + XLA_CUDA_RETURN_IF_ERROR( + ncclCommInitRank(&raw_comms[i], devices_.size(), nccl_id, i)); + } + return Status::OK(); + }(); + // Always call ncclGroupEnd(). + XLA_CUDA_RETURN_IF_ERROR(ncclGroupEnd()); + + // Populate comms_ from the raw comms we created above. If we encountered + // an error above we'll later clear comms_ thus destroying any raw comms + // that were created before the error. + for (int i = 0; i < devices_.size(); ++i) { + VLOG(3) << absl::StreamFormat("Device %d assigned ncclComm %p", + devices_[i], raw_comms[i]); + CHECK(raw_comms[i] != nullptr || !status.ok()); + comms_.emplace_back(raw_comms[i]); + } + if (!status.ok()) { + comms_.clear(); + } + + return status; } - mutable tensorflow::mutex mutex_; - tensorflow::condition_variable all_participants_present_; - tensorflow::condition_variable deinitialized_; + Status status_; + std::vector devices_; + std::vector comms_; - Status initialize_status_ GUARDED_BY(mutex_); - std::vector participants_ GUARDED_BY(mutex_); - int64 current_generation_ GUARDED_BY(mutex_) = 0; - bool initialized_ GUARDED_BY(mutex_) = false; - - struct Comm { - explicit Comm(ncclComm_t nccl_comm) : nccl_comm(nccl_comm) {} - - // Movable, but not copyable. 
- Comm(Comm&& c) : nccl_comm(c.nccl_comm) { c.nccl_comm.reset(); } - Comm& operator=(Comm&& c) { - nccl_comm = c.nccl_comm; - c.nccl_comm.reset(); - return *this; - } - Comm(const Comm&) = delete; - Comm& operator=(const Comm&) = delete; - - absl::optional nccl_comm; - - ~Comm() { - if (nccl_comm.has_value()) { - VLOG(3) << absl::StreamFormat("Destroying comm %p", *nccl_comm); - ncclCommDestroy(*nccl_comm); - } - } - }; - // Communication handles for our NCCL clique. Key is device ordinal. - absl::flat_hash_map comms_ GUARDED_BY(mutex_); - - // Refcounts of which devices are "actively participating" in all-reduces. - // These devices don't necessarily have an open comm, but the next time we run - // an operation, we'll create a NCCL clique between all of them. - absl::flat_hash_map device_refcounts_ GUARDED_BY(mutex_); - - // The participating threads wait for this to count down in order to know we - // can begin the teardown process. - absl::optional done_; + // This mutex is in a unique_ptr so NcclClique can be movable. + std::unique_ptr mu_ = + absl::make_unique(); }; -Status GlobalRendezvousManager::SubmitParticipant(ParticipantData participant) { - auto all_participants_present = [this, &participant]() - EXCLUSIVE_LOCKS_REQUIRED(mutex_) -> bool { - return participants_.size() >= participant.replica_count; - }; +// Global cache of NCCL cliques. An entry in this map is kept alive as long as +// there's a reference to it somewhere. A Thunk holds a reference to each +// Clique it's ever used. +// +// A consequence of the fact that this is process-global is that we'll only ever +// have one clique alive for a given set of GPUs. This means that a process +// will never do two collective operations concurrently on the same set of GPUs. +RefcountingHashMap& GlobalNcclCliqueMap() { + static auto& m = *new RefcountingHashMap( + [](const NcclCliqueKey& key) { + return absl::make_unique(key.devices); + }); + return m; +} - { - tensorflow::mutex_lock lock(mutex_); +// The set of threads that want to do a collective op together all pick the same +// Rendezvous object out of the global cache and call SubmitParticipant. +// +// The Rendezvous instance handles waiting for all threads to join, ensuring +// that a clique exists for the desired set of GPUs, etc. +// +// Rendezvous objects can only be used once. +class Rendezvous { + public: + Rendezvous() = default; - // Spot check for consistent replica counts among submitting threads. - if (!participants_.empty() && - (participants_.back().replica_count != participant.replica_count || - participants_.back().originator != participant.originator)) { - return InvalidArgument( - "Running two XLA modules with AllReduces in parallel is not " - "supported. It is possible this is due to a bug where were try to " - "run two different AllReduces from the same module at once. " - "(Attempted a rendezvous with a different replica count from other " - "participants; existing: %s; submitted: %s)", - participants_.back().ToString(), participant.ToString()); - } - participants_.push_back(participant); + // Runs the all-reduce on the given thread. If successful, returns + // - a handle to the clique that was used, so that the caller may keep the + // clique alive if it chooses. + // - a BlockingCounter initialized to the number of participants, so that + // the caller can coordinate with the participants one last time if it + // chooses. This is useful for coordinating destruction of the Rendezvous. 
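// Editorial sketch (not part of this patch) of how a caller is expected to
// consume SubmitParticipant's result, mirroring what
// NcclAllReduceThunk::ExecuteOnStream does further down in this file:
//
//   std::shared_ptr<Rendezvous> rendezvous =
//       GlobalRendezvousMap()[rendezvous_key];
//   TF_ASSIGN_OR_RETURN(auto clique_and_counter,
//                       rendezvous->SubmitParticipant(participant));
//   cliques_kept_alive.insert(clique_and_counter.first);  // hold the clique
//   rendezvous.reset();                  // drop our reference to the map entry
//   clique_and_counter.second->DecrementCount();
//   clique_and_counter.second->Wait();   // all threads drop it together
//
// cliques_kept_alive is a made-up name standing in for the thunk's AuxData.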
+ StatusOr, + std::shared_ptr>> + SubmitParticipant(ParticipantData participant); - if (all_participants_present()) { - all_participants_present_.notify_all(); - } - } + private: + Status DoAllReduce(ParticipantData participant, ncclComm_t comm); + tensorflow::mutex mu_; + tensorflow::condition_variable all_participants_present_; + + bool initialized_ GUARDED_BY(mu_) = false; + absl::optional done_; + std::vector participants_ GUARDED_BY(mu_); + + // BlockingCounter returned by SubmitParticipant. Initialized by the primary + // thread. + std::shared_ptr returned_blocking_counter_; +}; + +// Global map of Rendezvous objects. A thread participating in a collective op +// looks up its Rendezvous in this map to find the other threads that it's +// participating with. +// +// Rendezvous objects are one-time use, so they're removed from this map once +// we're through with them. +RefcountingHashMap& GlobalRendezvousMap() { + static auto& m = *new RefcountingHashMap(); + return m; +} + +StatusOr, + std::shared_ptr>> +Rendezvous::SubmitParticipant(ParticipantData participant) { // We pull into our thread a) the communication handle and b) whether we're // the "primary" thread for this rendezvous -- the "primary" thread has some // additional responsibilities for setup/teardown. ncclComm_t comm; bool primary; + std::shared_ptr clique; + + // Releases the lock on the clique (held only by the primary thread). + Cleanup> clique_lock_releaser; { - tensorflow::mutex_lock lock(mutex_); - while (!all_participants_present()) { - // Once all the participants have arrived, all participating threads will - // cross this barrier, though only (the first) one will be the "primary". + tensorflow::mutex_lock lock(mu_); + CHECK(!initialized_); + + // Spot check for consistent replica counts among submitting threads. + if (!participants_.empty() && + (participants_.back().replica_count != participant.replica_count || + participants_.back().element_count != participant.element_count || + participants_.back().rendezvous_key != participant.rendezvous_key)) { + return InvalidArgument( + "Mismatch among all-reduce participants. Expected same " + "replica-count, element-count, and rendezvous-key but were %s and %s", + participants_.back().ToString(), participant.ToString()); + } + participants_.push_back(participant); + + // Wait here for all participants to arrive. + while (participants_.size() < participant.replica_count) { all_participants_present_.wait(lock); } + if (participants_.size() == participant.replica_count) { + all_participants_present_.notify_all(); + } - // Somebody will be the first -- that thread has some additional - // responsibilities. + // The first thread to get here has additional responsibilities, such as + // ensuring that there's a NCCL clique available for us to use. primary = !initialized_; - CHECK_EQ(participant.generation_counter, current_generation_); + // Look up or create the NCCL clique for this set of devices. + std::vector devices; + for (const auto& p : participants_) { + devices.push_back(p.device_ordinal); + } + clique = GlobalNcclCliqueMap()[NcclCliqueKey(devices)]; - // Bump the generation counter so the other threads know we've completed the - // global rendezvous and have set up the AllReduce. if (primary) { VLOG(3) << "Primary initializing accounting data."; initialized_ = true; done_.emplace(participant.replica_count); + returned_blocking_counter_ = + std::make_shared( + participant.replica_count); - // Check if all participants_ are in comms_. 
If not, we will rebuild the - // clique to include them. (This can't be spelled using absl::c_any_of - // because it needs to touch comms_ and tensorflow::mutex lacks an - // AssertHeld() function that would let us assert that the lambda is run - // while holding the lock.) - bool new_devices_found = false; - for (const auto& p : participants_) { - if (!comms_.contains(p.device_ordinal)) { - new_devices_found = true; - break; - } - } - - if (new_devices_found) { - absl::flat_hash_set new_clique_device_ordinals; - for (const auto& kv : comms_) { - new_clique_device_ordinals.insert(kv.first); - } - for (const auto& p : participants_) { - new_clique_device_ordinals.insert(p.device_ordinal); - } - - initialize_status_ = ReinitializeNcclClique(new_clique_device_ordinals); - VLOG(3) << "Done initializing communication channels; status: " - << initialize_status_; - if (!initialize_status_.ok()) { - DeinitializeGeneration(); - } - } + // Acquire exclusive access to the NCCL clique itself so that two + // unrelated collective operations won't try to use the clique + // concurrently. + clique->Lock(); + clique_lock_releaser = MakeCleanup([clique] { clique->Unlock(); }); } - if (!initialize_status_.ok()) { - // TODO(b/125951860): If this fails once, it will fail forever. - return initialize_status_; + if (!clique->status().ok()) { + VLOG(1) + << "SubmitParticipant failing because clique failed to initialize: " + << clique->status().ToString(); + return clique->status(); } - comm = *comms_.at(participant.device_ordinal).nccl_comm; + comm = clique->comm(participant.device_ordinal); // Drop the lock at the end of scope so other participants may enter. } VLOG(3) << "Performing all reduce from device ordinal: " << participant.device_ordinal; - Status all_reduce_status = DoAllReduce(participant, comm); - - VLOG(3) << "Waiting for all participants to complete enqueue."; + VLOG(3) << "This thread done with all-reduce op."; done_->DecrementCount(); + // The primary owns the lock on the NCCL clique. Hold it until all threads + // are done. (We'll release it when we return from this function.) if (primary) { - // Primary thread clears out the AllReduce state when everybody is done to - // make it clean-slate for any subsequent AllReduce request (e.g. number of - // replicas may change in the next request). - // - // Note surrounding TODOs for only reinitializing this when the replica - // count / participants actually change -- lots of "playing it safe" - // happening in this first cut. 
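// Descriptive note (comment only, not part of this patch): the release happens
// via clique_lock_releaser, the MakeCleanup object created above, which calls
// clique->Unlock() when this function returns. Waiting on done_ first ensures
// no participant is still enqueueing on the clique's comms at that point, so
// another collective cannot interleave with this one on the same comms.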
+ VLOG(3) + << "Primary waiting for all participants to complete all-reduce op."; done_->Wait(); - VLOG(3) << "All participants completed enqueue."; - VLOG(3) << "Primary thread clearing."; - tensorflow::mutex_lock lock(mutex_); - DeinitializeGeneration(); - VLOG(3) << "Generation is now: " << current_generation_; - deinitialized_.notify_all(); - } else { - VLOG(3) << "Waiting to deinitialize."; - tensorflow::mutex_lock lock(mutex_); - while (initialized_) { - deinitialized_.wait(lock); - } + VLOG(3) << "All participants completed all-reduce op."; } VLOG(3) << "Returning status: " << all_reduce_status; - return all_reduce_status; + if (!all_reduce_status.ok()) { + return all_reduce_status; + } + return std::make_pair(clique, returned_blocking_counter_); } -Status GlobalRendezvousManager::ReinitializeNcclClique( - const absl::flat_hash_set& device_ordinals) { - comms_.clear(); - - std::vector ordinals_vec(device_ordinals.begin(), device_ordinals.end()); - std::vector comm_vec; - comm_vec.resize(device_ordinals.size()); - - VLOG(3) << absl::StreamFormat( - "Initializing nccl comms for participant devices {%s}", - absl::StrJoin(ordinals_vec, ", ")); - ncclResult_t result = ncclCommInitAll(comm_vec.data(), comm_vec.size(), - /*devlist=*/ordinals_vec.data()); - if (result != ncclSuccess) { - return InternalError( - "Failed to initialize NCCL communication channels for %d participants: " - "%s", - ordinals_vec.size(), ncclGetErrorString(result)); - } - - for (int64 i = 0; i < ordinals_vec.size(); ++i) { - VLOG(3) << absl::StreamFormat("Device ordinal %d assigned ncclComm %p", - ordinals_vec[i], comm_vec[i]); - CHECK(comms_.emplace(ordinals_vec[i], Comm{comm_vec[i]}).second); - } - return Status::OK(); -} - -Status GlobalRendezvousManager::DoAllReduce(ParticipantData participant, - ncclComm_t comm) { +Status Rendezvous::DoAllReduce(ParticipantData participant, ncclComm_t comm) { se::StreamExecutor* executor = participant.stream->parent(); se::cuda::ScopedActivateExecutorContext scoped_context(executor); cudaStream_t* cu_stream = reinterpret_cast( @@ -400,14 +527,12 @@ Status GlobalRendezvousManager::DoAllReduce(ParticipantData participant, "datatype=ncclFloat, op=ncclSum, comm=%p, stream=%p)", send_buffer, recv_buffer, participant.element_count, static_cast(comm), cu_stream); - ncclResult_t result = ncclAllReduce(send_buffer, recv_buffer, - /*count=*/participant.element_count, - /*datatype=*/ncclFloat, - /*op=*/ncclSum, - /*comm=*/comm, - /*stream=*/*cu_stream); - TF_RET_CHECK(ncclSuccess == result) - << "Failed to perform all-reduce: " << ncclGetErrorString(result); + XLA_CUDA_RETURN_IF_ERROR(ncclAllReduce(send_buffer, recv_buffer, + /*count=*/participant.element_count, + /*datatype=*/ncclFloat, + /*op=*/ncclSum, + /*comm=*/comm, + /*stream=*/*cu_stream)); VLOG(3) << "Done performing all reduce for ordinal: " << participant.device_ordinal; @@ -415,95 +540,100 @@ Status GlobalRendezvousManager::DoAllReduce(ParticipantData participant, return Status::OK(); } -void GlobalRendezvousManager::AddrefParticipatingDevice(int device_ordinal) { - // Addref'ing a device doesn't do anything other than increment its refcount. - // We'll update our nccl clique if necessary during the next call to - // SubmitParticipant. - tensorflow::mutex_lock lock(mutex_); - device_refcounts_[device_ordinal]++; -} - -void GlobalRendezvousManager::DecrefParticipatingDevices( - absl::Span device_ordinals) { - // Decref'ing devices causes us to destroy the nccl clique if any devices were - // removed due to having refcount 0. 
We'll rebuild the new, smaller clique - // during the next call to SubmitParticipant. - tensorflow::mutex_lock lock(mutex_); - bool removed_device = false; - for (int device_ordinal : device_ordinals) { - auto it = device_refcounts_.find(device_ordinal); - CHECK(it != device_refcounts_.end()); - it->second--; - if (it->second == 0) { - device_refcounts_.erase(it); - removed_device = true; - } - } - - if (removed_device) { - comms_.clear(); - } -} - -static GlobalRendezvousManager* GetGlobalRendezvous() { - static auto* manager = new GlobalRendezvousManager; - return manager; -} - } // namespace +// Extra data stored in NcclAllReduceThunk that we didn't want to expose in the +// header. In particular, this stores the thunk's cache of all NcclCliques it's +// ever used. This causes those cliques to stay alive as long as the thunk +// lives, which is how we avoid expensive reinitialization of NCCL cliques. +struct NcclAllReduceThunk::AuxData { + tensorflow::mutex mu; + absl::flat_hash_set> cliques GUARDED_BY(mu); +}; + /*static*/ absl::flat_hash_set NcclAllReduceThunk::DevicesWithOpenNcclChannels() { - return GetGlobalRendezvous()->DevicesWithOpenNcclChannels(); + absl::flat_hash_set devices; + GlobalNcclCliqueMap().ForEach( + [&](const NcclCliqueKey& k, const std::shared_ptr&) { + devices.insert(k.devices.begin(), k.devices.end()); + }); + return devices; } +NcclAllReduceThunk::NcclAllReduceThunk( + int64 replica_count, int64 element_count, + const BufferAllocation::Slice& source_buffer, + const BufferAllocation::Slice& destination_buffer, + const HloInstruction* all_reduce) + : Thunk(Thunk::kNcclAllReduce, all_reduce), + replica_count_(replica_count), + element_count_(element_count), + source_buffer_(source_buffer), + destination_buffer_(destination_buffer), + aux_data_(absl::make_unique()) {} + Status NcclAllReduceThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream, - HloExecutionProfiler* profiler) { - auto* global_rendezvous = GetGlobalRendezvous(); + const RunId& run_id, HloExecutionProfiler* profiler) { + // Find or create the rendezvous for this collective operation. + RendezvousKey rendezvous_key( + run_id, Cast(hlo_instruction())); + std::shared_ptr rendezvous = + GlobalRendezvousMap()[rendezvous_key]; - ParticipantData participant; + ParticipantData participant(rendezvous_key); participant.replica_count = replica_count_; participant.element_count = element_count_; participant.device_ordinal = stream->parent()->device_ordinal(); - participant.generation_counter = global_rendezvous->GetCurrentGeneration(); participant.source_data = buffer_allocations.GetDeviceAddress(source_buffer_); participant.destination_data = buffer_allocations.GetDeviceAddress(destination_buffer_); participant.stream = stream; - participant.originator = this; - // We currently say that that all GPUs this thunk has ever touched are - // "actively participating" in cross-device operations, until the thunk itself - // is destroyed. - // - // This policy is an attempt to avoid thrashing the GPU (ncclCommInitAll is - // very expensive) while also freeing resources on the GPUs when we can. The - // idea is, creating new thunks is tantamount to running the whole XLA:GPU - // compiler stack, so that shouldn't happen terribly often. - bool new_device; + // Do the operation. 
+ StatusOr, + std::shared_ptr>> + result = rendezvous->SubmitParticipant(participant); + if (!result.ok()) { + VLOG(1) << "NcclAllReduceThunk::ExecuteOnStream failed: " + << result.status().ToString(); + return result.status(); + } + + std::shared_ptr clique; + std::shared_ptr blocking_counter; + std::tie(clique, blocking_counter) = std::move(result).ValueOrDie(); + + // Keep the clique we used alive for as long as this Thunk lives. Creating + // new NCCL cliques is expensive, and this is how we avoid thrashing them. { - tensorflow::mutex_lock lock(mu_); - new_device = devices_seen_.insert(participant.device_ordinal).second; - } - if (new_device) { - GetGlobalRendezvous()->AddrefParticipatingDevice( - participant.device_ordinal); + tensorflow::mutex_lock lock(aux_data_->mu); + aux_data_->cliques.insert(std::move(clique)); } - return GetGlobalRendezvous()->SubmitParticipant(std::move(participant)); + // Drop our reference to the Rendezvous and wait for all other threads to do + // the same. If we didn't do this, one of the threads could run past this + // point, reenter ExecuteOnStream for another all-reduce, and attempt to reuse + // the Rendezvous! + // + // An alternative way of accomplishing this goal would be to implement + // RefcountingHashMap::erase() and call it during SubmitParticipant. But + // erase() is deceptively complex to implement correctly. + rendezvous.reset(); + blocking_counter->DecrementCount(); + blocking_counter->Wait(); + + return Status::OK(); } -NcclAllReduceThunk::~NcclAllReduceThunk() { - GetGlobalRendezvous()->DecrefParticipatingDevices( - std::vector(devices_seen_.begin(), devices_seen_.end())); -} +NcclAllReduceThunk::~NcclAllReduceThunk() {} #else Status NcclAllReduceThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream, - HloExecutionProfiler* profiler) { + const RunId& run_id, HloExecutionProfiler* profiler) { return Unimplemented( "NCCL support is not available: this binary was not built with a CUDA " "compiler, which is necessary to build the NCCL source library."); @@ -516,7 +646,7 @@ NcclAllReduceThunk::DevicesWithOpenNcclChannels() { return {}; } -#endif // GOOGLE_CUDA +struct NcclAllReduceThunk::AuxData {}; NcclAllReduceThunk::NcclAllReduceThunk( int64 replica_count, int64 element_count, @@ -529,5 +659,7 @@ NcclAllReduceThunk::NcclAllReduceThunk( source_buffer_(source_buffer), destination_buffer_(destination_buffer) {} +#endif // GOOGLE_CUDA + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h index 9ff4fb187af..52ba4950565 100644 --- a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h @@ -50,6 +50,9 @@ class NcclAllReduceThunk : public Thunk { // TODO(b/125951860): Plumb more datatypes / reduction operators. Initial // implementation is simply F32 summation. + // + // TODO(b/125951860): Support all-reduces with replica groups, i.e. + // all-reduces that compute multiple sums across subsets of all replicas. 
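  // Editorial example (comment only, not part of this patch): under the
  // current restrictions, an HLO all-reduce over f32[1024] with a summation
  // computation and replica count R maps to R concurrent calls to
  // ExecuteOnStream (one per participating GPU), each contributing
  // element_count = 1024 floats to a single ncclSum across all R participants.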
NcclAllReduceThunk(int64 replica_count, int64 element_count, const BufferAllocation::Slice& source_buffer, const BufferAllocation::Slice& destination_buffer, @@ -57,18 +60,21 @@ class NcclAllReduceThunk : public Thunk { ~NcclAllReduceThunk() override; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& run_id, HloExecutionProfiler* profiler) override; private: + // Extra data stored in NcclAllReduceThunk whose types we don't want exposed + // in the header file. (This is mainly because the implementation of + // NcclAllReduceThunk is different depending on whether CUDA is enabled in the + // build, and we don't want to expose *that* mess in the header.) + struct AuxData; + const int64 replica_count_; const int64 element_count_; const BufferAllocation::Slice source_buffer_; const BufferAllocation::Slice destination_buffer_; - - tensorflow::mutex mu_; - // Set of GPUs that ExecuteOnStream has been called on. - absl::flat_hash_set devices_seen_ GUARDED_BY(mu_); + std::unique_ptr aux_data_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index d8249e99d42..93fdc67d8ad 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -108,6 +108,7 @@ limitations under the License. #include "tensorflow/core/platform/tracing.h" #include "tensorflow/core/profiler/lib/traceme.h" #include "tensorflow/stream_executor/cuda/cuda_diagnostics.h" +#include "tensorflow/stream_executor/cuda/ptxas_utils.h" namespace xla { namespace gpu { @@ -157,7 +158,7 @@ string GetLibdeviceDir(const HloModuleConfig& hlo_module_config) { "uses routines from libdevice.", hlo_module_config); - // GetCudaRotCandidates always inclues ".", but but if everything fails, we + // GetCudaRootCandidates always inclues ".", but but if everything fails, we // return it anyway. Better than returning the empty string. 
return "."; } @@ -265,7 +266,7 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, HloPassPipeline pipeline("conv_canonicalization"); pipeline.AddInvariantChecker(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); - pipeline.AddPass(stream_exec, device_allocator); + pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); @@ -520,7 +521,6 @@ StatusOr> NVPTXCompiler::RunBackend( BufferSizeBytesFunction(), /*color_alignment=*/ [](LogicalBuffer::Color) { return kXlaAllocatedBufferAlignBytes; }, - /*allow_input_output_aliasing=*/false, /*allocate_buffers_for_constants=*/true)); DumpHloModuleIfEnabled(*module, *buffer_assignment, "after_optimizations"); @@ -677,8 +677,9 @@ std::vector NVPTXCompiler::CompilePtxOrGetCachedResult( if (inserted) { CHECK(!cache_value->compilation_done); if (!ptx.empty()) { - StatusOr> maybe_cubin = CompilePtx( - stream_exec, *cache_ptx, PtxCompilationOptions(hlo_module_config)); + StatusOr> maybe_cubin = se::cuda::CompilePtx( + stream_exec->device_ordinal(), cache_ptx->c_str(), + PtxOptsFromConfig(hlo_module_config)); if (maybe_cubin.ok()) { cache_value->cubin_data = std::move(maybe_cubin).ValueOrDie(); VLOG(2) << "Compiled PTX size:" << ptx.size() diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc index e0f3e84a4cb..527305070b7 100644 --- a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc @@ -30,7 +30,7 @@ OutfeedThunk::OutfeedThunk(ShapeTree outfeed_slices, Status OutfeedThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream, - HloExecutionProfiler* profiler) { + const RunId& /*run_id*/, HloExecutionProfiler* profiler) { VLOG(2) << "Outfeeding from GPU: " << hlo_instruction()->ToString(); auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.h b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.h index 8ed89f05f0c..5e7bc7cea1a 100644 --- a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.h @@ -39,7 +39,7 @@ class OutfeedThunk : public Thunk { OutfeedThunk& operator=(const OutfeedThunk&) = delete; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& run_id, HloExecutionProfiler* profiler) override; private: diff --git a/tensorflow/compiler/xla/service/gpu/redzone_allocator.cc b/tensorflow/compiler/xla/service/gpu/redzone_allocator.cc index 9427a44a90c..64db95ce98a 100644 --- a/tensorflow/compiler/xla/service/gpu/redzone_allocator.cc +++ b/tensorflow/compiler/xla/service/gpu/redzone_allocator.cc @@ -26,6 +26,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/stream_executor/cuda/ptxas_utils.h" #include "tensorflow/stream_executor/device_memory.h" #include "tensorflow/stream_executor/kernel.h" #include "tensorflow/stream_executor/kernel_spec.h" @@ -272,8 +273,10 @@ StatusOr RedzoneAllocator::CheckRedzones( se::StreamExecutor* executor = stream->parent(); absl::Span compiled_ptx = {}; - StatusOr> compiled_ptx_or = CompilePtxOrGetCached( - executor, redzone_checker_ptx, PtxCompilationOptions(hlo_module_config_)); + StatusOr> compiled_ptx_or = + se::cuda::CompilePtxOrGetCached(executor->device_ordinal(), + redzone_checker_ptx, + PtxOptsFromConfig(hlo_module_config_)); if (compiled_ptx_or.ok()) { compiled_ptx = compiled_ptx_or.ValueOrDie(); } else { diff --git a/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc b/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc index 84285be70a4..2f456938d92 100644 --- a/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc @@ -35,11 +35,11 @@ Status SequentialThunk::Initialize(const GpuExecutable& executable, Status SequentialThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream, - HloExecutionProfiler* profiler) { + const RunId& run_id, HloExecutionProfiler* profiler) { auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); for (const auto& thunk : thunks_) { TF_RETURN_IF_ERROR( - thunk->ExecuteOnStream(buffer_allocations, stream, profiler)); + thunk->ExecuteOnStream(buffer_allocations, stream, run_id, profiler)); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/sequential_thunk.h b/tensorflow/compiler/xla/service/gpu/sequential_thunk.h index 3c4de1d1a6c..e617c99c2c9 100644 --- a/tensorflow/compiler/xla/service/gpu/sequential_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/sequential_thunk.h @@ -42,7 +42,7 @@ class SequentialThunk : public Thunk { Status Initialize(const GpuExecutable& executable, se::StreamExecutor* executor) override; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& run_id, HloExecutionProfiler* profiler) override; private: diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc index ca409fff67b..75b6f31e3dc 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc @@ -226,186 +226,11 @@ Status ExecuteKernelOnStream(const se::KernelBase& kernel, return Status::OK(); } -// Prints a warning if the ptxas at ptxas_path has known bugs. -// -// Only prints a warning the first time it's called for a particular value of -// ptxas_path. -// -// Locks on entry. -void WarnIfBadPtxasVersion(const string& ptxas_path) { - static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); - static std::unordered_set* seen_ptxas_paths GUARDED_BY(mu) = - new std::unordered_set(); - - tensorflow::mutex_lock lock(mu); - if (!seen_ptxas_paths->insert(ptxas_path).second) { - // Already checked this ptx binary, nothing to do. 
- return; - } - - tensorflow::SubProcess ptxas; - ptxas.SetProgram(ptxas_path, {ptxas_path, "--version"}); - ptxas.SetChannelAction(tensorflow::CHAN_STDOUT, tensorflow::ACTION_PIPE); - if (!ptxas.Start()) { - LOG(WARNING) << "Couldn't invoke " << ptxas_path << " --version"; - return; - } - - string out; - int exit_code = ptxas.Communicate(/*stdin_input=*/nullptr, &out, - /*stderr_output=*/nullptr); - if (exit_code != 0) { - LOG(WARNING) << "Running " << ptxas_path << " --version returned " - << exit_code; - return; - } - - int64 vmaj, vmin, vdot; - string vmaj_str, vmin_str, vdot_str; - if (!RE2::PartialMatch(out, R"(\bV(\d+)\.(\d+)\.(\d+)\b)", &vmaj_str, - &vmin_str, &vdot_str) || - !absl::SimpleAtoi(vmaj_str, &vmaj) || - !absl::SimpleAtoi(vmin_str, &vmin) || - !absl::SimpleAtoi(vdot_str, &vdot)) { - LOG(WARNING) << "Couldn't parse ptxas version in output of " << ptxas_path - << " --version:\n" - << out; - return; - } - - // We need ptxas >= 9.0 as a hard requirement, because we compile targeting - // PTX 6.0. An older ptxas will just fail to compile any of our code. - // - // ptxas 9.0 before 9.0.276 and ptxas 9.1 before 9.1.121 miscompile some - // address calculations with large offsets (e.g. "load ptr + large_constant"), - // b/70245379. - // - // ptxas 9.1.121 miscompiles some large multioutput fusions, again in a way - // that appears related to address calculations, b/111107644. ptxas 9.2.88 - // appears to work, as far as we can tell. - if (vmaj < 9) { - LOG(ERROR) - << "You are using ptxas 8.x, but XLA requires ptxas 9.x (and strongly " - "prefers >= 9.2.88). Compilation of XLA kernels below will likely " - "fail.\n\nYou do not need to update CUDA; cherry-picking the ptxas " - "binary is sufficient."; - } else if (std::make_tuple(vmaj, vmin, vdot) < std::make_tuple(9, 2, 88)) { - LOG(WARNING) - << "*** WARNING *** You are using ptxas " << vmaj << "." << vmin << "." - << vdot - << ", which is older than 9.2.88. 
ptxas 9.x before 9.2.88 is known to " - "miscompile XLA code, leading to incorrect results or " - "invalid-address errors.\n\nYou do not need to update to CUDA " - "9.2.88; cherry-picking the ptxas binary is sufficient."; - } -} - -StatusOr> CompilePtxOrGetCached( - se::StreamExecutor* executor, absl::string_view ptx, - PtxCompilationOptions compilation_options) { - using PtxCacheKey = std::tuple; - static tensorflow::mutex ptx_cache_mutex(tensorflow::LINKER_INITIALIZED); - static auto& ptx_cache GUARDED_BY(ptx_cache_mutex) = - *new absl::flat_hash_map>(); - - tensorflow::mutex_lock lock(ptx_cache_mutex); - PtxCacheKey cache_key{executor, std::string(ptx), - compilation_options.ToTuple()}; - auto it = ptx_cache.find(cache_key); - if (it == ptx_cache.end()) { - TF_ASSIGN_OR_RETURN(std::vector compiled, - CompilePtx(executor, ptx, compilation_options)); - it = ptx_cache.emplace(cache_key, std::move(compiled)).first; - } - - CHECK(it != ptx_cache.end()); - const std::vector& compiled = it->second; - return absl::MakeSpan(compiled); -} - -StatusOr> CompilePtx( - se::StreamExecutor* stream_exec, absl::string_view ptx, - PtxCompilationOptions compile_ptx_options) { - int cc_major, cc_minor; - if (!stream_exec->GetDeviceDescription().cuda_compute_capability(&cc_major, - &cc_minor)) { - LOG(WARNING) - << "Couldn't get compute capability for device; assuming sm_20."; - cc_major = 2; - cc_minor = 0; - } - - tensorflow::profiler::TraceMe activity( - "Compile PTX", tensorflow::profiler::TraceMeLevel::kInfo); - auto env = tensorflow::Env::Default(); - string ptxas_path; - for (const string& cuda_root : tensorflow::CandidateCudaRoots( - /*preferred_location=*/compile_ptx_options.xla_gpu_cuda_data_dir)) { - ptxas_path = tensorflow::io::JoinPath(cuda_root, "bin", "ptxas"); - VLOG(2) << "Looking for ptxas at " << ptxas_path; - if (env->FileExists(ptxas_path).ok()) { - break; - } - } - TF_RETURN_IF_ERROR(env->FileExists(ptxas_path)); - VLOG(2) << "Using ptxas at " << ptxas_path; - - WarnIfBadPtxasVersion(ptxas_path); - - // Write ptx into a temporary file. - string ptx_path; - if (!env->LocalTempFilename(&ptx_path)) { - return InternalError("couldn't get temp PTX file name"); - } - auto ptx_cleaner = tensorflow::gtl::MakeCleanup([&ptx_path] { - TF_CHECK_OK(tensorflow::Env::Default()->DeleteFile(ptx_path)); - }); - - TF_RETURN_IF_ERROR(tensorflow::WriteStringToFile(env, ptx_path, ptx)); - VLOG(2) << "ptx written to: " << ptx_path; - - // Invoke ptxas and collect its output. - string cubin_path; - if (!env->LocalTempFilename(&cubin_path)) { - return InternalError("couldn't get temp CUBIN file name"); - } - auto cubin_cleaner = tensorflow::gtl::MakeCleanup([&cubin_path] { - // CUBIN file may never be created, so the failure to delete it should not - // produce TF error. 
- tensorflow::Env::Default()->DeleteFile(cubin_path).IgnoreError(); - }); - tensorflow::SubProcess ptxas_info_dumper; - std::vector ptxas_args = { - ptxas_path, ptx_path, "-o", cubin_path, - absl::StrCat("-arch=sm_", cc_major, cc_minor)}; - if (VLOG_IS_ON(2)) { - ptxas_args.push_back("-v"); - } - if (compile_ptx_options.xla_gpu_disable_ptxas_optimizations) { - ptxas_args.push_back("-O0"); - } - ptxas_info_dumper.SetProgram(ptxas_path, ptxas_args); - ptxas_info_dumper.SetChannelAction(tensorflow::CHAN_STDERR, - tensorflow::ACTION_PIPE); - if (!ptxas_info_dumper.Start()) { - return InternalError("Failed to launch ptxas"); - } - string stderr_output; - int exit_status = ptxas_info_dumper.Communicate( - /*stdin_input=*/nullptr, /*stdout_output=*/nullptr, &stderr_output); - XLA_LOG_LINES(tensorflow::INFO, stderr_output); - if (exit_status != 0) { - return InternalError("ptxas exited with non-zero error code %d", - exit_status); - } - - // Read in the result of compilation and return it as a byte vector. - string cubin; - TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(tensorflow::Env::Default(), - cubin_path, &cubin)); - std::vector cubin_vector(cubin.begin(), cubin.end()); - return cubin_vector; +se::cuda::PtxCompilationOptions PtxOptsFromConfig( + const HloModuleConfig& hlo_module_config) { + return se::cuda::PtxCompilationOptions( + hlo_module_config.debug_options().xla_gpu_disable_ptxas_optimizations(), + hlo_module_config.debug_options().xla_gpu_cuda_data_dir()); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.h b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h index 06ac7dca634..483ab210558 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_executor_util.h +++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/stream_executor/cuda/ptxas_utils.h" #include "tensorflow/stream_executor/kernel_spec.h" // Helper functions for interacting with StreamExecutor. @@ -103,47 +104,9 @@ Status ExecuteKernelOnStream(const se::KernelBase& kernel, int64 threads_per_block, int64 block_count, se::Stream* stream); -// Options for compiling with PTX. -struct PtxCompilationOptions { - bool xla_gpu_disable_ptxas_optimizations; - std::string xla_gpu_cuda_data_dir; - - using PtxOptionsTuple = std::tuple; - - explicit PtxCompilationOptions(const HloModuleConfig& hlo_module_config) - : xla_gpu_disable_ptxas_optimizations( - hlo_module_config.debug_options() - .xla_gpu_disable_ptxas_optimizations()), - xla_gpu_cuda_data_dir( - hlo_module_config.debug_options().xla_gpu_cuda_data_dir()) {} - - // For comparison and hashing. - PtxOptionsTuple ToTuple() { - return std::make_tuple(xla_gpu_disable_ptxas_optimizations, - xla_gpu_cuda_data_dir); - } -}; - -// Compiles the given PTX string using ptxas and returns the resulting machine -// code (i.e. a cubin) as a byte array. -// -// Queries stream executor stream_exec to get CUDA compute capability from the -// device. -// -// compile_ptx_options is used to query for the CUDA location in case it is -// customized in a passed flag, and for controlling ptxas optimizations. -// It can be constructed from HloModuleConfig. 
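With the hunks above, the GPU backend no longer locates and invokes ptxas itself; it only converts DebugOptions into se::cuda::PtxCompilationOptions through the new PtxOptsFromConfig helper and delegates compilation and caching (which returns an unowned view of the cubin) to stream_executor. A rough sketch of what a call site looks like after this change; stream_exec, ptx, and config are placeholder names for a se::StreamExecutor*, a PTX string, and an HloModuleConfig already available to the caller, and the uint8 element types are inferred from the neighboring hunks:

// Sketch only, not part of the patch.
StatusOr<absl::Span<const uint8>> compiled_or =
    se::cuda::CompilePtxOrGetCached(stream_exec->device_ordinal(), ptx,
                                    PtxOptsFromConfig(config));
if (!compiled_or.ok()) {
  return compiled_or.status();  // or continue without the cubin, as the
                                // redzone check above chooses to do
}
absl::Span<const uint8> cubin = compiled_or.ValueOrDie();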
-StatusOr> CompilePtx( - se::StreamExecutor* stream_exec, absl::string_view ptx, - PtxCompilationOptions compile_ptx_options); - -// Same as CompilePtx, but caches the result, and returns unowned view of -// the compiled binary. -// -// A copy of the string provided in ptx will be made. -StatusOr> CompilePtxOrGetCached( - se::StreamExecutor* executor, absl::string_view ptx, - PtxCompilationOptions compilation_options); +// Create PtxCompilationOptions out of HloModuleConfig. +se::cuda::PtxCompilationOptions PtxOptsFromConfig( + const HloModuleConfig& hlo_module_config); } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/target_util.cc b/tensorflow/compiler/xla/service/gpu/target_util.cc index 8225cd79a66..746f74b8e45 100644 --- a/tensorflow/compiler/xla/service/gpu/target_util.cc +++ b/tensorflow/compiler/xla/service/gpu/target_util.cc @@ -36,14 +36,6 @@ struct TargetIntrinsics { // corresponding to the give TargetIntrinsicID. struct TargetIntrinsics GetIntrinsic(TargetIntrinsicID intrin) { switch (intrin) { - case TargetIntrinsicID::kShflDownF32: { - return {llvm::Intrinsic::nvvm_shfl_sync_down_f32, - llvm::Intrinsic::not_intrinsic}; - } - case TargetIntrinsicID::kShflDownI32: { - return {llvm::Intrinsic::nvvm_shfl_sync_down_i32, - llvm::Intrinsic::not_intrinsic}; - } case TargetIntrinsicID::kThreadIdx: { return {llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, llvm::Intrinsic::amdgcn_workitem_id_x}; diff --git a/tensorflow/compiler/xla/service/gpu/target_util.h b/tensorflow/compiler/xla/service/gpu/target_util.h index b8f796c7259..a7497b91390 100644 --- a/tensorflow/compiler/xla/service/gpu/target_util.h +++ b/tensorflow/compiler/xla/service/gpu/target_util.h @@ -31,9 +31,7 @@ namespace gpu { // Enmeration to get target specific intrinsics. enum class TargetIntrinsicID { - kShflDownF32 = 0, - kShflDownI32, - kThreadIdx, + kThreadIdx = 0, kThreadIdy, kThreadIdz, kBlockIdx, diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD index b6ce15bb384..4c229046e14 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD @@ -4,9 +4,10 @@ # TODO(jlebar): None of these tests actually use the GPU, so they should not # need to run on machines with GPUs present. -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = [":friends"]) +package( + default_visibility = [":friends"], + licenses = ["notice"], # Apache 2.0 +) package_group( name = "friends", diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h index bdd06718717..9670a3ece08 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk.h +++ b/tensorflow/compiler/xla/service/gpu/thunk.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -92,7 +93,7 @@ class Thunk { // // Precondition: Initialize(stream->parent()) has been called. 
virtual Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& run_id, HloExecutionProfiler* profiler) = 0; protected: diff --git a/tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.cc b/tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.cc index 5200a2af412..2635a7b3c45 100644 --- a/tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.cc @@ -70,7 +70,7 @@ TriangularSolveThunk::TriangularSolveThunk( Status TriangularSolveThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream, - HloExecutionProfiler* profiler) { + const RunId& /*run_id*/, HloExecutionProfiler* profiler) { VLOG(3) << "uplo=" << se::blas::UpperLowerString(uplo_) << " side=" << se::blas::SideString(side_) << " diagonal=" << se::blas::DiagonalString(unit_diagonal_) diff --git a/tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h b/tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h index c947162ea32..94bf6bf6442 100644 --- a/tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h @@ -49,7 +49,7 @@ class TriangularSolveThunk : public Thunk { TriangularSolveThunk& operator=(const TriangularSolveThunk&) = delete; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& run_id, HloExecutionProfiler* profiler) override; private: diff --git a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc index 989b542ff45..f7dda240367 100644 --- a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc @@ -23,7 +23,7 @@ namespace xla { namespace gpu { Status TupleThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& /*run_id*/, HloExecutionProfiler* profiler) { auto size = tuple_element_buffers_.size(); auto tuple_element_buffer_addresses = absl::make_unique(size); diff --git a/tensorflow/compiler/xla/service/gpu/tuple_thunk.h b/tensorflow/compiler/xla/service/gpu/tuple_thunk.h index dcdbf2cf3c2..47784c5c373 100644 --- a/tensorflow/compiler/xla/service/gpu/tuple_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/tuple_thunk.h @@ -46,7 +46,7 @@ class TupleThunk : public Thunk { TupleThunk& operator=(const TupleThunk&) = delete; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& run_id, HloExecutionProfiler* profiler) override; private: diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.cc b/tensorflow/compiler/xla/service/gpu/while_thunk.cc index c4754fe3789..0223582f2a9 100644 --- a/tensorflow/compiler/xla/service/gpu/while_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/while_thunk.cc @@ -48,7 +48,7 @@ Status WhileThunk::Initialize(const GpuExecutable& executable, } Status WhileThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& run_id, HloExecutionProfiler* profiler) { se::DeviceMemoryBase condition_result_data = buffer_allocations.GetDeviceAddress(condition_result_buffer_index_); @@ -59,7 +59,7 @@ Status WhileThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, profiler->StartHloComputation(); VLOG(3) << "Executing condition computation"; 
TF_RETURN_IF_ERROR(condition_thunk_sequence_->ExecuteOnStream( - buffer_allocations, stream, profiler)); + buffer_allocations, stream, run_id, profiler)); profiler->FinishHloComputation(hlo_instruction()->while_condition()); // Copy the result of condition computation and break the loop if 'false'. @@ -83,8 +83,8 @@ Status WhileThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, VLOG(3) << "Executing body computation"; // Invoke thunk sequence for while 'body' computation, and pass on // 'profiler' to measure the timing of the thunks in 'body_thunk_sequence_'. - TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(buffer_allocations, - stream, profiler)); + TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream( + buffer_allocations, stream, run_id, profiler)); profiler->FinishHloComputation(hlo_instruction()->while_body()); } return Status::OK(); diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.h b/tensorflow/compiler/xla/service/gpu/while_thunk.h index 9270f95ee67..97ac24f61cc 100644 --- a/tensorflow/compiler/xla/service/gpu/while_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/while_thunk.h @@ -49,7 +49,7 @@ class WhileThunk : public Thunk { Status Initialize(const GpuExecutable& executable, se::StreamExecutor* executor) override; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& run_id, HloExecutionProfiler* profiler) override; private: diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.h b/tensorflow/compiler/xla/service/hlo_alias_analysis.h index d09ec15e83a..08ef3eabfb6 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.h @@ -43,7 +43,7 @@ class HloAliasAnalysis { static StatusOr> Run( const HloModule* module, const HloDataflowAnalysis::FusionCanShareBufferFunction& - fusion_can_share_buffer); + fusion_can_share_buffer = nullptr); string ToString() const; diff --git a/tensorflow/compiler/xla/service/hlo_buffer.h b/tensorflow/compiler/xla/service/hlo_buffer.h index a81078fdc96..91597d6f705 100644 --- a/tensorflow/compiler/xla/service/hlo_buffer.h +++ b/tensorflow/compiler/xla/service/hlo_buffer.h @@ -93,6 +93,17 @@ class HloBuffer { // Return all values contained in this buffer. const std::vector& values() const { return values_; } + // Memory space color. Used to indicate the memory space that the hlo buffer + // needs to live in. + BufferValue::Color color() const { + // Invariant: All values in the buffer should have the same color. + BufferValue::Color result = values()[0]->color(); + for (const HloValue* value : values()) { + DCHECK_EQ(result, value->color()); + } + return result; + } + // Return the unique HLO value in the buffer. CHECK fails if the buffer does // not contain exactly one value. 
const HloValue& GetUniqueValue() const { diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 195c84b034f..908c1ad451a 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -166,14 +166,23 @@ Status HloComputation::RemoveParameter(int64 param_no) { return Status::OK(); } -Status HloComputation::RemoveUnusedParameters() { - CHECK(IsFusionComputation()); +Status HloComputation::RemoveUnusedParametersFromFusedComputation() { + return RemoveUnusedParametersImpl(/*allow_non_fusion=*/false); +} + +Status HloComputation::RemoveUnusedParametersFromAnyComputation() { + return RemoveUnusedParametersImpl(/*allow_non_fusion=*/true); +} + +Status HloComputation::RemoveUnusedParametersImpl(bool allow_non_fusion) { + CHECK(allow_non_fusion || IsFusionComputation()); int64 removed = 0; for (int64 i = 0; i < param_instructions_.size(); ++i) { HloInstruction* param_instruction = param_instructions_[i]; if (param_instruction->user_count() == 0 && param_instruction != root_instruction()) { - TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction)); + TF_RETURN_IF_ERROR( + RemoveInstructionImpl(param_instruction, allow_non_fusion)); ++removed; continue; } @@ -185,14 +194,15 @@ Status HloComputation::RemoveUnusedParameters() { StrCat("param_", param_no))); TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr)); param_instructions_[param_no] = new_instr; - TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction)); + TF_RETURN_IF_ERROR( + RemoveInstructionImpl(param_instruction, allow_non_fusion)); } } param_instructions_.resize(param_instructions_.size() - removed); return Status::OK(); } -bool HloComputation::IsRemovable(const HloInstruction* instruction) { +bool HloComputation::IsSafelyRemovable(const HloInstruction* instruction) { // If the instruction has control predecessors or successors then we cannot // remove the instruction without violating ordering constraints (added, for // example, to avert interference due to buffer aliasing). 
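The renamed and newly added removal entry points in this hunk are easiest to see together. The sketch below is not part of the patch; computation and dead_instr are hypothetical, and it assumes the caller has already re-established whatever invariant (control dependencies, paired Send/Recv, and so on) made the instruction unsafe to remove in the first place.

// Illustrative use of the new HloComputation removal API inside a pass that
// has rewritten a non-fusion computation.
Status CleanupAfterRewrite(HloComputation* computation,
                           HloInstruction* dead_instr) {
  // Unlike RemoveUnusedParametersFromFusedComputation, this variant is also
  // allowed on non-fusion computations.
  TF_RETURN_IF_ERROR(computation->RemoveUnusedParametersFromAnyComputation());
  if (dead_instr->user_count() == 0 &&
      dead_instr != computation->root_instruction()) {
    if (computation->IsSafelyRemovable(dead_instr)) {
      TF_RETURN_IF_ERROR(computation->RemoveInstruction(dead_instr));
    } else {
      // ForceRemoveInstruction skips the IsSafelyRemovable() check; the
      // caller is responsible for keeping the module invariants intact.
      TF_RETURN_IF_ERROR(computation->ForceRemoveInstruction(dead_instr));
    }
  }
  return Status::OK();
}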
@@ -223,7 +233,7 @@ Status HloComputation::RemoveInstructionAndUnusedOperands( TF_RET_CHECK(root_instruction() != instruction); TF_RET_CHECK(instruction->user_count() == 0); - TF_RET_CHECK(IsRemovable(instruction)) + TF_RET_CHECK(IsSafelyRemovable(instruction)) << "Cannot remove instruction: " << instruction->ToString(); absl::flat_hash_set removed; std::queue worklist; @@ -233,7 +243,7 @@ Status HloComputation::RemoveInstructionAndUnusedOperands( worklist.pop(); if (removed.contains(item) || item->user_count() != 0 || - item == root_instruction() || !IsRemovable(item) || + item == root_instruction() || !IsSafelyRemovable(item) || (item->HasSideEffect() && item != instruction)) { continue; } @@ -248,9 +258,18 @@ Status HloComputation::RemoveInstructionAndUnusedOperands( } Status HloComputation::RemoveInstruction(HloInstruction* instruction) { + return RemoveInstructionImpl(instruction, /*ignore_safety_check=*/false); +} + +Status HloComputation::ForceRemoveInstruction(HloInstruction* instruction) { + return RemoveInstructionImpl(instruction, /*ignore_safety_check=*/true); +} + +Status HloComputation::RemoveInstructionImpl(HloInstruction* instruction, + bool ignore_safety_check) { VLOG(2) << "Removing instruction " << instruction->name() << " from computation " << name(); - TF_RET_CHECK(IsRemovable(instruction)) + TF_RET_CHECK(ignore_safety_check || IsSafelyRemovable(instruction)) << "cannot remove instruction: " << instruction->ToString(); TF_RET_CHECK(root_instruction() != instruction) << "cannot remove root instruction " << instruction->name(); @@ -291,6 +310,16 @@ void HloComputation::set_root_instruction(HloInstruction* new_root_instruction, } DCHECK(root_found); + if (parent() && parent()->has_entry_computation() && + parent()->entry_computation() == this) { + if (!Shape::Equal()(new_root_instruction->shape(), + root_instruction_->shape())) { + // Rebuild input output alias config now that we have a new output shape. + parent()->input_output_alias_config() = + HloInputOutputAliasConfig(new_root_instruction->shape()); + } + } + root_instruction_ = new_root_instruction; } diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 89dbe93b36b..ad6cc2fee41 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -115,7 +115,12 @@ class HloComputation { // Remove unused parameters from the computation. // Note this is only applicatable to the computation for the fusion // instruction. - Status RemoveUnusedParameters(); + Status RemoveUnusedParametersFromFusedComputation(); + + // Remove unused parameters from the computation. Unlike + // RemoveUnusedParametersFromFusedComputation, this function can be used + // to remove parameters from non-fusion computations. + Status RemoveUnusedParametersFromAnyComputation(); // Adds a new parameter instruction to a fusion computation. // @@ -135,6 +140,11 @@ class HloComputation { // users. Instruction is deallocated with this call. Status RemoveInstruction(HloInstruction* instruction); + // Removes an instruction from the computation. The instruction must have no + // users. Instruction is deallocated with this call. The instruction will be + // removed even if it is marked as not removable. 
+ Status ForceRemoveInstruction(HloInstruction* instruction); + // Remove an instruction (including side effecting ones) from the computation // and also transitively any operand that has no side effect and no users post // removing an instruction. The instruction must have no users. Instruction is @@ -378,13 +388,13 @@ class HloComputation { // the HLO computation with the exception of fusion computation. A parameter // instruction is removable for a fusion computation. // - // Note that IsRemovable() is a necessariy condition to remove an instruction - // rather than a sufficient condition. For example, instructions with - // side-effect (e.g., Send, Infeed) may be removed from a computation, but the - // transformation must guarantee the invariants relevant to the instructions - // still hold (e.g., Send and Recv must be removed together to make each - // channel complete). - bool IsRemovable(const HloInstruction* instruction); + // Note that IsSafelyRemovable() is a necessary condition to remove an + // instruction rather than a sufficient condition. For example, instructions + // with side-effect (e.g., Send, Infeed) may be removed from a computation, + // but the transformation must guarantee the invariants relevant to the + // instructions still hold (e.g., Send and Recv must be removed together to + // make each channel complete). + bool IsSafelyRemovable(const HloInstruction* instruction); // Returns a map from channel-id to the group of instructions associated with // the channel. These instructions will be considered as a single node for @@ -459,6 +469,11 @@ class HloComputation { std::vector* post_order, HloInstruction* root, absl::flat_hash_map* visited) const; + Status RemoveUnusedParametersImpl(bool allow_non_fusion); + + Status RemoveInstructionImpl(HloInstruction* instruction, + bool ignore_safety_check); + string name_; int64 unique_id_; HloInstruction* root_instruction_; diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index 9036ae8d5fd..a1586af7b5a 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -149,7 +149,7 @@ string HloDataflowAnalysis::ToString() const { StrAppend(&out, " Instruction value sets:\n"); for (const HloComputation* computation : module_.computations()) { for (const HloInstruction* instruction : computation->instructions()) { - StrAppend(&out, " ", instruction->name(), ":\n"); + StrAppend(&out, "Instruction: \n ", instruction->name(), ":\n"); if (instruction->shape().IsTuple()) { GetInstructionValueSet(instruction) .ForEachElement([this, &instruction, &out]( @@ -1044,7 +1044,7 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( } if (fusion_can_share_buffer_ != nullptr) { - return fusion_can_share_buffer_(user, operand); + return fusion_can_share_buffer_(user, operand, user_index); } if (user->IsLoopFusion() || user->IsInputFusion()) { diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h index ece17fc4c3e..de4ea8a80df 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h @@ -49,12 +49,14 @@ class HloDataflowAnalysis { // default strategy. // // The first parameter of the function should be the fusion instruction, the - // second parameter should be an operand of the fusion instruction.
+ // second parameter should be an operand of the fusion instruction. The third + // parameter should be the output index of the fusion. // // TODO(b/80315712): Find a better way to tell whether a fusion can share // buffer. using FusionCanShareBufferFunction = std::function; + const HloInstruction* fusion, const HloInstruction* operand, + const ShapeIndex& fusion_index)>; // Run dataflow analysis on the given module. Parameters: // @@ -128,7 +130,7 @@ class HloDataflowAnalysis { int64 value_count() const { return values_.size(); } // Return a vector of all HloValues stabily sorted by HloValue::Id. - const std::vector& values() const { return values_vector_; } + const std::vector& values() const { return values_vector_; } // Return the call graph used for computing the dataflow. const CallGraph& call_graph() const { return *call_graph_; } @@ -153,6 +155,8 @@ class HloDataflowAnalysis { HloInstruction* user, const ShapeIndex& user_index) const; + const HloModule& module() const { return module_; } + protected: HloDataflowAnalysis( const HloModule& module, bool ssa_form, @@ -238,7 +242,7 @@ class HloDataflowAnalysis { std::vector value_ids_to_delete_; // A vector containing all HloValues sorted by HloValue::Id. - std::vector values_vector_; + std::vector values_vector_; // The Id to use for the next HloValue. HloValue::Id next_value_id_ = 0; diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index cb2341a80be..275feab5030 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -2576,7 +2576,8 @@ TEST_F(CanShareOperandBufferWithUserTest, FusionCanShareBufferCustomized) { auto fusion = computation_->CreateFusionInstruction( {add, two, mul}, HloInstruction::FusionKind::kInput); RunAnalysis(/*fusion_can_share_buffer=*/[](const HloInstruction* fusion, - const HloInstruction*) { + const HloInstruction*, + const ShapeIndex& output_index) { return fusion->IsLoopFusion(); }); diff --git a/tensorflow/compiler/xla/service/hlo_dce.cc b/tensorflow/compiler/xla/service/hlo_dce.cc index a5a11f09cf4..702de4fef86 100644 --- a/tensorflow/compiler/xla/service/hlo_dce.cc +++ b/tensorflow/compiler/xla/service/hlo_dce.cc @@ -49,7 +49,7 @@ StatusOr HloDCE::Run(HloModule* module) { for (auto* instruction : computation->instructions()) { if (instruction != computation->root_instruction() && instruction->user_count() == 0 && - computation->IsRemovable(instruction) && + computation->IsSafelyRemovable(instruction) && !instruction->HasSideEffect()) { dead_roots.push_back(instruction); } diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 0320979102f..21cc216b33b 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -782,26 +783,15 @@ Status HloEvaluator::HandleTuple(HloInstruction* tuple) { namespace { -// Straightforward implementation of 1D DFT transform. Uses passed-in start -// index and stride to gather inputs from the data vector into the preallocated -// buffer, computes the result, and writes it back to the same locations in the -// data vector. Runs in O(length^2) time. -// -// Parameters contract_output and expand_input are used to avoid unnecessary -// calculations. 
When contract_output is set to true, then only (length / 2) + 1 -// output values are computed. When expand_input is set to true, then -// (length / 2) + 1 values from the data set are used to re-create the full set -// of size 'length', on which the transform is then performed. -// -void NaiveDft1D(int64 length, int64 start, int64 stride, bool inverse, - bool contract_output, bool expand_input, - absl::Span data, absl::Span buffer) { - CHECK_GT(data.size(), start + (length - 1) * stride); - CHECK_GT(buffer.size(), length - 1); - - // Copy input data to 1D vector. +// Common code used by 1D implementations, which copies data from the input to +// the contiguous buffer. Returns true if all copied values are zero. +bool GatherToBuffer(absl::Span data, int64 length, int64 start, + int64 stride, bool expand_input, + absl::Span buffer) { + CHECK_GE(buffer.size(), length); bool input_is_zero = true; const int64 ub = expand_input ? length / 2 + 1 : length; + CHECK_GE(data.size(), start + (ub - 1) * stride); for (int64 k = 0; k < ub; k++) { complex128 value = data[start + k * stride]; input_is_zero &= value == complex128(0.0, 0.0); @@ -815,22 +805,118 @@ void NaiveDft1D(int64 length, int64 start, int64 stride, bool inverse, } } } + return input_is_zero; +} + +// Returns (conjugated, if 'inverse' is true) k-th twiddle for the given length. +inline complex128 Twiddle(int64 k, int64 length, bool inverse) { + auto coeff = std::exp(complex128(0.0, -2.0 * M_PI * k / length)); + return inverse ? std::conj(coeff) : coeff; +} + +// Straightforward implementation of 1D DFT transform of arbitrary length. Uses +// passed-in start index and stride to gather inputs from the data vector into +// the preallocated buffer, computes the result, and writes it back to the same +// locations in the data vector. Runs in O(length^2) time. +// +// Parameters contract_output and expand_input are used to avoid unnecessary +// calculations. When contract_output is set to true, then only (length / 2) + 1 +// output values are computed. When expand_input is set to true, then +// (length / 2) + 1 values from the data set are used to re-create the full set +// of size 'length', on which the transform is then performed. +// +void NaiveDft1D(int64 length, int64 start, int64 stride, bool inverse, + bool contract_output, bool expand_input, + absl::Span data, absl::Span buffer) { + const bool input_is_zero = + GatherToBuffer(data, length, start, stride, expand_input, buffer); - // Do 1D transformation with double precision. if (!input_is_zero) { const int64 ub = contract_output ? length / 2 + 1 : length; for (int64 k = 0; k < ub; k++) { complex128 value = complex128(0.0, 0.0); for (int n = 0; n < length; n++) { - auto coeff = std::exp(complex128(0.0, -2.0 * M_PI * n * k / length)); - value += (inverse ? std::conj(buffer[n]) : buffer[n]) * coeff; + value += buffer[n] * Twiddle(n * k, length, inverse); } data[start + k * stride] = - inverse ? std::conj(value) / complex128(length, 0.0) : value; + inverse ? value / complex128(length, 0.0) : value; } } } +// Non-recursive implementation of the Cooley-Tukey radix-2 decimation in time. +// Performs 1D FFT transform for the lengths, which are powers of 2. Runs in +// O(length * log(length)) time. Uses the same parameters as the naive +// implementation above, except that the preallocated buffer must be at least +// twice as big as the length of the transform, because the buffer is used to +// hold both input and output values for each stage of the transform. 
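The comment above describes the power-of-two path only in prose. As a standalone illustration of the same radix-2 decimation-in-time butterflies (even + twiddle * odd and even - twiddle * odd per stage, with the twiddle conjugated for the inverse transform), here is a compact version operating on a std::vector of std::complex<double>. It is a sketch of the algorithm, not of the evaluator's code: it uses an explicit bit-reversal permutation, whereas Fft1D below ping-pongs between the two halves of its scratch buffer instead.

#include <cmath>
#include <complex>
#include <cstddef>
#include <utility>
#include <vector>

// In-place iterative radix-2 DIT FFT; a.size() must be a power of two.
void Radix2Fft(std::vector<std::complex<double>>& a, bool inverse) {
  const size_t n = a.size();
  // Reorder the input into bit-reversed index order.
  for (size_t i = 1, j = 0; i < n; ++i) {
    size_t bit = n >> 1;
    for (; j & bit; bit >>= 1) j ^= bit;
    j ^= bit;
    if (i < j) std::swap(a[i], a[j]);
  }
  // log2(n) stages of butterflies.
  for (size_t len = 2; len <= n; len <<= 1) {
    const double ang = (inverse ? 2.0 : -2.0) * M_PI / static_cast<double>(len);
    const std::complex<double> wlen(std::cos(ang), std::sin(ang));
    for (size_t i = 0; i < n; i += len) {
      std::complex<double> w(1.0, 0.0);
      for (size_t k = 0; k < len / 2; ++k) {
        const std::complex<double> even = a[i + k];
        const std::complex<double> odd = w * a[i + k + len / 2];
        a[i + k] = even + odd;            // even + twiddle * odd
        a[i + k + len / 2] = even - odd;  // even - twiddle * odd
        w *= wlen;
      }
    }
  }
  if (inverse) {
    for (std::complex<double>& x : a) x /= static_cast<double>(n);
  }
}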
+// +void Fft1D(int64 length, int64 start, int64 stride, bool inverse, + bool contract_output, bool expand_input, absl::Span data, + absl::Span buffer) { + CHECK(IsPowerOfTwo(static_cast(length))); + const bool input_is_zero = + GatherToBuffer(data, length, start, stride, expand_input, buffer); + + if (!input_is_zero) { + auto generate_twiddles = [](int64 length, bool inverse) { + std::vector twiddles; + // Need only half the twiddles. + for (int64 k = 0; k < length / 2; k++) { + twiddles.push_back(Twiddle(k, length, inverse)); + } + return twiddles; + }; + + // Indices into the parts of the buffer used for input and output values. + int64 in_base = length; + int64 out_base = 0; + + // At each stage, we "split" the input data into num_blocks, with block_size + // values in each block. + for (int64 num_blocks = 1; num_blocks < length; num_blocks *= 2) { + // Swap input and output parts of the buffer. + std::swap(in_base, out_base); + auto twiddles = generate_twiddles(num_blocks * 2, inverse); + const int64 block_size = length / num_blocks; + const int64 next_iteration_block_size = block_size / 2; + for (int64 block = 0; block < num_blocks; block++) { + const int64 in_offset = in_base + block * block_size; + const int64 out_offset = out_base + block * next_iteration_block_size; + // For each (even, odd) pair of values in the block, calculate two + // output values as even + twiddle * odd and even - twiddle * odd. + for (int64 pair = 0; pair < block_size / 2; pair++) { + const complex128 even = buffer[in_offset + pair]; + const complex128 odd = buffer[in_offset + block_size / 2 + pair]; + const complex128 twiddled_odd = twiddles[block] * odd; + buffer[out_offset + pair] = even + twiddled_odd; + buffer[out_offset + length / 2 + pair] = even - twiddled_odd; + } + } + } + // Copy computed result back to data. + const int64 ub = contract_output ? length / 2 + 1 : length; + for (int64 k = 0; k < ub; k++) { + complex128 value = buffer[out_base + k]; + data[start + k * stride] = + inverse ? value / complex128(length, 0.0) : value; + } + } +} + +// Determine, which implementation of 1D transform to use and call it. +void Dft1D(int64 length, int64 start, int64 stride, bool inverse, + bool contract_output, bool expand_input, absl::Span data, + absl::Span buffer) { + if (IsPowerOfTwo(static_cast(length))) { + Fft1D(length, start, stride, inverse, contract_output, expand_input, data, + buffer); + } else { + NaiveDft1D(length, start, stride, inverse, contract_output, expand_input, + data, buffer); + } +} + // Helper to reverse the order of dimension lengths in the passed-in literal. std::vector GetDimensionLengths(const Literal& literal) { std::vector lengths = literal.shape().dimensions(); @@ -906,8 +992,8 @@ void Sweep(int64 fft_rank, FftType fft_type, const int64 stride = fft_strides[sweep_axis]; const bool expand_input = input_is_truncated && sweep_axis == 0; const bool contract_oputput = output_is_truncated && sweep_axis == 0; - NaiveDft1D(length, start, stride, inverse, contract_oputput, expand_input, - data, buffer); + Dft1D(length, start, stride, inverse, contract_oputput, expand_input, + data, buffer); } else if (axis == sweep_axis) { // Visit only the elements with coordinate 0 along the sweep axis. sweep(sweep_axis, axis - 1, start); @@ -1207,10 +1293,10 @@ Status CheckParameters(const Shape& input_shape, const Shape& output_shape, } // namespace -// Flexible but slow implementation of the discrete Fourier transform. 
All -// transform types (FFT, IFFT, RFFT, and IRFFT) are supported, as well as the -// arbitrary rank and length of each dimension of the transform, and arbitrary -// layouts of the input and output literals. +// Flexible implementation of the discrete Fourier transform. All transform +// types (FFT, IFFT, RFFT, and IRFFT) are supported, as well as the arbitrary +// rank and length of each dimension of the transform, and arbitrary layouts of +// the input and output literals. // // The input literal in operand 0 provides input data, which must be complex64 // for FFT, IFFT, IRFFT transforms and float for RFFT. The transform is computed @@ -1241,15 +1327,18 @@ Status CheckParameters(const Shape& input_shape, const Shape& output_shape, // complex64[64][16][9] input array will use all input values and will produce // float[64][16][16] output. // -// The implementation of the 1D transform is a straightforward loop nest. The -// transforms of higher ranks apply sets of 1D transforms along each axis. For -// example, the 2D transform is computed by applying 1D transforms to each -// column followed by applying 1D transforms to each row. +// The implementation of the 1D transform for lengths that are powers of 2 is +// the Cooley-Tukey radix-2 decimation-in-time. For all other 1D transform +// lengths, a straightforward, but slow, loop nest is used. The transforms of +// higher ranks apply sets of 1D transforms along each axis. For example, the 2D +// transform is computed by applying 1D transforms to each column followed by +// applying 1D transforms to each row. // // In general, a transform of rank n runs in O(N0*N1*...*Nn*(N0+N1+...+Nn)) -// time, where Ni is the length of the transform's i-th dimension. It is -// possible to reduce the run time to O(N0*N1*...(log(N0)+log(N1)+...)) by -// plugging in a more efficient 1D implementation. +// time, where Ni is the length of the transform's i-th dimension. However, for +// dimension lengths that are powers of 2, the run time along these dimensions +// is reduced to log(Ni) in the summation, giving the runtime of +// O(N0*N1*...*Nn*(log(N0)+log(N1)+...+log(Nn))) in the best case. // Status HloEvaluator::HandleFft(HloInstruction* fft) { const FftType fft_type = fft->fft_type(); @@ -1275,8 +1364,14 @@ Status HloEvaluator::HandleFft(HloInstruction* fft) { // Linearized working data set. std::vector data(fft_size); - // Temporary buffer allocated once and used in 1D sweeps. - std::vector buffer(*absl::c_max_element(fft_lengths)); + // Temporary buffer allocated once and used in 1D sweeps. For dimension + // length values that are powers of 2, the buffer should be twice as large. + int64 buffer_size = 0; + for (auto len : fft_lengths) { + int64 size = IsPowerOfTwo(static_cast(len)) ? len * 2 : len; + buffer_size = std::max(buffer_size, size); + } + std::vector buffer(buffer_size); // Sizes of each axis of input and output literals.
const auto input_lengths = GetDimensionLengths(input_literal); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index c4266f95fcc..888434774bb 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -129,6 +129,62 @@ class HloEvaluatorTest : public HloTestBase { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } + std::unique_ptr MaxComputationScalarF32() { + HloComputation::Builder max_computation("max"); + Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + auto param_lhs = max_computation.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "lhs")); + auto param_rhs = max_computation.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "rhs")); + max_computation.AddInstruction(HloInstruction::CreateBinary( + scalar_shape, HloOpcode::kMaximum, param_lhs, param_rhs)); + return max_computation.Build(); + } + + void ReduceWindowMaxIotaTest(int window_size, int padding, int stride, + int window_dilation, int base_dilation, + const Literal& expected) { + HloComputation::Builder b(TestName()); + + // arg: + // f32[4,4] { + // { 0, 1, 2, 3 }, + // { 4, 5, 6, 7 }, + // { 8, 9, 10, 11 }, + // { 12, 13, 14, 15 } + // } + auto arg_array = absl::make_unique>(4, 4); + arg_array->FillIota(0); + auto arg_literal = LiteralUtil::CreateR2FromArray2D(*arg_array); + + HloInstruction* arg_instruction = b.AddInstruction( + HloInstruction::CreateConstant(std::move(arg_literal))); + auto init_value = b.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.f))); + auto max_func = m_->AddEmbeddedComputation(MaxComputationScalarF32()); + + Window window; + WindowDimension dim; + dim.set_size(window_size); + dim.set_stride(stride); + dim.set_padding_low(padding); + dim.set_padding_high(padding); + dim.set_window_dilation(window_dilation); + dim.set_base_dilation(base_dilation); + *window.add_dimensions() = dim; + *window.add_dimensions() = dim; + + int dim0 = expected.shape().dimensions(0); + int dim1 = expected.shape().dimensions(1); + Shape shape = ShapeUtil::MakeShape(F32, {dim0, dim1}); + b.AddInstruction(HloInstruction::CreateReduceWindow( + shape, arg_instruction, init_value, window, max_func)); + + m_->AddEntryComputation(b.Build()); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); + } + protected: explicit HloEvaluatorTest(bool use_bfloat16) : use_bfloat16_(use_bfloat16) { InitializeFftData(); @@ -2585,16 +2641,7 @@ TEST_P(HloEvaluatorBf16Test, ReduceWindowMax) { auto init_value = b.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.f))); - - HloComputation::Builder max_computation("max"); - Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); - auto param_lhs = max_computation.AddInstruction( - HloInstruction::CreateParameter(0, scalar_shape, "lhs")); - auto param_rhs = max_computation.AddInstruction( - HloInstruction::CreateParameter(1, scalar_shape, "rhs")); - max_computation.AddInstruction(HloInstruction::CreateBinary( - scalar_shape, HloOpcode::kMaximum, param_lhs, param_rhs)); - auto max_func = m_->AddEmbeddedComputation(max_computation.Build()); + auto max_func = m_->AddEmbeddedComputation(MaxComputationScalarF32()); Window window; WindowDimension dim; @@ -2619,56 +2666,79 @@ TEST_P(HloEvaluatorBf16Test, ReduceWindowMax) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorBf16Test, 
ReduceWindowMaxWindowDilation) { - HloComputation::Builder b(TestName()); +TEST_P(HloEvaluatorBf16Test, ReduceWindowMaxIotaWindowDilation) { + auto expected = LiteralUtil::CreateR2({{10, 11}, {14, 15}}); + ReduceWindowMaxIotaTest( + /*window_size=*/2, + /*padding=*/0, + /*stride=*/1, + /*window_dilation=*/2, + /*base_dilation=*/1, + /*expected=*/expected); +} - // arg: - // f32[3,3] { - // { 1, 2, 3 }, - // { 5, 6, 7 }, - // { 9, 10, 11 }, - // } - auto arg_array = absl::make_unique>(3, 3); - arg_array->FillUnique(1.0f); - auto arg_literal = LiteralUtil::CreateR2FromArray2D(*arg_array); +TEST_P(HloEvaluatorBf16Test, ReduceWindowMaxIotaStrideWindowDilation) { + auto expected = LiteralUtil::CreateR2({{10}}); + ReduceWindowMaxIotaTest( + /*window_size=*/2, + /*padding=*/0, + /*stride=*/2, + /*window_dilation=*/2, + /*base_dilation=*/1, + /*expected=*/expected); +} - HloInstruction* arg_instruction = - b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal))); +TEST_P(HloEvaluatorBf16Test, ReduceWindowMaxIotaBaseDilation) { + auto expected = LiteralUtil::CreateR2({{0, 1, 1, 2, 2, 3}, + {4, 5, 5, 6, 6, 7}, + {4, 5, 5, 6, 6, 7}, + {8, 9, 9, 10, 10, 11}, + {8, 9, 9, 10, 10, 11}, + {12, 13, 13, 14, 14, 15}}); + ReduceWindowMaxIotaTest( + /*window_size=*/2, + /*padding=*/0, + /*stride=*/1, + /*window_dilation=*/1, + /*base_dilation=*/2, + /*expected=*/expected); +} - auto init_value = b.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.f))); +TEST_P(HloEvaluatorBf16Test, ReduceWindowMaxIotaStrideBaseDilation) { + auto expected = + LiteralUtil::CreateR2({{0, 1, 2}, {4, 5, 6}, {8, 9, 10}}); + ReduceWindowMaxIotaTest( + /*window_size=*/2, + /*padding=*/0, + /*stride=*/2, + /*window_dilation=*/1, + /*base_dilation=*/2, + /*expected=*/expected); +} - HloComputation::Builder max_computation("max"); - Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); - auto param_lhs = max_computation.AddInstruction( - HloInstruction::CreateParameter(0, scalar_shape, "lhs")); - auto param_rhs = max_computation.AddInstruction( - HloInstruction::CreateParameter(1, scalar_shape, "rhs")); - max_computation.AddInstruction(HloInstruction::CreateBinary( - scalar_shape, HloOpcode::kMaximum, param_lhs, param_rhs)); - auto max_func = m_->AddEmbeddedComputation(max_computation.Build()); +TEST_P(HloEvaluatorBf16Test, ReduceWindowMaxIotaStrideBothDilation) { + auto expected = + LiteralUtil::CreateR2({{5, 6, 7}, {9, 10, 11}, {13, 14, 15}}); + ReduceWindowMaxIotaTest( + /*window_size=*/2, + /*padding=*/0, + /*stride=*/2, + /*window_dilation=*/2, + /*base_dilation=*/2, + /*expected=*/expected); +} - Window window; - WindowDimension dim; - dim.set_size(2); - dim.set_stride(1); - dim.set_padding_low(0); - dim.set_padding_high(0); - dim.set_window_dilation(2); - dim.set_base_dilation(1); - *window.add_dimensions() = dim; - *window.add_dimensions() = dim; - - Shape shape = ShapeUtil::MakeShape(F32, {1, 1}); - b.AddInstruction(HloInstruction::CreateReduceWindow( - shape, arg_instruction, init_value, window, max_func)); - - m_->AddEntryComputation(b.Build()); - - TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); - - auto expected = LiteralUtil::CreateR2({{11}}); - EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); +TEST_P(HloEvaluatorBf16Test, ReduceWindowMaxIotaPaddingStrideBaseDilation) { + // The base is dilated first, and then padding is applied, hence this result. 
+ auto expected = + LiteralUtil::CreateR2({{0, 2, 3}, {8, 10, 11}, {12, 14, 15}}); + ReduceWindowMaxIotaTest( + /*window_size=*/3, + /*padding=*/1, + /*stride=*/3, + /*window_dilation=*/1, + /*base_dilation=*/2, + /*expected=*/expected); } TEST_P(HloEvaluatorBf16Test, ReduceWindowAdd) { diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index c3b5838cf0a..a6a84d226f5 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -2673,16 +2673,27 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { std::vector base_index(rank); bool out_of_bound = false; for (int64 i = 0; i < rank; ++i) { + // Padding is applied to the dilated base. Say that padding is 3 and + // dilation is 2 for some dimension. After applying base dilation and + // padding, the dimension looks like: + // P P P E D D E D D ... E D D E P P P + // where E are the elements and D are the holes. So, the elements are + // located in indices: padding + k*base_dilation for k = {0, 1, 2, ...}. + // We are accessing elements in the transformed base at indices: + // window_count_index * stride + window_index * window_dilation. + // Solving for k gives us + // (win_count_i * stride + win_i * win_dilation - pad) / base_dilation + // When this is a natural number, we index an original element. + // Otherwise, we index a 0 (pad or hole), and we don't need to apply + // the callback f. base_index[i] = window_count_index[i] * window.dimensions(i).stride() + window_index[i] * window.dimensions(i).window_dilation() - window.dimensions(i).padding_low(); - // We are not in the base area if the dilation placed us out of bounds. if (base_index[i] % window.dimensions(i).base_dilation() != 0) { out_of_bound = true; break; } - // Apply the dilation to the base area. base_index[i] /= window.dimensions(i).base_dilation(); if (base_index[i] < 0 || base_index[i] >= base_shape.dimensions(i)) { out_of_bound = true; diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 7a6d563b83f..934c96d7630 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -1596,8 +1596,8 @@ Status HloFusionInstruction::DeduplicateFusionOperands() { if (operands_to_remove.empty()) { return Status::OK(); } - TF_RETURN_IF_ERROR( - fused_instructions_computation()->RemoveUnusedParameters()); + TF_RETURN_IF_ERROR(fused_instructions_computation() + ->RemoveUnusedParametersFromFusedComputation()); RemoveOperandsAtAscendingIndices(operands_to_remove); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index 2c63247eea8..142b8f18ee6 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -109,6 +109,8 @@ class HloModule { return entry_computation_; } + bool has_entry_computation() const { return entry_computation_ != nullptr; } + // Returns the root instruction shape of entry computation. // // Precondition: entry_computation_ is not nullptr. 
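The index arithmetic spelled out in the hlo_evaluator_typed_visitor.h comment above can be sanity-checked against the ReduceWindowMaxIotaPaddingStrideBaseDilation expectation ({{0, 2, 3}, {8, 10, 11}, {12, 14, 15}}). The self-contained snippet below is editorial, not part of the patch; it recomputes output element (1, 0) of that test with the same formula, skipping holes and padding exactly as the visitor does, and prints 8.

#include <algorithm>
#include <cstdio>

int main() {
  // 4x4 iota base, and the test's window parameters.
  const int base[4][4] = {
      {0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}, {12, 13, 14, 15}};
  const int window_size = 3, padding = 1, stride = 3;
  const int window_dilation = 1, base_dilation = 2;
  const int out_row = 1, out_col = 0;
  int result = 0;  // init value of the reduce-window
  for (int wi = 0; wi < window_size; ++wi) {
    for (int wj = 0; wj < window_size; ++wj) {
      int bi = out_row * stride + wi * window_dilation - padding;
      int bj = out_col * stride + wj * window_dilation - padding;
      if (bi % base_dilation != 0 || bj % base_dilation != 0) continue;  // hole
      bi /= base_dilation;
      bj /= base_dilation;
      if (bi < 0 || bi >= 4 || bj < 0 || bj >= 4) continue;  // padding
      result = std::max(result, base[bi][bj]);
    }
  }
  std::printf("output[1][0] = %d\n", result);  // prints 8
  return 0;
}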
diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index 5ba390acfd4..9fb0cd7e077 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -207,7 +207,7 @@ StatusOr HloRunner::ExecuteWithDeviceBuffers( stream.Init(); ServiceExecutableRunOptions service_run_options = GetServiceRunOptionsForDevice(backend().default_device_ordinal(), &stream, - nullptr); + nullptr, RunId()); TF_ASSIGN_OR_RETURN(std::unique_ptr executable, CreateExecutable(std::move(module), run_hlo_passes)); @@ -243,7 +243,7 @@ StatusOr HloRunner::ExecuteWithDeviceBuffers( stream.Init(); ServiceExecutableRunOptions service_run_options = GetServiceRunOptionsForDevice(backend().default_device_ordinal(), &stream, - nullptr); + nullptr, RunId()); TF_ASSIGN_OR_RETURN( ScopedShapedBuffer retval, @@ -294,6 +294,7 @@ StatusOr> HloRunner::ExecuteReplicated( options.num_replicas * options.arguments.size() + 1); std::vector> argument_buffer_slices; int64 index = 0; + RunId run_id; for (int64 i = 0; i < options.num_replicas; ++i) { int64 device = (*device_assignment)(i, 0); TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, @@ -301,7 +302,7 @@ StatusOr> HloRunner::ExecuteReplicated( streams.push_back(absl::make_unique(executor)); streams.back()->Init(); service_run_options.emplace_back(GetServiceRunOptionsForDevice( - device, streams.back().get(), device_assignment)); + device, streams.back().get(), device_assignment, run_id)); // Copy arguments to device. for (const Literal* argument : options.arguments) { @@ -443,7 +444,8 @@ StatusOr> HloRunner::CreateExecutable( } ServiceExecutableRunOptions HloRunner::GetServiceRunOptionsForDevice( - int64 device, se::Stream* stream, DeviceAssignment* device_assignment) { + int64 device, se::Stream* stream, DeviceAssignment* device_assignment, + RunId run_id) { ExecutableRunOptions run_options; run_options.set_device_ordinal(device); run_options.set_stream(stream); @@ -453,6 +455,7 @@ ServiceExecutableRunOptions HloRunner::GetServiceRunOptionsForDevice( if (device_assignment != nullptr) { run_options.set_device_assignment(device_assignment); } + run_options.set_run_id(run_id); return ServiceExecutableRunOptions(run_options, backend().StreamBorrower()); } diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h index 7e666a8186e..c077ccd95fe 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.h +++ b/tensorflow/compiler/xla/service/hlo_runner.h @@ -206,7 +206,8 @@ class HloRunner { // will be used to configure the replication parameters. Replicated executions // should pass the device_assignment parameter. ServiceExecutableRunOptions GetServiceRunOptionsForDevice( - int64 device, se::Stream* stream, DeviceAssignment* device_assignment); + int64 device, se::Stream* stream, DeviceAssignment* device_assignment, + RunId run_id); std::unique_ptr backend_; }; diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc index ba856fc17af..18ab401bc89 100644 --- a/tensorflow/compiler/xla/service/hlo_value.cc +++ b/tensorflow/compiler/xla/service/hlo_value.cc @@ -91,7 +91,8 @@ string HloValue::ToShortString() const { ? defining_index().ToString() : ""; return StrCat(id(), " ", is_phi_ ? "PHI " : "", - defining_instruction()->name(), index_str); + defining_instruction()->name(), index_str, " @", + (has_color() ? 
color().value() : -1)); } string HloValue::ToString(int indent) const { diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD index 7f0c1ccc728..feb3db64048 100644 --- a/tensorflow/compiler/xla/service/interpreter/BUILD +++ b/tensorflow/compiler/xla/service/interpreter/BUILD @@ -3,9 +3,10 @@ load( "if_static", ) -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = ["//visibility:public"]) +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) cc_library( name = "interpreter_transfer_manager", diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD index e1303f60779..72813d493cf 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/BUILD +++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -1,9 +1,10 @@ # Description: # Libraries for helping construct LLVM IR for XLA backends. -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = [":friends"]) +package( + default_visibility = [":friends"], + licenses = ["notice"], # Apache 2.0 +) package_group( name = "friends", @@ -39,6 +40,7 @@ cc_library( "//tensorflow/compiler/xla/service:logical_buffer", "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@llvm//:core", ], diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc index 761c6879db8..cd1431aa709 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "absl/container/flat_hash_set.h" #include "llvm/IR/MDBuilder.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" @@ -40,15 +41,14 @@ void AliasAnalysis::AddAliasingInformationToIrArray(const HloInstruction& hlo, // with our temporary buffers. buffer_slice = BufferAllocation::Slice(kParameterAllocation, 0, 0); } else { - const std::set slices = - assignment_.GetAllSlices(&hlo, index); - if (slices.empty() || slices.size() > 1) { + auto unique_slice = assignment_.GetUniqueSlice(&hlo, index); + if (!unique_slice.ok()) { // Skip HLOs which don't have a buffer assigned or for which the // buffer can't be determined statically. We cannot determine their // aliasing properties in these cases. return; } - buffer_slice = *slices.begin(); + buffer_slice = unique_slice.ValueOrDie(); } if (module_.config().debug_options().xla_llvm_enable_alias_scope_metadata()) { @@ -134,15 +134,26 @@ llvm::MDNode* AliasAnalysis::GetNoaliasMetadataForBuffer( // 3. Operands of the given hlo. // // This set can be increased as we need. - std::vector worklist; + std::vector worklist; + absl::flat_hash_set added_to_worklist; auto add_buffers_to_worklist = - [&worklist, &assignment](const HloInstruction* instruction) { + [&](const HloInstruction* instruction) { + // Buffers of parameters cannot be added to the noalias set. 
+ if (instruction->opcode() == HloOpcode::kParameter) { + return; + } + if (added_to_worklist.contains(instruction)) { + return; + } + added_to_worklist.insert(instruction); ShapeUtil::ForEachSubshape( instruction->shape(), [&](const Shape& /*shape*/, const ShapeIndex& index) { - for (const LogicalBuffer* buffer : + for (const BufferValue* buffer : assignment.GetSourceBuffers(instruction, index)) { - worklist.push_back(buffer); + if (assignment.HasAllocation(*buffer)) { + worklist.push_back(buffer); + } } }); }; @@ -160,12 +171,7 @@ llvm::MDNode* AliasAnalysis::GetNoaliasMetadataForBuffer( } std::set buffers; - for (const LogicalBuffer* buffer : worklist) { - // Skip buffers which cannot be added to the noalias set. - if (!assignment.HasAllocation(*buffer) || - buffer->instruction()->opcode() == HloOpcode::kParameter) { - continue; - } + for (const BufferValue* buffer : worklist) { const BufferAllocation::Slice noalias_slice = assignment.GetAssignedAllocation(*buffer).GetSlice(*buffer); // Our buffer must not overlap with the noalias slice. diff --git a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc index 4974cb57db3..ba199f35712 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc @@ -23,6 +23,37 @@ limitations under the License. namespace xla { namespace llvm_ir { +bool MayBeImplementedAsInPlaceDynamicUpdateSlice(const HloInstruction* instr) { + // Today we can't emit a dynamic-update-slice if the DUS node is parallelized; + // the emitter will not emit correct code. It's possible to change this, but + // then ParallelTaskAssigner would have to somehow know whether a node *will* + // be emitted as an in-place DUS, and it can't, because it doesn't have a + // buffer assignment when it runs. + if (!instr->outer_dimension_partitions().empty()) { + return false; + } + + // Until we know the final buffer assignment, any unfused dynamic-update-slice + // might be implementable as an in-place DUS. + if (instr->opcode() == HloOpcode::kDynamicUpdateSlice) { + return true; + } + + // A fusion may be implementable as an in-place dynamic update slice if + // - it's a loop fusion, + // - dynamic-update-slice is the root of the fusion, and + // - operand 0 of the dynamic-update-slice is a parameter to the fusion + // (ignoring any get-tuple-element operations in the way). + if (instr->IsLoopFusion()) { + const HloInstruction* fused_root = instr->fused_expression_root(); + return fused_root->opcode() == HloOpcode::kDynamicUpdateSlice && + fused_root->operand(0)->LatestNonGteAncestor()->opcode() == + HloOpcode::kParameter; + } + + return false; +} + bool CanUpdateDynamicSliceInPlace(HloInstruction* dynamic_update_slice, const BufferAssignment& assignment) { CHECK_EQ(HloOpcode::kDynamicUpdateSlice, dynamic_update_slice->opcode()); @@ -32,6 +63,29 @@ bool CanUpdateDynamicSliceInPlace(HloInstruction* dynamic_update_slice, assignment.SharesTopLevelSlice(dynamic_update_slice, operand); } +bool CanEmitFusedDynamicUpdateSliceInPlace(HloInstruction* fusion, + const BufferAssignment& assignment) { + CHECK_EQ(fusion->opcode(), HloOpcode::kFusion); + if (!MayBeImplementedAsInPlaceDynamicUpdateSlice(fusion)) { + return false; + } + + // Walk DynamicUpdateSlice operand(0) to fused parameter and get its + // associated operand. See if it shares an allocation with this operand.
+ HloInstruction* fused_root = fusion->fused_expression_root(); + HloInstruction* fusion_operand; + ShapeIndex index; + std::tie(fusion_operand, index) = + fused_root->mutable_operand(0)->LatestNonGteAncestorAndIndex(); + // MayBeImplementedAsInPlaceDynamicUpdateSlice should have ensured that + // fusion_operand is a parameter. + CHECK_EQ(fusion_operand->opcode(), HloOpcode::kParameter); + auto* operand = fusion->operand(fusion_operand->parameter_number()); + return assignment.HasAllocationAt(operand, index) && + assignment.HasAllocationAt(fusion, {}) && + assignment.SharesSliceAtIndex(fusion, {}, operand, index); +} + // Shared implementation of EmitDynamicUpdateSliceInPlace and // EmitFusedDynamicUpdateSliceInPlace. // diff --git a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h index c4da28229d0..70dc368d5d7 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h @@ -30,6 +30,22 @@ namespace llvm_ir { using GeneratorForOperandIrArrays = std::function()>; +// Determines whether the given instruction might be implemented as an +// in-place dynamic-update-slice after we have a buffer assignment. +// +// If this returns false, then CanUpdateDynamicSliceInPlace and +// CanEmitFusedDynamicUpdateSliceInPlace will also return false. +// +// This is useful if you want to check whether an instruction might be an +// in-place DUS during an HLO pass, at which point you don't have a buffer +// assignment. +// +// Note that simplifications to the HLO graph might change this function from +// returning false to returning true. Specifically, simplifying the contents of +// fusion nodes might cause a false->true transition. In general this isn't a +// problem by the time you're calling this function, but beware. +bool MayBeImplementedAsInPlaceDynamicUpdateSlice(const HloInstruction* instr); + // Checks if we can emit code for the given DynamicUpdateSlice node that updates // its input in place. Returns true if the dynamic-update-slice's // array-to-be-updated and output share the same BufferAllocation::Slice. @@ -40,28 +56,8 @@ bool CanUpdateDynamicSliceInPlace(HloInstruction* dynamic_update_slice, // Checks if the given fusion node is amenable to being implemented by // EmitFusedDynamicUpdateSliceInPlace. -inline bool CanEmitFusedDynamicUpdateSliceInPlace( - HloInstruction* fusion, const BufferAssignment& assignment) { - CHECK_EQ(fusion->opcode(), HloOpcode::kFusion); - HloInstruction* fused_root = fusion->fused_expression_root(); - if (fused_root->opcode() != HloOpcode::kDynamicUpdateSlice || - !fusion->IsLoopFusion()) { - return false; - } - // Walk DynamicUpdateSlice operand(0) to fused parameter and get its - // associated operand. See if it shares an allocation with this operand. 
- HloInstruction* fusion_operand; - ShapeIndex index; - std::tie(fusion_operand, index) = - fused_root->mutable_operand(0)->LatestNonGteAncestorAndIndex(); - if (fusion_operand->opcode() != HloOpcode::kParameter) { - return false; - } - auto* operand = fusion->operand(fusion_operand->parameter_number()); - return assignment.HasAllocationAt(operand, index) && - assignment.HasAllocationAt(fusion, {}) && - assignment.SharesSliceAtIndex(fusion, {}, operand, index); -} +bool CanEmitFusedDynamicUpdateSliceInPlace(HloInstruction* fusion, + const BufferAssignment& assignment); // Emits IR for running the given dynamic-update-slice op in-place -- that is, // where the input and output buffers share the same slice, so we can simply diff --git a/tensorflow/compiler/xla/service/platform_util.cc b/tensorflow/compiler/xla/service/platform_util.cc index 886a0545624..75e704bac66 100644 --- a/tensorflow/compiler/xla/service/platform_util.cc +++ b/tensorflow/compiler/xla/service/platform_util.cc @@ -263,12 +263,18 @@ PlatformUtil::GetStreamExecutors( // Block here in thread_pool destructor until all devices are initialized. } VLOG(1) << "Device initialization complete"; - if (absl::c_all_of(stream_executors, - [](se::StreamExecutor* s) { return s == nullptr; })) { + + std::vector out; + for (se::StreamExecutor* executor : stream_executors) { + if (executor != nullptr) { + out.push_back(executor); + } + } + if (out.empty()) { return InternalError("no supported devices found for platform %s", platform->Name()); } - return stream_executors; + return out; } } // namespace xla diff --git a/tensorflow/compiler/xla/service/platform_util.h b/tensorflow/compiler/xla/service/platform_util.h index 592b20282f3..5764f2c11d9 100644 --- a/tensorflow/compiler/xla/service/platform_util.h +++ b/tensorflow/compiler/xla/service/platform_util.h @@ -58,9 +58,7 @@ class PlatformUtil { static StatusOr GetPlatformExceptFor( const string& platform_name); - // Returns a vector of StreamExecutors for the given platform. The vector is - // indexed by device ordinal (device numbering used by StreamExecutor). If an - // element is nullptr, then the device is present by not supported by XLA. + // Returns a vector of StreamExecutors for the given platform. // If populated, only the devices in allowed_devices will have // their StreamExecutors initialized, otherwise all StreamExecutors will be // initialized and returned. diff --git a/tensorflow/compiler/xla/service/service_executable_run_options.h b/tensorflow/compiler/xla/service/service_executable_run_options.h index 7fc66310ee7..58028aebe1f 100644 --- a/tensorflow/compiler/xla/service/service_executable_run_options.h +++ b/tensorflow/compiler/xla/service/service_executable_run_options.h @@ -24,7 +24,7 @@ limitations under the License. namespace xla { // Class containing options for running a LocalExecutable and other auxiliary -// data, now only a stream cache for GPU backend. +// data. 
class ServiceExecutableRunOptions { public: using StreamBorrower = std::function(int)>; diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc index 999e8a9c0ac..cec954645cc 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc @@ -46,7 +46,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { // Don't try this transformation if the while loop isn't removable, since if // it succeeds ultimately we're going to have to replace the old while loop // with a new one. - if (!while_op->parent()->IsRemovable(while_op)) { + if (!while_op->parent()->IsSafelyRemovable(while_op)) { VLOG(2) << "Can't remove dead parameters from non-removable while op."; return false; } @@ -455,7 +455,7 @@ static StatusOr TryRemoveConstantParams(HloInstruction* while_op) { static StatusOr TryRemoveWhileLoop(HloInstruction* while_op) { // Cowardly refuse to remove loops that are not removable. In practice, this // means that we can't remove loops that have control predecessors/successors. - if (!while_op->parent()->IsRemovable(while_op)) { + if (!while_op->parent()->IsSafelyRemovable(while_op)) { VLOG(2) << "Not attempting to remove while loop that is not removable: " << while_op->ToShortString(); return false; diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc index 661b7aa7d99..4c221e2c116 100644 --- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc +++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc @@ -35,7 +35,7 @@ StatusOr ZeroSizedHloElimination::Run(HloModule* module) { instruction->opcode() == HloOpcode::kConstant) { continue; } - if (comp->IsRemovable(instruction) && + if (comp->IsSafelyRemovable(instruction) && ShapeUtil::IsZeroElementArray(instruction->shape())) { // If the instruction doesn't have a layout, use a default layout for // the literal. 
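The new dynamic_update_slice_util.h comments above spell out a two-stage contract: MayBeImplementedAsInPlaceDynamicUpdateSlice is the conservative answer available during HLO passes, before any BufferAssignment exists, while CanEmitFusedDynamicUpdateSliceInPlace and CanUpdateDynamicSliceInPlace give the precise answer once buffer assignment has run. A minimal sketch of how a caller might split the two checks; the wrapper names ShouldSkipParallelization and WillEmitInPlace are illustrative only and not part of this change:

#include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h"

namespace xla {

// Pass-time check: no BufferAssignment exists yet, so only the conservative
// answer is available. A true result means "leave this instruction alone; it
// might be emitted as an in-place DUS".
bool ShouldSkipParallelization(const HloInstruction* instr) {
  return llvm_ir::MayBeImplementedAsInPlaceDynamicUpdateSlice(instr);
}

// Emission-time check: a BufferAssignment is available, so the precise answer
// can be computed from the actual buffer slices.
bool WillEmitInPlace(HloInstruction* fusion,
                     const BufferAssignment& assignment) {
  return llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(fusion, assignment);
}

}  // namespace xla

Since the header comment notes that simplifying a fusion's contents can flip the conservative check from false to true, callers should re-query it after running simplification passes rather than caching an earlier result.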
diff --git a/tensorflow/compiler/xla/status_macros.h b/tensorflow/compiler/xla/status_macros.h index 315136acc71..c37087cb2c8 100644 --- a/tensorflow/compiler/xla/status_macros.h +++ b/tensorflow/compiler/xla/status_macros.h @@ -187,28 +187,4 @@ class StatusAdaptorForMacros { .with_log_stack_trace() \ .add_ret_check_failure(#condition) -#define TF_ASSERT_OK_AND_ASSIGN(lhs, rexpr) \ - TF_ASSERT_OK_AND_ASSIGN_IMPL( \ - TF_STATUS_MACROS_CONCAT_NAME(_status_or_value, __COUNTER__), lhs, \ - rexpr); - -#define TF_ASSERT_OK_AND_ASSIGN_IMPL(statusor, lhs, rexpr) \ - auto statusor = (rexpr); \ - ASSERT_TRUE(statusor.status().ok()) << statusor.status(); \ - lhs = std::move(statusor.ValueOrDie()) - -#define TF_STATUS_MACROS_CONCAT_NAME(x, y) TF_STATUS_MACROS_CONCAT_IMPL(x, y) -#define TF_STATUS_MACROS_CONCAT_IMPL(x, y) x##y - -#define TF_ASSIGN_OR_RETURN(lhs, rexpr) \ - TF_ASSIGN_OR_RETURN_IMPL( \ - TF_STATUS_MACROS_CONCAT_NAME(_status_or_value, __COUNTER__), lhs, rexpr) - -#define TF_ASSIGN_OR_RETURN_IMPL(statusor, lhs, rexpr) \ - auto statusor = (rexpr); \ - if (TF_PREDICT_FALSE(!statusor.ok())) { \ - return statusor.status(); \ - } \ - lhs = std::move(statusor.ValueOrDie()) - #endif // TENSORFLOW_COMPILER_XLA_STATUS_MACROS_H_ diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index cff87c59938..b2ba65eb46d 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -8,10 +8,9 @@ load( ) load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") -licenses(["notice"]) # Apache 2.0 - package( default_visibility = [":friends"], + licenses = ["notice"], # Apache 2.0 ) package_group( @@ -1715,7 +1714,7 @@ xla_test( # This test is tagged "manual" because it requires multiple GPUs, and # Forge only supports single-GPU tests. Guitar skips "manual" tests # unless they're also tagged "guitar". - "noguitar", # TODO(b/131524578): Re-enable this. + "guitar", "manual", "multi_gpu", "no_oss", diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index d700437ed35..daaf332ed0f 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -105,7 +105,7 @@ class ClientLibraryTestBase : public ::testing::Test { const Shape* shape_with_output_layout = nullptr); // This executes the computation via the reference client (which connects a - // interpreter backend). The result is used as the expected values of the + // interpreter backend). The result is used as the expected value of the // computation. StatusOr ExecuteAndTransferReference( const XlaComputation& computation, @@ -385,6 +385,9 @@ class ClientLibraryTestBase : public ::testing::Test { StatusOr> ComputeValueAndReference( XlaBuilder* builder, absl::Span arguments); + // Converts an f32 literal to bf16 if use_bfloat16_ is true. + Literal MaybeConvertLiteralToBfloat16(const Literal& literal); + LocalClient* client_; LocalClient* ref_client_; // To compute reference result. ExecutionOptions execution_options_; @@ -402,8 +405,7 @@ class ClientLibraryTestBase : public ::testing::Test { const string& error_message)>& verify_output, const Shape* output_with_layout = nullptr); - // Converts an f32 shape/literal to bf16 if use_bfloat16_ is true. - Literal MaybeConvertLiteralToBfloat16(const Literal& literal); + // Converts an f32 shape to bf16 if use_bfloat16_ is true. 
Shape MaybeConvertShapeToBfloat16(const Shape& shape); // Whether to run tests with all float-type input/output converted to diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc index 7eaa2791d47..2843d77607e 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.cc +++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc @@ -117,8 +117,10 @@ LocalClientTestBase::LocalClientTestBase(se::Platform* platform) : local_client_( ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie()), thread_pool_wrapper_(new EigenThreadPoolWrapper()) { + // Take the first executor, since it's the default one. stream_executor_ = PlatformUtil::GetStreamExecutors(local_client_->platform()) - .ValueOrDie()[local_client_->default_device_ordinal()]; + .ValueOrDie() + .front(); transfer_manager_ = TransferManager::GetForPlatform(local_client_->platform()).ValueOrDie(); } diff --git a/tensorflow/compiler/xla/tests/multi_device_all_reduce_test.cc b/tensorflow/compiler/xla/tests/multi_device_all_reduce_test.cc index 7895895e3e7..da0c94c8fa9 100644 --- a/tensorflow/compiler/xla/tests/multi_device_all_reduce_test.cc +++ b/tensorflow/compiler/xla/tests/multi_device_all_reduce_test.cc @@ -21,7 +21,9 @@ limitations under the License. #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/core/lib/core/blocking_counter.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/core/threadpool.h" // Tests cross-GPU all-reduce operatons. // @@ -210,5 +212,43 @@ XLA_TEST_F(MultiDeviceAllReduceTest, NcclChannelCaching) { EXPECT_THAT(OpenNcclChannels(), IsEmpty()); } +// Runs the same executable many times concurrently. The all-reduces should not +// conflict with one another. 
+XLA_TEST_F(MultiDeviceAllReduceTest, ManyConcurrentAllReduces) { + const int64 kNumElems = 1024; + const int64 kNumThreads = 200; + const int64 kRunsPerThread = 10; + + auto config = GetModuleConfigForTest(); + config.set_replica_count(2); + auto executable = test_runner_ + .CreateExecutable(MakeCrsModule(kNumElems, config), + /*run_hlo_passes=*/true) + .ValueOrDie(); + std::vector devices = {0, 1}; + auto device_assn = MakeDeviceAssn(devices); + + std::vector input_vec(kNumElems); + absl::c_iota(input_vec, 0); + auto input_literal = LiteralUtil::CreateR1(input_vec); + HloRunner::ReplicatedExecuteOptions opts; + opts.num_replicas = devices.size(); + opts.use_threads = true; + opts.arguments.push_back(&input_literal); + + tensorflow::BlockingCounter done(kNumThreads * kRunsPerThread); + tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), TestName(), + kNumThreads); + for (int64 i = 0; i < kNumThreads * kRunsPerThread; ++i) { + pool.Schedule([&] { + TF_ASSERT_OK( + test_runner_.ExecuteReplicated(executable.get(), opts, &device_assn) + .status()); + done.DecrementCount(); + }); + } + done.Wait(); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index 352b59f248b..fc0a4f541c6 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -527,32 +527,20 @@ XLA_TEST_P(ReduceWindowTest, Add128In128) { TEST_P(ReduceWindowTest, R2ReduceWindowInceptionFromBroadcast) { Array2D input_array(14, 14, 1.0f); const auto input = CreateConstantFromArray(input_array, &builder_); - int win_len = 3; int stride = 1; Padding padding = Padding::kSame; ReduceWindowAdd(input, {win_len, win_len}, {stride, stride}, padding); - - auto res = ReferenceUtil::ReduceWindow2DAdd( - input_array, 0.0f, {win_len, win_len}, {stride, stride}, padding); - - ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), - {}, DefaultErrorSpec()); + ComputeAndCompare(&builder_, {}, DefaultErrorSpec()); } TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) { Array2D input_array(6, 4, 1.0f); XlaOp input = Broadcast( CreateConstantFromLiteral(LiteralUtil::One(F32), &builder_), {6, 4}); - Padding padding = Padding::kSame; ReduceWindowAdd(input, {4, 2}, {3, 3}, padding); - - auto res = ReferenceUtil::ReduceWindow2DAdd(input_array, 0.0f, {4, 2}, {3, 3}, - padding); - - ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), - {}, DefaultErrorSpec()); + ComputeAndCompare(&builder_, {}, DefaultErrorSpec()); } INSTANTIATE_TEST_CASE_P(ReduceWindowTestInstance, ReduceWindowTest, @@ -1056,77 +1044,139 @@ struct R2ReduceWindowTestData { int64 base_bounds[2]; int64 window_bounds[2]; int64 strides[2]; + int64 base_dilation[2]; + int64 window_dilation[2]; int64 pad_low[2]; int64 pad_high[2]; int64 layout[2]; Reducer reducer; } kR2TestCases[] = { {/*base_bounds=*/{4, 18}, /*window_bounds=*/{2, 4}, - /*strides=*/{1, 2}, /*pad_low=*/{0, 1}, /*pad_high=*/{1, 1}, + /*strides=*/{1, 2}, + /*base_dilation=*/{1, 1}, /*window_dilation=*/{1, 1}, + /*pad_low=*/{0, 1}, /*pad_high=*/{1, 1}, /*layout=*/{0, 1}, /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{2, 5}, /*window_bounds=*/{2, 4}, - /*strides=*/{1, 1}, /*pad_low=*/{0, 1}, /*pad_high=*/{1, 2}, + /*strides=*/{1, 1}, + /*base_dilation=*/{1, 1}, /*window_dilation=*/{1, 1}, + /*pad_low=*/{0, 1}, /*pad_high=*/{1, 2}, /*layout=*/{0, 1}, /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{1, 3}, 
/*window_bounds=*/{2, 3}, - /*strides=*/{1, 1}, /*pad_low=*/{0, 1}, /*pad_high=*/{1, 1}, + /*strides=*/{1, 1}, + /*base_dilation=*/{1, 1}, /*window_dilation=*/{1, 1}, + /*pad_low=*/{0, 1}, /*pad_high=*/{1, 1}, /*layout=*/{0, 1}, /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{3, 129}, /*window_bounds=*/{1, 100}, - /*strides=*/{2, 99}, /*pad_low=*/{0, 0}, /*pad_high=*/{35, 35}, + /*strides=*/{2, 99}, + /*base_dilation=*/{1, 1}, /*window_dilation=*/{1, 1}, + /*pad_low=*/{0, 0}, /*pad_high=*/{35, 35}, /*layout=*/{0, 1}, /*reducer=*/Reducer::kAdd}, // TODO(b/74260408): This test last failed on GPU on 2018-03-08, likely due to a // ptxas bug. #ifndef XLA_TEST_BACKEND_GPU {/*base_bounds=*/{6, 152}, /*window_bounds=*/{2, 25}, - /*strides=*/{5, 4}, /*pad_low=*/{0, 1}, /*pad_high=*/{10, 11}, + /*strides=*/{5, 4}, + /*base_dilation=*/{1, 1}, /*window_dilation=*/{1, 1}, + /*pad_low=*/{0, 1}, /*pad_high=*/{10, 11}, /*layout=*/{0, 1}, /*reducer=*/Reducer::kAdd}, #endif {/*base_bounds=*/{6, 4}, /*window_bounds=*/{4, 2}, - /*strides=*/{3, 3}, /*pad_low=*/{0, 1}, /*pad_high=*/{0, 1}, + /*strides=*/{3, 3}, + /*base_dilation=*/{1, 1}, /*window_dilation=*/{1, 1}, + /*pad_low=*/{0, 1}, /*pad_high=*/{0, 1}, /*layout=*/{0, 1}, /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{5, 147}, /*window_bounds=*/{1, 36}, - /*strides=*/{4, 5}, /*pad_low=*/{0, 0}, /*pad_high=*/{17, 17}, + /*strides=*/{4, 5}, + /*base_dilation=*/{1, 1}, /*window_dilation=*/{1, 1}, + /*pad_low=*/{0, 0}, /*pad_high=*/{17, 17}, /*layout=*/{1, 0}, /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{4, 153}, /*window_bounds=*/{2, 93}, - /*strides=*/{1, 1}, /*pad_low=*/{0, 1}, /*pad_high=*/{46, 46}, + /*strides=*/{1, 1}, + /*base_dilation=*/{1, 1}, /*window_dilation=*/{1, 1}, + /*pad_low=*/{0, 1}, /*pad_high=*/{46, 46}, /*layout=*/{1, 0}, /*reducer=*/Reducer::kAdd}, // Regression test for a bug that appeared in Inception (b/34784899). 
{/*base_bounds=*/{28, 28}, /*window_bounds=*/{3, 3}, - /*strides=*/{1, 1}, /*pad_low=*/{1, 1}, /*pad_high=*/{1, 1}, + /*strides=*/{1, 1}, + /*base_dilation=*/{1, 1}, /*window_dilation=*/{1, 1}, + /*pad_low=*/{1, 1}, /*pad_high=*/{1, 1}, /*layout=*/{1, 0}, /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{4, 4}, /*window_bounds=*/{2, 2}, - /*strides=*/{1, 1}, /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0}, + /*strides=*/{1, 1}, + /*base_dilation=*/{1, 1}, /*window_dilation=*/{1, 1}, + /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0}, /*layout=*/{1, 0}, - /*reducer=*/Reducer::kAdd}, + /*reducer=*/Reducer::kMax}, + {/*base_bounds=*/{4, 4}, /*window_bounds=*/{2, 2}, + /*strides=*/{1, 1}, + /*base_dilation=*/{1, 1}, /*window_dilation=*/{2, 2}, + /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0}, + /*layout=*/{1, 0}, + /*reducer=*/Reducer::kMax}, + {/*base_bounds=*/{4, 4}, /*window_bounds=*/{2, 2}, + /*strides=*/{1, 1}, + /*base_dilation=*/{2, 2}, /*window_dilation=*/{1, 1}, + /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0}, + /*layout=*/{1, 0}, + /*reducer=*/Reducer::kMax}, + {/*base_bounds=*/{4, 4}, /*window_bounds=*/{2, 2}, + /*strides=*/{2, 2}, + /*base_dilation=*/{2, 2}, /*window_dilation=*/{1, 1}, + /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0}, + /*layout=*/{1, 0}, + /*reducer=*/Reducer::kMax}, + {/*base_bounds=*/{4, 4}, /*window_bounds=*/{2, 2}, + /*strides=*/{2, 2}, + /*base_dilation=*/{2, 2}, /*window_dilation=*/{1, 1}, + /*pad_low=*/{3, 3}, /*pad_high=*/{3, 3}, + /*layout=*/{1, 0}, + /*reducer=*/Reducer::kMax}, + {/*base_bounds=*/{4, 4}, /*window_bounds=*/{2, 2}, + /*strides=*/{2, 2}, + /*base_dilation=*/{2, 2}, /*window_dilation=*/{2, 2}, + /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0}, + /*layout=*/{1, 0}, + /*reducer=*/Reducer::kMax}, // Regression test for a bug that appeared in Inception (b/34784899). {/*base_bounds=*/{4, 32}, /*window_bounds=*/{2, 2}, - /*strides=*/{2, 2}, /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0}, + /*strides=*/{2, 2}, + /*base_dilation=*/{1, 1}, /*window_dilation=*/{1, 1}, + /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0}, /*layout=*/{1, 0}, /*reducer=*/Reducer::kAdd}, // Regression test for b/73903312: bf16 lacks precision to store result of // very large windows. Testing with a reasonable window larger than 128. {/*base_bounds=*/{8, 130}, /*window_bounds=*/{1, 130}, - /*strides=*/{1, 1}, /*pad_low=*/{0, 130}, /*pad_high=*/{0, 0}, + /*strides=*/{1, 1}, + /*base_dilation=*/{1, 1}, /*window_dilation=*/{1, 1}, + /*pad_low=*/{0, 130}, /*pad_high=*/{0, 0}, /*layout=*/{1, 0}, /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{8, 256}, /*window_bounds=*/{1, 4}, - /*strides=*/{1, 64}, /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0}, + /*strides=*/{1, 64}, + /*base_dilation=*/{1, 1}, /*window_dilation=*/{1, 1}, + /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0}, /*layout=*/{1, 0}, /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{4096, 4096}, /*window_bounds=*/{1, 4}, - /*strides=*/{1, 1024}, /*pad_low=*/{0, 0}, /*pad-high=*/{0, 0}, + /*strides=*/{1, 1024}, + /*base_dilation=*/{1, 1}, /*window_dilation=*/{1, 1}, + /*pad_low=*/{0, 0}, /*pad-high=*/{0, 0}, /*layout=*/{1, 0}, /*reducer=*/Reducer::kAdd}, // Regression test for b/72234705: bf16 lacks precision to store incremental // results on very large windows. Using smaller window with minor dim 128. 
{/*base_bounds=*/{8, 128}, /*window_bounds=*/{2, 128}, - /*strides=*/{1, 1}, /*pad_low=*/{0, 0}, /*pad-high=*/{0, 0}, + /*strides=*/{1, 1}, + /*base_dilation=*/{1, 1}, /*window_dilation=*/{1, 1}, + /*pad_low=*/{0, 0}, /*pad-high=*/{0, 0}, /*layout=*/{1, 0}, /*reducer=*/Reducer::kAdd}, }; @@ -1135,9 +1185,11 @@ string R2ReduceWindowTestDataToString( ::testing::tuple>& data) { const auto& param = ::testing::get<0>(data.param); string str = absl::StrCat( - "base_bounds_", absl::StrJoin(param.base_bounds, "x"), // - "__window_bounds_", absl::StrJoin(param.window_bounds, "x"), // - "__strides_", absl::StrJoin(param.strides, "x"), // + "base_bounds_", absl::StrJoin(param.base_bounds, "x"), // + "__window_bounds_", absl::StrJoin(param.window_bounds, "x"), // + "__strides_", absl::StrJoin(param.strides, "x"), // + "__base_dilation_", absl::StrJoin(param.base_dilation, "x"), // + "__window_dilation_", absl::StrJoin(param.window_dilation, "x"), // "__pad_low_", absl::StrJoin(param.pad_low, "x"), "__pad_high_", absl::StrJoin(param.pad_high, "x"), "__layout_", param.layout[0], "_", param.layout[1], // @@ -1158,14 +1210,18 @@ class R2ReduceWindowTest : public ReduceWindowTestBase, XlaBuilder b(TestName()); const auto& param = ::testing::get<0>(GetParam()); - const float kInitValue = 0.0f; Array2D input(param.base_bounds[0], param.base_bounds[1], 1.0f); + if (!::testing::get<1>(GetParam())) { + // We only do this in F32 mode, to avoid precision issues with BF16. + input = *MakeLinspaceArray2D(0, 100, param.base_bounds[0], + param.base_bounds[1]); + } Literal input_literal = LiteralUtil::CreateR2FromArray2DWithLayout( input, LayoutUtil::MakeLayout(param.layout)); XlaOp parameter; - auto input_arg = CreateParameterAndTransferLiteral(0, input_literal, "p0", - &b, ¶meter); + CreateParameterAndTransferLiteral(0, input_literal, "p0", &b, ¶meter); + std::vector> padding(2); for (int i = 0; i < 2; ++i) { padding[i] = {param.pad_low[i], param.pad_high[i]}; @@ -1173,6 +1229,7 @@ class R2ReduceWindowTest : public ReduceWindowTestBase, auto computation = param.reducer == kAdd ? CreateScalarAddComputation(FloatType(), &b) : CreateScalarMaxComputation(FloatType(), &b); + const float kInitValue = 0.0f; auto init_value = CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b); ReduceWindowWithGeneralPadding( @@ -1181,20 +1238,12 @@ class R2ReduceWindowTest : public ReduceWindowTestBase, /*computation=*/computation, /*window_dimensions=*/param.window_bounds, /*window_strides=*/param.strides, - /*base_dilations=*/{}, - /*window_dilations=*/{}, + /*base_dilations=*/param.base_dilation, + /*window_dilations=*/param.window_dilation, /*padding=*/padding); - auto reduce_func = param.reducer == kAdd - ? 
+[](float a, float b) { return a + b; } - : +[](float a, float b) { return std::max(a, b); }; - auto expected = ReferenceUtil::ReduceWindow2DGeneric( - /*operand=*/input, /*init=*/kInitValue, /*reduce_func=*/reduce_func, - /*window=*/param.window_bounds, - /*stride=*/param.strides, /*padding=*/padding); - - ComputeAndCompareLiteral(&b, LiteralUtil::CreateFromArray(*expected), - {input_arg.get()}, DefaultErrorSpec()); + ComputeAndCompare(&b, {MaybeConvertLiteralToBfloat16(input_literal)}, + DefaultErrorSpec()); } }; diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc index 4337aa4bf9a..1fa43c65445 100644 --- a/tensorflow/compiler/xla/tests/test_utils_test.cc +++ b/tensorflow/compiler/xla/tests/test_utils_test.cc @@ -258,7 +258,7 @@ XLA_TEST_F(TestUtilsTest, MakeFakeArgumentsForGather) { auto module = ParseHloString(R"( HloModule Test -ENTRY %module(paramater.0: f32[200,100,300], parameter.1: s32[10,2]) -> +ENTRY %module(parameter.0: f32[200,100,300], parameter.1: s32[10,2]) -> f32[10,300] { %parameter.0 = f32[200,100,300] parameter(0) %parameter.1 = s32[10,2] parameter(1) diff --git a/tensorflow/compiler/xla/tests/transfer_manager_test.cc b/tensorflow/compiler/xla/tests/transfer_manager_test.cc index 00b72cedbf5..697c24e6587 100644 --- a/tensorflow/compiler/xla/tests/transfer_manager_test.cc +++ b/tensorflow/compiler/xla/tests/transfer_manager_test.cc @@ -100,6 +100,28 @@ XLA_TEST_F(TransferManagerTest, TransferR1F32) { result); } +XLA_TEST_F(TransferManagerTest, TransferR1F32AwkwardSizes) { + // Test transferring R1s from 0 to kMaxR1Size. The goal is to find bugs + // related to "awkwardly" sized R1s. + constexpr int kMaxR1Size = (1 << 11); + for (int i = 0; i < kMaxR1Size; ++i) { + std::vector inputs(i); + std::iota(inputs.begin(), inputs.end(), 0); + Literal literal = LiteralUtil::CreateR1(inputs); + const Shape& shape = literal.shape(); + auto device_buffer = AllocateDeviceBuffer(shape); + + // Round trip literal through device. + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, + device_buffer)); + TF_ASSERT_OK_AND_ASSIGN( + Literal result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); + + LiteralTestUtil::ExpectR1Equal(inputs, result); + } +} + XLA_TEST_F(TransferManagerTest, TransferR1LargeF32) { std::vector test_vector(1024 * 1024); std::iota(test_vector.begin(), test_vector.end(), 0); @@ -276,8 +298,8 @@ XLA_TEST_F(TransferManagerTest, TransferComplexValueInTuple) { } XLA_TEST_F(TransferManagerTest, TransferTokenFromDevice) { - // "Copy" a token from the device. The token has no physical representation so - // no copying is actually performed, but it shouldn't fail. + // "Copy" a token from the device. The token has no physical representation + // so no copying is actually performed, but it shouldn't fail. // TODO(b/110532604): Add transferring the token to device when this is // supported. auto device_buffer = AllocateDeviceBuffer(ShapeUtil::MakeTokenShape()); diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index 4edd13c79c7..fe8e83512f4 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -1,8 +1,9 @@ # Tools and utilities that aid in XLA development and usage. 
-licenses(["notice"]) # Apache 2.0 - -package(default_visibility = ["//tensorflow/compiler/xla:internal"]) +package( + default_visibility = ["//tensorflow/compiler/xla:internal"], + licenses = ["notice"], # Apache 2.0 +) # Filegroup used to collect source files for dependency checking. filegroup( diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index 257b1ef5c3d..411b305c6ab 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -89,6 +89,8 @@ struct Options { Options() : intra_op_thread_pool_size(tensorflow::port::NumSchedulableCPUs()) {} + bool NeedsRealData() const { return !use_fake_data && !compile_only; } + string fake_infeed_shape; string fake_outfeed_shape; @@ -106,6 +108,8 @@ struct Options { int num_runs = 1; int intra_op_thread_pool_size; + + bool compile_only = false; }; StatusOr> CompileExecutable( @@ -355,9 +359,9 @@ StatusOr> ParseRecordIoFile(absl::string_view filename, CHECK(!snapshots.empty()) << "No proto is successfully parsed from the file - the file possibly " "has a mismatched compression option, format, etc."; - CHECK(opts.use_fake_data) - << "Without --use_fake_data, you must pass an HloSnapshot -- HloProto " - "and textual HLO don't carry real data."; + CHECK(!opts.NeedsRealData()) + << "Without --use_fake_data or --compile_only, you must pass an " + "HloSnapshot -- HloProto and textual HLO don't carry real data."; return snapshots; } @@ -373,9 +377,9 @@ StatusOr ParseSingleHloFile(const string& filename, if (s.code() == tensorflow::error::NOT_FOUND) { return s; } - CHECK(opts.use_fake_data) - << "Without --use_fake_data, you must pass an HloSnapshot -- HloProto " - "and textual HLO don't carry real data."; + CHECK(!opts.NeedsRealData()) + << "Without --use_fake_data or --compile_only, you must pass an " + "HloSnapshot -- HloProto and textual HLO don't carry real data."; fprintf(stderr, "%s: is not HloSnapshot. Trying HloProto.\n", filename.c_str()); @@ -457,6 +461,11 @@ int RealMain(absl::Span args, const Options& opts) { exit_status = EXIT_FAILURE; continue; } + + if (opts.compile_only) { + continue; + } + LocalExecutable* executable = executables[i].ValueOrDie().get(); LOG(ERROR) << "Running iteration " << i; StatusOr result_status = @@ -518,6 +527,9 @@ int main(int argc, char** argv) { &opts.intra_op_thread_pool_size, "How many threads to use in the intra-op thread pool. " "Defaults to the number of CPUs."), + tensorflow::Flag("compile_only", &opts.compile_only, + "Whether the input should only be compiled, as opposed " + "to compiled and executed."), }; xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index 55b092cfbaa..dacb5faa228 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -548,6 +548,78 @@ Status EraseElementFromVector(std::vector* container, const T& value) { container->erase(it); return Status::OK(); } + +// MakeCleanup(f) returns an RAII cleanup object that calls 'f' in its +// destructor. The easiest way to use MakeCleanup is with a lambda argument, +// capturing the return value in an 'auto' local variable. Most users will not +// need more sophisticated syntax than that. 
+// +// Example: +// void func() { +// auto resource = acquire_resource(); +// auto cleanup = MakeCleanup([&] { release_resource(resource); }); +// TF_RETURN_IF_ERROR(...); // phew, calls release_resource! +// } +// +// You can use Cleanup directly, instead of using MakeCleanup and auto, +// but there's rarely a reason to do that. +// +// You can call 'release()' on a Cleanup object to cancel the cleanup +// +// You probably do not want to capture by reference in the cleanup lambda a +// variable that is returned by the function. This can lead to disabling of RVO +// at best, and undefined behavior at worst. +template +class Cleanup { + public: + Cleanup() : released_(true), f_() {} + + template + explicit Cleanup(G&& f) : f_(std::forward(f)) {} + + Cleanup(Cleanup&& src) : released_(src.is_released()), f_(src.release()) {} + + // Implicitly move-constructible from any compatible Cleanup. The source + // will be released as if src.release() were called. A moved-from Cleanup can + // be safely destroyed or reassigned. + template + Cleanup(Cleanup&& src) : released_(src.is_released()), f_(src.release()) {} + + // Assignment to a Cleanup object behaves like destroying it and making a new + // one in its place, analogous to unique_ptr semantics. + Cleanup& operator=(Cleanup&& src) { + if (!released_) std::move(f_)(); + released_ = src.released_; + f_ = src.release(); + return *this; + } + + ~Cleanup() { + if (!released_) std::move(f_)(); + } + + // Releases the cleanup function instead of running it. Hint: use + // c.release()() to run early. + F release() { + released_ = true; + return std::move(f_); + } + + bool is_released() const { return released_; } + + private: + static_assert(!std::is_reference::value, "F must not be a reference"); + + bool released_ = false; + F f_; +}; + +template ::type> +ABSL_MUST_USE_RESULT Cleanup MakeCleanup(F&& f) { + return Cleanup(std::forward(f)); +} + } // namespace xla #define XLA_LOG_LINES(SEV, STRING) \ diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 67f76d00703..7dc87ae08b6 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -610,7 +610,7 @@ message OpSharding { // all-to-all). message ReplicaGroup { // The ids of the replicas that belongs to the same group. The ordering of the - // ids matters in some op (e.g., all-to-all). + // ids matters in some ops (e.g., all-to-all). 
repeated int64 replica_ids = 1; } diff --git a/tensorflow/compiler/xrt/BUILD b/tensorflow/compiler/xrt/BUILD index acd984f9e99..694e75a447d 100644 --- a/tensorflow/compiler/xrt/BUILD +++ b/tensorflow/compiler/xrt/BUILD @@ -44,12 +44,15 @@ cc_library( srcs = [ "xrt_compilation_cache.cc", "xrt_device.cc", + "xrt_memory_manager.cc", "xrt_state.cc", "xrt_util.cc", ], hdrs = [ "xrt_compilation_cache.h", "xrt_device.h", + "xrt_memory_manager.h", + "xrt_refptr.h", "xrt_state.h", "xrt_util.h", ], diff --git a/tensorflow/compiler/xrt/cc/BUILD b/tensorflow/compiler/xrt/cc/BUILD index 5c1e86b76b4..59a965945ad 100644 --- a/tensorflow/compiler/xrt/cc/BUILD +++ b/tensorflow/compiler/xrt/cc/BUILD @@ -1,7 +1,6 @@ -licenses(["notice"]) # Apache 2.0 - package( default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 ) load( diff --git a/tensorflow/compiler/xrt/client/BUILD b/tensorflow/compiler/xrt/client/BUILD index 3908f026bcf..c06ae7fb1cb 100644 --- a/tensorflow/compiler/xrt/client/BUILD +++ b/tensorflow/compiler/xrt/client/BUILD @@ -65,6 +65,7 @@ cc_library( ":xrt_tf_client", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/service:computation_placer", diff --git a/tensorflow/compiler/xrt/client/xrt_client.cc b/tensorflow/compiler/xrt/client/xrt_client.cc index c1f06e91c4f..a91aba17650 100644 --- a/tensorflow/compiler/xrt/client/xrt_client.cc +++ b/tensorflow/compiler/xrt/client/xrt_client.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xrt/client/xrt_client.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xrt/client/xrt_tf_client.h" @@ -71,8 +72,11 @@ xla::StatusOr DeserializeTensorProtoAsLiteral( } // namespace -XrtBuffer::XrtBuffer(XrtTensorHandle handle, xla::Shape shape) - : handle_(std::move(handle)), shape_(std::move(shape)) {} +XrtBuffer::XrtBuffer(XrtTensorHandle handle, int xrt_device_ordinal, + xla::Shape shape) + : handle_(std::move(handle)), + xrt_device_ordinal_(xrt_device_ordinal), + shape_(std::move(shape)) {} XrtBuffer::~XrtBuffer() { Delete(); } @@ -100,17 +104,27 @@ XrtBuffer::~XrtBuffer() { Delete(); } "XRTAllocate", {&literal_handle}, /*output_arity=*/1, /*attrs=*/{}, tf_device_id)[0]); - return std::make_shared(std::move(buffer_handle), literal.shape()); + return std::make_shared(std::move(buffer_handle), + xrt_device_ordinal, literal.shape()); } /*static*/ xla::StatusOr> XrtBuffer::MakeTuple( const std::shared_ptr& context, - const std::vector>& elements) { + const std::vector>& elements, + int xrt_device_ordinal) { if (elements.empty()) { - return errors::Unimplemented( - "The arity zero case of MakeTuple is not implemented."); + // XRTMakeTuple cannot construct empty tuples. Construct via a literal + // instead. 
+ return FromLiteral(context, xrt_device_ordinal, + xla::LiteralUtil::MakeTuple({})); } - int tf_device_id = elements[0]->handle().device_id(); + + if (xrt_device_ordinal < 0 || + xrt_device_ordinal >= context->tf_device_ids().size()) { + return errors::InvalidArgument("Invalid XRT device ordinal ", + xrt_device_ordinal); + } + int tf_device_id = context->tf_device_ids().at(xrt_device_ordinal); xrt::XLATupleNode tuple_description; std::vector element_shapes; element_shapes.reserve(elements.size()); @@ -144,7 +158,8 @@ XrtBuffer::~XrtBuffer() { Delete(); } XrtTensorHandle buffer_handle = std::move(context->tf_context()->EnqueueOp( "XRTMakeTuple", args, /*output_arity=*/1, attrs, tf_device_id)[0]); return std::make_shared( - std::move(buffer_handle), xla::ShapeUtil::MakeTupleShape(element_shapes)); + std::move(buffer_handle), xrt_device_ordinal, + xla::ShapeUtil::MakeTupleShape(element_shapes)); } xla::StatusOr XrtBuffer::ToLiteral() const { @@ -193,8 +208,8 @@ XrtBuffer::DestructureTuple() { handle_.context()->EnqueueOp("XRTSubTuple", {&handle_, &index}, /*output_arity=*/1, /*attrs=*/{}, handle_.device_id())[0]); - output.push_back( - std::make_shared(std::move(sub), shape_.tuple_shapes(i))); + output.push_back(std::make_shared( + std::move(sub), xrt_device_ordinal_, shape_.tuple_shapes(i))); } return output; } @@ -343,7 +358,8 @@ xla::StatusOr> XrtExecutable::Execute( XrtTensorHandle result_handle = std::move(handle_.context()->EnqueueOp( "XRTExecute", inputs, /*output_arity=*/1, attrs, tf_device_id)[0]); - return std::make_shared(std::move(result_handle), shape_.result()); + return std::make_shared(std::move(result_handle), + xrt_device_ordinal, shape_.result()); } xla::StatusOr>> @@ -453,7 +469,7 @@ XrtExecutable::ExecuteReplicated( // TODO(phawkins): use a per-core result shape here. results(i, j) = std::make_shared( - std::move(outputs[output_num]), shape_.result()); + std::move(outputs[output_num]), xrt_device_ordinal, shape_.result()); ++output_num; } } diff --git a/tensorflow/compiler/xrt/client/xrt_client.h b/tensorflow/compiler/xrt/client/xrt_client.h index c54f156e95b..fe0b650fb95 100644 --- a/tensorflow/compiler/xrt/client/xrt_client.h +++ b/tensorflow/compiler/xrt/client/xrt_client.h @@ -55,7 +55,8 @@ class XrtBuffer { // Builds a new XrtBuffer tuple from its constituent parts. static xla::StatusOr> MakeTuple( const std::shared_ptr& context, - const std::vector>& elements); + const std::vector>& elements, + int xrt_device_ordinal); // Converts an XrtBuffer to an XLA literal, copying the buffer from the remote // host. Blocks until the buffer is available. @@ -71,7 +72,7 @@ class XrtBuffer { // tensors and vice-versa for TF interoperability. XrtBuffer() = default; - XrtBuffer(XrtTensorHandle handle, xla::Shape shape); + XrtBuffer(XrtTensorHandle handle, int xrt_device_ordinal, xla::Shape shape); ~XrtBuffer(); // Calls Delete(). // A buffer reference is moveable but not copyable. @@ -81,11 +82,13 @@ class XrtBuffer { XrtBuffer& operator=(XrtBuffer&&) = default; const XrtTensorHandle& handle() const { return handle_; } + int xrt_device_ordinal() const { return xrt_device_ordinal_; } const xla::Shape& shape() const { return shape_; } private: // Tensor that contains the XRT allocation ID. 
XrtTensorHandle handle_; + int xrt_device_ordinal_; xla::Shape shape_; }; diff --git a/tensorflow/compiler/xrt/client/xrt_client_test.cc b/tensorflow/compiler/xrt/client/xrt_client_test.cc index e64c986f44e..d9e94b01d2c 100644 --- a/tensorflow/compiler/xrt/client/xrt_client_test.cc +++ b/tensorflow/compiler/xrt/client/xrt_client_test.cc @@ -302,6 +302,22 @@ TEST_F(XrtClientTest, TupleDestructuringAndDelete) { pieces[1]->Delete(); } +TEST_F(XrtClientTest, EmptyTuples) { + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr context, MakeContext()); + + // Tests sending a literal to and from the device. + TF_ASSERT_OK_AND_ASSIGN( + std::shared_ptr buffer, + XrtBuffer::MakeTuple(context, /*elements=*/{}, /*xrt_device_ordinal=*/0)); + TF_ASSERT_OK_AND_ASSIGN(std::vector> pieces, + buffer->DestructureTuple()); + EXPECT_EQ(pieces.size(), 0); + + TF_ASSERT_OK_AND_ASSIGN(xla::Literal out, buffer->ToLiteral()); + ASSERT_TRUE(out.shape().IsTuple()); + EXPECT_EQ(out.shape().tuple_shapes_size(), 0); +} + TEST_F(XrtClientTest, TupleConstructionAndDestructuring) { TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr context, MakeContext()); @@ -326,8 +342,9 @@ TEST_F(XrtClientTest, TupleConstructionAndDestructuring) { EXPECT_TRUE(xla::LiteralTestUtil::Equal(b, b_in)); std::vector> elems = {a_buffer, b_buffer}; - TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr buffer, - XrtBuffer::MakeTuple(context, elems)); + TF_ASSERT_OK_AND_ASSIGN( + std::shared_ptr buffer, + XrtBuffer::MakeTuple(context, elems, /*xrt_device_ordinal=*/0)); TF_ASSERT_OK_AND_ASSIGN(std::vector> pieces, buffer->DestructureTuple()); diff --git a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc index d89dc4642be..231387e314f 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xrt/xrt.pb.h" #include "tensorflow/compiler/xrt/xrt_compilation_cache.h" #include "tensorflow/compiler/xrt/xrt_device.h" +#include "tensorflow/compiler/xrt/xrt_memory_manager.h" #include "tensorflow/compiler/xrt/xrt_state.h" #include "tensorflow/compiler/xrt/xrt_util.h" #include "tensorflow/core/framework/op_kernel.h" @@ -72,28 +73,31 @@ uint32 GetXLARandomSeed() { } xla::StatusOr GetInputBuffers( - ResourceMgr* rm, const std::vector& input_coords, - bool release_inputs) { + XRTMemoryManager::WorkingSet* working_set, xla::Backend* backend, + const std::vector& input_coords, bool release_inputs) { InputBuffers input_buffers; input_buffers.input_tuples.reserve(input_coords.size()); input_buffers.input_allocations.reserve(input_coords.size()); input_buffers.input_pointers.reserve(input_coords.size()); for (size_t i = 0; i < input_coords.size(); ++i) { - XRTTupleAllocation* tuple; TF_RETURN_IF_ERROR( - XRTTupleAllocation::Lookup(rm, input_coords[i].handle, &tuple)); + working_set->LookupAndPin(backend, input_coords[i].handle)); + auto tuple = working_set->PinnedTuples().back(); input_buffers.input_tuples.emplace_back(tuple); if (release_inputs) { // We are holding a reference to the tuple, so we can safely delete it // from the resource manager here. 
- TF_RETURN_IF_ERROR(XRTTupleAllocation::DeleteFromResourceManager( - rm, input_coords[i].handle)); + TF_RETURN_IF_ERROR( + working_set->MemoryManager()->Release(input_coords[i].handle)); VLOG(2) << "Released allocation handle " << input_coords[i].handle; } if (input_coords[i].index.empty()) { - input_buffers.input_allocations.emplace_back(tuple->ToShapedBuffer()); + TF_ASSIGN_OR_RETURN(xla::ShapedBuffer shaped_buffer, + tuple->ToShapedBuffer()); + input_buffers.input_allocations.emplace_back(std::move(shaped_buffer)); } else { - xla::ShapedBuffer shaped_buffer = tuple->ToShapedBuffer(); + TF_ASSIGN_OR_RETURN(xla::ShapedBuffer shaped_buffer, + tuple->ToShapedBuffer()); TF_ASSIGN_OR_RETURN(xla::ShapedBuffer sub_shaped_buffer, shaped_buffer.SubShapedBuffer(input_coords[i].index)); input_buffers.input_allocations.emplace_back( @@ -107,28 +111,25 @@ xla::StatusOr GetInputBuffers( } xla::StatusOr GetChainedOpInputs( - const xrt::XRTChainedExecuteOp& op, int current_index, - absl::Span> ops_outputs) { + const xrt::XRTChainedExecuteOp& op, + absl::Span> op_inputs) { InputBuffers input_buffers; input_buffers.input_tuples.reserve(op.inputs_size()); input_buffers.input_allocations.reserve(op.inputs_size()); input_buffers.input_pointers.reserve(op.inputs_size()); - for (auto& input : op.inputs()) { - if (input.op_index() >= current_index) { - return errors::InvalidArgument( - "Input index ", input.op_index(), - " is above the current position: ", current_index); - } - input_buffers.input_tuples.emplace_back(ops_outputs[input.op_index()]); + for (int i = 0; i < op.inputs_size(); ++i) { + auto& input = op.inputs(i); + input_buffers.input_tuples.emplace_back(op_inputs[i]); // Thanks to the greatness of proto3, there is no way to query for // explicitly set fields, so the default for output_index (zero) means no // sub-index. As consequence, the real index is output_index - 1. if (input.output_index() == 0) { - input_buffers.input_allocations.emplace_back( - input_buffers.input_tuples.back()->ToShapedBuffer()); + TF_ASSIGN_OR_RETURN(xla::ShapedBuffer shaped_buffer, + input_buffers.input_tuples.back()->ToShapedBuffer()); + input_buffers.input_allocations.emplace_back(std::move(shaped_buffer)); } else { - xla::ShapedBuffer shaped_buffer = - input_buffers.input_tuples.back()->ToShapedBuffer(); + TF_ASSIGN_OR_RETURN(xla::ShapedBuffer shaped_buffer, + input_buffers.input_tuples.back()->ToShapedBuffer()); TF_ASSIGN_OR_RETURN( xla::ShapedBuffer sub_shaped_buffer, shaped_buffer.SubShapedBuffer({input.output_index() - 1})); @@ -142,7 +143,7 @@ xla::StatusOr GetChainedOpInputs( return std::move(input_buffers); } -xla::StatusOr> ExecuteComputation( +xla::StatusOr> RunExecutable( OpKernelContext* context, XRTGenericDeviceAccessor::ScopedRef* device_ref, xla::LocalExecutable* executable, const InputBuffers& input_buffers, se::Stream* stream, int rng_seed) { @@ -190,15 +191,35 @@ xla::StatusOr> ExecuteComputation( } xla::StatusOr> ExecuteComputation( - OpKernelContext* context, ResourceMgr* rm, + OpKernelContext* context, XRTMemoryManager* memory_manager, + XRTGenericDeviceAccessor::ScopedRef* device_ref, + xla::LocalExecutable* executable, const InputBuffers& input_buffers, + se::Stream* stream, int rng_seed) { + auto runfn = [&]() { + return RunExecutable(context, device_ref, executable, input_buffers, stream, + rng_seed); + }; + + // We pass zero as requested_free_size as there is no simple way to get the + // peak heap size. 
Upon zero, the Run() API will try to free chunks of device + // memory, until either the runfn can run, or we run out of freeable memory. + return memory_manager->Run>( + runfn, device_ref->backend(), device_ref->device_ordinal(), + /*requested_free_size=*/0); +} + +xla::StatusOr> ExecuteComputation( + OpKernelContext* context, const RefPtr& memory_manager, XRTGenericDeviceAccessor::ScopedRef* device_ref, xla::LocalExecutable* executable, const std::vector& input_coords, bool release_inputs, se::Stream* stream, int rng_seed) { + XRTMemoryManager::WorkingSet working_set(memory_manager); TF_ASSIGN_OR_RETURN(InputBuffers input_buffers, - GetInputBuffers(rm, input_coords, release_inputs)); - return ExecuteComputation(context, device_ref, executable, input_buffers, - stream, rng_seed); + GetInputBuffers(&working_set, device_ref->backend(), + input_coords, release_inputs)); + return ExecuteComputation(context, memory_manager.get(), device_ref, + executable, input_buffers, stream, rng_seed); } // XRTExecuteOp @@ -265,8 +286,9 @@ Status XRTExecuteOp::DoWork(OpKernelContext* context) { se::Stream* stream = context->op_device_context() ? context->op_device_context()->stream() : nullptr; + RefPtr memory_manager = XRTMemoryManager::Get(rm); TF_ASSIGN_OR_RETURN(std::vector input_coords, - GetComputationInputs(context, rm, "input_handles")); + GetComputationInputs(context, "input_handles")); std::unique_ptr entry; TF_RETURN_IF_ERROR(cache->Lookup(compilation_handle, &entry)); @@ -279,10 +301,11 @@ Status XRTExecuteOp::DoWork(OpKernelContext* context) { TF_ASSIGN_OR_RETURN( RefPtr output_tuple, - ExecuteComputation(context, rm, &device_ref, executable, input_coords, - release_inputs, stream, rng_seed)); + ExecuteComputation(context, memory_manager, &device_ref, executable, + input_coords, release_inputs, stream, rng_seed)); - return CreateExecuteOutput(context, rm, std::move(output_tuple), + return CreateExecuteOutput(context, memory_manager.get(), + std::move(output_tuple), config_proto.return_exploded_tuple()); } @@ -346,22 +369,23 @@ Status XRTExecuteChainedOp::DoWork(OpKernelContext* context) { se::Stream* stream = context->op_device_context() ? 
context->op_device_context()->stream() : nullptr; - auto execute_op = - [&](const xrt::XRTChainedExecuteOp& op, int current_index, - absl::Span> ops_outputs) + RefPtr memory_manager = XRTMemoryManager::Get(rm); + auto execute_op = [&](const xrt::XRTChainedExecuteOp& op, + absl::Span> op_inputs) -> xla::StatusOr> { TF_ASSIGN_OR_RETURN(InputBuffers input_buffers, - GetChainedOpInputs(op, current_index, ops_outputs)); + GetChainedOpInputs(op, op_inputs)); std::unique_ptr entry; TF_RETURN_IF_ERROR(cache->Lookup(op.computation_handle(), &entry)); xla::LocalExecutable* executable = entry->get().get_executable(); - return ExecuteComputation(context, &device_ref, executable, input_buffers, - stream, rng_seed); + return ExecuteComputation(context, memory_manager.get(), &device_ref, + executable, input_buffers, stream, rng_seed); }; - return ExecuteChained(context, rm, plan, config, execute_op); + return ExecuteChained(context, memory_manager, device_ref.backend(), + device_ref.device_ordinal(), plan, config, execute_op); } XRTExecuteChainedOp::~XRTExecuteChainedOp() = default; diff --git a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h index 8a54e0987e5..c3511b1d5d4 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h +++ b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/compiler/xrt/xrt.pb.h" #include "tensorflow/compiler/xrt/xrt_device.h" +#include "tensorflow/compiler/xrt/xrt_memory_manager.h" #include "tensorflow/compiler/xrt/xrt_state.h" #include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/framework/op_kernel.h" @@ -103,8 +104,8 @@ class XRTStateHelpers { TF_RET_CHECK( TensorShapeUtils::IsScalar(input_tensor_list[input_index].shape())); int64 key = input_tensor_list[input_index].scalar()(); - TF_RETURN_IF_ERROR( - XRTTupleAllocation::Lookup(rm, key, &input.allocation)); + TF_ASSIGN_OR_RETURN(input.allocation, + XRTMemoryManager::Get(rm)->Lookup(key)); input.release_allocation_after_use = release_this_input; } } @@ -192,17 +193,14 @@ class XRTAllocateOp : public OpKernel { class DeviceAccessor::ScopedRef device_ref; OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(ctx, &device_ref)); + RefPtr memory_manager = XRTMemoryManager::Get(rm); XRTTupleAllocation* allocation; OP_REQUIRES_OK(ctx, XRTTupleAllocation::CreateAndTransfer( - literal, device_ref.backend(), + literal, memory_manager.get(), device_ref.backend(), device_ref.device_ordinal(), &allocation)); - // Intern takes ownership of our reference to allocation. - int64 key; - OP_REQUIRES_OK(ctx, allocation->Intern(rm, &key)); - Tensor output(DT_INT64, TensorShape({})); - output.scalar()() = key; + output.scalar()() = memory_manager->Register(allocation); ctx->set_output(0, output); } }; @@ -291,17 +289,14 @@ class XRTAllocateFromTensorOp : public OpKernel { class DeviceAccessor::ScopedRef device_ref; OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(ctx, &device_ref)); + RefPtr memory_manager = XRTMemoryManager::Get(rm); XRTTupleAllocation* allocation; OP_REQUIRES_OK(ctx, XRTTupleAllocation::CreateAndTransfer( - literal, device_ref.backend(), + literal, memory_manager.get(), device_ref.backend(), device_ref.device_ordinal(), &allocation)); - // Intern takes ownership of our reference to allocation. 
- int64 key; - OP_REQUIRES_OK(ctx, allocation->Intern(rm, &key)); - Tensor output(DT_INT64, TensorShape({})); - output.scalar()() = key; + output.scalar()() = memory_manager->Register(allocation); ctx->set_output(0, output); } @@ -342,28 +337,22 @@ class XRTSubTupleOp : public OpKernel { ResourceMgr* rm; OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm)); - XRTTupleAllocation* allocation; - OP_REQUIRES_OK( - ctx, XRTTupleAllocation::Lookup(rm, allocation_handle, &allocation)); - core::ScopedUnref allocation_unref(allocation); + RefPtr memory_manager = XRTMemoryManager::Get(rm); + RefPtr allocation; + OP_REQUIRES_OK(ctx, memory_manager->Lookup(allocation_handle, &allocation)); if (discard_) { VLOG(2) << "Releasing handle " << allocation_handle; - OP_REQUIRES_OK(ctx, XRTTupleAllocation::DeleteFromResourceManager( - rm, allocation_handle)); + OP_REQUIRES_OK(ctx, memory_manager->Release(allocation_handle)); } XRTTupleAllocation* suballocation; OP_REQUIRES_OK( - ctx, XRTTupleAllocation::MakeSubBuffer(allocation, shape_index, + ctx, XRTTupleAllocation::MakeSubBuffer(allocation.get(), shape_index, &suballocation, !discard_)); - // Intern takes ownership of our reference to suballocation. - int64 key; - OP_REQUIRES_OK(ctx, suballocation->Intern(rm, &key)); - Tensor output(DT_INT64, TensorShape({})); - output.scalar()() = key; + output.scalar()() = memory_manager->Register(suballocation); ctx->set_output(0, output); } }; @@ -398,14 +387,6 @@ class XRTMakeTupleOp : public OpKernel { // exit. std::vector input_vector( arg_list.size()); - auto cleanup = gtl::MakeCleanup([&input_vector] { - for (auto& input : input_vector) { - if (input.allocation != nullptr) { - input.allocation->Unref(); - } - } - }); - ResourceMgr* rm; OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm)); @@ -425,28 +406,22 @@ class XRTMakeTupleOp : public OpKernel { OP_REQUIRES_OK( ctx, DeviceAccessor::InitScopedRef(ctx, device_ordinal, &device_ref)); + RefPtr memory_manager = XRTMemoryManager::Get(rm); XRTTupleAllocation* output_allocation; OP_REQUIRES_OK(ctx, XRTTupleAllocation::MakeTuple( - device_ref.backend(), device_ref.device_ordinal(), - tuple_shape_tree, &output_allocation)); - // Add a ScopedUnref to simplify the error path while calling - // DeleteFromResourceManager. - core::ScopedUnref unref(output_allocation); + memory_manager.get(), device_ref.backend(), + device_ref.device_ordinal(), tuple_shape_tree, + &output_allocation)); + RefPtr output_ptr(output_allocation); for (int i = 0; i < input_vector.size(); ++i) { if (input_vector[i].release_allocation_after_use) { - OP_REQUIRES_OK(ctx, XRTTupleAllocation::DeleteFromResourceManager( - rm, arg_list[i].scalar()())); + OP_REQUIRES_OK(ctx, + memory_manager->Release(arg_list[i].scalar()())); } } - // Intern takes ownership of a reference to output_allocation, so add - // another since the ScopedUnref will release one when this method exits. 
- output_allocation->Ref(); - int64 key; - OP_REQUIRES_OK(ctx, output_allocation->Intern(rm, &key)); - Tensor output(DT_INT64, TensorShape({})); - output.scalar()() = key; + output.scalar()() = memory_manager->Register(std::move(output_ptr)); ctx->set_output(0, output); } }; @@ -473,15 +448,13 @@ class XRTReadLiteralOp : public OpKernel { ResourceMgr* rm; OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm)); - XRTTupleAllocation* allocation; - OP_REQUIRES_OK( - ctx, XRTTupleAllocation::Lookup(rm, allocation_handle, &allocation)); - core::ScopedUnref allocation_unref(allocation); + RefPtr memory_manager = XRTMemoryManager::Get(rm); + RefPtr allocation; + OP_REQUIRES_OK(ctx, memory_manager->Lookup(allocation_handle, &allocation)); if (discard_) { VLOG(2) << "Releasing handle " << allocation_handle; - OP_REQUIRES_OK(ctx, XRTTupleAllocation::DeleteFromResourceManager( - rm, allocation_handle)); + OP_REQUIRES_OK(ctx, memory_manager->Release(allocation_handle)); } // We are guaranteed that the underlying device object won't be deleted out @@ -491,9 +464,7 @@ class XRTReadLiteralOp : public OpKernel { ctx, allocation->device_ordinal(), &device_ref)); xla::Literal literal(allocation->on_host_shape()); - OP_REQUIRES_OK( - ctx, allocation->ToLiteral(device_ref.backend(), - device_ref.device_ordinal(), &literal)); + OP_REQUIRES_OK(ctx, allocation->ToLiteral(device_ref.backend(), &literal)); xla::LiteralProto literal_proto = literal.ToProto(); Tensor output(DT_STRING, TensorShape({})); @@ -529,15 +500,13 @@ class XRTReadToTensorOp : public OpKernel { ResourceMgr* rm; OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm)); - XRTTupleAllocation* allocation; - OP_REQUIRES_OK( - ctx, XRTTupleAllocation::Lookup(rm, allocation_handle, &allocation)); - core::ScopedUnref allocation_unref(allocation); + RefPtr memory_manager = XRTMemoryManager::Get(rm); + RefPtr allocation; + OP_REQUIRES_OK(ctx, memory_manager->Lookup(allocation_handle, &allocation)); if (discard_) { VLOG(2) << "Releasing handle " << allocation_handle; - OP_REQUIRES_OK(ctx, XRTTupleAllocation::DeleteFromResourceManager( - rm, allocation_handle)); + OP_REQUIRES_OK(ctx, memory_manager->Release(allocation_handle)); } // We are guaranteed that the underlying device object won't be deleted out @@ -573,15 +542,14 @@ class XRTReadToTensorOp : public OpKernel { XRTTupleAllocation* sub; TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer( - allocation, index, &sub, /*alias_parent_allocation=*/true)); + allocation.get(), index, &sub, /*alias_parent_allocation=*/true)); core::ScopedUnref sub_unref(sub); xla::MutableBorrowingLiteral literal; TF_RETURN_IF_ERROR(HostTensorToMutableBorrowingLiteral( xla::LayoutUtil::GetWithDefaultLayout(*subshape), output_tensor, &literal)); - TF_RETURN_IF_ERROR(sub->ToLiteral( - device_ref.backend(), device_ref.device_ordinal(), &literal)); + TF_RETURN_IF_ERROR(sub->ToLiteral(device_ref.backend(), &literal)); ++output; return Status::OK(); @@ -624,10 +592,10 @@ class XRTWriteLiteralOp : public OpKernel { ResourceMgr* rm; OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm)); - XRTTupleAllocation* allocation; - OP_REQUIRES_OK( - ctx, XRTTupleAllocation::Lookup(rm, allocation_handle, &allocation)); - core::ScopedUnref allocation_unref(allocation); + RefPtr memory_manager = XRTMemoryManager::Get(rm); + RefPtr allocation; + OP_REQUIRES_OK(ctx, memory_manager->Lookup(allocation_handle, &allocation)); + // We are guaranteed that the underlying device object won't be deleted out // from under 
us, while the ScopedRef is live. typename DeviceAccessor::ScopedRef device_ref; @@ -657,12 +625,12 @@ class XRTReleaseAllocationOp : public OpKernel { ResourceMgr* rm; OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm)); + RefPtr memory_manager = XRTMemoryManager::Get(rm); const Tensor& allocation_handle = ctx->input(0); auto flat_keys = allocation_handle.flat(); for (int64 i = 0; i < flat_keys.size(); ++i) { int64 key = flat_keys(i); - OP_REQUIRES_OK(ctx, - XRTTupleAllocation::DeleteFromResourceManager(rm, key)); + OP_REQUIRES_OK(ctx, memory_manager->Release(key)); VLOG(2) << "Released allocation handle " << key; } } @@ -684,7 +652,7 @@ class XRTReleaseAllAllocationsOp : public OpKernel { ResourceMgr* rm; OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm)); - OP_REQUIRES_OK(ctx, XRTTupleAllocation::ReleaseAllAllocations(rm)); + XRTMemoryManager::Get(rm)->ReleaseAllAllocations(); } }; @@ -701,11 +669,11 @@ class XRTCompactAllocationsOp : public OpKernel { ResourceMgr* rm; OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm)); + RefPtr memory_manager = XRTMemoryManager::Get(rm); class DeviceAccessor::ScopedRef device_ref; OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(ctx, &device_ref)); - OP_REQUIRES_OK(ctx, - XRTTupleAllocation::CompactAllocations( - rm, device_ref.backend(), device_ref.device_ordinal())); + OP_REQUIRES_OK(ctx, memory_manager->CompactAllocations( + device_ref.backend(), device_ref.device_ordinal())); } }; diff --git a/tensorflow/compiler/xrt/tests/BUILD b/tensorflow/compiler/xrt/tests/BUILD index 3a19327e5b5..f8341e1ee0f 100644 --- a/tensorflow/compiler/xrt/tests/BUILD +++ b/tensorflow/compiler/xrt/tests/BUILD @@ -1,13 +1,12 @@ -licenses(["notice"]) # Apache 2.0 - package( default_visibility = [ "//learning/brain:__subpackages__", "//tensorflow/compiler:__subpackages__", ], + licenses = ["notice"], # Apache 2.0 ) -load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test", "tf_cc_test") +load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_cuda_cc_test") load( "//tensorflow/core:platform/default/build_config_root.bzl", "tf_cuda_tests_tags", @@ -34,6 +33,8 @@ cc_library( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/lib:constants", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xrt:xrt_proto", "//tensorflow/compiler/xrt:xrt_server", diff --git a/tensorflow/compiler/xrt/tests/raw_api_test.cc b/tensorflow/compiler/xrt/tests/raw_api_test.cc index 305b3a67fae..b5108acff16 100644 --- a/tensorflow/compiler/xrt/tests/raw_api_test.cc +++ b/tensorflow/compiler/xrt/tests/raw_api_test.cc @@ -25,6 +25,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" @@ -47,6 +49,15 @@ limitations under the License. 
namespace tensorflow { namespace { +class XrtClientSession : public ClientSession { + public: + explicit XrtClientSession(const Scope& scope) : ClientSession(scope) { + auto clear_all = ops::XRTReleaseAllAllocations(scope); + std::vector outputs; + TF_CHECK_OK(Run(ClientSession::FeedType(), {}, {clear_all}, &outputs)); + } +}; + string* xla_test_device_ptr; // initial value set in main() string* xla_platform_ptr; // initial value set in main() @@ -235,6 +246,26 @@ xla::XlaComputation AddAndSubTuple() { return builder.Build().ValueOrDie(); } +xla::XlaComputation BroadcastComputation( + const xla::Shape& shape, absl::Span dimensions) { + xla::XlaBuilder builder("BroadcastComputation"); + auto p0 = xla::Parameter(&builder, 0, shape, "P0"); + xla::Broadcast(p0, dimensions); + return builder.Build().ValueOrDie(); +} + +xla::XlaComputation IsEqualComputation(const xla::Shape& shape) { + xla::XlaBuilder builder("IsEqualComputation"); + auto p0 = xla::Parameter(&builder, 0, shape, "P0"); + auto p1 = xla::Parameter(&builder, 1, shape, "P1"); + auto cmp = + xla::Ne(xla::Sub(p0, p1), xla::Zero(&builder, shape.element_type())); + auto icmp = xla::ConvertElementType(cmp, xla::S32); + xla::ReduceAll(icmp, xla::Zero(&builder, xla::S32), + xla::CreateScalarAddComputation(xla::S32, &builder)); + return builder.Build().ValueOrDie(); +} + void StoreComputationSnapshot(const xla::XlaComputation& computation, xla::HloSnapshot* dst) { auto snapshot = computation.Snapshot().ValueOrDie(); @@ -279,7 +310,7 @@ TEST(RawApiTest, AllocFromTensor) { auto read_back = ops::XRTReadLiteralAndRelease(root, handle); TF_ASSERT_OK(root.status()); - ClientSession session(root); + XrtClientSession session(root); std::vector outputs; TF_EXPECT_OK(session.Run({read_back}, &outputs)); EXPECT_EQ(outputs.size(), 1); @@ -310,7 +341,7 @@ TEST(RawApiTest, AllocFromTensorTuple) { auto read_back = ops::XRTReadLiteralAndRelease(root, handle); TF_ASSERT_OK(root.status()); - ClientSession session(root); + XrtClientSession session(root); std::vector outputs; TF_EXPECT_OK(session.Run({read_back}, &outputs)); EXPECT_EQ(outputs.size(), 1); @@ -336,7 +367,7 @@ TEST(RawApiTest, AllocFromTensorTupleSingle) { auto read_back = ops::XRTReadLiteralAndRelease(root, handle); TF_ASSERT_OK(root.status()); - ClientSession session(root); + XrtClientSession session(root); std::vector outputs; TF_EXPECT_OK(session.Run({read_back}, &outputs)); EXPECT_EQ(outputs.size(), 1); @@ -362,7 +393,7 @@ TEST(RawApiTest, AllocFromTensorRelayout) { auto read_back = ops::XRTReadLiteralAndRelease(root, handle); TF_ASSERT_OK(root.status()); - ClientSession session(root); + XrtClientSession session(root); std::vector outputs; TF_EXPECT_OK(session.Run({read_back}, &outputs)); EXPECT_EQ(outputs.size(), 1); @@ -389,7 +420,7 @@ TEST(RawApiTest, AllocAndRewrite) { auto read_back = ops::XRTReadLiteral(root, handle); TF_ASSERT_OK(root.status()); - ClientSession session(root); + XrtClientSession session(root); std::vector outputs; TF_EXPECT_OK(session.Run({read_back, handle}, &outputs)); EXPECT_EQ(outputs.size(), 2); @@ -442,7 +473,7 @@ TEST(RawApiTest, AllocReleaseMany) { auto handle2 = ops::XRTAllocate(root, value2); TF_ASSERT_OK(root.status()); - ClientSession session(root); + XrtClientSession session(root); std::vector outputs; TF_EXPECT_OK(session.Run({handle1, handle2}, &outputs)); EXPECT_EQ(outputs.size(), 2); @@ -491,7 +522,7 @@ TEST(RawApiTest, CompileAndReleaseMany) { auto c_handle2 = ops::XRTCompile(root, computation2); TF_ASSERT_OK(root.status()); - ClientSession 
session(root); + XrtClientSession session(root); std::vector outputs; TF_EXPECT_OK(session.Run({c_handle1.handle, c_handle2.handle}, &outputs)); EXPECT_EQ(outputs.size(), 2); @@ -518,7 +549,7 @@ TEST(RawApiTest, AllocAndClearAll) { auto handle = ops::XRTAllocate(root, value); TF_ASSERT_OK(root.status()); - ClientSession session(root); + XrtClientSession session(root); std::vector outputs; TF_EXPECT_OK(session.Run({handle}, &outputs)); EXPECT_EQ(outputs.size(), 1); @@ -549,7 +580,7 @@ TEST(RawApiTest, ReadAndWriteState) { root.WithControlDependencies(read_back), handle); TF_ASSERT_OK(root.status()); - ClientSession session(root); + XrtClientSession session(root); std::vector outputs; TF_EXPECT_OK( session.Run(ClientSession::FeedType(), {read_back}, {release}, &outputs)); @@ -571,7 +602,7 @@ TEST(RawApiTest, ReadAndWriteStateAutoFree) { auto read_back = ops::XRTReadLiteralAndRelease(root, handle); TF_ASSERT_OK(root.status()); - ClientSession session(root); + XrtClientSession session(root); std::vector outputs; TF_EXPECT_OK(session.Run({read_back}, &outputs)); @@ -602,7 +633,7 @@ TEST(RawApiTest, SubBuffer) { auto value_00 = ops::XRTReadLiteralAndRelease(root, sub_00); TF_ASSERT_OK(root.status()); - ClientSession session(root); + XrtClientSession session(root); std::vector outputs; TF_EXPECT_OK(session.Run({value_0, value_1, value_00}, &outputs)); @@ -678,7 +709,7 @@ TEST(RawApiTest, MakeTuple) { auto res_1 = ops::XRTReadLiteralAndRelease(root, handle_4); TF_ASSERT_OK(root.status()); - ClientSession session(root); + XrtClientSession session(root); std::vector outputs; TF_EXPECT_OK(session.Run({res_0, res_1}, &outputs)); xla::LiteralProto response_0; @@ -718,7 +749,7 @@ TEST(RawApiTest, ExecuteChainedOpByOp) { root, ops::Const(root.WithDevice("/device:CPU:0"), c_sub_scale)); TF_ASSERT_OK(root.status()); - ClientSession session(root); + XrtClientSession session(root); std::vector outputs; TF_EXPECT_OK( session.Run({c_add_scale_op.handle, c_sub_scale_op.handle}, &outputs)); @@ -788,7 +819,7 @@ TEST(RawApiTest, ExecuteChained) { root, ops::Const(root.WithDevice("/device:CPU:0"), c_sub_scale)); TF_ASSERT_OK(root.status()); - ClientSession session(root); + XrtClientSession session(root); std::vector outputs; TF_EXPECT_OK( session.Run({c_add_scale_op.handle, c_sub_scale_op.handle}, &outputs)); @@ -920,7 +951,7 @@ TEST(RawApiTest, CompileAndExecute) { auto read_back = ops::XRTReadLiteralAndRelease(root, result); TF_ASSERT_OK(root.status()); - ClientSession session(root); + XrtClientSession session(root); std::vector outputs; TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs)); @@ -975,7 +1006,7 @@ TEST(RawApiTest, CompileAndExecuteWithArgumentVector) { auto read_back = ops::XRTReadLiteralAndRelease(root, result); TF_ASSERT_OK(root.status()); - ClientSession session(root); + XrtClientSession session(root); std::vector outputs; TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs)); @@ -1025,7 +1056,7 @@ TEST(RawApiTest, CompileWithXlaReturnShapes) { auto release = ops::XRTReleaseCompilationHandle(root, c_handle.handle); TF_ASSERT_OK(root.status()); - ClientSession session(root); + XrtClientSession session(root); std::vector outputs; TF_EXPECT_OK(session.Run(ClientSession::FeedType(), {c_handle.program_shape}, {release}, &outputs)); @@ -1094,7 +1125,7 @@ TEST(RawApiTest, DotGeneralWithLayoutTest) { auto read_back = ops::XRTReadLiteralAndRelease(root, result); TF_ASSERT_OK(root.status()); - ClientSession session(root); + XrtClientSession session(root); std::vector 
outputs; TF_EXPECT_OK(session.Run({read_back}, &outputs)); @@ -1129,7 +1160,7 @@ TEST(RawApiTest, CompileAndExecuteZeroArg) { auto read_back = ops::XRTReadLiteralAndRelease(root, result); TF_ASSERT_OK(root.status()); - ClientSession session(root); + XrtClientSession session(root); std::vector outputs; TF_EXPECT_OK(session.Run({read_back}, &outputs)); @@ -1179,7 +1210,7 @@ TEST(RawApiTest, CompileAndExecuteReturnTuple) { auto read_back = ops::XRTReadLiteralAndRelease(root, result); TF_ASSERT_OK(root.status()); - ClientSession session(root); + XrtClientSession session(root); std::vector outputs; TF_EXPECT_OK(session.Run({read_back}, &outputs)); @@ -1230,7 +1261,7 @@ TEST(RawApiTest, CompileAndExecuteReturnExplodedTuple) { {Output(p0_handle), Output(p1_handle)}); TF_ASSERT_OK(root.status()); - ClientSession session(root); + XrtClientSession session(root); std::vector outputs; TF_EXPECT_OK(session.Run({result}, &outputs)); EXPECT_EQ(outputs.size(), 1); @@ -1272,7 +1303,7 @@ TEST(RawApiTest, LeakCompilationReference) { auto c_handle = ops::XRTCompile(root, computation); TF_ASSERT_OK(root.status()); - ClientSession session(root); + XrtClientSession session(root); std::vector outputs; TF_EXPECT_OK(session.Run({c_handle.handle}, &outputs)); } @@ -1316,7 +1347,7 @@ TEST(RawApiTest, CompileAndExecuteWithReusedBuffers) { e.set_release_compilation_handle(true); Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); - ClientSession session(root); + XrtClientSession session(root); auto e_config = ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString()); auto c_data = @@ -1412,7 +1443,7 @@ TEST(RawApiTest, CompileAndExecuteWithS64Argument) { auto read_back = ops::XRTReadLiteralAndRelease(root, result); TF_ASSERT_OK(root.status()); - ClientSession session(root); + XrtClientSession session(root); std::vector outputs; TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs)); @@ -1444,7 +1475,7 @@ TEST(RawApiTest, TestDeviceMemoryCompaction) { } TF_ASSERT_OK(root.status()); - ClientSession session(root); + XrtClientSession session(root); std::vector outputs; TF_EXPECT_OK(session.Run(handle_outputs, &outputs)); EXPECT_EQ(outputs.size(), handle_outputs.size()); @@ -1488,6 +1519,95 @@ TEST(RawApiTest, TestDeviceMemoryCompaction) { } } +TEST(RawApiTest, TestDeviceMemorySwap) { + const xla::Shape scalar_shape = xla::ShapeUtil::MakeShape(xla::F32, {}); + // 100MB F32 tensor. + const xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {5000, 5000}); + const xla::int64 tensor_size = xla::ShapeUtil::ByteSizeOf(shape); + // On CPU we cannot trigger OOM/swap. For TPU and GPU we select 16GB as + // maximum memory. + xla::int64 device_memory_size = 8LL * 1024 * 1024 * 1024; + if (*xla_test_device_ptr == "TPU" || *xla_test_device_ptr == "XLA_GPU") { + device_memory_size = 16LL * 1024 * 1024 * 1024; + } + + xrt::XLAAllocation p0; + *p0.mutable_value() = xla::LiteralUtil::CreateR0(0.90434).ToProto(); + + // Create a computation which broadcasts a scalar to a big tensor. + xrt::XLAComputation c_bcast; + { + auto shapes = c_bcast.mutable_config()->mutable_program_shape(); + *shapes->add_parameters() = scalar_shape.ToProto(); + *shapes->mutable_result() = shape.ToProto(); + StoreComputationSnapshot( + BroadcastComputation(scalar_shape, shape.dimensions()), + c_bcast.mutable_hlo_snapshot()); + } + + // Create a computation which compares two tensors. 
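Aside (illustrative sketch, not part of the patch): the sizing chosen by this test is what forces the swap path. Assuming the 8 GiB default selected above, the arithmetic used by the num_tensors computation further down works out roughly as follows; the constant names here are hypothetical and only mirror the test's values.
#include <cstdint>
constexpr int64_t kTensorBytes = 5000LL * 5000 * sizeof(float);   // 100,000,000 bytes per tensor
constexpr int64_t kDeviceBytes = 8LL * 1024 * 1024 * 1024;        // 8 GiB default device memory
constexpr int64_t kNumTensors = 8 + kDeviceBytes / kTensorBytes;  // 8 + 85 = 93 tensors
static_assert(kNumTensors * kTensorBytes > kDeviceBytes,
              "the test must oversubscribe device memory to trigger swapping");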
+ xrt::XLAComputation c_equal; + { + auto shapes = c_equal.mutable_config()->mutable_program_shape(); + *shapes->add_parameters() = shape.ToProto(); + *shapes->add_parameters() = shape.ToProto(); + *shapes->mutable_result() = + xla::ShapeUtil::MakeShape(xla::S32, {}).ToProto(); + StoreComputationSnapshot(IsEqualComputation(shape), + c_equal.mutable_hlo_snapshot()); + } + + xrt::XRTExecutionConfig e; + e.set_release_input_handles(false); + e.set_release_compilation_handle(false); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + XrtClientSession session(root); + auto e_config = + ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString()); + auto bcast_computation = + ops::Const(root.WithDevice("/device:CPU:0"), c_bcast.SerializeAsString()); + auto c_bcast_handle = ops::XRTCompile(root, bcast_computation); + auto equal_computation = + ops::Const(root.WithDevice("/device:CPU:0"), c_equal.SerializeAsString()); + auto c_equal_handle = ops::XRTCompile(root, equal_computation); + auto p0_value = + ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString()); + auto p0_handle = ops::XRTAllocate(root, p0_value); + std::vector outputs; + std::vector device_handles; + + // Create more data the device can take using the broadcast computation. + xla::int64 num_tensors = 8 + device_memory_size / tensor_size; + for (xla::int64 i = 0; i < num_tensors; ++i) { + auto result = ops::XRTExecute(root, c_bcast_handle.handle, e_config, + {Output(p0_handle)}); + TF_ASSERT_OK(root.status()); + TF_ASSERT_OK(session.Run({result}, &outputs)); + EXPECT_EQ(outputs.size(), 1); + device_handles.push_back(outputs[0].scalar()()); + } + + // Trigger computations on XRT handles to verify the swap-out/swap-in logic, + // by comparing sequential couple of tensors. + auto zero_literal = xla::LiteralUtil::CreateR0(0); + for (size_t i = 0; i + 1 < device_handles.size(); ++i) { + auto exec_op = ops::XRTExecute( + root, c_equal_handle.handle, e_config, + {Input(device_handles[i]), Input(device_handles[i + 1])}); + auto read_back = ops::XRTReadLiteral(root, exec_op); + + TF_ASSERT_OK(root.status()); + TF_ASSERT_OK(session.Run({read_back}, &outputs)); + EXPECT_EQ(outputs.size(), 1); + + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + auto literal = xla::Literal::CreateFromProto(response).ValueOrDie(); + EXPECT_EQ(literal, zero_literal); + } +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/xrt/xrt_memory_manager.cc b/tensorflow/compiler/xrt/xrt_memory_manager.cc new file mode 100644 index 00000000000..3a304764800 --- /dev/null +++ b/tensorflow/compiler/xrt/xrt_memory_manager.cc @@ -0,0 +1,353 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xrt/xrt_memory_manager.h" + +#include +#include +#include + +#include "absl/memory/memory.h" +#include "tensorflow/core/lib/random/random.h" + +namespace tensorflow { +namespace { + +// We use kDeviceBits to store the device ordinal in the handle. We store the +// device in the upper part of the int64 handle to make sure the random bits are +// in the lower part which is better when storing the handle as a key for +// unordered maps. +const int kDeviceBits = 12; + +int64 MakeDeviceHandle(int64 device_ordinal, int64 rnd_value) { + const int64 kUidMask = (static_cast(1) << (64 - kDeviceBits)) - 1; + return (device_ordinal << (64 - kDeviceBits)) | (rnd_value & kUidMask); +} + +int GetDeviceFromHandle(int64 handle) { + return (handle >> (64 - kDeviceBits)) & ((1 << kDeviceBits) - 1); +} + +} // namespace + +class XRTMemoryManager::DeviceContext { + struct Alloc { + explicit Alloc(RefPtr tuple) + : tuple(std::move(tuple)) {} + + RefPtr tuple; + }; + + using AllocList = std::list; + + public: + int64 Register(RefPtr tuple) { + while (true) { + int64 handle = MakeDeviceHandle(tuple->device_ordinal(), CreateUid()); + mutex_lock lock(lock_); + allocs_.emplace_front(tuple); + if (alloc_map_.emplace(handle, allocs_.begin()).second) { + return handle; + } + // The chances of hitting an existing handle are so remote, it is much + // more convenient to add to the list before, and eventually removing. + allocs_.erase(allocs_.begin()); + } + } + + bool Release(int64 handle) { + mutex_lock lock(lock_); + auto it = alloc_map_.find(handle); + if (it == alloc_map_.end()) { + return false; + } + allocs_.erase(it->second); + alloc_map_.erase(it); + return true; + } + + RefPtr Lookup(int64 handle) { + mutex_lock lock(lock_); + auto it = alloc_map_.find(handle); + if (it == alloc_map_.end()) { + return nullptr; + } + // LRU + allocs_.splice(allocs_.begin(), allocs_, it->second); + return it->second->tuple; + } + + void Clear() { + mutex_lock lock(lock_); + alloc_map_.clear(); + allocs_.clear(); + } + + Status CompactAllocations(XRTMemoryManager* memory_manager, + xla::Backend* backend) { + VLOG(4) << "CompactAllocations started"; + mutex_lock lock(lock_); + Status status; + std::vector swapped; + // We are swapping out from the most recently used allocations. This is + // desirable since the most recently used will be finding themselves at the + // bottom of the allocation space. Since these are more likely to be pinned + // allocations, a further trim done by following TryFreeMemory() call will + // eventually drop the higher located allocations, with better chance of + // reducing fragmentation. + // Also, by swapping out the pinned allocations first, those will also be + // the first to be restored, and hence if we will ever find OOM on the way + // out, we would more likely be swapping in not pinned ones. + for (auto it = allocs_.begin(); it != allocs_.end(); ++it) { + // We are compacting all the allocations, so we will temporarily swap out + // even pinned allocations. + auto swap_result_or = it->tuple->SwapOut(backend, /*swap_pinned=*/true); + if (!swap_result_or.ok()) { + status = swap_result_or.status(); + break; + } + if (swap_result_or.ValueOrDie()) { + swapped.push_back(it); + } + } + // At this point we have released all the device memory we could release. + // Load back the tuple allocations we have swapped out above. 
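Aside (illustrative sketch, not part of the patch): the handle layout produced by the MakeDeviceHandle()/GetDeviceFromHandle() helpers earlier in this file keeps the device ordinal in the upper kDeviceBits (12) bits and the random uid in the lower 52 bits. A minimal standalone mirror of that round trip, assuming the same constants:
#include <cassert>
#include <cstdint>

constexpr int kDeviceBits = 12;

int64_t MakeDeviceHandle(int64_t device_ordinal, int64_t rnd_value) {
  const int64_t kUidMask = (static_cast<int64_t>(1) << (64 - kDeviceBits)) - 1;
  return (device_ordinal << (64 - kDeviceBits)) | (rnd_value & kUidMask);
}

int GetDeviceFromHandle(int64_t handle) {
  return (handle >> (64 - kDeviceBits)) & ((1 << kDeviceBits) - 1);
}

int main() {
  // Device ordinal 3, arbitrary uid; the uid survives in the low 52 bits.
  int64_t handle = MakeDeviceHandle(/*device_ordinal=*/3, /*rnd_value=*/0xABCDEF);
  assert(GetDeviceFromHandle(handle) == 3);
  assert((handle & ((static_cast<int64_t>(1) << (64 - kDeviceBits)) - 1)) == 0xABCDEF);
  return 0;
}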
+ for (auto& it : swapped) { + auto swap_result_or = it->tuple->SwapIn(memory_manager, backend); + if (!swap_result_or.ok()) { + // If we failed to restored a pinned allocation, better to CHECK here + // than wondering why XRTTupleAllocation calls fail with errors about + // missing buffers. + CHECK(!it->tuple->IsPinned()); // Crash OK + if (status.ok()) { + status = swap_result_or.status(); + } + } + } + VLOG(4) << "CompactAllocations finished: " << status; + return status; + } + + // Tries to free size bytes by freeing some unpinned device memory. Returns + // the amount of memory which was able to free. + xla::StatusOr TryFreeMemory(xla::Backend* backend, size_t size) { + mutex_lock lock(lock_); + size_t swapped_size = 0; + for (auto it = allocs_.rbegin(); it != allocs_.rend(); ++it) { + TF_ASSIGN_OR_RETURN(bool swap_result, + it->tuple->SwapOut(backend, /*swap_pinned=*/false)); + if (swap_result) { + swapped_size += it->tuple->GetDeviceMemorySize(); + if (swapped_size >= size) { + break; + } + } + } + VLOG(3) << "Swapped out " << swapped_size << " bytes"; + return swapped_size; + } + + private: + static int64 CreateUid() { + int64 uid; + do { + uid = random::New64() & INT64_MAX; + } while (uid == InvalidKey()); + return uid; + } + + // We store Alloc records inside an std::list so we can LRU it, and + // store the list iterators within the handle map, as list iterators don't get + // invalidated by (other elements) removals or position swaps. + mutex lock_; + AllocList allocs_; + std::unordered_map alloc_map_; +}; + +XRTMemoryManager::WorkingSet::WorkingSet( + RefPtr memory_manager) + : memory_manager_(std::move(memory_manager)) {} + +XRTMemoryManager::WorkingSet::~WorkingSet() { + for (auto& tuple : pinned_tuples_) { + tuple->Unpin(); + } +} + +Status XRTMemoryManager::WorkingSet::LookupAndPin(xla::Backend* backend, + int64 handle) { + TF_ASSIGN_OR_RETURN(auto tuple, memory_manager_->Lookup(handle)); + TF_RETURN_IF_ERROR( + tuple->PinAndSwapIn(memory_manager_.get(), backend).status()); + pinned_tuples_.push_back(std::move(tuple)); + return Status::OK(); +} + +/* static */ RefPtr XRTMemoryManager::Get(ResourceMgr* rm) { + static string* container = new string("XrtState"); + static string* name = new string("MemoryManager"); + XRTMemoryManager* memory_manager = nullptr; + TF_CHECK_OK(rm->LookupOrCreate( + *container, *name, &memory_manager, [](XRTMemoryManager** ret) { + *ret = new XRTMemoryManager(); + return Status::OK(); + })); + return memory_manager; +} + +int64 XRTMemoryManager::Register(RefPtr tuple) { + DeviceContext* device_context = GetDeviceContext(tuple->device_ordinal(), + /*create_if_missing=*/true); + return device_context->Register(std::move(tuple)); +} + +xla::StatusOr> XRTMemoryManager::Lookup( + int64 handle) { + int device_ordinal = GetDeviceFromHandle(handle); + DeviceContext* device_context = GetDeviceContext(device_ordinal, + /*create_if_missing=*/false); + if (device_context == nullptr) { + return errors::NotFound("XRT memory handle not found: ", handle); + } + RefPtr tuple = device_context->Lookup(handle); + if (tuple == nullptr) { + return errors::NotFound("XRT memory handle not found: ", handle); + } + return std::move(tuple); +} + +Status XRTMemoryManager::Release(int64 handle) { + int device_ordinal = GetDeviceFromHandle(handle); + DeviceContext* device_context = GetDeviceContext(device_ordinal, + /*create_if_missing=*/false); + if (device_context == nullptr || !device_context->Release(handle)) { + return errors::NotFound("XRT memory handle not found: ", 
handle); + } + return Status::OK(); +} + +Status XRTMemoryManager::CompactAllocations(xla::Backend* backend, + int device_ordinal) { + DeviceContext* device_context = GetDeviceContext(device_ordinal, + /*create_if_missing=*/false); + return device_context != nullptr + ? device_context->CompactAllocations(this, backend) + : Status::OK(); +} + +void XRTMemoryManager::ReleaseAllAllocations() { + mutex_lock lock(lock_); + for (auto& device_context : device_contexts_) { + if (device_context != nullptr) { + device_context->Clear(); + } + } +} + +xla::StatusOr XRTMemoryManager::Allocate( + xla::Backend* backend, int device_ordinal, size_t size) { + se::DeviceMemoryAllocator* allocator = backend->memory_allocator(); + auto memory_or = + allocator->Allocate(device_ordinal, size, /*retry_on_failure=*/false); + if (memory_or.status().code() == error::RESOURCE_EXHAUSTED) { + VLOG(4) << "Allocate of " << size << " bytes failed on device " + << device_ordinal; + + DeviceContext* device_context = + GetDeviceContext(device_ordinal, + /*create_if_missing=*/false); + if (device_context != nullptr) { + Status status = device_context->TryFreeMemory(backend, size).status(); + if (status.ok()) { + // As long as there is no error, we still try again the allocation, even + // if the TryFreeMemory() call ended up freeing less memory than the + // required size. Fragmentation could make the memory allocation succeed + // even if the freed memory is indeed lower. + memory_or = allocator->Allocate(device_ordinal, size, + /*retry_on_failure=*/false); + } else if (status.code() != error::RESOURCE_EXHAUSTED) { + VLOG(4) << "Allocate of " << size << " bytes on device " + << device_ordinal << ": " << status; + return status; + } + } + } + return memory_or; +} + +string XRTMemoryManager::DebugString() const { + // We might want to emit more detailed information here, like per device + // memory allocations. + return "XRTMemoryManager"; +} + +XRTMemoryManager::DeviceContext* XRTMemoryManager::GetDeviceContext( + int device_ordinal, bool create_if_missing) { + mutex_lock lock(lock_); + if (device_ordinal >= device_contexts_.size()) { + if (!create_if_missing) { + return nullptr; + } + device_contexts_.resize(device_ordinal + 1); + } + DeviceContext* device_context = device_contexts_[device_ordinal].get(); + if (device_context == nullptr && create_if_missing) { + device_contexts_[device_ordinal] = absl::make_unique(); + device_context = device_contexts_[device_ordinal].get(); + } + return device_context; +} + +Status XRTMemoryManager::TryFreeMemoryStep(MemoryReclaimContext* mrctx, + const Status& status) { + DeviceContext* device_context = GetDeviceContext(mrctx->device_ordinal, + /*create_if_missing=*/false); + if (device_context == nullptr) { + return status; + } + if (!mrctx->done_freeing) { + // If the caller passed us a zero requested_free_size, we try to free chunks + // of kMaxFreeSize memory, until either the run function suceeds, or we run + // out of freeable memory. + const size_t kMaxFreeSize = 1000000000; + size_t free_size = + (mrctx->requested_free_size > 0) + ? 
std::min(mrctx->requested_free_size - mrctx->free_size,
+ kMaxFreeSize)
+ : kMaxFreeSize;
+ if (free_size > 0) {
+ auto free_size_or =
+ device_context->TryFreeMemory(mrctx->backend, free_size);
+ if (!free_size_or.ok()) {
+ return status;
+ }
+ size_t size = free_size_or.ValueOrDie();
+ mrctx->free_size += size;
+ if (size > 0) {
+ return Status::OK();
+ }
+ }
+ mrctx->done_freeing = true;
+ }
+ if (!mrctx->done_compacting) {
+ mrctx->done_compacting = true;
+ if (device_context->CompactAllocations(this, mrctx->backend).ok()) {
+ return Status::OK();
+ }
+ }
+ return status;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/xrt/xrt_memory_manager.h b/tensorflow/compiler/xrt/xrt_memory_manager.h
new file mode 100644
index 00000000000..445be45cf57
--- /dev/null
+++ b/tensorflow/compiler/xrt/xrt_memory_manager.h
@@ -0,0 +1,177 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XRT_XRT_MEMORY_MANAGER_H_
+#define TENSORFLOW_COMPILER_XRT_XRT_MEMORY_MANAGER_H_
+
+#include
+#include
+
+#include "tensorflow/compiler/xla/service/backend.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/compiler/xrt/xrt_refptr.h"
+#include "tensorflow/compiler/xrt/xrt_state.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/stream_executor/device_memory_allocator.h"
+#include "tensorflow/stream_executor/stream_executor.h"
+
+namespace tensorflow {
+
+// The XRTMemoryManager manages all the XRT allocations. It is a ResourceBase
+// object which lives within the ResourceMgr. There is only one XRT memory
+// manager object within the ResourceMgr container.
+class XRTMemoryManager : public ResourceBase {
+ // The DeviceContext class, defined and implemented locally inside the
+ // xrt_memory_manager.cc file, holds, for each device, all the information
+ // related to the XRT memory management for that device.
+ class DeviceContext;
+
+ public:
+ // A working set is a set of tuple allocations which are the inputs of a given
+ // operation, and as such they must be pinned in device memory. The tuple
+ // allocations added to the WorkingSet will be unpinned at object destruction.
+ class WorkingSet {
+ public:
+ explicit WorkingSet(RefPtr<XRTMemoryManager> memory_manager);
+
+ ~WorkingSet();
+
+ // Looks up the tuple handle within the memory manager, and pins it to the
+ // device (if not already pinned).
+ Status LookupAndPin(xla::Backend* backend, int64 handle);
+
+ const std::vector<RefPtr<XRTTupleAllocation>>& PinnedTuples() const {
+ return pinned_tuples_;
+ }
+
+ const RefPtr<XRTMemoryManager>& MemoryManager() const {
+ return memory_manager_;
+ }
+
+ private:
+ RefPtr<XRTMemoryManager> memory_manager_;
+ std::vector<RefPtr<XRTTupleAllocation>> pinned_tuples_;
+ };
+
+ // Retrieves the XRTMemoryManager singleton stored within the ResourceMgr.
+ static RefPtr<XRTMemoryManager> Get(ResourceMgr* rm);
+
+ // Registers an XRTTupleAllocation and returns the unique handle identifying
+ // it.
+ int64 Register(RefPtr<XRTTupleAllocation> tuple);
+
+ // Looks up a handle returned by the Register() API and returns the
+ // XRTTupleAllocation behind it.
+ xla::StatusOr<RefPtr<XRTTupleAllocation>> Lookup(int64 handle);
+
+ Status Lookup(int64 handle, RefPtr<XRTTupleAllocation>* tuple) {
+ TF_ASSIGN_OR_RETURN(*tuple, Lookup(handle));
+ return Status::OK();
+ }
+
+ // Releases a handle by dropping the reference count held on the
+ // XRTTupleAllocation by the XRTMemoryManager. Existing XRTTupleAllocation
+ // references will continue to be valid.
+ Status Release(int64 handle);
+
+ // Tries to compact all the memory allocations on a given device. This is
+ // currently done by swapping out all the existing allocations and swapping
+ // them back in.
+ Status CompactAllocations(xla::Backend* backend, int device_ordinal);
+
+ // Releases all the device memory allocated by XRT within the resource
+ // manager.
+ void ReleaseAllAllocations();
+
+ // Tries to allocate size bytes of device memory from the device_ordinal
+ // device. If the underlying allocator call fails, it might attempt to free
+ // some unpinned device memory and try the allocation again.
+ xla::StatusOr<se::OwningDeviceMemory> Allocate(xla::Backend* backend,
+ int device_ordinal,
+ size_t size);
+
+ // Runs the specified function, handling the error::RESOURCE_EXHAUSTED
+ // status code coming out of it. In such cases, we run different memory
+ // freeing operations trying to make runfn succeed. The requested_free_size
+ // argument represents a hint of the requested memory size which would make
+ // runfn succeed.
+ template <typename T>
+ xla::StatusOr<T> Run(const std::function<xla::StatusOr<T>()>& runfn,
+ xla::Backend* backend, int device_ordinal,
+ size_t requested_free_size);
+
+ string DebugString() const override;
+
+ // Returns the invalid key value, which will never be generated by the
+ // Register() API.
+ static int64 InvalidKey() { return 0; }
+
+ private:
+ // Structure used to track the progress of a try-to-free operation. It is
+ // initialized and then passed to the TryFreeMemoryStep() API.
+ struct MemoryReclaimContext {
+ MemoryReclaimContext(xla::Backend* backend, int device_ordinal,
+ size_t requested_free_size)
+ : backend(backend),
+ device_ordinal(device_ordinal),
+ requested_free_size(requested_free_size) {}
+
+ xla::Backend* const backend = nullptr;
+ const int device_ordinal = 0;
+ const size_t requested_free_size = 0;
+ size_t free_size = 0;
+ bool done_freeing = false;
+ bool done_compacting = false;
+ };
+
+ DeviceContext* GetDeviceContext(int device_ordinal, bool create_if_missing);
+
+ // Called multiple times while trying to make a memory-consuming function call
+ // fit. Performs progressively more expensive memory reduction operations,
+ // until returning error::RESOURCE_EXHAUSTED when no further reductions are
+ // possible.
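Aside (illustrative sketch, not part of the patch): this is roughly how an op kernel is expected to combine the API declared above: pin inputs through a WorkingSet, then wrap allocation-heavy work in Run<T>() so a RESOURCE_EXHAUSTED error triggers memory reclamation and a retry. The helper name RunWithPinnedInputs and the 1 MiB sizes are hypothetical; `rm`, `backend` and `device_ordinal` are assumed to come from the surrounding kernel code.
Status RunWithPinnedInputs(ResourceMgr* rm, xla::Backend* backend,
                           int device_ordinal, int64 input_handle) {
  RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
  // Pin the input tuple for the lifetime of the working set; it is unpinned
  // automatically when working_set goes out of scope.
  XRTMemoryManager::WorkingSet working_set(memory_manager);
  TF_RETURN_IF_ERROR(working_set.LookupAndPin(backend, input_handle));
  // Any RESOURCE_EXHAUSTED status returned by the lambda makes Run<T>() free
  // or compact unpinned allocations and call the lambda again.
  TF_ASSIGN_OR_RETURN(
      se::OwningDeviceMemory memory,
      memory_manager->Run<se::OwningDeviceMemory>(
          [&]() { return memory_manager->Allocate(backend, device_ordinal,
                                                  /*size=*/1 << 20); },
          backend, device_ordinal, /*requested_free_size=*/1 << 20));
  (void)memory;  // Use the allocated memory here.
  return Status::OK();
}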
+ Status TryFreeMemoryStep(MemoryReclaimContext* mrctx, const Status& status); + + mutex lock_; + std::vector> device_contexts_; +}; + +template +xla::StatusOr XRTMemoryManager::Run( + const std::function()>& runfn, xla::Backend* backend, + int device_ordinal, size_t requested_free_size) { + MemoryReclaimContext mrctx(backend, device_ordinal, requested_free_size); + while (true) { + // We assume that runfn is a relatively fast-fail function compared to the + // operations required to free up the required memory. Here we call into the + // TryFreeMemoryStep() API multiple times, which will run progressively more + // expensive operations. + auto result_or = runfn(); + if (result_or.status().code() != error::RESOURCE_EXHAUSTED) { + return result_or; + } + TF_RETURN_IF_ERROR(TryFreeMemoryStep(&mrctx, result_or.status())); + } +} + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_XRT_XRT_MEMORY_MANAGER_H_ diff --git a/tensorflow/compiler/xrt/xrt_refptr.h b/tensorflow/compiler/xrt/xrt_refptr.h new file mode 100644 index 00000000000..2db20dd71ce --- /dev/null +++ b/tensorflow/compiler/xrt/xrt_refptr.h @@ -0,0 +1,108 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Utility functions in support of the XRT API. + +#ifndef TENSORFLOW_COMPILER_XRT_XRT_REFPTR_H_ +#define TENSORFLOW_COMPILER_XRT_XRT_REFPTR_H_ + +#include + +namespace tensorflow { + +// Reference counted smart pointer for XRT objects providing the standard +// Ref()/Unref() APIs. +template +class RefPtr { + public: + RefPtr() = default; + // Creates a RefPtr from a pointer. This is an ownership transfer operation, + // and the caller has to own a valid reference to ptr (unless ptr is nullptr). 
+ RefPtr(T* ptr) : ptr_(ptr) {} // NOLINT + RefPtr(const RefPtr& other) : ptr_(other.ptr_) { Acquire(ptr_); } + RefPtr(RefPtr&& other) : ptr_(other.ptr_) { other.ptr_ = nullptr; } + + ~RefPtr() { Release(ptr_); } + + RefPtr& operator=(const RefPtr& other) { + if (this != &other) { + Acquire(other.ptr_); + Release(ptr_); + ptr_ = other.ptr_; + } + return *this; + } + + RefPtr& operator=(RefPtr&& other) { + if (this != &other) { + Release(ptr_); + ptr_ = other.ptr_; + other.ptr_ = nullptr; + } + return *this; + } + + operator bool() const { return ptr_ != nullptr; } // NOLINT + bool operator==(const RefPtr& rhs) const { return ptr_ == rhs.ptr_; } + bool operator!=(const RefPtr& rhs) const { return ptr_ != rhs.ptr_; } + bool operator==(const T* ptr) const { return ptr_ == ptr; } + bool operator!=(const T* ptr) const { return ptr_ != ptr; } + bool operator==(std::nullptr_t ptr) const { return ptr_ == ptr; } + bool operator!=(std::nullptr_t ptr) const { return ptr_ != ptr; } + + T* get() const { return ptr_; } + + T* operator->() const { + CHECK(ptr_ != nullptr); // Crash OK + return ptr_; + } + + T& operator*() const { + CHECK(ptr_ != nullptr); // Crash OK + return *ptr_; + } + + T* release() { + T* ptr = ptr_; + ptr_ = nullptr; + return ptr; + } + + // Resets the RefPtr from a pointer. This is an ownership transfer operation, + // and the caller has to own a valid reference to ptr (unless ptr is nullptr). + void reset(T* ptr = nullptr) { + Release(ptr_); + ptr_ = ptr; + } + + private: + static void Release(T* ptr) { + if (ptr != nullptr) { + ptr->Unref(); + } + } + + static void Acquire(T* ptr) { + if (ptr != nullptr) { + ptr->Ref(); + } + } + + T* ptr_ = nullptr; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_XRT_XRT_REFPTR_H_ diff --git a/tensorflow/compiler/xrt/xrt_state.cc b/tensorflow/compiler/xrt/xrt_state.cc index fa25b727a3d..2f5eb5aec1e 100644 --- a/tensorflow/compiler/xrt/xrt_state.cc +++ b/tensorflow/compiler/xrt/xrt_state.cc @@ -18,31 +18,24 @@ limitations under the License. #include "tensorflow/compiler/xrt/xrt_state.h" -#include - #include #include #include #include #include "absl/memory/memory.h" -#include "absl/strings/str_cat.h" -#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/backend.h" -#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/core/framework/resource_mgr.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/random/random.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/stream_executor/stream_executor.h" +#include "tensorflow/compiler/xrt/xrt_memory_manager.h" namespace tensorflow { - namespace { +// Helper typedef to make ShapeTree ForEach helper lambda signatures more +// readable. They need a type of const T& where in this case T is the +// following pointer. 
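Aside (illustrative sketch, not part of the patch): the RefPtr<T> wrapper introduced in xrt_refptr.h above adopts an existing reference when constructed from a raw pointer, adds a reference on copy, and drops one on destruction. A minimal sketch of those semantics, assuming a stand-in RefCountedThing type derived from core::RefCounted:
class RefCountedThing : public core::RefCounted {};

void RefPtrExample() {
  // A freshly created RefCounted object starts with one reference; the RefPtr
  // constructor adopts it (ownership transfer, no extra Ref()).
  RefPtr<RefCountedThing> ptr(new RefCountedThing());
  // Copying adds a reference; both copies drop theirs when leaving scope.
  RefPtr<RefCountedThing> copy = ptr;
  CHECK(ptr == copy);
  // release() hands the reference back to the caller, who must Unref() it.
  RefCountedThing* raw = copy.release();
  raw->Unref();
}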
+typedef XRTBufferAllocation* XRTBufferAllocationPtr; + class BufferAllocStats { public: struct Stats { @@ -71,26 +64,15 @@ class BufferAllocStats { std::map stats_; }; -const char* kTupleContainer = "tuples"; - -int64 get_uid() { - int64 uid; - do { - uid = random::New64() & INT64_MAX; - } while (uid == XRTTupleAllocation::InvalidKey()); - return uid; -} - BufferAllocStats* GetAllocStats() { static BufferAllocStats* stats = new BufferAllocStats(); return stats; } Status AllocateScopedShapedBuffer( - xla::Backend* backend, int device_ordinal, const xla::Shape& shape, - std::unique_ptr* buffer) { + XRTMemoryManager* memory_manager, xla::Backend* backend, int device_ordinal, + const xla::Shape& shape, std::unique_ptr* buffer) { auto transfer_manager = backend->transfer_manager(); - auto allocator = backend->memory_allocator(); TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal)); // XLA may use a different representation on device than the representation on @@ -111,18 +93,19 @@ Status AllocateScopedShapedBuffer( // it goes out of scope. That's useful if we return early as the result of an // error allocating one of the later buffers. *buffer = absl::make_unique( - shape, on_device_shape, allocator, device_ordinal); + shape, on_device_shape, backend->memory_allocator(), device_ordinal); for (auto& index_to_buffer : (*buffer)->buffers()) { - xla::Shape subshape = + const xla::Shape& subshape = xla::ShapeUtil::GetSubshape(on_device_shape, index_to_buffer.first); uint64 size = transfer_manager->GetByteSizeRequirement(subshape); TF_ASSIGN_OR_RETURN( se::OwningDeviceMemory buffer, - allocator->Allocate(device_ordinal, size, /*retry_on_failure=*/false)); + memory_manager->Allocate(backend, device_ordinal, size)); // Move our buffer into shaped_buffer, which takes ownership of it. index_to_buffer.second = buffer.Release(); VLOG(2) << "Allocated buffer at " << index_to_buffer.second.opaque() - << " index " << index_to_buffer.first.ToString(); + << " index " << index_to_buffer.first.ToString() << " (" << size + << " bytes)"; } TF_RETURN_IF_ERROR( @@ -136,8 +119,7 @@ Status AllocateScopedShapedBuffer( XRTBufferAllocation::XRTBufferAllocation(const se::DeviceMemoryBase& allocation, int device_ordinal, se::DeviceMemoryAllocator* allocator) - : size_(allocation.size()), - allocation_(allocation), + : allocation_(allocation), device_ordinal_(device_ordinal), allocator_(allocator) { if (VLOG_IS_ON(2)) { @@ -153,21 +135,15 @@ XRTBufferAllocation::~XRTBufferAllocation() { GetAllocStats()->ReportFree(device_ordinal_, allocation_.size()); } // Deallocate explicitly allows allocation_ to be null. - Status s = allocator_->Deallocate(device_ordinal_, allocation_); - // Nothing to do but check fail here if memory datastructures are corrupted. - CHECK(s.ok()); - VLOG(2) << "Freed buffer at " << allocation_.opaque(); + TF_CHECK_OK(allocator_->Deallocate(device_ordinal_, allocation_)); + VLOG(2) << "Freed buffer at " << allocation_.opaque() << " (" + << allocation_.size() << " bytes)"; } const se::DeviceMemoryBase& XRTBufferAllocation::allocation() { return allocation_; } -void XRTBufferAllocation::DiscardAllocation() { - // Replace the allocation with a null. 
- allocation_ = se::DeviceMemoryBase(); -} - XRTTupleAllocation::XRTTupleAllocation(int device_ordinal, se::DeviceMemoryAllocator* allocator, const xla::Shape& on_host_shape, @@ -176,23 +152,29 @@ XRTTupleAllocation::XRTTupleAllocation(int device_ordinal, allocator_(allocator), on_host_shape_(on_host_shape), on_device_shape_(on_device_shape), - buffers_(&on_device_shape_) {} + buffers_(&on_device_shape_), + pin_count_(0) {} -XRTTupleAllocation::~XRTTupleAllocation() { - for (auto& buffer : buffers_) { - buffer.second->Unref(); +XRTTupleAllocation::~XRTTupleAllocation() { ReleaseBuffers(); } + +void XRTTupleAllocation::ReleaseBuffers() { + for (auto& index_buffer : buffers_) { + if (index_buffer.second != nullptr) { + index_buffer.second->Unref(); + index_buffer.second = nullptr; + } } } /*static*/ Status XRTTupleAllocation::CreateAndTransfer( - const xla::LiteralBase& literal, xla::Backend* backend, int device_ordinal, + const xla::LiteralBase& literal, XRTMemoryManager* memory_manager, + xla::Backend* backend, int device_ordinal, XRTTupleAllocation** allocation) { auto transfer_manager = backend->transfer_manager(); - auto allocator = backend->memory_allocator(); - std::unique_ptr scoped_buffer; - TF_RETURN_IF_ERROR(AllocateScopedShapedBuffer( - backend, device_ordinal, literal.shape(), &scoped_buffer)); + TF_RETURN_IF_ERROR(AllocateScopedShapedBuffer(memory_manager, backend, + device_ordinal, literal.shape(), + &scoped_buffer)); TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal)); TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice( stream.get(), literal, *scoped_buffer)); @@ -202,11 +184,13 @@ XRTTupleAllocation::~XRTTupleAllocation() { // call. To avoid a leak, there must be no error-case returns from here until // the end of the method. auto shaped_buffer = scoped_buffer->release(); - *allocation = new XRTTupleAllocation(device_ordinal, allocator, - shaped_buffer.on_host_shape(), - shaped_buffer.on_device_shape()); + *allocation = new XRTTupleAllocation( + device_ordinal, backend->memory_allocator(), + shaped_buffer.on_host_shape(), shaped_buffer.on_device_shape()); (*allocation) - ->InitializeFromShapedBuffer(shaped_buffer, allocator, device_ordinal); + ->InitializeFromShapedBuffer(shaped_buffer, backend->memory_allocator(), + device_ordinal); + (*allocation)->SetDeviceMemorySize(); return Status::OK(); } @@ -220,24 +204,22 @@ XRTTupleAllocation::~XRTTupleAllocation() { shaped_buffer.on_device_shape()); (*allocation) ->InitializeFromShapedBuffer(shaped_buffer, allocator, device_ordinal); + (*allocation)->SetDeviceMemorySize(); return Status::OK(); } -Status XRTTupleAllocation::ToLiteral(xla::Backend* backend, int device_ordinal, +Status XRTTupleAllocation::ToLiteral(xla::Backend* backend, xla::MutableLiteralBase* literal) { - auto transfer_manager = backend->transfer_manager(); - TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal)); + mutex_lock lock(lock_); + return literal_ == nullptr ? StoreToLiteral(backend, literal) + : literal->CopyFrom(*literal_); +} - // Validate the allocation buffers as if nulls gets to - // TransferLiteralFromDevice() a CHECK is issued. 
- xla::ShapedBuffer shaped_buffer = ToShapedBuffer(); - for (auto& index_buffer : shaped_buffer.buffers()) { - if (index_buffer.second.is_null()) { - return errors::InvalidArgument("Literal buffer at index ", - index_buffer.first.ToString(), - " has been released"); - } - } +Status XRTTupleAllocation::StoreToLiteral(xla::Backend* backend, + xla::MutableLiteralBase* literal) { + auto transfer_manager = backend->transfer_manager(); + TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal())); + TF_ASSIGN_OR_RETURN(xla::ShapedBuffer shaped_buffer, ToShapedBuffer()); return transfer_manager->TransferLiteralFromDevice(stream.get(), shaped_buffer, *literal); } @@ -250,52 +232,102 @@ Status XRTTupleAllocation::WriteLiteral(xla::Backend* backend, xla::ShapeUtil::HumanStringWithLayout(literal.shape()), " device=", xla::ShapeUtil::HumanStringWithLayout(on_host_shape())); } + mutex_lock lock(lock_); + if (literal_ != nullptr) { + // The allocation is currently swapped out, and we have a host literal for + // its content. Just update the host literal with the new value. + return literal_->CopyFrom(literal); + } + TF_ASSIGN_OR_RETURN(xla::ShapedBuffer shaped_buffer, ToShapedBuffer()); auto transfer_manager = backend->transfer_manager(); TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal())); return transfer_manager->TransferLiteralToDevice(stream.get(), literal, - ToShapedBuffer()); + shaped_buffer); } +xla::StatusOr XRTTupleAllocation::SwapOut(xla::Backend* backend, + bool swap_pinned) { + mutex_lock lock(lock_); + if (literal_ == nullptr && (!IsPinned() || swap_pinned)) { + xla::Literal literal(on_host_shape()); + TF_RETURN_IF_ERROR(StoreToLiteral(backend, &literal)); + ReleaseBuffers(); + literal_ = absl::make_unique(std::move(literal)); + return true; + } + return false; +} + +xla::StatusOr XRTTupleAllocation::SwapIn(XRTMemoryManager* memory_manager, + xla::Backend* backend) { + // We need to call AllocateScopedShapedBuffer() outside the locks, since the + // XRTMemoryManager might end up calling back into the SwapOut() API. + // So we do a quick check before using the IsSwapped() API, and it can happen + // that the allocation becomes swapped in after the check. This means which we + // will end up doing an allocation, and then releasing it soon after (via its + // scoped variables). This is an unlikely scenario (two threads calling + // SwapIn() on the same allocation) though. 
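Aside (illustrative sketch, not part of the patch): an XRTTupleAllocation is now in one of two states: resident (buffers_ holds device memory, literal_ is null) or swapped out (buffers_ released, literal_ holds the host copy). A minimal round trip under that assumption, with `tuple`, `memory_manager` and `backend` assumed to exist:
Status SwapRoundTrip(XRTTupleAllocation* tuple,
                     XRTMemoryManager* memory_manager, xla::Backend* backend) {
  // Unpinned allocations can be swapped out to a host literal...
  TF_ASSIGN_OR_RETURN(bool swapped_out,
                      tuple->SwapOut(backend, /*swap_pinned=*/false));
  if (swapped_out) {
    CHECK(tuple->IsSwapped());
    // ...and must be swapped back in before their device buffers are used.
    TF_ASSIGN_OR_RETURN(bool swapped_in, tuple->SwapIn(memory_manager, backend));
    CHECK(swapped_in);
  }
  return Status::OK();
}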
+ if (!IsSwapped()) { + return false; + } + + auto transfer_manager = backend->transfer_manager(); + std::unique_ptr scoped_buffer; + TF_RETURN_IF_ERROR( + AllocateScopedShapedBuffer(memory_manager, backend, device_ordinal(), + on_host_shape(), &scoped_buffer)); + TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal())); + + mutex_lock lock(lock_); + if (literal_ != nullptr) { + TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice( + stream.get(), *literal_, *scoped_buffer)); + + auto shaped_buffer = scoped_buffer->release(); + InitializeFromShapedBuffer(shaped_buffer, backend->memory_allocator(), + device_ordinal()); + literal_ = nullptr; + return true; + } + return false; +} + +xla::StatusOr XRTTupleAllocation::PinAndSwapIn( + XRTMemoryManager* memory_manager, xla::Backend* backend) { + Pin(); + return SwapIn(memory_manager, backend); +} + +bool XRTTupleAllocation::IsSwapped() const { + mutex_lock lock(lock_); + return literal_ != nullptr; +} + +int64 XRTTupleAllocation::Pin() { return pin_count_.fetch_add(1); } + +int64 XRTTupleAllocation::Unpin() { return pin_count_.fetch_sub(1); } + +bool XRTTupleAllocation::IsPinned() const { return pin_count_ != 0; } + void XRTTupleAllocation::DiscardAllocation( const xla::ShapeIndex& buffer_index) { buffers_.element(buffer_index)->DiscardAllocation(); } -const xla::Shape& XRTTupleAllocation::on_host_shape() { return on_host_shape_; } +const xla::Shape& XRTTupleAllocation::on_host_shape() const { + return on_host_shape_; +} -const xla::Shape& XRTTupleAllocation::on_device_shape() { +const xla::Shape& XRTTupleAllocation::on_device_shape() const { return on_device_shape_; } -int XRTTupleAllocation::device_ordinal() { return device_ordinal_; } +int XRTTupleAllocation::device_ordinal() const { return device_ordinal_; } -const se::DeviceMemoryBase& XRTTupleAllocation::root_allocation() { +const se::DeviceMemoryBase& XRTTupleAllocation::root_allocation() const { return buffers_.element({})->allocation(); } -/*static*/ Status XRTTupleAllocation::Lookup(ResourceMgr* rm, int64 key, - XRTTupleAllocation** allocation) { - string key_string = absl::StrCat(key); - TF_RETURN_IF_ERROR(rm->Lookup(kTupleContainer, key_string, allocation)); - return Status::OK(); -} - -/*static*/ Status XRTTupleAllocation::DeleteFromResourceManager(ResourceMgr* rm, - int64 key) { - string key_string = absl::StrCat(key); - return rm->Delete(kTupleContainer, key_string); -} - -/* static */ Status XRTTupleAllocation::ReleaseAllAllocations(ResourceMgr* rm) { - VLOG(1) << "Releasing all XRT held device memory"; - return rm->Cleanup(kTupleContainer); -} - -// Helper typedef to make ShapeTree ForEach helper lambda signatures more -// readable. They need a type of const T& where in this case T is the -// following pointer. 
-typedef XRTBufferAllocation* XRTBufferAllocationPtr; - /*static*/ Status XRTTupleAllocation::MakeSubBuffer( XRTTupleAllocation* parent, const xla::ShapeIndex& subshape, XRTTupleAllocation** allocation, bool alias_parent_allocation) { @@ -330,46 +362,21 @@ typedef XRTBufferAllocation* XRTBufferAllocationPtr; parent_index.push_back(index[i]); } *buffer = parent->buffers_.element(parent_index); - *parent->buffers_.mutable_element(parent_index) = - new XRTBufferAllocation(se::DeviceMemoryBase(), - parent->device_ordinal(), - parent->allocator_); + *parent->buffers_.mutable_element(parent_index) = nullptr; }); } - + (*allocation)->SetDeviceMemorySize(); return Status::OK(); } -/* static */ Status XRTTupleAllocation::CompactAllocations( - ResourceMgr* rm, xla::Backend* backend, int device_ordinal) { - std::vector tuples; - rm->GetContainerResources(kTupleContainer, &tuples); - - std::vector> host_tuples; - for (auto& rm_tuple : tuples) { - XRTTupleAllocation* tuple = - dynamic_cast(rm_tuple.resource.get()); - if (tuple->device_ordinal() == device_ordinal) { - xla::Literal literal(tuple->on_host_shape()); - TF_RETURN_IF_ERROR(tuple->ToLiteral(backend, device_ordinal, &literal)); - host_tuples.emplace_back(rm_tuple.name, std::move(literal)); - // At this point there are two references held onto the XRTTupleAllocation - // object. One in the ResourceMgr, which we release here, and one held - // within the tuples vector, which we release in the tuples.clear() call - // below. - TF_RETURN_IF_ERROR( - rm->Delete(kTupleContainer, rm_tuple.name)); +void XRTTupleAllocation::SetDeviceMemorySize() { + size_t size = 0; + for (auto& index_buffer : buffers_) { + if (index_buffer.second != nullptr) { + size += index_buffer.second->allocation().size(); } } - tuples.clear(); - - for (auto& name_literal : host_tuples) { - XRTTupleAllocation* tuple; - TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateAndTransfer( - name_literal.second, backend, device_ordinal, &tuple)); - TF_RETURN_IF_ERROR(rm->Create(kTupleContainer, name_literal.first, tuple)); - } - return Status::OK(); + device_memory_size_ = size; } /* static */ Status XRTTupleAllocation::ExpandTreeOfTuples( @@ -414,7 +421,7 @@ typedef XRTBufferAllocation* XRTBufferAllocationPtr; } /*static*/ Status XRTTupleAllocation::MakeTuple( - xla::Backend* backend, int device_ordinal, + XRTMemoryManager* memory_manager, xla::Backend* backend, int device_ordinal, const xla::ShapeTree& elements, XRTTupleAllocation** allocation) { auto transfer_manager = backend->transfer_manager(); @@ -429,8 +436,8 @@ typedef XRTBufferAllocation* XRTBufferAllocationPtr; // The aliasing is determined below based on whether or not all the inputs are // released while being transferred. allocation_tmp is a local pointer that is // copied to *allocation at the end only if the method succeeds. - auto allocation_tmp = new XRTTupleAllocation(device_ordinal, allocator, - host_shape, device_shape); + XRTTupleAllocation* allocation_tmp = new XRTTupleAllocation( + device_ordinal, allocator, host_shape, device_shape); core::ScopedUnref allocation_unref(allocation_tmp); // First allocate device memory for the new tuple index tables, one at each // internal node of the elements tree. 
Do this in a separate pass into a @@ -444,12 +451,12 @@ typedef XRTBufferAllocation* XRTBufferAllocationPtr; TF_RETURN_IF_ERROR(elements.ForEachElementWithStatus( [&](const xla::ShapeIndex& index, const ExpandedTupleInput& element) { if (!elements.IsLeaf(index)) { - xla::Shape subshape = + const xla::Shape& subshape = xla::ShapeUtil::GetSubshape(device_shape, index); uint64 size = transfer_manager->GetByteSizeRequirement(subshape); - TF_ASSIGN_OR_RETURN(se::OwningDeviceMemory buffer, - allocator->Allocate(device_ordinal, size, - /*retry_on_failure=*/false)); + TF_ASSIGN_OR_RETURN( + se::OwningDeviceMemory buffer, + memory_manager->Allocate(backend, device_ordinal, size)); VLOG(2) << "Allocated buffer at " << buffer->opaque() << " index " << index.ToString(); // Move the new buffer into new_tuple_buffers, which takes ownership @@ -487,10 +494,8 @@ typedef XRTBufferAllocation* XRTBufferAllocationPtr; // validated that release_allocation_after_use is false if // element.allocation appears in more than one leaf. element.allocation->buffers_.ForEachMutableElement( - [&](const xla::ShapeIndex& index, XRTBufferAllocationPtr* buffer) { - *buffer = new XRTBufferAllocation( - se::DeviceMemoryBase(), element.allocation->device_ordinal(), - element.allocation->allocator_); + [&](const xla::ShapeIndex&, XRTBufferAllocationPtr* buffer) { + *buffer = nullptr; }); } else { // Increment the refcount on each newly-aliased buffer. @@ -506,6 +511,7 @@ typedef XRTBufferAllocation* XRTBufferAllocationPtr; allocator); } }); + allocation_tmp->SetDeviceMemorySize(); // Because the internal nodes of tuple_buffers are exactly the new index // tables, WriteTupleIndexTables will write only the new index tables and not // rewrite the index tables for the existing allocations. @@ -519,36 +525,47 @@ typedef XRTBufferAllocation* XRTBufferAllocationPtr; return Status::OK(); } -Status XRTTupleAllocation::Intern(ResourceMgr* rm, int64* key) { - *key = get_uid(); - string key_string = absl::StrCat(*key); - return rm->Create(kTupleContainer, key_string, this); -} - -bool XRTTupleAllocation::IsExclusiveOwner() { - for (const auto& buffer : buffers_) { - if (!buffer.second->RefCountIsOne()) return false; +bool XRTTupleAllocation::IsExclusiveOwner() const { + for (const auto& index_buffer : buffers_) { + if (index_buffer.second != nullptr && + !index_buffer.second->RefCountIsOne()) { + return false; + } } return true; } +size_t XRTTupleAllocation::GetDeviceMemorySize() const { + return device_memory_size_; +} + void XRTTupleAllocation::InitializeFromShapedBuffer( const xla::ShapedBuffer& shaped_buffer, se::DeviceMemoryAllocator* allocator, int device_ordinal) { - for (auto& buffer : buffers_) { + for (auto& index_buffer : buffers_) { + if (index_buffer.second != nullptr) { + index_buffer.second->Unref(); + } // Make a reference-counted version of the allocated buffer. 
- buffer.second = new XRTBufferAllocation(shaped_buffer.buffer(buffer.first), - device_ordinal, allocator); + index_buffer.second = new XRTBufferAllocation( + shaped_buffer.buffer(index_buffer.first), device_ordinal, allocator); } } -xla::ShapedBuffer XRTTupleAllocation::ToShapedBuffer() { +xla::StatusOr XRTTupleAllocation::ToShapedBuffer() { xla::ShapedBuffer shaped_buffer(on_host_shape(), on_device_shape(), allocator_->platform(), device_ordinal_); - for (const auto& buffer : buffers_) { - shaped_buffer.set_buffer(buffer.second->allocation(), buffer.first); + for (const auto& index_buffer : buffers_) { + if (index_buffer.second == nullptr || + index_buffer.second->allocation().is_null()) { + return errors::InvalidArgument("Literal buffer at index ", + index_buffer.first.ToString(), + " has been released"); + } + shaped_buffer.set_buffer(index_buffer.second->allocation(), + index_buffer.first); } - return shaped_buffer; + return std::move(shaped_buffer); } Status XRTTupleAllocation::AliasBufferFrom(const XRTTupleAllocation& source, @@ -556,37 +573,69 @@ Status XRTTupleAllocation::AliasBufferFrom(const XRTTupleAllocation& source, const xla::ShapeIndex& dest_index) { XRTBufferAllocation* source_buffer = source.buffers_.element(source_index); XRTBufferAllocation* dest_buffer = buffers_.element(dest_index); - // We allow the destination size being zero, because there are cases where we - // are coming in later filling in null/uninitialized device buffers. - // In all other cases, the size of the new buffer must match. - if (source_buffer->size() != dest_buffer->size() && - dest_buffer->size() != 0) { - return errors::InvalidArgument( - "Source buffer at index ", source_index.ToString(), - " does not match the size of destination buffer at index ", - dest_index.ToString(), ": ", source_buffer->size(), " vs ", - dest_buffer->size()); + if (dest_buffer != nullptr) { + // We allow the destination size being zero, because there are cases where + // we are coming in later filling in null/uninitialized device buffers. In + // all other cases, the size of the new buffer must match. + if (source_buffer->allocation().size() != + dest_buffer->allocation().size() && + dest_buffer->allocation().size() != 0) { + return errors::InvalidArgument( + "Source buffer at index ", source_index.ToString(), + " does not match the size of destination buffer at index ", + dest_index.ToString(), ": ", source_buffer->allocation().size(), + " vs ", dest_buffer->allocation().size()); + } + } else { + const xla::Shape& source_subshape = + xla::ShapeUtil::GetSubshape(source.on_device_shape(), source_index); + const xla::Shape& dest_subshape = + xla::ShapeUtil::GetSubshape(on_device_shape(), dest_index); + if (!xla::ShapeUtil::Equal(source_subshape, dest_subshape)) { + return errors::InvalidArgument( + "Source and destination subshapes do not match: source=", + xla::ShapeUtil::HumanStringWithLayout(source_subshape), + " dest=", xla::ShapeUtil::HumanStringWithLayout(dest_subshape)); + } } *buffers_.mutable_element(dest_index) = source_buffer; source_buffer->Ref(); - dest_buffer->Unref(); + if (dest_buffer != nullptr) { + // If we handed over the ownership of a buffer in ToDeviceMemoryTree(), we + // will be called here on the way back from execution, to alias back the + // buffer at that index. In that case the buffers will be the same. So we + // need to discard the memory at the destination buffer, before releasing + // the reference. 
+ if (dest_buffer->allocation().IsSameAs(source_buffer->allocation()) && + dest_buffer != source_buffer) { + dest_buffer->DiscardAllocation(); + } + dest_buffer->Unref(); + } return Status::OK(); } -xla::ShapeTree +xla::StatusOr> XRTTupleAllocation::ToDeviceMemoryTree( const std::function& release_checker) { xla::ShapeTree shaped_tree(on_device_shape()); - for (const auto& buffer : buffers_) { - if (!release_checker(buffer.first)) { - *shaped_tree.mutable_element(buffer.first) = buffer.second->allocation(); + for (const auto& index_buffer : buffers_) { + if (index_buffer.second == nullptr || + index_buffer.second->allocation().is_null()) { + return errors::InvalidArgument("Literal buffer at index ", + index_buffer.first.ToString(), + " has been released"); + } + if (!release_checker(index_buffer.first)) { + *shaped_tree.mutable_element(index_buffer.first) = + index_buffer.second->allocation(); } else { - *shaped_tree.mutable_element(buffer.first) = se::OwningDeviceMemory( - buffer.second->allocation(), device_ordinal_, allocator_); - DiscardAllocation(buffer.first); + // We keep the ownership of the device memory here. + *shaped_tree.mutable_element(index_buffer.first) = se::OwningDeviceMemory( + index_buffer.second->allocation(), device_ordinal_, allocator_); } } - return shaped_tree; + return std::move(shaped_tree); } } // namespace tensorflow diff --git a/tensorflow/compiler/xrt/xrt_state.h b/tensorflow/compiler/xrt/xrt_state.h index 4d284382532..929c77b3f5c 100644 --- a/tensorflow/compiler/xrt/xrt_state.h +++ b/tensorflow/compiler/xrt/xrt_state.h @@ -18,6 +18,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XRT_XRT_STATE_H_ #define TENSORFLOW_COMPILER_XRT_XRT_STATE_H_ +#include #include #include #include @@ -27,17 +28,21 @@ limitations under the License. #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/compiler/xrt/xrt_refptr.h" #include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/stream_executor/device_memory_allocator.h" #include "tensorflow/stream_executor/stream_executor.h" namespace tensorflow { +// Cannot include xrt_memory_manager.h here, as it needs to include this file. +class XRTMemoryManager; + // TODO(misard) make this a Tensor if and when that makes sense. // A reference-counted wrapper around a buffer allocation. This maps an XLA // tuple index or a non-tuple XLA shape to a region of device memory. The device @@ -51,36 +56,23 @@ class XRTBufferAllocation : public core::RefCounted { // The region of device memory being wrapped. const se::DeviceMemoryBase& allocation(); - // Sets the DeviceMemoryBase to be null. DiscardAllocation should be called - // when ownership of the underlying buffer has been transferred, e.g., to an - // output buffer when input and output buffers are aliased during - // execution. The call to DiscardAllocation prevents any device buffer being - // freed when the reference count drops to zero. - void DiscardAllocation(); - - // Returns the expected size of the allocation. 
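The xrt_state.cc changes above converge on one convention for released storage: when a sub-buffer is handed off (aliased into an execution result, swapped out, or moved into a new tuple), its entry in the buffers_ shape tree is set to nullptr, and consumers such as ToShapedBuffer() and ToDeviceMemoryTree() must fail with InvalidArgument instead of handing out a partially valid view. The sketch below illustrates that convention with toy types; Buffer, TupleStorage and ToFullView are hypothetical names, not part of the patch.

```c++
// Toy illustration only: "Buffer" and "TupleStorage" stand in for
// XRTBufferAllocation and the buffers_ shape tree of the patch.
#include <cstddef>
#include <iostream>
#include <map>
#include <optional>
#include <string>
#include <vector>

struct Buffer {
  std::size_t size = 0;  // Device memory footprint of this piece.
};

struct TupleStorage {
  // Index -> buffer; a nullptr entry means the buffer was handed off
  // (e.g. aliased into an execution result) or swapped out.
  std::map<std::string, Buffer*> buffers;

  // Mirrors the ToShapedBuffer()/ToDeviceMemoryTree() checks above: fail fast
  // if any piece has been released instead of returning a half-valid view.
  std::optional<std::vector<Buffer*>> ToFullView() const {
    std::vector<Buffer*> view;
    for (const auto& [index, buffer] : buffers) {
      if (buffer == nullptr) {
        std::cerr << "buffer at index " << index << " has been released\n";
        return std::nullopt;
      }
      view.push_back(buffer);
    }
    return view;
  }

  // Mirrors SetDeviceMemorySize(): the footprint counts only live entries.
  std::size_t DeviceMemorySize() const {
    std::size_t total = 0;
    for (const auto& [index, buffer] : buffers) {
      if (buffer != nullptr) total += buffer->size;
    }
    return total;
  }
};

int main() {
  Buffer root{128}, leaf{64};
  TupleStorage storage{{{"{}", &root}, {"{0}", &leaf}}};
  std::cout << "footprint: " << storage.DeviceMemorySize() << "\n";  // 192
  storage.buffers["{0}"] = nullptr;  // Simulate a released sub-buffer.
  std::cout << "valid view? " << (storage.ToFullView() ? "yes" : "no") << "\n";
}
```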
Since DiscardAllocation() will - // set allocation_ to {null,0}, and since later we might want to replace the - // discarded buffer with a new one, we need to be able to verify the size - // compatibility. - uint64 size() const { return size_; } + void DiscardAllocation() { allocation_ = se::DeviceMemoryBase(); } private: - uint64 size_ = 0; se::DeviceMemoryBase allocation_; int device_ordinal_; se::DeviceMemoryAllocator* allocator_; }; -// Entry in the resource manager corresponding to an allocation handle returned -// to a client. The handle identifies an immutable tuple of data in device -// memory. New handles can be created in three ways: by passing a literal in -// which case device memory is allocated and the literal is transferred to that -// memory; by aliasing a sub-shape of an existing tuple-shaped handle; or by -// aliasing a vector of existing handles to create a new tuple. The underlying -// storage is reference-counted. When a handle is released, the reference count -// of each storage buffer is decremented, and buffers with no outstanding -// references are freed. -class XRTTupleAllocation : public ResourceBase { +// A XRTTupleAllocation represents an allocated memory area on the device. +// New tuples can be created in three ways: by passing a literal in which case +// device memory is allocated and the literal is transferred to that memory; by +// aliasing a sub-shape of an existing tuple-shaped handle; or by aliasing a +// vector of existing handles to create a new tuple. The underlying storage is +// reference-counted. When a handle is released, the reference count of each +// storage buffer is decremented, and buffers with no outstanding references are +// freed. +class XRTTupleAllocation : public core::RefCounted { public: ~XRTTupleAllocation() override; @@ -88,6 +80,7 @@ class XRTTupleAllocation : public ResourceBase { // literal to that memory, and returns a XRTTupleAllocation handle to the // allocated buffers. static Status CreateAndTransfer(const xla::LiteralBase& literal, + XRTMemoryManager* memory_manager, xla::Backend* backend, int device_ordinal, XRTTupleAllocation** allocation); @@ -106,16 +99,11 @@ class XRTTupleAllocation : public ResourceBase { XRTTupleAllocation** allocation, bool alias_parent_allocation); - // Runs a compaction cycle which copies the device data to host, frees the - // device data, and then reallocate and send back the data. - static Status CompactAllocations(ResourceMgr* rm, xla::Backend* backend, - int device_ordinal); - // A structure describing a leaf of a tree of tuples to expand. Each leaf // contains an allocation and indicates whether or not the allocation's handle // should be freed after incorporating its buffers into the expanded tree. struct ExpandedTupleInput { - XRTTupleAllocation* allocation; + RefPtr allocation; bool release_allocation_after_use; }; @@ -129,52 +117,70 @@ class XRTTupleAllocation : public ResourceBase { // an input is repeated, release_input_handle must be false for every leaf // where that input appears. The latter property is not validated by MakeTuple // and must be enforced by the caller. - static Status MakeTuple(xla::Backend* backend, int device_ordinal, + static Status MakeTuple(XRTMemoryManager* memory_manager, + xla::Backend* backend, int device_ordinal, const xla::ShapeTree& elements, XRTTupleAllocation** allocation); - // Retrieves the allocation interned under key from rm. The caller owns a - // reference to allocation after looking it up. 
- static Status Lookup(ResourceMgr* rm, int64 key, - XRTTupleAllocation** allocation); - - // Deletes the reference in the rm to an allocation interned under key. - static Status DeleteFromResourceManager(ResourceMgr* rm, int64 key); - - // Releases all the device memory allocated by XRT within the resource - // manager. - static Status ReleaseAllAllocations(ResourceMgr* rm); - - // Returns the invalid key value, which will be never generated by the - // Intern() API. - static int64 InvalidKey() { return 0; } - - // Adds the allocation to a ResourceMgr and returns the key that will be used - // to retrieve it. Transfers a reference on *this to rm. - Status Intern(ResourceMgr* rm, int64* key); - // Copies the allocation from device to host and returns it in literal. - Status ToLiteral(xla::Backend* backend, int device_ordinal, - xla::MutableLiteralBase* literal); + Status ToLiteral(xla::Backend* backend, xla::MutableLiteralBase* literal); // Write a new literal value to the allocation. Status WriteLiteral(xla::Backend* backend, const xla::Literal& literal); + // Stores the content of the tuple allocation into the internal literal, and + // releases all the device buffers. The swap_pinned flag tells whether a + // pinned allocation should be swapped out. It should be false in all cases + // except during the memory compaction operation run by the XRTMemoryManager. + // Returns a boolean telling whether the allocation was swapped out. + xla::StatusOr<bool> SwapOut(xla::Backend* backend, bool swap_pinned); + + // Allocates the device memory required to store the tuple value held within + // the internal literal, and transfers the literal value into the device + // memory. Returns a boolean telling whether the allocation was swapped in. + xla::StatusOr<bool> SwapIn(XRTMemoryManager* memory_manager, + xla::Backend* backend); + + // Pins the allocation first, then swaps it in (if it is not already). After + // this API returns, the allocation is pinned and its content is in device + // memory. The caller is responsible for releasing the pin-count using the + // Unpin() API. + xla::StatusOr<bool> PinAndSwapIn(XRTMemoryManager* memory_manager, + xla::Backend* backend); + + // Checks whether the allocation is currently swapped out. + bool IsSwapped() const; + + // Increases the pin-count of this allocation. If the pin-count is greater + // than 0, the allocation cannot be swapped. Returns the pin-count value + // before the increase. + int64 Pin(); + + // Decreases the pin-count of this allocation. Returns the pin-count value + // before the decrease. + int64 Unpin(); + + // Checks whether the allocation is currently pinned. + bool IsPinned() const; + // True if none of the buffers in the allocation are aliased by any other live // handle. - bool IsExclusiveOwner(); + bool IsExclusiveOwner() const; + + // Retrieves the footprint, in terms of device memory, of this allocation. + size_t GetDeviceMemorySize() const; // The ordinal of the device holding this tuple. - int device_ordinal(); + int device_ordinal() const; // Returns the shape of the tuple as seen by the host. - const xla::Shape& on_host_shape(); + const xla::Shape& on_host_shape() const; // Returns the shape of the tuple as stored on the device. - const xla::Shape& on_device_shape(); + const xla::Shape& on_device_shape() const; // Returns the buffer pointed to by the root of the tuple.
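The pin/swap contract documented above (Pin() and Unpin() return the previous pin-count, and a pinned allocation is only swapped out when the compaction path explicitly asks for it) can be illustrated with a small self-contained sketch. PinnableAllocation and TrySwapOut are toy stand-ins, not the real XRTTupleAllocation API; the pin count is assumed to live in a std::atomic as in the member list further below.

```c++
// Minimal sketch of the Pin()/Unpin()/SwapOut() contract; toy class only.
#include <atomic>
#include <cstdint>
#include <iostream>

class PinnableAllocation {
 public:
  // Returns the pin-count value before the increase, matching the comment
  // on XRTTupleAllocation::Pin().
  int64_t Pin() { return pin_count_.fetch_add(1); }

  // Returns the pin-count value before the decrease.
  int64_t Unpin() { return pin_count_.fetch_sub(1); }

  bool IsPinned() const { return pin_count_ > 0; }

  // A swap-out request succeeds only if the allocation is unpinned, or the
  // caller explicitly asks to swap a pinned one (the compaction path).
  bool TrySwapOut(bool swap_pinned) {
    if (IsPinned() && !swap_pinned) return false;
    swapped_ = true;  // Real code would move device buffers into a literal.
    return true;
  }

  bool IsSwapped() const { return swapped_; }

 private:
  std::atomic<int64_t> pin_count_{0};
  bool swapped_ = false;
};

int main() {
  PinnableAllocation alloc;
  alloc.Pin();                                   // pin-count: 0 -> 1
  std::cout << alloc.TrySwapOut(false) << "\n";  // 0: pinned, not swapped
  alloc.Unpin();                                 // pin-count: 1 -> 0
  std::cout << alloc.TrySwapOut(false) << "\n";  // 1: swapped out
}
```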
- const se::DeviceMemoryBase& root_allocation(); + const se::DeviceMemoryBase& root_allocation() const; // Stops managing the storage for the allocation at buffer_index, e.g., // because it has been aliased to the output buffer of a computation. @@ -182,7 +188,7 @@ class XRTTupleAllocation : public ResourceBase { // Returns the tree of allocations as a ShapedBuffer. This tree may not have // the same shape as on_host_shape. - xla::ShapedBuffer ToShapedBuffer(); + xla::StatusOr<xla::ShapedBuffer> ToShapedBuffer(); // Aliases the source buffer at source_index into the current tuple allocation // dest_index. @@ -191,14 +197,22 @@ class XRTTupleAllocation : public ResourceBase { const xla::ShapeIndex& dest_index); // Returns the device memory tree of this allocation. If the release_checker - // function returns true for a given index, the ownership of the device memory - // at that index is transferred to the result. Every attempt to read the value - // at that index will fail. - xla::ShapeTree<xla::MaybeOwningDeviceMemory> ToDeviceMemoryTree( + // function returns true for a given index, owned device memory is returned + // to the caller. But the tuple allocation cannot release the ownership in + // full, as the execute operation might fail. So we rely on a call to + // AliasBufferFrom() to re-alias back the buffers. This is not great (to say + // the least), but the current aliasing logic relies on + // MaybeOwningDeviceMemory being owned, to detect the fact that the user may + // want to alias a buffer. Unfortunately to do that, it needs to release the + // ownership, which is a problem if the execute fails. + // This calls for a refactoring of the whole owning/maybe-owning interface to + // introduce a sharing concept (IOW shared_ptr model vs. unique_ptr). + // We'd need something similar to XRTTupleAllocation instead of + // ScopedShapedBuffer, which wants ownership and does not allow sharing. + xla::StatusOr<xla::ShapeTree<xla::MaybeOwningDeviceMemory>> + ToDeviceMemoryTree( const std::function<bool(const xla::ShapeIndex&)>& release_checker); - string DebugString() const override { return "XLA allocation handle"; } - private: // Creates a new handle with (tuple) shape. XRTTupleAllocation(int device_ordinal, se::DeviceMemoryAllocator* allocator, @@ -211,6 +225,21 @@ class XRTTupleAllocation : public ResourceBase { se::DeviceMemoryAllocator* allocator, int device_ordinal); + // Releases all the XRTBufferAllocation buffer references and sets the + // corresponding shape tree entry to nullptr. + void ReleaseBuffers(); + + // Stores the content of the allocation from device memory to the target host + // literal. + Status StoreToLiteral(xla::Backend* backend, + xla::MutableLiteralBase* literal); + + // Sets the total size of the buffers held within this allocation's buffers. + // This API should be called once when an XRTTupleAllocation object is + // created, as the XRTTupleAllocation shapes never change, and hence neither + // does the device memory size. + void SetDeviceMemorySize(); + // Takes a tree 'elements' where each leaf is an allocation, validates that // they are all on device_ordinal managed by allocator, and returns in // host_shape and device_shape the host/device shapes of the expanded tree, @@ -221,9 +250,13 @@ class XRTTupleAllocation : public ResourceBase { se::DeviceMemoryAllocator* allocator, xla::Shape* host_shape, xla::Shape* device_shape); + // The lock which protects the internal operations of the tuple allocation. It + // is mutable to allow const-like operations to be declared as such. + mutable mutex lock_; + // Location of the memory that is being managed.
- int device_ordinal_; - se::DeviceMemoryAllocator* allocator_; + const int device_ordinal_; + se::DeviceMemoryAllocator* const allocator_; // The shape that the caller thinks the tuple has. const xla::Shape on_host_shape_; @@ -233,6 +266,13 @@ class XRTTupleAllocation : public ResourceBase { // The tree of reference-counted buffers, which uses on_device_shape_ as its // shape. xla::ShapeTree<XRTBufferAllocationPtr> buffers_; + // The footprint of the allocation, when residing on device memory. + size_t device_memory_size_ = 0; + // If the allocation is swapped out, this is the literal storing its content. + std::unique_ptr<xla::Literal> literal_; + // A pinned allocation is one which cannot be swapped out. If pin_count_ > 0 + // then the allocation is pinned. + std::atomic<int64> pin_count_; }; } // namespace tensorflow diff --git a/tensorflow/compiler/xrt/xrt_util.cc b/tensorflow/compiler/xrt/xrt_util.cc index 518c993f390..baa7112710e 100644 --- a/tensorflow/compiler/xrt/xrt_util.cc +++ b/tensorflow/compiler/xrt/xrt_util.cc @@ -25,6 +25,88 @@ limitations under the License. namespace tensorflow { namespace { +// The ScopedHandles data structure is used in the ExecuteChained() API and its +// task is to track tuple allocation registrations. It is used both to track +// intermediate results of a chained computation and its final results. Anything +// which is marked to be released will be released using the XRTMemoryManager +// once the object is destroyed (unless an explicit call to Drop() or Release() +// is made). +class ScopedHandles { + public: + explicit ScopedHandles(RefPtr<XRTMemoryManager> memory_manager) + : memory_manager_(std::move(memory_manager)) {} + + ~ScopedHandles() { + for (size_t i = 0; i < handles_.size(); ++i) { + if (handles_release_[i]) { + memory_manager_->Release(handles_[i]).IgnoreError(); + } + } + } + + int64 operator[](size_t index) const { return handles_.at(index); } + + size_t size() const { return handles_.size(); } + + // Adds the given handle at the index position, marking it releasable + // according to the release argument. If a to-be-released handle already + // exists at the same index, it will be released. + Status Add(size_t index, int64 handle, bool release) { + if (index >= handles_.size()) { + handles_.resize(index + 1, XRTMemoryManager::InvalidKey()); + handles_release_.resize(index + 1, false); + } + if (handles_release_[index]) { + Status status = memory_manager_->Release(handles_[index]); + if (!status.ok()) { + if (release) { + memory_manager_->Release(handle).IgnoreError(); + } + return status; + } + } + handles_[index] = handle; + handles_release_[index] = release; + return Status::OK(); + } + + // Adds a to-be-released tuple allocation at the given index. + Status Add(size_t index, RefPtr<XRTTupleAllocation> tuple) { + return Add(index, memory_manager_->Register(std::move(tuple)), + /*release=*/true); + } + + // Drops the handle at the given index, and releases it using + // XRTMemoryManager::Release() if it is marked as to-be-released. + Status Drop(size_t index) { + if (handles_release_.at(index)) { + TF_RETURN_IF_ERROR(memory_manager_->Release(handles_[index])); + } + Release(index); + return Status::OK(); + } + + // Releases the handle at the given index. The destructor will not call the + // XRTMemoryManager::Release() API on such a handle.
+ int64 Release(size_t index) { + int64 handle = handles_.at(index); + handles_[index] = XRTMemoryManager::InvalidKey(); + handles_release_[index] = false; + return handle; + } + + // Looks up the handle stored at the given index, and returns the matching + // tuple allocation. + xla::StatusOr> Lookup(size_t index) const { + return memory_manager_->Lookup(handles_.at(index)); + } + + private: + RefPtr memory_manager_; + std::vector handles_; + std::vector handles_release_; +}; + bool DebugOptionsPassThroughEnabled() { const char* env = getenv("TF_XLA_DEBUG_OPTIONS_PASSTHROUGH"); bool enabled = @@ -61,6 +143,23 @@ Status MakeOutput(const RefPtr& output, int64 index, return Status::OK(); } +Status PopulateOpWorkingSet(xla::Backend* backend, + const xrt::XRTChainedExecuteOp& op, + int current_index, const ScopedHandles& outputs, + XRTMemoryManager::WorkingSet* working_set) { + for (int i = 0; i < op.inputs_size(); ++i) { + auto& input = op.inputs(i); + if (input.op_index() >= current_index) { + return errors::InvalidArgument( + "Input index ", input.op_index(), + " is above the current position: ", current_index); + } + TF_RETURN_IF_ERROR( + working_set->LookupAndPin(backend, outputs[input.op_index()])); + } + return Status::OK(); +} + } // namespace xla::DebugOptions BuildXlaDebugOptions(const xla::DebugOptions& ref_options) { @@ -81,7 +180,7 @@ xla::DebugOptions BuildXlaDebugOptions(const xla::DebugOptions& ref_options) { } xla::StatusOr> GetComputationInputs( - OpKernelContext* context, ResourceMgr* rm, const char* input_name) { + OpKernelContext* context, const char* input_name) { OpInputList arg_list; TF_RETURN_IF_ERROR(context->input_list(input_name, &arg_list)); // Concatenate all input uids from list of scalars-or-vectors carrying them. @@ -102,7 +201,8 @@ xla::StatusOr> GetComputationInputs( return std::move(input_coords); } -Status CreateExecuteOutput(OpKernelContext* context, ResourceMgr* rm, +Status CreateExecuteOutput(OpKernelContext* context, + XRTMemoryManager* memory_manager, RefPtr output_tuple, bool return_exploded_tuple) { if (return_exploded_tuple && output_tuple->on_host_shape().IsTuple()) { @@ -117,23 +217,21 @@ Status CreateExecuteOutput(OpKernelContext* context, ResourceMgr* rm, TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer( output_tuple.get(), {i}, &suballocation, /*alias_parent_allocation=*/false)); - int64 key; - TF_RETURN_IF_ERROR(suballocation->Intern(rm, &key)); - output_tensor->vec()(i) = key; + output_tensor->vec()(i) = memory_manager->Register(suballocation); } } else { Tensor* output_tensor; TF_RETURN_IF_ERROR( context->allocate_output(0, TensorShape({}), &output_tensor)); - int64 key; - TF_RETURN_IF_ERROR(output_tuple->Intern(rm, &key)); - output_tuple.release(); - output_tensor->scalar()() = key; + output_tensor->scalar()() = + memory_manager->Register(std::move(output_tuple)); } return Status::OK(); } -Status ExecuteChained(OpKernelContext* context, ResourceMgr* rm, +Status ExecuteChained(OpKernelContext* context, + const RefPtr& memory_manager, + xla::Backend* backend, int device_ordinal, const xrt::XRTChainedExecutePlan& plan, const xrt::XRTChainedExecuteConfig& config, const ChainedExecuteFn& execute_op) { @@ -145,41 +243,43 @@ Status ExecuteChained(OpKernelContext* context, ResourceMgr* rm, uses[input.op_index()] += 1; } } - std::vector> ops_outputs(plan.ops_size()); - std::vector> results; + + ScopedHandles outputs(memory_manager); + ScopedHandles results(memory_manager); for (int i = 0; i < plan.ops_size(); ++i) { auto& op = plan.ops(i); if 
(op.op_oneof_case() == xrt::XRTChainedExecuteOp::kDataHandle) { - // This operation is a device data load. Fetch the proper - // XRTTupleAllocation behind the user handle and fill up the op output at - // the current position. - XRTTupleAllocation* tuple; - TF_RETURN_IF_ERROR( - XRTTupleAllocation::Lookup(rm, op.data_handle(), &tuple)); - ops_outputs[i].reset(tuple); + // This operation is a device data load. Set the handle as output and + // leave the release flag off, since this is not an intermediate output. + TF_RETURN_IF_ERROR(outputs.Add(i, op.data_handle(), /*release=*/false)); } else if (op.op_oneof_case() == xrt::XRTChainedExecuteOp::kComputationHandle) { // This is an XRT execute operation, forward to the device specific - // handler. - TF_ASSIGN_OR_RETURN(ops_outputs[i], execute_op(op, i, ops_outputs)); + // handler. Populating the working set makes sure the input allocations + // for this execute operations are pinned to device memory. + XRTMemoryManager::WorkingSet working_set(memory_manager); + TF_RETURN_IF_ERROR( + PopulateOpWorkingSet(backend, op, i, outputs, &working_set)); + TF_ASSIGN_OR_RETURN(auto tuple, + execute_op(op, working_set.PinnedTuples())); + TF_RETURN_IF_ERROR(outputs.Add(i, std::move(tuple))); } else { return errors::InvalidArgument( "Undefined operation kind at post-order position ", i); } // If the result of this chained operation is an output result, feed the - // results vector at the desired position. + // results at the desired position. for (auto& output : op.outputs()) { - if (output.result_index() >= results.size()) { - results.resize(output.result_index() + 1); - } - TF_RETURN_IF_ERROR(MakeOutput(ops_outputs[i], output.output_index(), - &results[output.result_index()])); + TF_ASSIGN_OR_RETURN(auto tuple, outputs.Lookup(i)); + RefPtr result; + TF_RETURN_IF_ERROR(MakeOutput(tuple, output.output_index(), &result)); + TF_RETURN_IF_ERROR(results.Add(output.result_index(), std::move(result))); } // Drop intermediate results which have no more users. for (auto& input : op.inputs()) { uses[input.op_index()] -= 1; if (uses[input.op_index()] == 0) { - ops_outputs[input.op_index()].reset(); + TF_RETURN_IF_ERROR(outputs.Drop(input.op_index())); } } } @@ -188,12 +288,7 @@ Status ExecuteChained(OpKernelContext* context, ResourceMgr* rm, TF_RETURN_IF_ERROR(context->allocate_output( 0, TensorShape({static_cast(results.size())}), &output_tensor)); for (size_t i = 0; i < results.size(); ++i) { - int64 key = XRTTupleAllocation::InvalidKey(); - if (results[i] != nullptr) { - TF_RETURN_IF_ERROR(results[i]->Intern(rm, &key)); - results[i].release(); - } - output_tensor->vec()(i) = key; + output_tensor->vec()(i) = results.Release(i); } return Status::OK(); } diff --git a/tensorflow/compiler/xrt/xrt_util.h b/tensorflow/compiler/xrt/xrt_util.h index 07159dd5677..32244a63081 100644 --- a/tensorflow/compiler/xrt/xrt_util.h +++ b/tensorflow/compiler/xrt/xrt_util.h @@ -18,97 +18,19 @@ limitations under the License. 
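Before the xrt_util.h interface changes below, the handle lifetime logic that the ExecuteChained() rewrite above implements is worth spelling out: each op output is registered through ScopedHandles, use counts are computed in a first pass, and intermediate outputs are dropped as soon as their last consumer has executed. The following self-contained sketch models only that bookkeeping; Op and RunChain are hypothetical toy types (the real code works on XRTChainedExecutePlan and RefPtr handles), and the drop rule is simplified.

```c++
// A stripped-down model of the handle lifetime logic in ExecuteChained():
// every op output is tracked, and an intermediate output is dropped as soon
// as its last consumer has run.
#include <cstdio>
#include <functional>
#include <vector>

struct Op {
  std::vector<int> inputs;   // Indices of earlier ops whose outputs we consume.
  bool is_result = false;    // Whether this op's output is a final result.
};

void RunChain(const std::vector<Op>& ops,
              const std::function<int(const Op&)>& execute_op) {
  // First pass: count how many later ops consume each output (the uses[]
  // vector of ExecuteChained()).
  std::vector<int> uses(ops.size(), 0);
  for (const Op& op : ops) {
    for (int input : op.inputs) ++uses[input];
  }

  std::vector<int> outputs(ops.size(), -1);  // -1 == released / never produced.
  for (size_t i = 0; i < ops.size(); ++i) {
    outputs[i] = execute_op(ops[i]);
    // Drop intermediate outputs that no longer have pending consumers, so the
    // device memory they model can be reclaimed before the chain finishes.
    for (int input : ops[i].inputs) {
      if (--uses[input] == 0 && !ops[input].is_result) {
        std::printf("dropping intermediate output of op %d\n", input);
        outputs[input] = -1;
      }
    }
  }
}

int main() {
  // op0 and op1 feed op2; only op2 is a final result.
  std::vector<Op> ops = {{{}, false}, {{}, false}, {{0, 1}, true}};
  RunChain(ops, [](const Op&) { static int next = 0; return next++; });
}
```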
#ifndef TENSORFLOW_COMPILER_XRT_XRT_UTIL_H_ #define TENSORFLOW_COMPILER_XRT_XRT_UTIL_H_ +#include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla.pb.h" #include "tensorflow/compiler/xrt/xrt.pb.h" +#include "tensorflow/compiler/xrt/xrt_memory_manager.h" +#include "tensorflow/compiler/xrt/xrt_refptr.h" #include "tensorflow/compiler/xrt/xrt_state.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/lib/core/status.h" namespace tensorflow { -// Reference counted smart pointer for XRT objects providing the standard -// Ref()/Unref() APIs. -template -class RefPtr { - public: - RefPtr() = default; - // Creates a RefPtr from a pointer. This is an ownership transfer operation, - // and the caller has to own a valid reference to ptr (unless ptr is nullptr). - RefPtr(T* ptr) : ptr_(ptr) {} - RefPtr(const RefPtr& other) : ptr_(other.ptr_) { Acquire(ptr_); } - RefPtr(RefPtr&& other) : ptr_(other.ptr_) { other.ptr_ = nullptr; } - - ~RefPtr() { Release(ptr_); } - - RefPtr& operator=(const RefPtr& other) { - if (this != &other) { - Acquire(other.ptr_); - Release(ptr_); - ptr_ = other.ptr_; - } - return *this; - } - - RefPtr& operator=(RefPtr&& other) { - if (this != &other) { - Release(ptr_); - ptr_ = other.ptr_; - other.ptr_ = nullptr; - } - return *this; - } - - operator bool() const { return ptr_ != nullptr; } - bool operator==(const RefPtr& rhs) const { return ptr_ == rhs.ptr_; } - bool operator!=(const RefPtr& rhs) const { return ptr_ != rhs.ptr_; } - bool operator==(const T* ptr) const { return ptr_ == ptr; } - bool operator!=(const T* ptr) const { return ptr_ != ptr; } - bool operator==(std::nullptr_t ptr) const { return ptr_ == ptr; } - bool operator!=(std::nullptr_t ptr) const { return ptr_ != ptr; } - - T* get() const { return ptr_; } - - T* operator->() const { - CHECK(ptr_ != nullptr); // Crash OK - return ptr_; - } - - T& operator*() const { - CHECK(ptr_ != nullptr); // Crash OK - return *ptr_; - } - - T* release() { - T* ptr = ptr_; - ptr_ = nullptr; - return ptr; - } - - // Resets the RefPtr from a pointer. This is an ownership transfer operation, - // and the caller has to own a valid reference to ptr (unless ptr is nullptr). - void reset(T* ptr = nullptr) { - Release(ptr_); - ptr_ = ptr; - } - - private: - static void Release(T* ptr) { - if (ptr != nullptr) { - ptr->Unref(); - } - } - - static void Acquire(T* ptr) { - if (ptr != nullptr) { - ptr->Ref(); - } - } - - T* ptr_ = nullptr; -}; - struct InputCoords { explicit InputCoords(int64 handle) : handle(handle) {} InputCoords(int64 handle, xla::ShapeIndex index) @@ -128,12 +50,13 @@ xla::DebugOptions BuildXlaDebugOptions(const xla::DebugOptions& ref_options); // Populates the input_coords with a list of input coordinates from a input_name // op argument. xla::StatusOr> GetComputationInputs( - OpKernelContext* context, ResourceMgr* rm, const char* input_name); + OpKernelContext* context, const char* input_name); // Create the XRT execute output tensor given the computation result // (output_tuple). The return_exploded_tuple tells whether a tuple result should // be returned as vector of handles representing each tuple child. 
-Status CreateExecuteOutput(OpKernelContext* context, ResourceMgr* rm, +Status CreateExecuteOutput(OpKernelContext* context, + XRTMemoryManager* memory_manager, RefPtr output_tuple, bool return_exploded_tuple); @@ -141,9 +64,11 @@ Status CreateExecuteOutput(OpKernelContext* context, ResourceMgr* rm, // function. using ChainedExecuteFn = std::function>( - const xrt::XRTChainedExecuteOp&, int, + const xrt::XRTChainedExecuteOp&, absl::Span>)>; -Status ExecuteChained(OpKernelContext* context, ResourceMgr* rm, +Status ExecuteChained(OpKernelContext* context, + const RefPtr& memory_manager, + xla::Backend* backend, int device_ordinal, const xrt::XRTChainedExecutePlan& plan, const xrt::XRTChainedExecuteConfig& config, const ChainedExecuteFn& execute_op); diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 6760ef265d3..c9ee6f9ac83 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -1,9 +1,10 @@ # Description: # contains parts of TensorFlow that are experimental or unstable and which are not supported. -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = ["//tensorflow:__subpackages__"]) +package( + default_visibility = ["//tensorflow:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) load("//third_party/mpi:mpi.bzl", "if_mpi") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") diff --git a/tensorflow/contrib/all_reduce/BUILD b/tensorflow/contrib/all_reduce/BUILD index f6c6560c1c3..2ebb821f0b3 100644 --- a/tensorflow/contrib/all_reduce/BUILD +++ b/tensorflow/contrib/all_reduce/BUILD @@ -3,9 +3,10 @@ # APIs are subject to change. Eventually to be replaced by equivalent # functionality within TensorFlow core. -package(default_visibility = ["//tensorflow:__subpackages__"]) - -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//tensorflow:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) diff --git a/tensorflow/contrib/android/BUILD b/tensorflow/contrib/android/BUILD index 5608e7ddafa..3806237cf9e 100644 --- a/tensorflow/contrib/android/BUILD +++ b/tensorflow/contrib/android/BUILD @@ -3,16 +3,17 @@ load("@build_bazel_rules_android//android:rules.bzl", "android_library") -package(default_visibility = ["//visibility:public"]) - -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) load( "//tensorflow:tensorflow.bzl", - "tf_copts", "if_android", + "tf_copts", ) exports_files([ diff --git a/tensorflow/contrib/android/asset_manager_filesystem.cc b/tensorflow/contrib/android/asset_manager_filesystem.cc index d14b2126a0f..a5aa950bff6 100644 --- a/tensorflow/contrib/android/asset_manager_filesystem.cc +++ b/tensorflow/contrib/android/asset_manager_filesystem.cc @@ -27,7 +27,7 @@ namespace { string RemoveSuffix(const string& name, const string& suffix) { string output(name); StringPiece piece(output); - str_util::ConsumeSuffix(&piece, suffix); + absl::ConsumeSuffix(&piece, suffix); return string(piece); } @@ -230,7 +230,7 @@ string AssetManagerFileSystem::NormalizeDirectoryPath(const string& fname) { string AssetManagerFileSystem::RemoveAssetPrefix(const string& name) { StringPiece piece(name); - str_util::ConsumePrefix(&piece, prefix_); + absl::ConsumePrefix(&piece, prefix_); return string(piece); } diff --git a/tensorflow/contrib/autograph/BUILD b/tensorflow/contrib/autograph/BUILD index e37ad7a7581..da83008c422 100644 --- a/tensorflow/contrib/autograph/BUILD +++ 
b/tensorflow/contrib/autograph/BUILD @@ -1,4 +1,6 @@ -licenses(["notice"]) # Apache 2.0 +package( + licenses = ["notice"], # Apache 2.0 +) load("//tensorflow:tensorflow.bzl", "py_test") diff --git a/tensorflow/contrib/autograph/examples/benchmarks/BUILD b/tensorflow/contrib/autograph/examples/benchmarks/BUILD index 651b108e239..0cc42f01fbd 100644 --- a/tensorflow/contrib/autograph/examples/benchmarks/BUILD +++ b/tensorflow/contrib/autograph/examples/benchmarks/BUILD @@ -1,4 +1,6 @@ -licenses(["notice"]) # Apache 2.0 +package( + licenses = ["notice"], # Apache 2.0 +) load("//tensorflow:tensorflow.bzl", "py_test") load("//tensorflow/tools/test:performance.bzl", "tf_py_logged_benchmark") diff --git a/tensorflow/contrib/batching/BUILD b/tensorflow/contrib/batching/BUILD index bc9b2b05172..3b8e3059501 100644 --- a/tensorflow/contrib/batching/BUILD +++ b/tensorflow/contrib/batching/BUILD @@ -2,10 +2,9 @@ package( default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 ) -licenses(["notice"]) # Apache 2.0 - load( "//tensorflow:tensorflow.bzl", "py_test", @@ -34,6 +33,7 @@ py_test( name = "batch_ops_test", size = "small", srcs = ["python/ops/batch_ops_test.py"], + python_version = "PY2", shard_count = 5, srcs_version = "PY2AND3", tags = [ diff --git a/tensorflow/contrib/bigtable/BUILD b/tensorflow/contrib/bigtable/BUILD index 71538e0770d..bc4c145668f 100644 --- a/tensorflow/contrib/bigtable/BUILD +++ b/tensorflow/contrib/bigtable/BUILD @@ -2,19 +2,18 @@ package( default_visibility = ["//tensorflow:internal"], + licenses = ["notice"], # Apache 2.0 ) -licenses(["notice"]) # Apache 2.0 - load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") load( "//tensorflow:tensorflow.bzl", + "tf_cc_test", "tf_copts", "tf_custom_op_library", "tf_gen_op_libs", "tf_gen_op_wrapper_py", "tf_kernel_library", - "tf_cc_test", "tf_py_test", ) diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lib.cc b/tensorflow/contrib/bigtable/kernels/bigtable_lib.cc index 0bdaf3ae0bd..01cedd8d762 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_lib.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_lib.cc @@ -73,7 +73,7 @@ string RegexFromStringSet(const std::vector& strs) { if (uniq.size() == 1) { return *uniq.begin(); } - return str_util::Join(uniq, "|"); + return absl::StrJoin(uniq, "|"); } } // namespace tensorflow diff --git a/tensorflow/contrib/boosted_trees/BUILD b/tensorflow/contrib/boosted_trees/BUILD index 6791e379107..95c08f67e54 100644 --- a/tensorflow/contrib/boosted_trees/BUILD +++ b/tensorflow/contrib/boosted_trees/BUILD @@ -1,13 +1,14 @@ # TensorFlow code for training gradient boosted trees. -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = [ + "//visibility:public", + ], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) -package(default_visibility = [ - "//visibility:public", -]) - load("//tensorflow:tensorflow.bzl", "py_test") load("//tensorflow:tensorflow.bzl", "tf_custom_op_library") load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs") diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD index 968aff18053..8a2beede37d 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD +++ b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD @@ -1,14 +1,13 @@ # This directory contains estimators to train and run inference on # gradient boosted trees on top of TensorFlow. 
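Several of the surrounding file diffs (asset_manager_filesystem.cc, bigtable_lib.cc, bigquery_table_accessor_test.cc) replace the old tensorflow::str_util helpers with their Abseil equivalents. The standalone demo below shows the three absl calls involved; it assumes the program is built against Abseil and is not part of the patch itself.

```c++
// Demo of the absl string helpers that replace the old str_util calls in the
// surrounding diffs: ConsumeSuffix, StrJoin, StrContains.
#include <iostream>
#include <set>
#include <string>

#include "absl/strings/match.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/strings/strip.h"

int main() {
  // absl::ConsumeSuffix() trims in place and reports whether it matched,
  // mirroring the RemoveSuffix() helper in asset_manager_filesystem.cc.
  absl::string_view name = "model.tflite";
  bool stripped = absl::ConsumeSuffix(&name, ".tflite");
  std::cout << stripped << " " << name << "\n";  // 1 model

  // absl::StrJoin() replaces str_util::Join() in RegexFromStringSet().
  std::set<std::string> columns = {"a", "b", "c"};
  std::cout << absl::StrJoin(columns, "|") << "\n";  // a|b|c

  // absl::StrContains() replaces str_util::StrContains() in the test helper.
  std::cout << absl::StrContains("hello world", "world") << "\n";  // 1
}
```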
-licenses(["notice"]) # Apache 2.0 - -exports_files(["LICENSE"]) - package( default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 ) +exports_files(["LICENSE"]) + load("//tensorflow:tensorflow.bzl", "py_test") py_library( diff --git a/tensorflow/contrib/boosted_trees/lib/BUILD b/tensorflow/contrib/boosted_trees/lib/BUILD index 634dfab1090..56c55a4055d 100644 --- a/tensorflow/contrib/boosted_trees/lib/BUILD +++ b/tensorflow/contrib/boosted_trees/lib/BUILD @@ -1,17 +1,16 @@ # Description: # This directory contains common utilities used in boosted_trees. -licenses(["notice"]) # Apache 2.0 - -exports_files(["LICENSE"]) - package( default_visibility = [ "//tensorflow/contrib/boosted_trees:__subpackages__", "//tensorflow/contrib/boosted_trees:friends", ], + licenses = ["notice"], # Apache 2.0 ) +exports_files(["LICENSE"]) + load("//tensorflow:tensorflow.bzl", "py_test") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow:tensorflow.bzl", "tf_cc_binary") diff --git a/tensorflow/contrib/boosted_trees/proto/BUILD b/tensorflow/contrib/boosted_trees/proto/BUILD index b07f0a43142..ed84c7a02d7 100644 --- a/tensorflow/contrib/boosted_trees/proto/BUILD +++ b/tensorflow/contrib/boosted_trees/proto/BUILD @@ -1,4 +1,6 @@ -licenses(["notice"]) # Apache 2.0 +package( + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) diff --git a/tensorflow/contrib/boosted_trees/resources/BUILD b/tensorflow/contrib/boosted_trees/resources/BUILD index c0651868453..1205ce55694 100644 --- a/tensorflow/contrib/boosted_trees/resources/BUILD +++ b/tensorflow/contrib/boosted_trees/resources/BUILD @@ -1,14 +1,13 @@ -licenses(["notice"]) # Apache 2.0 - -exports_files(["LICENSE"]) - package( default_visibility = [ "//tensorflow/contrib/boosted_trees:__subpackages__", "//tensorflow/contrib/boosted_trees:friends", ], + licenses = ["notice"], # Apache 2.0 ) +exports_files(["LICENSE"]) + cc_library( name = "stamped_resource", hdrs = ["stamped_resource.h"], diff --git a/tensorflow/contrib/checkpoint/python/BUILD b/tensorflow/contrib/checkpoint/python/BUILD index caedf5b2d1d..aa5c47c6350 100644 --- a/tensorflow/contrib/checkpoint/python/BUILD +++ b/tensorflow/contrib/checkpoint/python/BUILD @@ -1,6 +1,7 @@ -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = ["//tensorflow:internal"]) +package( + default_visibility = ["//tensorflow:internal"], + licenses = ["notice"], # Apache 2.0 +) load("//tensorflow:tensorflow.bzl", "tf_py_test") diff --git a/tensorflow/contrib/cloud/BUILD b/tensorflow/contrib/cloud/BUILD index 523a9efcf05..3a6b6232fb6 100644 --- a/tensorflow/contrib/cloud/BUILD +++ b/tensorflow/contrib/cloud/BUILD @@ -3,10 +3,9 @@ package( default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 ) -licenses(["notice"]) # Apache 2.0 - load( "//tensorflow:tensorflow.bzl", "tf_gen_op_libs", diff --git a/tensorflow/contrib/cloud/kernels/BUILD b/tensorflow/contrib/cloud/kernels/BUILD index 20f8c2b2453..13a03c81061 100644 --- a/tensorflow/contrib/cloud/kernels/BUILD +++ b/tensorflow/contrib/cloud/kernels/BUILD @@ -3,10 +3,9 @@ package( default_visibility = ["//visibility:private"], + licenses = ["notice"], # Apache 2.0 ) -licenses(["notice"]) # Apache 2.0 - load( "//tensorflow:tensorflow.bzl", "tf_cc_test", diff --git a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test.cc b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test.cc index 7416eb19d33..cb02cb88a84 100644 --- 
a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test.cc +++ b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test.cc @@ -30,7 +30,7 @@ constexpr char kTestDataset[] = "test-dataset"; constexpr char kTestTable[] = "test-table"; bool HasSubstr(StringPiece base, StringPiece substr) { - bool ok = str_util::StrContains(base, substr); + bool ok = absl::StrContains(base, substr); EXPECT_TRUE(ok) << base << ", expected substring " << substr; return ok; } diff --git a/tensorflow/contrib/cluster_resolver/BUILD b/tensorflow/contrib/cluster_resolver/BUILD index f944b7f8843..a552173fb55 100644 --- a/tensorflow/contrib/cluster_resolver/BUILD +++ b/tensorflow/contrib/cluster_resolver/BUILD @@ -6,10 +6,9 @@ package( default_visibility = [ "//tensorflow:__subpackages__", ], + licenses = ["notice"], # Apache 2.0 ) -licenses(["notice"]) # Apache 2.0 - py_library( name = "cluster_resolver_pip", srcs_version = "PY2AND3", diff --git a/tensorflow/contrib/cmake/external/png.cmake b/tensorflow/contrib/cmake/external/png.cmake index 174f7d1d47f..c102b327dce 100644 --- a/tensorflow/contrib/cmake/external/png.cmake +++ b/tensorflow/contrib/cmake/external/png.cmake @@ -16,8 +16,8 @@ include (ExternalProject) include (GNUInstallDirs) set(png_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/png_archive) -set(png_URL https://mirror.bazel.build/github.com/glennrp/libpng/archive/v1.6.35.tar.gz) -set(png_HASH SHA256=6d59d6a154ccbb772ec11772cb8f8beb0d382b61e7ccc62435bf7311c9f4b210) +set(png_URL https://mirror.bazel.build/github.com/glennrp/libpng/archive/v1.6.37.tar.gz) +set(png_HASH SHA256=ca74a0dace179a8422187671aee97dd3892b53e168627145271cad5b5ac81307) set(png_BUILD ${CMAKE_BINARY_DIR}/png/src/png) set(png_INSTALL ${CMAKE_BINARY_DIR}/png/install) diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake index cc263d7995c..24e45236a63 100644 --- a/tensorflow/contrib/cmake/tf_core_framework.cmake +++ b/tensorflow/contrib/cmake/tf_core_framework.cmake @@ -274,10 +274,9 @@ if (NOT WIN32) COMMAND ${PYTHON_EXECUTABLE} ${tensorflow_source_dir}/tensorflow/tools/git/gen_git_source.py ARGS --raw_generate ${VERSION_INFO_CC} --source_dir ${tensorflow_source_dir} --git_tag_override=${GIT_TAG_OVERRIDE} DEPENDS __force_rebuild) + set(tf_version_srcs ${tensorflow_source_dir}/tensorflow/core/util/version_info.cc) endif() -set(tf_version_srcs ${tensorflow_source_dir}/tensorflow/core/util/version_info.cc) - ######################################################## # tf_core_framework library ######################################################## diff --git a/tensorflow/contrib/compiler/BUILD b/tensorflow/contrib/compiler/BUILD index 773560fcd0b..2c7c56b361f 100644 --- a/tensorflow/contrib/compiler/BUILD +++ b/tensorflow/contrib/compiler/BUILD @@ -1,6 +1,7 @@ -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = [":friends"]) +package( + default_visibility = [":friends"], + licenses = ["notice"], # Apache 2.0 +) package_group( name = "friends", diff --git a/tensorflow/contrib/constrained_optimization/BUILD b/tensorflow/contrib/constrained_optimization/BUILD index bd81e36c423..ac5243d525d 100644 --- a/tensorflow/contrib/constrained_optimization/BUILD +++ b/tensorflow/contrib/constrained_optimization/BUILD @@ -1,6 +1,7 @@ -package(default_visibility = ["//visibility:public"]) - -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) diff --git 
a/tensorflow/contrib/copy_graph/BUILD b/tensorflow/contrib/copy_graph/BUILD index 6273bcf7a5c..55c75a30e14 100644 --- a/tensorflow/contrib/copy_graph/BUILD +++ b/tensorflow/contrib/copy_graph/BUILD @@ -1,12 +1,13 @@ # Description: # contains parts of TensorFlow that are experimental or unstable and which are not supported. -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//tensorflow:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) -package(default_visibility = ["//tensorflow:__subpackages__"]) - load("//tensorflow:tensorflow.bzl", "py_test") py_library( diff --git a/tensorflow/contrib/crf/BUILD b/tensorflow/contrib/crf/BUILD index 5c1a17df4f9..c57680f6e4f 100644 --- a/tensorflow/contrib/crf/BUILD +++ b/tensorflow/contrib/crf/BUILD @@ -2,12 +2,13 @@ # Contains classes to construct a CRF layer # APIs here are meant to evolve over time. -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//tensorflow:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) -package(default_visibility = ["//tensorflow:__subpackages__"]) - load("//tensorflow:tensorflow.bzl", "cuda_py_tests") py_library( diff --git a/tensorflow/contrib/cudnn_rnn/BUILD b/tensorflow/contrib/cudnn_rnn/BUILD index 174d82c1b9a..63f04de3317 100644 --- a/tensorflow/contrib/cudnn_rnn/BUILD +++ b/tensorflow/contrib/cudnn_rnn/BUILD @@ -4,10 +4,9 @@ package( default_visibility = ["//visibility:private"], + licenses = ["notice"], # Apache 2.0 ) -licenses(["notice"]) # Apache 2.0 - exports_files(["LICENSE"]) load("//tensorflow:tensorflow.bzl", "cuda_py_test") diff --git a/tensorflow/contrib/data/BUILD b/tensorflow/contrib/data/BUILD index 38f1c65a4d5..74e3ae067d6 100644 --- a/tensorflow/contrib/data/BUILD +++ b/tensorflow/contrib/data/BUILD @@ -1,6 +1,7 @@ -package(default_visibility = ["//tensorflow:internal"]) - -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//tensorflow:internal"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index 10475cf2866..354683505eb 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -1,6 +1,7 @@ -package(default_visibility = ["//tensorflow:internal"]) - -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//tensorflow:internal"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) diff --git a/tensorflow/contrib/data/python/kernel_tests/reduce_dataset_test.py b/tensorflow/contrib/data/python/kernel_tests/reduce_dataset_test.py index 78019fcc7d8..8132f81c1cd 100644 --- a/tensorflow/contrib/data/python/kernel_tests/reduce_dataset_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/reduce_dataset_test.py @@ -24,13 +24,11 @@ from tensorflow.contrib.data.python.ops import get_single_element from tensorflow.contrib.data.python.ops import grouping from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util -from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -@test_util.run_v1_only("deprecated API, no eager or V2 test coverage") +@test_util.run_all_in_graph_and_eager_modes class ReduceDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): 
@parameterized.named_parameters( @@ -51,13 +49,10 @@ class ReduceDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): sum_reducer = grouping.Reducer(init_fn, reduce_fn, finalize_fn) - stop_t = array_ops.placeholder(dtypes.int64, shape=[]) - dataset = dataset_ops.Dataset.range(stop_t) + dataset = dataset_ops.Dataset.range(stop) element = get_single_element.reduce_dataset(dataset, sum_reducer) - with self.cached_session() as sess: - value = sess.run(element, feed_dict={stop_t: stop}) - self.assertEqual(stop * (stop - 1) / 2, value) + self.assertEqual(stop * (stop - 1) / 2, self.evaluate(element)) if __name__ == "__main__": diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD index 0fb406f1167..a4176d522dc 100644 --- a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -1,6 +1,7 @@ -package(default_visibility = ["//tensorflow:internal"]) - -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//tensorflow:internal"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py index 6a88cc68162..0bff4fb7bcd 100644 --- a/tensorflow/contrib/data/python/ops/batching.py +++ b/tensorflow/contrib/data/python/ops/batching.py @@ -219,7 +219,7 @@ def assert_element_shape(expected_shapes): output_shapes = _merge_output_shapes( dataset_ops.get_legacy_output_shapes(dataset), expected_shapes) # pylint: disable=protected-access - return batching._RestructuredDataset( + return dataset_ops._RestructuredDataset( dataset.map(_check_shape), dataset_ops.get_legacy_output_types(dataset), output_shapes=output_shapes, diff --git a/tensorflow/contrib/decision_trees/proto/BUILD b/tensorflow/contrib/decision_trees/proto/BUILD index 06940a90d5c..0f58675af60 100644 --- a/tensorflow/contrib/decision_trees/proto/BUILD +++ b/tensorflow/contrib/decision_trees/proto/BUILD @@ -1,6 +1,7 @@ -package(default_visibility = ["//visibility:public"]) - -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) exports_files([ "LICENSE", diff --git a/tensorflow/contrib/deprecated/BUILD b/tensorflow/contrib/deprecated/BUILD index 035d8cfc37e..df747ea2c70 100644 --- a/tensorflow/contrib/deprecated/BUILD +++ b/tensorflow/contrib/deprecated/BUILD @@ -1,12 +1,13 @@ # Description: # Contains deprecated functions that we aren't quite ready to remove entirely -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//tensorflow:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) -package(default_visibility = ["//tensorflow:__subpackages__"]) - load("//tensorflow:tensorflow.bzl", "py_test") py_library( diff --git a/tensorflow/contrib/distribute/BUILD b/tensorflow/contrib/distribute/BUILD index 1fa4a9bcee1..112cb4ac54e 100644 --- a/tensorflow/contrib/distribute/BUILD +++ b/tensorflow/contrib/distribute/BUILD @@ -2,10 +2,9 @@ package( default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 ) -licenses(["notice"]) # Apache 2.0 - exports_files(["LICENSE"]) filegroup( diff --git a/tensorflow/contrib/distribute/README.md b/tensorflow/contrib/distribute/README.md index ea48cb390b9..680907252db 100644 --- a/tensorflow/contrib/distribute/README.md +++ b/tensorflow/contrib/distribute/README.md @@ -1,368 +1,5 @@ # Distribution Strategy -> *NOTE*: This is an experimental feature. 
The API and performance -> characteristics are subject to change. - -## Overview - -[`DistributionStrategy`](https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/distribute/DistributionStrategy) -API is an easy way to distribute your training -across multiple devices/machines. Our goal is to allow users to use existing -models and training code with minimal changes to enable distributed training. -Moreover, we've designed the API in such a way that it works with both eager and -graph execution. - -Currently we support several types of strategies: - -* [`MirroredStrategy`](https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/distribute/MirroredStrategy): -This does in-graph replication with synchronous training -on many GPUs on one machine. Essentially, we create copies of all variables in -the model's layers on each device. We then use all-reduce to combine gradients -across the devices before applying them to the variables to keep them in sync. -* [`CollectiveAllReduceStrategy`](https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/distribute/CollectiveAllReduceStrategy): -This is a version of `MirroredStrategy` for multi-worker training. It uses -a collective op to do all-reduce. This supports between-graph communication and -synchronization, and delegates the specifics of the all-reduce implementation to -the runtime (as opposed to encoding it in the graph). This allows it to perform -optimizations like batching and switch between plugins that support different -hardware or algorithms. In the future, this strategy will implement -fault-tolerance to allow training to continue when there is worker failure. - -* [`ParameterServerStrategy`](https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/distribute/ParameterServerStrategy): -This strategy supports using parameter servers either for multi-GPU local -training or asynchronous multi-machine training. When used to train locally, -variables are not mirrored, instead they are placed on the CPU and operations -are replicated across all local GPUs. In a multi-machine setting, some are -designated as workers and some as parameter servers. Each variable is placed on -one parameter server. Computation operations are replicated across all GPUs of -the workers. - -## Multi-GPU Training - -## Example with Keras API - -Let's see how to scale to multiple GPUs on one machine using `MirroredStrategy` with [tf.keras] (https://www.tensorflow.org/guide/keras). - -Let's define a simple input dataset for training this model. Note that currently we require using -[`tf.data.Dataset`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset) -with `DistributionStrategy`. - -```python -import tensorflow as tf -from tensorflow import keras - -features = tf.data.Dataset.from_tensors([1.]).repeat(10000).batch(10) -labels = tf.data.Dataset.from_tensors([1.]).repeat(10000).batch(10) -train_dataset = tf.data.Dataset.zip((features, labels)) -``` - -To distribute this Keras model on multiple GPUs using `MirroredStrategy` we -first instantiate a `MirroredStrategy` object. - -```python -distribution = tf.contrib.distribute.MirroredStrategy() -``` - -Take a very simple model consisting of a single layer. We need to create and compile -the model under the distribution strategy scope. 
- -```python -with distribution.scope(): - inputs = tf.keras.layers.Input(shape=(1,)) - predictions = tf.keras.layers.Dense(1)(inputs) - model = tf.keras.models.Model(inputs=inputs, outputs=predictions) - - model.compile(loss='mean_squared_error', - optimizer=tf.train.GradientDescentOptimizer(learning_rate=0.2)) -``` - -To train the model we call Keras `fit` API using the input dataset that we -created earlier, same as how we would in a non-distributed case. - -```python -model.fit(train_dataset, epochs=5, steps_per_epoch=10) -``` - -Similarly, we can also call `evaluate` and `predict` as before using appropriate -datasets. - -```python -model.evaluate(eval_dataset, steps=1) -model.predict(predict_dataset, steps=1) -``` - -That's all you need to train your model with Keras on multiple GPUs with -`MirroredStrategy`. It will take care of splitting up -the input dataset, replicating layers and variables on each device, and -combining and applying gradients. - -The model and input code does not have to change because we have changed the -underlying components of TensorFlow (such as -optimizer, batch norm and summaries) to become distribution-aware. -That means those components know how to -combine their state across devices. Further, saving and checkpointing works -seamlessly, so you can save with one or no distribution strategy and resume with -another. - - -## Example with Estimator API - -You can also use Distribution Strategy API with [`Estimator`](https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator). Let's see a simple example of it's usage with `MirroredStrategy`. - - -Consider a very simple model function which tries to learn a simple function. - -```python -def model_fn(features, labels, mode): - layer = tf.layers.Dense(1) - logits = layer(features) - - if mode == tf.estimator.ModeKeys.PREDICT: - predictions = {"logits": logits} - return tf.estimator.EstimatorSpec(mode, predictions=predictions) - - loss = tf.losses.mean_squared_error( - labels=labels, predictions=tf.reshape(logits, [])) - - if mode == tf.estimator.ModeKeys.EVAL: - return tf.estimator.EstimatorSpec(mode, loss=loss) - - if mode == tf.estimator.ModeKeys.TRAIN: - train_op = tf.train.GradientDescentOptimizer(0.2).minimize(loss) - return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op) -``` - -Again, let's define a simple input function to feed data for training this model. - - -```python -def input_fn(): - features = tf.data.Dataset.from_tensors([[1.]]).repeat(100) - labels = tf.data.Dataset.from_tensors(1.).repeat(100) - return tf.data.Dataset.zip((features, labels)) -``` - -Now that we have a model function and input function defined, we can define the -estimator. To use `MirroredStrategy`, all we need to do is: - -* Create an instance of the `MirroredStrategy` class. -* Pass it to the -[`RunConfig`](https://www.tensorflow.org/api_docs/python/tf/estimator/RunConfig) -parameter of `Estimator`. - - -```python -distribution = tf.contrib.distribute.MirroredStrategy() -config = tf.estimator.RunConfig(train_distribute=distribution) -classifier = tf.estimator.Estimator(model_fn=model_fn, config=config) -classifier.train(input_fn=input_fn) -classifier.evaluate(input_fn=input_fn) -``` - -That's it! This change will now configure estimator to run on all GPUs on your -machine. - - -## Customization and Performance Tips - -Above, we showed the easiest way to use [`MirroredStrategy`](https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/distribute/MirroredStrategy#__init__). 
-There are a few things you can customize in practice:
-
-* You can specify a list of specific GPUs (using the param `devices`) or the
-number of GPUs (using the param `num_gpus`), in case you don't want auto
-detection.
-* You can specify various parameters for all-reduce with the `cross_tower_ops`
-param, such as the all-reduce algorithm to use and gradient repacking.
-
-We've tried to make it such that you get the best performance for your existing
-model. We also recommend you follow the tips from the
-[Input Pipeline Performance Guide](https://www.tensorflow.org/performance/datasets_performance).
-Specifically, we found using [`map_and_batch`](https://www.tensorflow.org/performance/datasets_performance#map_and_batch)
-and [`dataset.prefetch`](https://www.tensorflow.org/performance/datasets_performance#pipelining)
-in the input function gives a solid boost in performance. When using
-`dataset.prefetch`, use `buffer_size=None` to let it detect the optimal buffer
-size.
-
-## Multi-worker Training
-### Overview
-
-For multi-worker training, no change is required to the `Estimator` model code.
-You can run the same model code for all tasks in your cluster, including
-parameter servers and the evaluator. But you need to use
-`tf.estimator.train_and_evaluate`, explicitly specify `num_gpus_per_worker`
-for your strategy object, and set the "TF\_CONFIG" environment variable for each
-binary running in your cluster. We'll provide a Kubernetes template in the
-[tensorflow/ecosystem](https://github.com/tensorflow/ecosystem) repo which sets
-"TF\_CONFIG" for your training tasks.
-
-### TF\_CONFIG environment variable
-
-The "TF\_CONFIG" environment variable is a JSON string that specifies which
-tasks constitute the cluster, their addresses, and each task's role in the
-cluster. One example of "TF\_CONFIG" is:
-
-```python
-TF_CONFIG='{
-    "cluster": {
-        "worker": ["host1:port", "host2:port", "host3:port"],
-        "ps": ["host4:port", "host5:port"]
-    },
-    "task": {"type": "worker", "index": 1}
-}'
-```
-
-This "TF\_CONFIG" specifies that there are three workers and two ps tasks in the
-cluster, along with their hosts and ports. The "task" part specifies the role of
-the current task in the cluster: worker 1. Valid roles in a cluster are "chief",
-"worker", "ps", and "evaluator". There should be no "ps" job for
-`CollectiveAllReduceStrategy` and `MirroredStrategy`. The "evaluator" job is
-optional and can have at most one task. It does single-machine evaluation; if
-you don't want to do evaluation, you can pass a dummy `input_fn` to the
-`tf.estimator.EvalSpec` of `tf.estimator.train_and_evaluate`.
-
-### Dataset
-
-The `input_fn` you provide to the estimator is run once per worker, so remember
-to scale up your batch size if you have multiple GPUs on each worker.
-
-The same `input_fn` will be used for all workers if you use
-`CollectiveAllReduceStrategy` or `ParameterServerStrategy`. Therefore, it is
-important to shuffle your dataset in your `input_fn`.
-
-`MirroredStrategy` will insert a `tf.data.Dataset.shard` call into your
-`input_fn` if `auto_shard_dataset` is set to `True`. As a result, each worker
-gets a fraction of your input data.
-
-### Performance Tips
-
-We have been actively working on multi-worker performance. Currently, prefer
-`CollectiveAllReduceStrategy` for synchronous multi-worker training.
-
-### Example
-
-Let's use the same example for multi-worker training. We'll start a cluster with
-3 workers doing synchronous all-reduce training.
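-
-As noted in the Dataset section above, every worker runs the same `input_fn`,
-so it should shuffle the data and use a batch size scaled by the number of GPUs
-on the worker. A minimal sketch of such an `input_fn` (the random data, batch
-size, and GPU count below are illustrative assumptions, not part of the example
-that follows):
-
-```python
-import numpy as np
-import tensorflow as tf
-
-GPUS_PER_WORKER = 2      # illustrative; keep consistent with num_gpus_per_worker
-PER_GPU_BATCH_SIZE = 32  # illustrative
-
-# Toy in-memory data, just to keep the sketch self-contained.
-features_array = np.random.random((1000, 1)).astype(np.float32)
-labels_array = np.random.random((1000, 1)).astype(np.float32)
-
-def input_fn():
-  dataset = tf.data.Dataset.from_tensor_slices((features_array, labels_array))
-  # Shuffle on every worker; each worker sees the full (unsharded) dataset here.
-  dataset = dataset.shuffle(buffer_size=1000).repeat()
-  # Scale the batch so each of this worker's GPUs gets PER_GPU_BATCH_SIZE examples.
-  return dataset.batch(PER_GPU_BATCH_SIZE * GPUS_PER_WORKER)
-```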
-In the following code snippet, we start multi-worker training using
-`tf.estimator.train_and_evaluate`:
-
-```python
-def model_main():
-  distribution = tf.contrib.distribute.CollectiveAllReduceStrategy(
-      num_gpus_per_worker=2)
-  config = tf.estimator.RunConfig(train_distribute=distribution)
-  estimator = tf.estimator.Estimator(model_fn=model_fn, config=config)
-  train_spec = tf.estimator.TrainSpec(input_fn=input_fn)
-  eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)
-  tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
-```
-
-**Note**: You don't have to set "TF\_CONFIG" manually if you use our provided
-Kubernetes template.
-
-You'll then need 3 machines; find out their host addresses and one available
-port on each machine. Then set "TF\_CONFIG" in each binary and run the above
-model code.
-
-On worker 0, run:
-
-```python
-os.environ["TF_CONFIG"] = json.dumps({
-    "cluster": {
-        "worker": ["host1:port", "host2:port", "host3:port"]
-    },
-    "task": {"type": "worker", "index": 0}
-})
-
-# Call the model_main function defined above.
-model_main()
-```
-
-On worker 1, run:
-
-```python
-os.environ["TF_CONFIG"] = json.dumps({
-    "cluster": {
-        "worker": ["host1:port", "host2:port", "host3:port"]
-    },
-    "task": {"type": "worker", "index": 1}
-})
-
-# Call the model_main function defined above.
-model_main()
-```
-
-On worker 2, run:
-
-```python
-os.environ["TF_CONFIG"] = json.dumps({
-    "cluster": {
-        "worker": ["host1:port", "host2:port", "host3:port"]
-    },
-    "task": {"type": "worker", "index": 2}
-})
-
-# Call the model_main function defined above.
-model_main()
-```
-
-Then you'll find your cluster has started training! You can inspect the logs of
-the workers or start TensorBoard.
-
-### Standalone client mode
-
-We have a new way to run distributed training: you can bring up standard
-TensorFlow servers in your cluster and run your model code anywhere, such as on
-your laptop.
-
-In the above example, instead of calling `model_main`, you can call
-`tf.contrib.distribute.run_standard_tensorflow_server().join()`. This will bring
-up a cluster running standard TensorFlow servers that wait for your request to
-start training.
-
-On your laptop, you can run:
-
-```python
-distribution = tf.contrib.distribute.CollectiveAllReduceStrategy(
-    num_gpus_per_worker=2)
-config = tf.estimator.RunConfig(
-    experimental_distribute=tf.contrib.distribute.DistributeConfig(
-        train_distribute=distribution,
-        remote_cluster={"worker": ["host1:port", "host2:port", "host3:port"]}))
-estimator = tf.estimator.Estimator(model_fn=model_fn, config=config)
-train_spec = tf.estimator.TrainSpec(input_fn=input_fn)
-eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)
-tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
-```
-
-Then you will see the training logs on your laptop. You can terminate the
-training by terminating your process on your laptop. You can also modify your
-code and run a new model against the same cluster.
-
-We've been optimizing the performance of standalone client mode. If you notice
-high latency between your laptop and your cluster, you can reduce that latency
-by running your model binary in the cluster.
-
-## Caveats
-
-This feature is in its early stages and there are a lot of improvements
-forthcoming:
-
-* Summaries are only computed in the first tower in `MirroredStrategy`.
-* Eager support is in the works; performance can be more challenging with eager
-execution.
-* We currently support the following predefined Keras callbacks: -`ModelCheckpointCallback`, `TensorBoardCallback`. We will soon be adding support for -some of the other callbacks such as `EarlyStopping`, `ReduceLROnPlateau`, etc. If you -create your own callback, you will not have access to all model properties and -validation data. -* If you are [`batching`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#batch) -your input data, we will place one batch on each GPU in each step. So your -effective batch size will be `num_gpus * batch_size`. Therefore, consider -adjusting your learning rate or batch size according to the number of GPUs. -We are working on addressing this limitation by splitting each batch across GPUs -instead. -* PartitionedVariables are not supported yet. - -## What's next? - -Please give distribution strategies a try. This feature is in early stages and -is evolving, so we welcome your feedback via -[issues on GitHub](https://github.com/tensorflow/tensorflow/issues/new). - - +See the guide for overview and examples: +[TensorFlow v1.x](https://www.tensorflow.org/guide/distribute_strategy), +[TensorFlow v2.x](https://www.tensorflow.org/alpha/guide/distribute_strategy). diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index c5ddf6b5533..ecece6b1ef2 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -8,10 +8,9 @@ package( default_visibility = [ "//tensorflow:internal", ], + licenses = ["notice"], # Apache 2.0 ) -licenses(["notice"]) # Apache 2.0 - exports_files(["LICENSE"]) py_library( @@ -86,6 +85,7 @@ cuda_py_test( ], tags = [ "multi_and_single_gpu", + "no_oss", # TODO(b/133330625) ], ) @@ -227,64 +227,6 @@ cuda_py_test( ], ) -cuda_py_test( - name = "estimator_integration_test", - srcs = ["estimator_integration_test.py"], - additional_deps = [ - "//tensorflow/python/distribute:combinations", - "//tensorflow/python/distribute:strategy_combinations", - "@absl_py//absl/testing:parameterized", - "//third_party/py/numpy", - "//tensorflow/contrib/optimizer_v2:training", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/eager:test", - "//tensorflow/python/estimator:estimator_py", - "//tensorflow/python/feature_column", - "//tensorflow/python:framework_ops", - "//tensorflow/python:platform", - "//tensorflow/python:summary", - ], - tags = [ - "multi_and_single_gpu", - "no_oss", # http://b/119349471 - "tf_integration_test", - ], -) - -cuda_py_test( - name = "estimator_training_test", - srcs = ["estimator_training_test.py"], - additional_deps = [ - ":collective_all_reduce_strategy", - "//tensorflow/python/distribute:combinations", - "//tensorflow/python/distribute:strategy_combinations", - ":mirrored_strategy", - "//tensorflow/python/distribute:multi_worker_test_base", - ":parameter_server_strategy", - "//third_party/py/numpy", - "//tensorflow/contrib/optimizer_v2:training", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/distribute:distribute_config", - "//tensorflow/python/distribute:distribute_coordinator", - "//tensorflow/python/distribute:distribute_coordinator_context", - "//tensorflow/python/eager:test", - "//tensorflow/python/estimator:estimator_py", - "//tensorflow/python/feature_column", - "//tensorflow/python:framework_ops", - "//tensorflow/python:platform", - "//tensorflow/python:summary", - ], - shard_count = 48, - tags = [ - "multi_and_single_gpu", - # TODO(b/118768923): Re-enable {a,m,t}san test. 
- "noasan", - "nomsan", - "notsan", - "no_oss", # http://b/119349471 - ], -) - py_library( name = "monitor", srcs = ["monitor.py"], diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py index d6eff47fdc5..588fa47c6ae 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py @@ -498,8 +498,13 @@ class DistributedCollectiveAllReduceStrategyTest( self.assertEqual('grpc', server_def.protocol) mock_called[0] = True + def mock_configure_collective_ops(*args, **kwargs): + del args, kwargs + with test.mock.patch.object(context.context(), 'enable_collective_ops', - mock_enable_collective_ops): + mock_enable_collective_ops), \ + test.mock.patch.object(context.context(), 'configure_collective_ops', + mock_configure_collective_ops): strategy, _, _ = self._get_test_object( task_type='worker', task_id=1, num_gpus=2, use_core_strategy=True) self.assertTrue(strategy.extended._std_server_started) diff --git a/tensorflow/contrib/distribute/python/estimator_integration_test.py b/tensorflow/contrib/distribute/python/estimator_integration_test.py deleted file mode 100644 index c46616ce60f..00000000000 --- a/tensorflow/contrib/distribute/python/estimator_integration_test.py +++ /dev/null @@ -1,139 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Tests that show that DistributionStrategy works with canned Estimator.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import shutil -import tempfile -from absl.testing import parameterized -import numpy as np -from tensorflow.contrib.optimizer_v2 import adagrad -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.distribute import combinations -from tensorflow.python.distribute import strategy_combinations -from tensorflow.python.eager import test -from tensorflow.python.estimator import run_config -from tensorflow.python.estimator import training -from tensorflow.python.estimator.canned import dnn_linear_combined -from tensorflow.python.estimator.canned import prediction_keys -from tensorflow.python.estimator.export import export -from tensorflow.python.estimator.inputs import numpy_io -from tensorflow.python.feature_column import feature_column_lib as feature_column -from tensorflow.python.framework import ops -from tensorflow.python.platform import gfile -from tensorflow.python.summary.writer import writer_cache - - -class DNNLinearCombinedClassifierIntegrationTest(test.TestCase, - parameterized.TestCase): - - def setUp(self): - self._model_dir = tempfile.mkdtemp() - - def dataset_input_fn(self, x, y, batch_size, shuffle): - - def input_fn(): - dataset = dataset_ops.Dataset.from_tensor_slices((x, y)) - if shuffle: - dataset = dataset.shuffle(batch_size) - dataset = dataset.repeat(10).batch(batch_size) - return dataset - - return input_fn - - @combinations.generate( - combinations.combine( - mode=['graph'], - distribution=[ - strategy_combinations.one_device_strategy, - strategy_combinations.mirrored_strategy_with_gpu_and_cpu, - strategy_combinations.mirrored_strategy_with_two_gpus, - ], - use_train_and_evaluate=[True, False])) - def test_complete_flow_with_mode(self, distribution, use_train_and_evaluate): - label_dimension = 2 - input_dimension = label_dimension - batch_size = 10 - data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32) - data = data.reshape(batch_size, label_dimension) - train_input_fn = self.dataset_input_fn( - x={'x': data}, - y=data, - batch_size=batch_size // distribution.num_replicas_in_sync, - shuffle=True) - eval_input_fn = self.dataset_input_fn( - x={'x': data}, - y=data, - batch_size=batch_size // distribution.num_replicas_in_sync, - shuffle=False) - predict_input_fn = numpy_io.numpy_input_fn( - x={'x': data}, batch_size=batch_size, shuffle=False) - - linear_feature_columns = [ - feature_column.numeric_column('x', shape=(input_dimension,)) - ] - dnn_feature_columns = [ - feature_column.numeric_column('x', shape=(input_dimension,)) - ] - feature_columns = linear_feature_columns + dnn_feature_columns - estimator = dnn_linear_combined.DNNLinearCombinedRegressor( - linear_feature_columns=linear_feature_columns, - dnn_hidden_units=(2, 2), - dnn_feature_columns=dnn_feature_columns, - label_dimension=label_dimension, - model_dir=self._model_dir, - # TODO(isaprykin): Work around the colocate_with error. 
- dnn_optimizer=adagrad.AdagradOptimizer(0.001), - linear_optimizer=adagrad.AdagradOptimizer(0.001), - config=run_config.RunConfig( - train_distribute=distribution, eval_distribute=distribution)) - - num_steps = 10 - if use_train_and_evaluate: - scores, _ = training.train_and_evaluate( - estimator, - training.TrainSpec(train_input_fn, max_steps=num_steps), - training.EvalSpec(eval_input_fn)) - else: - estimator.train(train_input_fn, steps=num_steps) - scores = estimator.evaluate(eval_input_fn) - - self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP]) - self.assertIn('loss', scores) - - predictions = np.array([ - x[prediction_keys.PredictionKeys.PREDICTIONS] - for x in estimator.predict(predict_input_fn) - ]) - self.assertAllEqual((batch_size, label_dimension), predictions.shape) - - feature_spec = feature_column.make_parse_example_spec(feature_columns) - serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn( - feature_spec) - export_dir = estimator.export_saved_model(tempfile.mkdtemp(), - serving_input_receiver_fn) - self.assertTrue(gfile.Exists(export_dir)) - - def tearDown(self): - if self._model_dir: - writer_cache.FileWriterCache.clear() - shutil.rmtree(self._model_dir) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/distribute/python/estimator_training_test.py b/tensorflow/contrib/distribute/python/estimator_training_test.py deleted file mode 100644 index 9eebdfd68d8..00000000000 --- a/tensorflow/contrib/distribute/python/estimator_training_test.py +++ /dev/null @@ -1,620 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Tests that show Distribute Coordinator works with Estimator.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import copy -import glob -import json -import os -import sys -import tempfile -from absl.testing import parameterized -import numpy as np - -from tensorflow.contrib.distribute.python import collective_all_reduce_strategy -from tensorflow.contrib.distribute.python import mirrored_strategy -from tensorflow.contrib.distribute.python import parameter_server_strategy -from tensorflow.contrib.optimizer_v2 import adagrad -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.distribute import combinations -from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib -from tensorflow.python.distribute import distribute_coordinator as dc -from tensorflow.python.distribute import estimator_training as dc_training -from tensorflow.python.distribute import multi_worker_test_base -from tensorflow.python.distribute.distribute_config import DistributeConfig -from tensorflow.python.eager import context -from tensorflow.python.estimator import exporter as exporter_lib -from tensorflow.python.estimator import run_config as run_config_lib -from tensorflow.python.estimator import training as estimator_training -from tensorflow.python.estimator.canned import dnn_linear_combined -from tensorflow.python.estimator.canned import prediction_keys -from tensorflow.python.estimator.export import export as export_lib -from tensorflow.python.feature_column import feature_column_lib as feature_column -from tensorflow.python.platform import gfile -from tensorflow.python.platform import test -from tensorflow.python.summary import summary_iterator -from tensorflow.python.summary.writer import writer_cache -from tensorflow.python.training import session_manager - - -BATCH_SIZE = 10 -LABEL_DIMENSION = 2 -DATA = np.linspace( - 0., 2., BATCH_SIZE * LABEL_DIMENSION, dtype=np.float32).reshape( - BATCH_SIZE, LABEL_DIMENSION) -EVAL_NAME = "foo" -EXPORTER_NAME = "saved_model_exporter" -MAX_STEPS = 10 - -CHIEF = dc._TaskType.CHIEF -EVALUATOR = dc._TaskType.EVALUATOR -WORKER = dc._TaskType.WORKER -PS = dc._TaskType.PS - -original_run_std_server = dc._run_std_server - - -class DistributeCoordinatorIntegrationTest( - multi_worker_test_base.IndependentWorkerTestBase, parameterized.TestCase): - - @classmethod - def setUpClass(cls): - """Create a local cluster with 2 workers.""" - super(DistributeCoordinatorIntegrationTest, cls).setUpClass() - cls._cluster_spec = multi_worker_test_base.create_in_process_cluster( - num_workers=3, num_ps=2, has_eval=True) - - def setUp(self): - self._model_dir = tempfile.mkdtemp() - super(DistributeCoordinatorIntegrationTest, self).setUp() - - def dataset_input_fn(self, x, y, batch_size, shuffle): - - def input_fn(): - dataset = dataset_ops.Dataset.from_tensor_slices((x, y)) - if shuffle: - dataset = dataset.shuffle(batch_size) - dataset = dataset.repeat(100).batch(batch_size) - return dataset - - return input_fn - - def _get_exporter(self, name, fc): - feature_spec = feature_column.make_parse_example_spec(fc) - serving_input_receiver_fn = ( - export_lib.build_parsing_serving_input_receiver_fn(feature_spec)) - return exporter_lib.LatestExporter( - name, serving_input_receiver_fn=serving_input_receiver_fn) - - def _extract_loss_and_global_step(self, event_folder): - """Returns the loss and global step in 
last event.""" - event_paths = glob.glob(os.path.join(event_folder, "events*")) - self.assertNotEmpty( - event_paths, msg="Event file not found in dir %s" % event_folder) - - loss = None - global_step_count = None - - for e in summary_iterator.summary_iterator(event_paths[-1]): - current_loss = None - for v in e.summary.value: - if v.tag == "loss": - current_loss = v.simple_value - - # If loss is not found, global step is meaningless. - if current_loss is None: - continue - - current_global_step = e.step - if global_step_count is None or current_global_step > global_step_count: - global_step_count = current_global_step - loss = current_loss - - return (loss, global_step_count) - - def _get_estimator(self, - train_distribute, - eval_distribute, - remote_cluster=None): - input_dimension = LABEL_DIMENSION - linear_feature_columns = [ - feature_column.numeric_column("x", shape=(input_dimension,)) - ] - dnn_feature_columns = [ - feature_column.numeric_column("x", shape=(input_dimension,)) - ] - - return dnn_linear_combined.DNNLinearCombinedRegressor( - linear_feature_columns=linear_feature_columns, - dnn_hidden_units=(2, 2), - dnn_feature_columns=dnn_feature_columns, - label_dimension=LABEL_DIMENSION, - model_dir=self._model_dir, - dnn_optimizer=adagrad.AdagradOptimizer(0.001), - linear_optimizer=adagrad.AdagradOptimizer(0.001), - config=run_config_lib.RunConfig( - experimental_distribute=DistributeConfig( - train_distribute=train_distribute, - eval_distribute=eval_distribute, - remote_cluster=remote_cluster))) - - def _complete_flow(self, - train_distribute, - eval_distribute, - remote_cluster=None, - use_train_and_evaluate=True): - estimator = self._get_estimator(train_distribute, eval_distribute, - remote_cluster) - - input_dimension = LABEL_DIMENSION - train_input_fn = self.dataset_input_fn( - x={"x": DATA}, - y=DATA, - batch_size=BATCH_SIZE // train_distribute.num_replicas_in_sync, - shuffle=True) - if eval_distribute: - eval_batch_size = BATCH_SIZE // eval_distribute.num_replicas_in_sync - else: - eval_batch_size = BATCH_SIZE - eval_input_fn = self.dataset_input_fn( - x={"x": DATA}, y=DATA, batch_size=eval_batch_size, shuffle=False) - - linear_feature_columns = [ - feature_column.numeric_column("x", shape=(input_dimension,)) - ] - dnn_feature_columns = [ - feature_column.numeric_column("x", shape=(input_dimension,)) - ] - feature_columns = linear_feature_columns + dnn_feature_columns - - eval_spec = estimator_training.EvalSpec( - name=EVAL_NAME, - input_fn=eval_input_fn, - steps=None, - exporters=self._get_exporter(EXPORTER_NAME, feature_columns), - start_delay_secs=0, - throttle_secs=1) - - if use_train_and_evaluate: - estimator_training.train_and_evaluate( - estimator, - estimator_training.TrainSpec(train_input_fn, max_steps=MAX_STEPS), - eval_spec) - else: - estimator.train(train_input_fn, max_steps=MAX_STEPS) - - latest_ckpt_path = estimator.latest_checkpoint() - metrics = estimator.evaluate(eval_input_fn, - checkpoint_path=latest_ckpt_path, - name=EVAL_NAME) - - # Export the eval result to files. - eval_result = estimator_training._EvalResult( - status=estimator_training._EvalStatus.EVALUATED, - metrics=metrics, - checkpoint_path=latest_ckpt_path) - evaluator = estimator_training._TrainingExecutor._Evaluator(estimator, - eval_spec, - None) - evaluator._export_eval_result(eval_result, True) - - return estimator - - def _inspect_train_and_eval_events(self, estimator): - # Make sure nothing is stuck in limbo. - writer_cache.FileWriterCache.clear() - - # Examine the training events. 
Use a range to check global step to avoid - # flakyness due to global step race condition. - training_loss, _ = self._extract_loss_and_global_step(self._model_dir) - self.assertIsNotNone(training_loss) - - # Examine the eval events. The global step should be accurate. - eval_dir = os.path.join(self._model_dir, "eval_" + EVAL_NAME) - eval_loss, eval_global_step = self._extract_loss_and_global_step( - event_folder=eval_dir) - self.assertIsNotNone(eval_loss) - self.assertGreaterEqual(eval_global_step, MAX_STEPS) - - # Examine the export folder. - export_dir = os.path.join( - os.path.join(self._model_dir, "export"), EXPORTER_NAME) - self.assertTrue(gfile.Exists(export_dir)) - - # Examine the ckpt for predict. - def predict_input_fn(): - return dataset_ops.Dataset.from_tensor_slices({ - "x": DATA - }).batch(BATCH_SIZE) - - predicted_proba = np.array([ - x[prediction_keys.PredictionKeys.PREDICTIONS] - for x in estimator.predict(predict_input_fn) - ]) - self.assertAllEqual((BATCH_SIZE, LABEL_DIMENSION), predicted_proba.shape) - - def _make_cross_device_ops(self, num_gpus_per_worker): - return cross_device_ops_lib.MultiWorkerAllReduce( - ["/job:worker/task:0", "/job:worker/task:1", "/job:worker/task:2"], - num_gpus_per_worker) - - def _get_strategy_object(self, strategy_cls, eval_strategy=False): - if strategy_cls == mirrored_strategy.CoreMirroredStrategy: - if eval_strategy: - return strategy_cls() - else: - return strategy_cls( - cross_device_ops=self._make_cross_device_ops( - num_gpus_per_worker=context.num_gpus())) - elif (strategy_cls == mirrored_strategy.MirroredStrategy and - not eval_strategy): - return strategy_cls( - num_gpus_per_worker=context.num_gpus(), - cross_device_ops=self._make_cross_device_ops( - num_gpus_per_worker=context.num_gpus())) - else: - return strategy_cls(num_gpus_per_worker=context.num_gpus()) - - @combinations.generate( - combinations.combine( - mode=["graph"], - train_distribute_cls=[ - collective_all_reduce_strategy.CollectiveAllReduceStrategy, - mirrored_strategy.MirroredStrategy, - mirrored_strategy.CoreMirroredStrategy, - parameter_server_strategy.ParameterServerStrategy - ], - eval_distribute_cls=[ - None, - mirrored_strategy.MirroredStrategy, - mirrored_strategy.CoreMirroredStrategy, - parameter_server_strategy.ParameterServerStrategy, - collective_all_reduce_strategy.CollectiveAllReduceStrategy, - ], - required_gpus=[0, 1])) - def test_complete_flow_standalone_client(self, train_distribute_cls, - eval_distribute_cls): - train_distribute = self._get_strategy_object(train_distribute_cls) - - if eval_distribute_cls: - eval_distribute = self._get_strategy_object( - eval_distribute_cls, eval_strategy=True) - else: - eval_distribute = None - - cluster_spec = copy.deepcopy(self._cluster_spec) - if (train_distribute_cls != - parameter_server_strategy.ParameterServerStrategy): - cluster_spec.pop("ps", None) - estimator = self._complete_flow(train_distribute, eval_distribute, - cluster_spec) - self._inspect_train_and_eval_events(estimator) - - @combinations.generate( - combinations.combine( - mode=["graph"], - eval_distribute_class=[ - None, - mirrored_strategy.MirroredStrategy, - mirrored_strategy.CoreMirroredStrategy, - parameter_server_strategy.ParameterServerStrategy, - ], - required_gpus=[0, 1])) - def test_complete_flow_standalone_client_collective_nccl( - self, eval_distribute_class): - train_distribute = ( - collective_all_reduce_strategy.CollectiveAllReduceStrategy( - num_gpus_per_worker=context.num_gpus(), - 
communication=cross_device_ops_lib.CollectiveCommunication.NCCL)) - - if eval_distribute_class: - eval_distribute = self._get_strategy_object( - eval_distribute_class, eval_strategy=True) - else: - eval_distribute = None - - cluster_spec = copy.deepcopy(self._cluster_spec) - cluster_spec.pop("ps", None) - estimator = self._complete_flow(train_distribute, eval_distribute, - cluster_spec) - self._inspect_train_and_eval_events(estimator) - - @combinations.generate( - combinations.combine( - mode=["graph"], - train_distribute_cls=[ - mirrored_strategy.MirroredStrategy, - mirrored_strategy.CoreMirroredStrategy, - ], - eval_distribute_cls=[ - None, - mirrored_strategy.MirroredStrategy, - mirrored_strategy.CoreMirroredStrategy, - ], - required_gpus=[0, 1])) - def test_estimator_standalone_client(self, train_distribute_cls, - eval_distribute_cls): - train_distribute = self._get_strategy_object(train_distribute_cls) - - if eval_distribute_cls: - eval_distribute = self._get_strategy_object(eval_distribute_cls) - else: - eval_distribute = None - - # We use the whole cluster for evaluation. - cluster = copy.deepcopy(self._cluster_spec) - cluster.pop("evaluator", None) - - estimator = self._complete_flow( - train_distribute, eval_distribute, remote_cluster=cluster, - use_train_and_evaluate=False) - self._inspect_train_and_eval_events(estimator) - - def _mock_run_std_server(self, *args, **kwargs): - ret = original_run_std_server(*args, **kwargs) - # Wait for all std servers to be brought up in order to reduce the chance of - # remote sessions taking local ports that have been assigned to std servers. - self._barrier.wait() - return ret - - def _independent_worker_fn( - self, - train_distribute, - eval_distribute, - ): - with test.mock.patch.object(dc, "_run_std_server", - self._mock_run_std_server): - self._complete_flow(train_distribute, eval_distribute) - - @combinations.generate( - combinations.combine( - mode=["graph"], - train_distribute_cls=[ - collective_all_reduce_strategy.CollectiveAllReduceStrategy, - parameter_server_strategy.ParameterServerStrategy, - ], - eval_distribute_cls=[ - None, - mirrored_strategy.MirroredStrategy, - mirrored_strategy.CoreMirroredStrategy, - parameter_server_strategy.ParameterServerStrategy, - collective_all_reduce_strategy.CollectiveAllReduceStrategy, - ], - required_gpus=[0, 1])) - def test_complete_flow_independent_worker_between_graph( - self, train_distribute_cls, eval_distribute_cls): - if (context.num_gpus() < 2 and eval_distribute_cls == - collective_all_reduce_strategy.CollectiveAllReduceStrategy): - self.skipTest("`CollectiveAllReduceStrategy` needs at least two towers.") - - train_distribute = self._get_strategy_object(train_distribute_cls) - - if eval_distribute_cls: - eval_distribute = self._get_strategy_object( - eval_distribute_cls, eval_strategy=True) - else: - eval_distribute = None - - if (train_distribute_cls == parameter_server_strategy - .ParameterServerStrategy): - cluster_spec = multi_worker_test_base.create_cluster_spec( - num_workers=3, num_ps=2, has_eval=True) - # 3 workers, 2 ps and 1 evaluator. - self._barrier = dc._Barrier(6) - else: - cluster_spec = multi_worker_test_base.create_cluster_spec( - num_workers=3, num_ps=0, has_eval=True) - # 3 workers and 1 evaluator. 
- self._barrier = dc._Barrier(4) - - threads = self.run_multiple_tasks_in_threads(self._independent_worker_fn, - cluster_spec, train_distribute, - eval_distribute) - threads_to_join = [] - for task_type, ts in threads.items(): - if task_type == PS: - continue - for t in ts: - threads_to_join.append(t) - self.join_independent_workers(threads_to_join) - - estimator = self._get_estimator(train_distribute, eval_distribute) - self._inspect_train_and_eval_events(estimator) - - @combinations.generate( - combinations.combine( - mode=["graph"], - train_distribute_cls=[ - mirrored_strategy.MirroredStrategy, - mirrored_strategy.CoreMirroredStrategy - ], - eval_distribute_cls=[ - None, - mirrored_strategy.MirroredStrategy, - mirrored_strategy.CoreMirroredStrategy - ], - required_gpus=[0, 1])) - def test_complete_flow_independent_worker_in_graph(self, train_distribute_cls, - eval_distribute_cls): - train_distribute = self._get_strategy_object(train_distribute_cls) - - if eval_distribute_cls: - eval_distribute = self._get_strategy_object( - eval_distribute_cls, eval_strategy=True) - else: - eval_distribute = None - - cluster_spec = multi_worker_test_base.create_cluster_spec( - num_workers=3, num_ps=0, has_eval=True) - # 3 workers and 1 evaluator. - self._barrier = dc._Barrier(4) - threads = self.run_multiple_tasks_in_threads(self._independent_worker_fn, - cluster_spec, train_distribute, - eval_distribute) - self.join_independent_workers([threads[WORKER][0], threads[EVALUATOR][0]]) - - estimator = self._get_estimator(train_distribute, eval_distribute) - self._inspect_train_and_eval_events(estimator) - - -TF_CONFIG_WITH_CHIEF = { - "cluster": { - "chief": ["fake_chief"], - }, - "task": { - "type": "chief", - "index": 0 - } -} - -TF_CONFIG_WITH_MASTER = { - "cluster": { - "master": ["fake_master"], - }, - "task": { - "type": "master", - "index": 0 - } -} - -TF_CONFIG_WITHOUT_TASK = {"cluster": {"chief": ["fake_worker"]}} - - -class RunConfigTest(test.TestCase): - - def test_previously_unexpected_cluster_spec(self): - with test.mock.patch.dict( - "os.environ", {"TF_CONFIG": json.dumps(TF_CONFIG_WITHOUT_TASK)}): - run_config_lib.RunConfig( - experimental_distribute=DistributeConfig( - train_distribute=mirrored_strategy.CoreMirroredStrategy( - ["/device:GPU:0", "/device:GPU:1"]))) - - def test_should_run_distribute_coordinator(self): - """Tests that should_run_distribute_coordinator return a correct value.""" - # We don't use distribute coordinator for local training. - self.assertFalse( - dc_training.should_run_distribute_coordinator( - run_config_lib.RunConfig())) - - # When `train_distribute` is not specified, don't use distribute - # coordinator. - with test.mock.patch.dict("os.environ", - {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_CHIEF)}): - self.assertFalse( - dc_training.should_run_distribute_coordinator( - run_config_lib.RunConfig())) - - # When `train_distribute` is specified and TF_CONFIG is detected, use - # distribute coordinator. 
- with test.mock.patch.dict("os.environ", - {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_CHIEF)}): - config_with_train_distribute = run_config_lib.RunConfig( - experimental_distribute=DistributeConfig( - train_distribute=mirrored_strategy.CoreMirroredStrategy( - ["/device:GPU:0", "/device:GPU:1"]))) - config_with_eval_distribute = run_config_lib.RunConfig( - experimental_distribute=DistributeConfig( - eval_distribute=mirrored_strategy.CoreMirroredStrategy( - ["/device:GPU:0", "/device:GPU:1"]))) - self.assertTrue( - dc_training.should_run_distribute_coordinator( - config_with_train_distribute)) - self.assertFalse( - dc_training.should_run_distribute_coordinator( - config_with_eval_distribute)) - - # With a master in the cluster, don't run distribute coordinator. - with test.mock.patch.dict("os.environ", - {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_MASTER)}): - config = run_config_lib.RunConfig( - experimental_distribute=DistributeConfig( - train_distribute=mirrored_strategy.CoreMirroredStrategy( - ["/device:GPU:0", "/device:GPU:1"]))) - self.assertFalse(dc_training.should_run_distribute_coordinator(config)) - - def test_init_run_config_duplicate_distribute(self): - with self.assertRaises(ValueError): - run_config_lib.RunConfig( - train_distribute=mirrored_strategy.CoreMirroredStrategy(), - experimental_distribute=DistributeConfig( - train_distribute=mirrored_strategy.CoreMirroredStrategy())) - - with self.assertRaises(ValueError): - run_config_lib.RunConfig( - eval_distribute=mirrored_strategy.CoreMirroredStrategy(), - experimental_distribute=DistributeConfig( - eval_distribute=mirrored_strategy.CoreMirroredStrategy())) - - def test_init_run_config_none_distribute_coordinator_mode(self): - # We don't use distribute coordinator for local training. - config = run_config_lib.RunConfig( - train_distribute=mirrored_strategy.CoreMirroredStrategy()) - dc_training.init_run_config(config, {}) - self.assertIsNone(config._distribute_coordinator_mode) - - # With a master in the cluster, don't run distribute coordinator. - with test.mock.patch.dict("os.environ", - {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_MASTER)}): - config = run_config_lib.RunConfig( - train_distribute=mirrored_strategy.CoreMirroredStrategy()) - self.assertIsNone(config._distribute_coordinator_mode) - - # When `train_distribute` is not specified, don't use distribute - # coordinator. - with test.mock.patch.dict("os.environ", - {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_CHIEF)}): - config = run_config_lib.RunConfig() - self.assertFalse(hasattr(config, "_distribute_coordinator_mode")) - - def test_init_run_config_independent_worker(self): - # When `train_distribute` is specified and TF_CONFIG is detected, use - # distribute coordinator with INDEPENDENT_WORKER mode. - with test.mock.patch.dict("os.environ", - {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_CHIEF)}): - config = run_config_lib.RunConfig( - train_distribute=mirrored_strategy.CoreMirroredStrategy()) - self.assertEqual(config._distribute_coordinator_mode, - dc.CoordinatorMode.INDEPENDENT_WORKER) - - def test_init_run_config_standalone_client(self): - # When `train_distribute` is specified, TF_CONFIG is detected and - # `experimental.remote_cluster` is set use distribute coordinator with - # STANDALONE_CLIENT mode. 
- config = run_config_lib.RunConfig( - train_distribute=mirrored_strategy.CoreMirroredStrategy(), - experimental_distribute=DistributeConfig( - remote_cluster={"chief": ["fake_worker"]})) - self.assertEqual(config._distribute_coordinator_mode, - dc.CoordinatorMode.STANDALONE_CLIENT) - - -if __name__ == "__main__": - # Reduce `recovery_wait_secs` from 30 seconds so the test completes quickly. - orig_init = session_manager.SessionManager.__init__ - - def new_init(*args, **kwargs): - kwargs.pop("recovery_wait_secs", None) - kwargs["recovery_wait_secs"] = 0.5 - orig_init(*args, **kwargs) - - session_manager.SessionManager.__init__ = new_init - - with test.mock.patch.object(sys, "exit", os._exit): - test.main() diff --git a/tensorflow/contrib/distribute/python/examples/BUILD b/tensorflow/contrib/distribute/python/examples/BUILD index 75fbc3bf53f..afabba7bfb4 100644 --- a/tensorflow/contrib/distribute/python/examples/BUILD +++ b/tensorflow/contrib/distribute/python/examples/BUILD @@ -4,10 +4,9 @@ package( default_visibility = [ "//tensorflow:internal", ], + licenses = ["notice"], # Apache 2.0 ) -licenses(["notice"]) # Apache 2.0 - exports_files(["LICENSE"]) py_binary( diff --git a/tensorflow/contrib/distribute/python/examples/mnist_eager_multigpu.py b/tensorflow/contrib/distribute/python/examples/mnist_eager_multigpu.py index c045a5586b9..502f94c5728 100644 --- a/tensorflow/contrib/distribute/python/examples/mnist_eager_multigpu.py +++ b/tensorflow/contrib/distribute/python/examples/mnist_eager_multigpu.py @@ -37,8 +37,6 @@ flags.DEFINE_integer("batch_size", 64, flags.DEFINE_integer("num_epochs", 10, "How many epochs to run?") flags.DEFINE_float("learning_rate", 0.01, "Learning Rate") flags.DEFINE_float("momentum", 0.5, "SGD momentum") -flags.DEFINE_boolean("use_function", False, - "Should we wrap the step in a tf.function.") FLAGS = flags.FLAGS NUM_TRAIN_IMAGES = 60000 @@ -70,15 +68,13 @@ def compute_loss(logits, labels): return loss * (1. / FLAGS.batch_size) -def mnist_datasets(): +def mnist_datasets(strategy): (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data() # Numpy defaults to dtype=float64; TF defaults to float32. Stick with float32. x_train, x_test = x_train / np.float32(255), x_test / np.float32(255) y_train, y_test = y_train.astype(np.int64), y_test.astype(np.int64) - # TODO(priyag): `strategy.make_numpy_iterator` can be used directly instead of - # converting to datasets. 
- train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)) - test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)) + train_dataset = strategy.experimental_make_numpy_dataset((x_train, y_train)) + test_dataset = strategy.experimental_make_numpy_dataset((x_test, y_test)) return train_dataset, test_dataset @@ -97,7 +93,7 @@ def main(unused_argv): strategy = tf.distribute.MirroredStrategy(devices) with strategy.scope(): - train_ds, test_ds = mnist_datasets() + train_ds, test_ds = mnist_datasets(strategy) train_ds = train_ds.shuffle(NUM_TRAIN_IMAGES).batch(FLAGS.batch_size) test_ds = test_ds.batch(FLAGS.batch_size) @@ -110,55 +106,47 @@ def main(unused_argv): test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy( "test_accuracy", dtype=tf.float32) - def train_step(inputs): - images, labels = inputs - with tf.GradientTape() as tape: - logits = model(images, training=True) + @tf.function + def train_epoch(train_dist_dataset): + """Training Step.""" + def step_fn(images, labels): + with tf.GradientTape() as tape: + logits = model(images, training=True) + loss = compute_loss(logits, labels) + grads = tape.gradient(loss, model.variables) + optimizer.apply_gradients(zip(grads, model.variables)) + training_loss.update_state(loss) + training_accuracy.update_state(labels, logits) + + for images, labels in train_dist_dataset: + strategy.experimental_run_v2(step_fn, args=(images, labels)) + + @tf.function + def test_epoch(test_dist_dataset): + """Testing Step.""" + def step_fn(images, labels): + logits = model(images, training=False) loss = compute_loss(logits, labels) - grads = tape.gradient(loss, model.variables) - optimizer.apply_gradients(zip(grads, model.variables)) - training_loss.update_state(loss) - training_accuracy.update_state(labels, logits) + test_loss.update_state(loss) + test_accuracy.update_state(labels, logits) - def test_step(inputs): - images, labels = inputs - logits = model(images, training=False) - loss = compute_loss(logits, labels) - test_loss.update_state(loss) - test_accuracy.update_state(labels, logits) + for images, labels in test_dist_dataset: + strategy.experimental_run_v2(step_fn, args=(images, labels)) - train_iterator = strategy.make_dataset_iterator(train_ds) - test_iterator = strategy.make_dataset_iterator(test_ds) - - for epoch in range(0, FLAGS.num_epochs): - # TODO(b/123315763): Create the tf.function outside this loop once we are - # able to initialize iterator in eager mode. 
- dist_train = lambda it: strategy.experimental_run(train_step, it) - dist_test = lambda it: strategy.experimental_run(test_step, it) - if FLAGS.use_function: - dist_train = tf.function(dist_train) - dist_test = tf.function(dist_test) + train_dist_dataset = strategy.experimental_distribute_dataset(train_ds) + test_dist_dataset = strategy.experimental_distribute_dataset(test_ds) + for epoch in range(FLAGS.num_epochs): # Train print("Starting epoch {}".format(epoch)) - train_iterator.initialize() - while True: - try: - dist_train(train_iterator) - except tf.errors.OutOfRangeError: - break + train_epoch(train_dist_dataset) print("Training loss: {:0.4f}, accuracy: {:0.2f}%".format( training_loss.result(), training_accuracy.result() * 100)) training_loss.reset_states() training_accuracy.reset_states() # Test - test_iterator.initialize() - while True: - try: - dist_test(test_iterator) - except tf.errors.OutOfRangeError: - break + test_epoch(test_dist_dataset) print("Test loss: {:0.4f}, accuracy: {:0.2f}%".format( test_loss.result(), test_accuracy.result() * 100)) test_loss.reset_states() diff --git a/tensorflow/contrib/distribute/python/optimizer_v2_test.py b/tensorflow/contrib/distribute/python/optimizer_v2_test.py index df5e5595ccb..bbae1174e49 100644 --- a/tensorflow/contrib/distribute/python/optimizer_v2_test.py +++ b/tensorflow/contrib/distribute/python/optimizer_v2_test.py @@ -72,8 +72,9 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase): def testTrainNetwork(self, distribution, optimizer_fn, use_callable_loss=True): with distribution.scope(): + optimizer = optimizer_fn() model_fn, dataset_fn, layer = minimize_loss_example( - optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss) + optimizer, use_bias=True, use_callable_loss=use_callable_loss) iterator = distribution.make_input_fn_iterator(lambda _: dataset_fn()) def run_step(): diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py index a5fead9596d..90f174d0d47 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py @@ -46,7 +46,7 @@ class ParameterServerStrategy(distribute_lib.StrategyV1): becomes local training where variables are assigned to local CPU or the only GPU. When each worker has more than one GPU, operations will be replicated on these GPUs. In both cases, operations are replicated but variables are not and - these workers share a common view for which paramater server a variable is + these workers share a common view for which parameter server a variable is assigned to. This class assumes between-graph replication will be used and works on a graph diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index e4b7b81d083..5f6bca0cbe1 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -2,12 +2,13 @@ # Contains ops for statistical distributions (with pdf, cdf, sample, etc...). # APIs here are meant to evolve over time. 
-package(default_visibility = [ - "//learning/brain/contrib/bayesflow:__subpackages__", - "//tensorflow:__subpackages__", -]) - -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = [ + "//learning/brain/contrib/bayesflow:__subpackages__", + "//tensorflow:__subpackages__", + ], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) @@ -94,7 +95,7 @@ cuda_py_test( "//third_party/py/numpy", "@six_archive//:six", "//tensorflow/contrib/learn", - "//tensorflow/contrib/learn:head_test", + "//tensorflow/contrib/learn:head_test_lib", "//tensorflow/python/ops/distributions", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_ops", diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD index a500f9fd34c..342bd5ae6d7 100644 --- a/tensorflow/contrib/eager/python/BUILD +++ b/tensorflow/contrib/eager/python/BUILD @@ -1,6 +1,7 @@ -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = ["//tensorflow:internal"]) +package( + default_visibility = ["//tensorflow:internal"], + licenses = ["notice"], # Apache 2.0 +) load("//tensorflow:tensorflow.bzl", "py_test") load("//tensorflow:tensorflow.bzl", "cuda_py_test") diff --git a/tensorflow/contrib/eager/python/examples/BUILD b/tensorflow/contrib/eager/python/examples/BUILD index fd5a44a7975..14e97a5138f 100644 --- a/tensorflow/contrib/eager/python/examples/BUILD +++ b/tensorflow/contrib/eager/python/examples/BUILD @@ -1,8 +1,9 @@ # TensorFlow code for training gradient boosted trees. -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = ["//tensorflow:internal"]) +package( + default_visibility = ["//tensorflow:internal"], + licenses = ["notice"], # Apache 2.0 +) py_library( name = "examples_pip", diff --git a/tensorflow/contrib/eager/python/examples/densenet/BUILD b/tensorflow/contrib/eager/python/examples/densenet/BUILD index a001d426fe2..2b85833c151 100644 --- a/tensorflow/contrib/eager/python/examples/densenet/BUILD +++ b/tensorflow/contrib/eager/python/examples/densenet/BUILD @@ -1,6 +1,7 @@ -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = ["//tensorflow:internal"]) +package( + default_visibility = ["//tensorflow:internal"], + licenses = ["notice"], # Apache 2.0 +) load("//tensorflow:tensorflow.bzl", "cuda_py_test") load("//tensorflow:tensorflow.bzl", "py_binary") diff --git a/tensorflow/contrib/eager/python/examples/gan/BUILD b/tensorflow/contrib/eager/python/examples/gan/BUILD index be561a1da66..aaf736c0ded 100644 --- a/tensorflow/contrib/eager/python/examples/gan/BUILD +++ b/tensorflow/contrib/eager/python/examples/gan/BUILD @@ -1,6 +1,7 @@ -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = ["//tensorflow:internal"]) +package( + default_visibility = ["//tensorflow:internal"], + licenses = ["notice"], # Apache 2.0 +) load("//tensorflow:tensorflow.bzl", "cuda_py_test") load("//tensorflow:tensorflow.bzl", "py_binary") diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/BUILD b/tensorflow/contrib/eager/python/examples/l2hmc/BUILD index 35d50990421..99edc8223d0 100644 --- a/tensorflow/contrib/eager/python/examples/l2hmc/BUILD +++ b/tensorflow/contrib/eager/python/examples/l2hmc/BUILD @@ -1,6 +1,7 @@ -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = ["//tensorflow:internal"]) +package( + default_visibility = ["//tensorflow:internal"], + licenses = ["notice"], # Apache 2.0 +) load("//tensorflow:tensorflow.bzl", "cuda_py_test") diff --git 
a/tensorflow/contrib/eager/python/examples/linear_regression/BUILD b/tensorflow/contrib/eager/python/examples/linear_regression/BUILD index 8536fdbf705..4a11f95902c 100644 --- a/tensorflow/contrib/eager/python/examples/linear_regression/BUILD +++ b/tensorflow/contrib/eager/python/examples/linear_regression/BUILD @@ -1,6 +1,7 @@ -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = ["//tensorflow:internal"]) +package( + default_visibility = ["//tensorflow:internal"], + licenses = ["notice"], # Apache 2.0 +) load("//tensorflow:tensorflow.bzl", "cuda_py_test") load("//tensorflow:tensorflow.bzl", "py_binary") diff --git a/tensorflow/contrib/eager/python/examples/resnet50/BUILD b/tensorflow/contrib/eager/python/examples/resnet50/BUILD index a80f3d210a4..f397925e9e1 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/BUILD +++ b/tensorflow/contrib/eager/python/examples/resnet50/BUILD @@ -1,6 +1,7 @@ -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = ["//tensorflow:internal"]) +package( + default_visibility = ["//tensorflow:internal"], + licenses = ["notice"], # Apache 2.0 +) load("//tensorflow:tensorflow.bzl", "cuda_py_test") diff --git a/tensorflow/contrib/eager/python/examples/revnet/BUILD b/tensorflow/contrib/eager/python/examples/revnet/BUILD index a48d08b8a3a..63c0edde775 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/BUILD +++ b/tensorflow/contrib/eager/python/examples/revnet/BUILD @@ -1,6 +1,7 @@ -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = ["//tensorflow:internal"]) +package( + default_visibility = ["//tensorflow:internal"], + licenses = ["notice"], # Apache 2.0 +) load("//tensorflow:tensorflow.bzl", "cuda_py_test") diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD b/tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD index aca0b2f05f6..d54ae37192c 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD +++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD @@ -1,6 +1,7 @@ -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = ["//tensorflow:internal"]) +package( + default_visibility = ["//tensorflow:internal"], + licenses = ["notice"], # Apache 2.0 +) load("//tensorflow:tensorflow.bzl", "cuda_py_test") load("//tensorflow:tensorflow.bzl", "py_binary") diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD b/tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD index ef683ce232b..3232644d4ff 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD +++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD @@ -1,6 +1,7 @@ -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = ["//tensorflow:internal"]) +package( + default_visibility = ["//tensorflow:internal"], + licenses = ["notice"], # Apache 2.0 +) load("//tensorflow:tensorflow.bzl", "cuda_py_test") load("//tensorflow:tensorflow.bzl", "py_binary") diff --git a/tensorflow/contrib/eager/python/examples/spinn/BUILD b/tensorflow/contrib/eager/python/examples/spinn/BUILD index 72f1829ffc4..3b676564e4d 100644 --- a/tensorflow/contrib/eager/python/examples/spinn/BUILD +++ b/tensorflow/contrib/eager/python/examples/spinn/BUILD @@ -1,6 +1,7 @@ -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = ["//tensorflow:internal"]) +package( + default_visibility = ["//tensorflow:internal"], + licenses = ["notice"], # Apache 2.0 +) load("//tensorflow:tensorflow.bzl", "cuda_py_test") load("//tensorflow:tensorflow.bzl", "py_test") diff --git 
a/tensorflow/contrib/eager/python/remote_test.py b/tensorflow/contrib/eager/python/remote_test.py index fb8ae11d6f6..fc78e46a5b1 100644 --- a/tensorflow/contrib/eager/python/remote_test.py +++ b/tensorflow/contrib/eager/python/remote_test.py @@ -23,6 +23,7 @@ import os import numpy as np +from tensorflow.python import pywrap_tensorflow from tensorflow.contrib.eager.python import parameter_server from tensorflow.core.protobuf import cluster_pb2 from tensorflow.core.protobuf import tensorflow_server_pb2 @@ -92,10 +93,11 @@ class RemoteExecutionTest(test.TestCase): def setUp(self): # Start the local server. + local_port = pywrap_tensorflow.TF_PickUnusedPortOrDie() context.set_server_def( server_def=get_server_def( JOB_NAME, - local_server_port=0, + local_server_port=local_port, remote_server_addresses=[ self._cached_server1_target, self._cached_server2_target ], diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index a888379f13e..f1cb596bce0 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -2,10 +2,9 @@ package( default_visibility = [ "//tensorflow:internal", ], + licenses = ["notice"], # Apache 2.0 ) -licenses(["notice"]) # Apache 2.0 - load("//tensorflow:tensorflow.bzl", "py_test") load("//tensorflow:tensorflow.bzl", "cuda_py_test") # PLACEHOLDER PIP REQUIREMENTS diff --git a/tensorflow/contrib/factorization/BUILD b/tensorflow/contrib/factorization/BUILD index ab510b86d15..f82b9e8dedd 100644 --- a/tensorflow/contrib/factorization/BUILD +++ b/tensorflow/contrib/factorization/BUILD @@ -2,12 +2,13 @@ # Contains ops for factorization of data, including matrix factorization and # clustering. -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//tensorflow:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) -package(default_visibility = ["//tensorflow:__subpackages__"]) - load("//tensorflow:tensorflow.bzl", "py_test") load("//tensorflow:tensorflow.bzl", "tf_custom_op_library") load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") diff --git a/tensorflow/contrib/factorization/examples/BUILD b/tensorflow/contrib/factorization/examples/BUILD index 363baa121ab..92bcaf870ba 100644 --- a/tensorflow/contrib/factorization/examples/BUILD +++ b/tensorflow/contrib/factorization/examples/BUILD @@ -1,11 +1,12 @@ # Example TensorFlow models using factorization ops. -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//tensorflow:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) -package(default_visibility = ["//tensorflow:__subpackages__"]) - load("//tensorflow:tensorflow.bzl", "tf_py_test") tf_py_test( diff --git a/tensorflow/contrib/factorization/kernels/BUILD b/tensorflow/contrib/factorization/kernels/BUILD index 23d7e088d06..7b9bef2c989 100644 --- a/tensorflow/contrib/factorization/kernels/BUILD +++ b/tensorflow/contrib/factorization/kernels/BUILD @@ -1,11 +1,12 @@ # OpKernels for data factorization and clustering. 
-licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//tensorflow:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) -package(default_visibility = ["//tensorflow:__subpackages__"]) - load("//tensorflow:tensorflow.bzl", "tf_cc_test") cc_library( diff --git a/tensorflow/contrib/feature_column/BUILD b/tensorflow/contrib/feature_column/BUILD index edd6f36e07c..9092f19c86e 100644 --- a/tensorflow/contrib/feature_column/BUILD +++ b/tensorflow/contrib/feature_column/BUILD @@ -2,10 +2,9 @@ package( default_visibility = [ "//tensorflow:internal", ], + licenses = ["notice"], # Apache 2.0 ) -licenses(["notice"]) # Apache 2.0 - load("//tensorflow:tensorflow.bzl", "tf_py_test") py_library( diff --git a/tensorflow/contrib/ffmpeg/BUILD b/tensorflow/contrib/ffmpeg/BUILD index f7b3273a4d3..9b47ec8d39a 100644 --- a/tensorflow/contrib/ffmpeg/BUILD +++ b/tensorflow/contrib/ffmpeg/BUILD @@ -1,12 +1,13 @@ # Ops that process audio and/or video files using FFmpeg. # (https://www.ffmpeg.org/) -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//tensorflow:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) -package(default_visibility = ["//tensorflow:__subpackages__"]) - load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") load("//tensorflow:tensorflow.bzl", "tf_copts") load("//tensorflow:tensorflow.bzl", "tf_custom_op_library") diff --git a/tensorflow/contrib/ffmpeg/decode_audio_op.cc b/tensorflow/contrib/ffmpeg/decode_audio_op.cc index 5ab57ca4cd4..ca65ad45326 100644 --- a/tensorflow/contrib/ffmpeg/decode_audio_op.cc +++ b/tensorflow/contrib/ffmpeg/decode_audio_op.cc @@ -137,18 +137,17 @@ class DecodeAudioOpV2 : public OpKernel { const tensorflow::StringPiece contents = contents_tensor.scalar()(); const string file_format = - str_util::Lowercase(file_format_tensor.scalar()()); + absl::AsciiStrToLower(file_format_tensor.scalar()()); const int32 samples_per_second = samples_per_second_tensor.scalar()(); const int32 channel_count = channel_count_tensor.scalar()(); const std::set valid_file_formats( kValidFileFormats, kValidFileFormats + TF_ARRAYSIZE(kValidFileFormats)); - OP_REQUIRES( - context, valid_file_formats.count(file_format) == 1, - errors::InvalidArgument("file_format must be one of {", - str_util::Join(valid_file_formats, ", "), - "}, but was: \"", file_format, "\"")); + OP_REQUIRES(context, valid_file_formats.count(file_format) == 1, + errors::InvalidArgument("file_format must be one of {", + absl::StrJoin(valid_file_formats, ", "), + "}, but was: \"", file_format, "\"")); OP_REQUIRES(context, samples_per_second > 0, errors::InvalidArgument( "samples_per_second must be positive, but got: ", @@ -220,14 +219,13 @@ class DecodeAudioOp : public OpKernel { public: explicit DecodeAudioOp(OpKernelConstruction* context) : OpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("file_format", &file_format_)); - file_format_ = str_util::Lowercase(file_format_); + file_format_ = absl::AsciiStrToLower(file_format_); const std::set valid_file_formats( kValidFileFormats, kValidFileFormats + TF_ARRAYSIZE(kValidFileFormats)); - OP_REQUIRES( - context, valid_file_formats.count(file_format_) == 1, - errors::InvalidArgument("file_format must be one of {", - str_util::Join(valid_file_formats, ", "), - "}, but was: \"", file_format_, "\"")); + OP_REQUIRES(context, valid_file_formats.count(file_format_) == 1, + errors::InvalidArgument("file_format must be one of {", + absl::StrJoin(valid_file_formats, ", "), 
+ "}, but was: \"", file_format_, "\"")); OP_REQUIRES_OK(context, context->GetAttr("channel_count", &channel_count_)); OP_REQUIRES(context, channel_count_ > 0, diff --git a/tensorflow/contrib/ffmpeg/default/BUILD b/tensorflow/contrib/ffmpeg/default/BUILD index 59bad8982dd..ec034946c5e 100644 --- a/tensorflow/contrib/ffmpeg/default/BUILD +++ b/tensorflow/contrib/ffmpeg/default/BUILD @@ -2,12 +2,13 @@ # Libraries and kernels for manipulating audio and video using FFmpeg. # (https://www.ffmpeg.org) -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//tensorflow:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) -package(default_visibility = ["//tensorflow:__subpackages__"]) - load("//tensorflow:tensorflow.bzl", "tf_cc_test") cc_library( diff --git a/tensorflow/contrib/ffmpeg/encode_audio_op.cc b/tensorflow/contrib/ffmpeg/encode_audio_op.cc index c00cccd8461..7de09e062ec 100644 --- a/tensorflow/contrib/ffmpeg/encode_audio_op.cc +++ b/tensorflow/contrib/ffmpeg/encode_audio_op.cc @@ -95,7 +95,7 @@ class EncodeAudioOpV2 : public OpKernel { bits_per_second_tensor.shape().DebugString())); const string file_format = - str_util::Lowercase(file_format_tensor.scalar()()); + absl::AsciiStrToLower(file_format_tensor.scalar()()); const int32 samples_per_second = samples_per_second_tensor.scalar()(); const int32 bits_per_second = bits_per_second_tensor.scalar()(); @@ -157,7 +157,7 @@ class EncodeAudioOp : public OpKernel { public: explicit EncodeAudioOp(OpKernelConstruction* context) : OpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("file_format", &file_format_)); - file_format_ = str_util::Lowercase(file_format_); + file_format_ = absl::AsciiStrToLower(file_format_); OP_REQUIRES(context, file_format_ == "wav", errors::InvalidArgument("file_format arg must be \"wav\".")); diff --git a/tensorflow/contrib/framework/BUILD b/tensorflow/contrib/framework/BUILD index 91e2954079e..f3385c07745 100644 --- a/tensorflow/contrib/framework/BUILD +++ b/tensorflow/contrib/framework/BUILD @@ -1,13 +1,14 @@ # Description: # contains parts of TensorFlow that are experimental or unstable and which are not supported. 
-package(default_visibility = [ - "//learning/brain:__subpackages__", - "//tensorflow:__subpackages__", - "//tensorflow_model_optimization:__subpackages__", -]) - -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = [ + "//learning/brain:__subpackages__", + "//tensorflow:__subpackages__", + "//tensorflow_model_optimization:__subpackages__", + ], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) diff --git a/tensorflow/contrib/fused_conv/BUILD b/tensorflow/contrib/fused_conv/BUILD index 2dfbd646a65..4c8c5d90b47 100644 --- a/tensorflow/contrib/fused_conv/BUILD +++ b/tensorflow/contrib/fused_conv/BUILD @@ -4,6 +4,7 @@ package( default_visibility = ["//visibility:private"], + licenses = ["notice"], # Apache 2.0 ) package_group( @@ -13,8 +14,6 @@ package_group( ], ) -licenses(["notice"]) # Apache 2.0 - exports_files(["LICENSE"]) load( @@ -73,14 +72,17 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_proto_parsing", + "//tensorflow/core:logger", "//tensorflow/core:stream_executor", "//tensorflow/core/kernels:bounds_check", "//tensorflow/core/kernels:conv_2d_hdrs", "//tensorflow/core/kernels:conv_ops_gpu_hdrs", + "//tensorflow/core/kernels:cwise_lib_hdrs", "//tensorflow/core/kernels:eigen_contraction_kernel", "//tensorflow/core/kernels:gpu_util_hdrs", "//tensorflow/core/kernels:ops_util_hdrs", "//third_party/eigen3", + "@com_google_absl//absl/time", "@local_config_cuda//cuda:cudnn_header", ], alwayslink = 1, @@ -101,6 +103,7 @@ tf_custom_op_library( "//tensorflow/core/kernels:bounds_check_lib", "//tensorflow/core/kernels:conv_2d_hdrs", "//tensorflow/core/kernels:conv_ops_gpu_hdrs", + "//tensorflow/core/kernels:cwise_lib_hdrs", "//tensorflow/core/kernels:eigen_contraction_kernel", "//tensorflow/core/kernels:gpu_util_hdrs", "//tensorflow/core/kernels:ops_util_hdrs", diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc index 9dda04f3929..c097a2e103c 100644 --- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc +++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_slice.h" #include "tensorflow/core/kernels/conv_2d.h" +#include "tensorflow/core/kernels/cwise_ops.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -90,8 +91,9 @@ struct Int8x4ToInt32 { template class LaunchFusedConv2DBiasActivationOp { - using T = qint8; // conv_input and filter type - using TempT = qint32; // temporary accumulator type for tensor contraction + using T = qint8; // conv_input and filter type + using ComputeT = float; // convert inputs to fp32 for tensor contraction + using TempT = float; // temporary accumulator type for tensor contraction public: void launch(OpKernelContext* ctx, bool cudnn_use_autotune, @@ -106,7 +108,7 @@ class LaunchFusedConv2DBiasActivationOp { // Output tensor has type T (QInt8), but we can only evaluate Int8 Tensor // contraction using 32-bit accumulation (QInt32). 
- Tensor temp_output(DT_QINT32, output->shape()); + Tensor temp_output(DataTypeToEnum::value, output->shape()); constexpr int32 row_dilation = 1; constexpr int32 col_dilation = 1; @@ -132,7 +134,8 @@ class LaunchFusedConv2DBiasActivationOp { auto in0 = conv_input.shaped({conv_width, filter.dim_size(2)}); auto in1 = filter.shaped({filter.dim_size(2), filter.dim_size(3)}); - out.device(device) = in0.contract(in1, dim_pair, output_kernel); + out.device(device) = in0.cast().contract( + in1.cast(), dim_pair, output_kernel); } else if (filter.dim_size(0) == conv_input.dim_size(1) && filter.dim_size(1) == conv_input.dim_size(2) && @@ -151,7 +154,8 @@ class LaunchFusedConv2DBiasActivationOp { auto in0 = conv_input.shaped({conv_input.dim_size(0), k}); auto in1 = filter.shaped({k, filter.dim_size(3)}); - out.device(device) = in0.contract(in1, dim_pair, output_kernel); + out.device(device) = in0.cast().contract( + in1.cast(), dim_pair, output_kernel); } else { auto out = temp_output.tensor(); @@ -159,9 +163,9 @@ class LaunchFusedConv2DBiasActivationOp { auto in1 = filter.tensor(); // Need to swap row/col when calling Eigen. - out.device(device) = - Eigen::SpatialConvolution(in0, in1, col_stride, row_stride, padding, - col_dilation, row_dilation, output_kernel); + out.device(device) = Eigen::SpatialConvolution( + in0.cast(), in1.cast(), col_stride, row_stride, + padding, col_dilation, row_dilation, output_kernel); } } @@ -219,23 +223,31 @@ class LaunchFusedConv2DBiasActivationOp { typename TTypes::UnalignedTensor output(output_base + col * stride, num_rows); - auto conv_output_scaled = - conv_output.cast() * conv_input_scale; + // TODO(ezhulenev): No-op cast optimization in Eigen cause dangling + // references and segfaults. + static_assert(std::is_same::value, + "Must use 'conv_output.cast()'"); + auto conv_output_scaled = conv_output * conv_input_scale; + ScaleType lower_bound = (activation_mode == ActivationMode::NONE ? static_cast(kMinRange) : 0); if (side_input_scale == 0.0f) { - output = (conv_output_scaled + bias) - .round() - .clip(lower_bound, static_cast(kMaxRange)) - .template cast(); + output = + (conv_output_scaled + bias) + // scalar_round_op_google uses HALF_TO_EVEN. + .unaryExpr(Eigen::internal::scalar_round_op_google()) + .clip(lower_bound, static_cast(kMaxRange)) + .template cast(); } else { auto side_input_scaled = side_input.cast() * side_input_scale; - output = (conv_output_scaled + bias + side_input_scaled) - .round() - .clip(lower_bound, static_cast(kMaxRange)) - .template cast(); + output = + (conv_output_scaled + bias + side_input_scaled) + // scalar_round_op_google uses HALF_TO_EVEN. 
+ .unaryExpr(Eigen::internal::scalar_round_op_google()) + .clip(lower_bound, static_cast(kMaxRange)) + .template cast(); } } } diff --git a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test_base.py b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test_base.py index 04edc7593a2..640a6b00965 100644 --- a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test_base.py +++ b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test_base.py @@ -966,6 +966,37 @@ class FusedConvInt8CPUTests(object): for test_param in self._test_params: self.runTest(test_param, apply_relu) + def testRoundingMode(self): + """Verify the fused convolution op uses half-to-even rounding mode.""" + batches = 1 + input_size = 2 + input_channels = 1 + output_channels = 1 + conv_input = np.array([1, 2, 3, 4]).reshape( + (batches, input_size, input_size, input_channels)).astype(np.int8) + kernel = np.array([1]).reshape( + (1, 1, input_channels, output_channels)).astype(np.int8) + biases = np.zeros((output_channels)).astype(np.float32) + + with self.session() as sess, self.test_scope(): + actual = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation( + math_ops.cast(conv_input, dtypes.qint8), + math_ops.cast(kernel, dtypes.qint8), + biases, + strides=[1, 1, 1, 1], + padding="SAME", + conv_input_scale=0.5, + side_input_scale=0.0, + activation_mode="None", + data_format="NHWC", + filter_format="HWIO") + actual_value = sess.run(actual) + # The convolution output scaled is [0.5, 1.0, 1.5, 2.0]. After rounding + # half to even, the final output is [0, 1, 2, 2]. + self.assertTrue( + np.array_equal(actual_value.flatten(), + np.array([0, 1, 2, 2]).astype(np.int8))) + # Test that GPU and CPU kernels produce identical results for QInt8 data type. class FusedConvInt8CorrespondenceTests(object): diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD index 3165e007996..ddd04947e9b 100644 --- a/tensorflow/contrib/gan/BUILD +++ b/tensorflow/contrib/gan/BUILD @@ -2,11 +2,12 @@ load("//tensorflow:tensorflow.bzl", "py_test") -package(default_visibility = [ - "//tensorflow:__subpackages__", -]) - -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = [ + "//tensorflow:__subpackages__", + ], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) diff --git a/tensorflow/contrib/gdr/BUILD b/tensorflow/contrib/gdr/BUILD index bf8b66dcfa5..797d0cdad3e 100644 --- a/tensorflow/contrib/gdr/BUILD +++ b/tensorflow/contrib/gdr/BUILD @@ -1,11 +1,12 @@ # Description: # GPU Direct RDMA Out-of-Band Tensor transport for TensorFlow. -package(default_visibility = [ - "//tensorflow:__subpackages__", -]) - -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = [ + "//tensorflow:__subpackages__", + ], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) diff --git a/tensorflow/contrib/graph_editor/BUILD b/tensorflow/contrib/graph_editor/BUILD index 35b6e638763..180cf69b07f 100644 --- a/tensorflow/contrib/graph_editor/BUILD +++ b/tensorflow/contrib/graph_editor/BUILD @@ -1,12 +1,13 @@ # Description: # contains parts of TensorFlow that are experimental or unstable and which are not supported. 
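The new testRoundingMode case pins down the behavior introduced above: the CPU kernel now accumulates in float, scales by conv_input_scale, adds the bias, rounds half to even (via scalar_round_op_google), clips to the int8 range, and casts back to qint8. A minimal numpy sketch of that post-processing path, using the same values as the test (numpy's round is also half-to-even), reproduces the expected [0, 1, 2, 2]; this is only an illustration of the arithmetic, not the kernel code itself.

import numpy as np

# Values mirror testRoundingMode: a 1x1 convolution with a unit kernel and
# zero bias reduces to an elementwise copy, so only scaling and rounding matter.
conv_input = np.array([1, 2, 3, 4], dtype=np.int8)
conv_input_scale = np.float32(0.5)
bias = np.float32(0.0)

# Accumulate in float32 (the kernel now casts its qint8 operands to float
# before the contraction), then scale, add bias, round half to even,
# clip to the qint8 range and cast back to int8.
acc = conv_input.astype(np.float32)            # contraction output for a unit kernel
out = np.round(acc * conv_input_scale + bias)  # np.round rounds half to even
out = np.clip(out, -128, 127).astype(np.int8)
print(out)  # [0 1 2 2]

With round-half-away-from-zero the first element would be 1 instead of 0, which is exactly the regression this test guards against.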
-licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//tensorflow:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) -package(default_visibility = ["//tensorflow:__subpackages__"]) - load("//tensorflow:tensorflow.bzl", "py_test") py_library( diff --git a/tensorflow/contrib/grid_rnn/BUILD b/tensorflow/contrib/grid_rnn/BUILD index d0b44640667..126078ae791 100644 --- a/tensorflow/contrib/grid_rnn/BUILD +++ b/tensorflow/contrib/grid_rnn/BUILD @@ -2,12 +2,13 @@ # Contains classes to construct GridRNN cells # APIs here are meant to evolve over time. -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//tensorflow:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) -package(default_visibility = ["//tensorflow:__subpackages__"]) - load("//tensorflow:tensorflow.bzl", "cuda_py_tests") py_library( diff --git a/tensorflow/contrib/hadoop/BUILD b/tensorflow/contrib/hadoop/BUILD index 178a8a6f084..87db7ea3b71 100644 --- a/tensorflow/contrib/hadoop/BUILD +++ b/tensorflow/contrib/hadoop/BUILD @@ -1,6 +1,7 @@ -package(default_visibility = ["//tensorflow:internal"]) - -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//tensorflow:internal"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) diff --git a/tensorflow/contrib/hooks/BUILD b/tensorflow/contrib/hooks/BUILD index d65b2d6026d..78fd9aaab82 100644 --- a/tensorflow/contrib/hooks/BUILD +++ b/tensorflow/contrib/hooks/BUILD @@ -2,12 +2,13 @@ # Contains `SessionRunHook`s for use with `MonitoredSession` and the # wrappers around it. -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//tensorflow:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) -package(default_visibility = ["//tensorflow:__subpackages__"]) - load("//tensorflow:tensorflow.bzl", "py_test") py_library( diff --git a/tensorflow/contrib/hvx/clock_cycle_profiling/BUILD b/tensorflow/contrib/hvx/clock_cycle_profiling/BUILD index e39c60b252a..dc0eea3f2e5 100644 --- a/tensorflow/contrib/hvx/clock_cycle_profiling/BUILD +++ b/tensorflow/contrib/hvx/clock_cycle_profiling/BUILD @@ -1,18 +1,19 @@ # Description: # contains parts of TensorFlow that are experimental or unstable and which are not supported. -licenses(["notice"]) # Apache 2.0 - load( "//tensorflow:tensorflow.bzl", - "tf_copts", "tf_cc_binary", + "tf_copts", +) + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 ) exports_files(["LICENSE"]) -package(default_visibility = ["//visibility:public"]) - tf_cc_binary( name = "clock_cycle_profiling", testonly = 1, diff --git a/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD b/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD index 92016e6a839..76ee1cc3b39 100644 --- a/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD +++ b/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD @@ -2,9 +2,10 @@ # Contains a tool to dump TensorFlow ops which are not supported # in TensorFlow HVX runtime. 
-package(default_visibility = ["//visibility:public"]) - -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) diff --git a/tensorflow/contrib/image/BUILD b/tensorflow/contrib/image/BUILD index c9d917fe20d..dfc1746f533 100755 --- a/tensorflow/contrib/image/BUILD +++ b/tensorflow/contrib/image/BUILD @@ -1,12 +1,13 @@ # Description: # Contains ops for image manipulation. -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) -package(default_visibility = ["//visibility:public"]) - load( "//tensorflow:tensorflow.bzl", "tf_cc_test", diff --git a/tensorflow/contrib/input_pipeline/BUILD b/tensorflow/contrib/input_pipeline/BUILD index cf786c062ea..777399184e8 100644 --- a/tensorflow/contrib/input_pipeline/BUILD +++ b/tensorflow/contrib/input_pipeline/BUILD @@ -17,9 +17,10 @@ load( ) load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") -package(default_visibility = ["//visibility:public"]) - -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) diff --git a/tensorflow/contrib/input_pipeline/kernels/BUILD b/tensorflow/contrib/input_pipeline/kernels/BUILD index 797605b8fe6..64b614651a1 100644 --- a/tensorflow/contrib/input_pipeline/kernels/BUILD +++ b/tensorflow/contrib/input_pipeline/kernels/BUILD @@ -1,12 +1,13 @@ # Description: # Contains kernels for the input pipeline. -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//tensorflow:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) -package(default_visibility = ["//tensorflow:__subpackages__"]) - cc_library( name = "input_pipeline_kernels", srcs = ["input_pipeline_kernels.cc"], diff --git a/tensorflow/contrib/integrate/BUILD b/tensorflow/contrib/integrate/BUILD index 9a2c94446fd..3cb268affac 100644 --- a/tensorflow/contrib/integrate/BUILD +++ b/tensorflow/contrib/integrate/BUILD @@ -1,12 +1,13 @@ # Description: # Integration and ODE solvers for TensorFlow. -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//tensorflow:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) -package(default_visibility = ["//tensorflow:__subpackages__"]) - load("//tensorflow:tensorflow.bzl", "py_test") py_library( diff --git a/tensorflow/contrib/keras/BUILD b/tensorflow/contrib/keras/BUILD index 7a4cab20d1a..a839693340e 100644 --- a/tensorflow/contrib/keras/BUILD +++ b/tensorflow/contrib/keras/BUILD @@ -2,12 +2,13 @@ # Contains the Keras API (internal TensorFlow version). # Note that tf.contrib.keras has been deprecated in favor of tf.keras. -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//tensorflow:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) -package(default_visibility = ["//tensorflow:__subpackages__"]) - py_library( name = "keras", srcs = [ diff --git a/tensorflow/contrib/kernel_methods/BUILD b/tensorflow/contrib/kernel_methods/BUILD index 833771eda0f..71c7bf99804 100644 --- a/tensorflow/contrib/kernel_methods/BUILD +++ b/tensorflow/contrib/kernel_methods/BUILD @@ -1,12 +1,13 @@ # Description: # Contains kernel methods for TensorFlow. 
-licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//tensorflow:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) -package(default_visibility = ["//tensorflow:__subpackages__"]) - load("//tensorflow:tensorflow.bzl", "py_test") py_library( diff --git a/tensorflow/contrib/kinesis/kernels/kinesis_dataset_ops.cc b/tensorflow/contrib/kinesis/kernels/kinesis_dataset_ops.cc index 95c7001371a..f24d091c3f8 100644 --- a/tensorflow/contrib/kinesis/kernels/kinesis_dataset_ops.cc +++ b/tensorflow/contrib/kinesis/kernels/kinesis_dataset_ops.cc @@ -42,7 +42,7 @@ Aws::Client::ClientConfiguration* InitializeDefaultClientConfig() { // is set with a truthy value. const char* load_config_env = getenv("AWS_SDK_LOAD_CONFIG"); string load_config = - load_config_env ? str_util::Lowercase(load_config_env) : ""; + load_config_env ? absl::AsciiStrToLower(load_config_env) : ""; if (load_config == "true" || load_config == "1") { Aws::String config_file; // If AWS_CONFIG_FILE is set then use it, otherwise use ~/.aws/config. diff --git a/tensorflow/contrib/labeled_tensor/BUILD b/tensorflow/contrib/labeled_tensor/BUILD index fb28d6689a6..da5d8f6b4e2 100644 --- a/tensorflow/contrib/labeled_tensor/BUILD +++ b/tensorflow/contrib/labeled_tensor/BUILD @@ -1,12 +1,13 @@ # Description: # Labels for TensorFlow. -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//tensorflow:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) -package(default_visibility = ["//tensorflow:__subpackages__"]) - load("//tensorflow:tensorflow.bzl", "py_test") py_library( diff --git a/tensorflow/contrib/layers/BUILD b/tensorflow/contrib/layers/BUILD index c6f6e722a4f..46040c64d43 100644 --- a/tensorflow/contrib/layers/BUILD +++ b/tensorflow/contrib/layers/BUILD @@ -1,13 +1,14 @@ # Description: # contains parts of TensorFlow that are experimental or unstable and which are not supported. -package(default_visibility = [ - "//learning/brain:__subpackages__", - "//tensorflow:__subpackages__", - "//tensorflow_model_optimization:__subpackages__", -]) - -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = [ + "//learning/brain:__subpackages__", + "//tensorflow:__subpackages__", + "//tensorflow_model_optimization:__subpackages__", + ], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) diff --git a/tensorflow/contrib/layers/kernels/BUILD b/tensorflow/contrib/layers/kernels/BUILD index 7aae09ff3e9..187a3a92d73 100644 --- a/tensorflow/contrib/layers/kernels/BUILD +++ b/tensorflow/contrib/layers/kernels/BUILD @@ -1,12 +1,13 @@ # Description: # Contains kernels for layers. -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//tensorflow:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) -package(default_visibility = ["//tensorflow:__subpackages__"]) - cc_library( name = "sparse_feature_cross_kernel", srcs = ["sparse_feature_cross_kernel.cc"], diff --git a/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc b/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc index 01893d60615..ee4b0373ef7 100644 --- a/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc +++ b/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc @@ -182,7 +182,7 @@ class StringCrosser { } // TODO(zakaria): this will copy the string twice, might effect // performance. 
- return str_util::Join(cross_vec, k_feature_separator); + return absl::StrJoin(cross_vec, k_feature_separator); } private: diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index 7507e1fffa6..bb3d73f7a17 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -2922,7 +2922,7 @@ def spatial_softmax(features, First computes the softmax over the spatial extent of each channel of a convolutional feature map. Then computes the expected 2D position of the points of maximal activation for each channel, resulting in a set of - feature keypoints [x1, y1, ... xN, yN] for all N channels. + feature keypoints [i1, j1, ... iN, jN] for all N channels. Read more here: "Learning visual feature spaces for robotic manipulation with @@ -2943,7 +2943,7 @@ def spatial_softmax(features, feature_keypoints: A `Tensor` with size [batch_size, num_channels * 2]; the expected 2D locations of each channel's feature keypoint (normalized to the range (-1,1)). The inner dimension is arranged as - [x1, y1, ... xN, yN]. + [i1, j1, ... iN, jN]. Raises: ValueError: If unexpected data_format specified. ValueError: If num_channels dimension is unspecified. diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index 1d0cac308f3..0a34e91e33f 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -1,17 +1,17 @@ # Description: # Contains TF Learn (aka Scikit Flow) sub-project with high level tensorflow API. -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = [ + "//engedu/ml/tf_from_scratch:__pkg__", + "//tensorflow:internal", + ], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) load("//tensorflow:tensorflow.bzl", "py_test") - -package(default_visibility = [ - "//engedu/ml/tf_from_scratch:__pkg__", - "//tensorflow:internal", -]) - load("//tensorflow:tensorflow.bzl", "py_test", "tf_py_test") py_library( @@ -112,6 +112,7 @@ py_test( name = "data_feeder_test", size = "small", srcs = ["python/learn/learn_io/data_feeder_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":learn", @@ -127,6 +128,7 @@ py_test( name = "estimators_test", size = "small", srcs = ["python/learn/estimators/estimators_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":learn", @@ -145,6 +147,7 @@ py_test( name = "metric_spec_test", size = "small", srcs = ["python/learn/metric_spec_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":learn", @@ -174,6 +177,7 @@ py_test( name = "export_strategy_test", size = "small", srcs = ["python/learn/export_strategy_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":learn", @@ -185,6 +189,7 @@ py_test( name = "graph_actions_test", size = "small", srcs = ["python/learn/graph_actions_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", tags = ["no_windows"], # TODO: needs investigation on Windows deps = [ @@ -208,6 +213,7 @@ py_test( name = "learn_runner_test", size = "small", srcs = ["python/learn/learn_runner_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":learn", @@ -222,6 +228,7 @@ py_test( name = "monitors_test", size = "small", srcs = ["python/learn/monitors_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", tags = ["no_pip_gpu"], # b/74437598 deps = [ @@ -247,6 +254,7 @@ py_test( name = "run_config_test", size = "small", srcs = 
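The spatial_softmax docstring fix above is purely notational: the returned keypoints are (row, column) pairs, i.e. [i1, j1, ..., iN, jN], not (x, y) pairs. As an illustration of that output layout (not the library implementation), a small numpy sketch of the expected-position computation with arbitrary inputs could look like this:

import numpy as np

def spatial_softmax_points(features):
  """features: [batch, height, width, channels] -> [batch, channels * 2].

  Output is arranged as [i1, j1, ..., iN, jN]; each coordinate is the
  softmax-weighted expected position of a channel, normalized to (-1, 1).
  """
  b, h, w, c = features.shape
  pos_i, pos_j = np.meshgrid(np.linspace(-1.0, 1.0, h),
                             np.linspace(-1.0, 1.0, w), indexing='ij')
  logits = features.reshape(b, h * w, c)
  weights = np.exp(logits - logits.max(axis=1, keepdims=True))
  weights /= weights.sum(axis=1, keepdims=True)                # softmax over h * w
  exp_i = (weights * pos_i.reshape(1, h * w, 1)).sum(axis=1)   # [b, c]
  exp_j = (weights * pos_j.reshape(1, h * w, 1)).sum(axis=1)   # [b, c]
  return np.stack([exp_i, exp_j], axis=-1).reshape(b, c * 2)   # interleaved i, j

print(spatial_softmax_points(np.random.randn(2, 5, 5, 3)).shape)  # (2, 6)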
["python/learn/estimators/run_config_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":learn", @@ -260,6 +268,7 @@ py_test( py_test( name = "tensor_signature_test", srcs = ["python/learn/estimators/tensor_signature_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", tags = [ "manual", # b/130760310 @@ -277,6 +286,7 @@ py_test( name = "estimator_test", size = "medium", srcs = ["python/learn/estimators/estimator_test.py"], + python_version = "PY2", shard_count = 2, srcs_version = "PY2AND3", tags = [ @@ -321,6 +331,7 @@ py_test( name = "estimator_input_test", size = "medium", srcs = ["python/learn/estimators/estimator_input_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":learn", @@ -342,6 +353,7 @@ py_test( name = "logistic_regressor_test", size = "small", srcs = ["python/learn/estimators/logistic_regressor_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":learn", @@ -362,6 +374,7 @@ py_test( name = "dnn_linear_combined_test", size = "medium", srcs = ["python/learn/estimators/dnn_linear_combined_test.py"], + python_version = "PY2", shard_count = 8, srcs_version = "PY2AND3", tags = ["no_oss"], # flaky b/70524820 @@ -387,6 +400,7 @@ py_test( name = "head_test", size = "medium", srcs = ["python/learn/estimators/head_test.py"], + python_version = "PY2", shard_count = 4, srcs_version = "PY2AND3", tags = ["noasan"], # times out b/63678675 @@ -417,6 +431,7 @@ py_test( name = "dnn_test", size = "medium", srcs = ["python/learn/estimators/dnn_test.py"], + python_version = "PY2", shard_count = 4, srcs_version = "PY2AND3", tags = ["notap"], @@ -441,6 +456,7 @@ py_test( name = "kmeans_test", size = "medium", srcs = ["python/learn/estimators/kmeans_test.py"], + python_version = "PY2", shard_count = 4, srcs_version = "PY2AND3", tags = [ @@ -467,6 +483,7 @@ py_test( name = "dynamic_rnn_estimator_test", size = "medium", srcs = ["python/learn/estimators/dynamic_rnn_estimator_test.py"], + python_version = "PY2", shard_count = 4, srcs_version = "PY2AND3", deps = [ @@ -493,6 +510,7 @@ py_test( name = "state_saving_rnn_estimator_test", size = "medium", srcs = ["python/learn/estimators/state_saving_rnn_estimator_test.py"], + python_version = "PY2", shard_count = 4, srcs_version = "PY2AND3", tags = ["noasan"], @@ -517,6 +535,7 @@ py_test( name = "linear_test", size = "medium", srcs = ["python/learn/estimators/linear_test.py"], + python_version = "PY2", shard_count = 20, srcs_version = "PY2AND3", tags = ["no_pip"], @@ -541,6 +560,7 @@ py_test( name = "debug_test", size = "medium", srcs = ["python/learn/estimators/debug_test.py"], + python_version = "PY2", shard_count = 4, srcs_version = "PY2AND3", deps = [ @@ -563,6 +583,7 @@ py_test( name = "composable_model_test", size = "medium", srcs = ["python/learn/estimators/composable_model_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":learn", @@ -580,6 +601,7 @@ py_test( name = "svm_test", size = "medium", srcs = ["python/learn/estimators/svm_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":learn", @@ -594,6 +616,7 @@ py_test( name = "grid_search_test", size = "small", srcs = ["python/learn/grid_search_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":learn", @@ -605,6 +628,7 @@ py_test( name = "io_test", size = "small", srcs = ["python/learn/learn_io/io_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":learn", @@ -618,6 +642,7 @@ py_test( name = "model_fn_test", size = "small", srcs = 
["python/learn/estimators/model_fn_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":learn", @@ -636,7 +661,11 @@ py_test( name = "multioutput_test", size = "small", srcs = ["python/learn/estimators/multioutput_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", + tags = [ + "no_oss", + ], deps = [ ":learn", "//tensorflow/python:client_testlib", @@ -648,6 +677,7 @@ py_test( name = "nonlinear_test", size = "medium", srcs = ["python/learn/estimators/nonlinear_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":learn", @@ -662,6 +692,7 @@ py_test( name = "regression_test", size = "small", srcs = ["python/learn/estimators/regression_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":learn", @@ -674,6 +705,7 @@ py_test( name = "rnn_common_test", size = "medium", srcs = ["python/learn/estimators/rnn_common_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":learn", @@ -688,6 +720,7 @@ py_test( name = "ops_test", size = "small", srcs = ["python/learn/ops/ops_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":learn", @@ -705,6 +738,7 @@ py_test( name = "seq2seq_ops_test", size = "small", srcs = ["python/learn/ops/seq2seq_ops_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":learn", @@ -720,6 +754,7 @@ py_test( name = "categorical_test", size = "small", srcs = ["python/learn/preprocessing/tests/categorical_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":learn", @@ -732,6 +767,7 @@ py_test( name = "categorical_vocabulary_test", size = "small", srcs = ["python/learn/preprocessing/tests/categorical_vocabulary_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":learn", @@ -743,6 +779,7 @@ py_test( name = "text_test", size = "small", srcs = ["python/learn/preprocessing/tests/text_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":learn", @@ -777,6 +814,7 @@ py_test( name = "pandas_io_test", size = "small", srcs = ["python/learn/learn_io/pandas_io_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":learn", @@ -792,6 +830,7 @@ py_test( size = "small", timeout = "moderate", srcs = ["python/learn/utils/export_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", tags = [ "manual", # http://b/31032996 @@ -819,6 +858,7 @@ py_test( name = "gc_test", size = "small", srcs = ["python/learn/utils/gc_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":learn", @@ -834,6 +874,7 @@ py_test( name = "saved_model_export_utils_test", size = "small", srcs = ["python/learn/utils/saved_model_export_utils_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", tags = ["no_windows"], # TODO: needs investigation on Windows deps = [ @@ -854,6 +895,7 @@ py_test( name = "input_fn_utils_test", size = "small", srcs = ["python/learn/utils/input_fn_utils_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":learn", @@ -867,6 +909,7 @@ py_test( name = "stability_test", size = "small", srcs = ["python/learn/estimators/stability_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":learn", @@ -884,6 +927,7 @@ py_test( py_binary( name = "inspect_checkpoint", srcs = ["python/learn/utils/inspect_checkpoint.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ "//tensorflow/contrib/framework:framework_py", diff --git a/tensorflow/contrib/learn/python/learn/datasets/BUILD 
b/tensorflow/contrib/learn/python/learn/datasets/BUILD index d6a43ee3a69..c872c55c6b8 100644 --- a/tensorflow/contrib/learn/python/learn/datasets/BUILD +++ b/tensorflow/contrib/learn/python/learn/datasets/BUILD @@ -1,8 +1,9 @@ # Prepare training and testing data. -package(default_visibility = ["//tensorflow:internal"]) - -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//tensorflow:internal"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) diff --git a/tensorflow/contrib/legacy_seq2seq/BUILD b/tensorflow/contrib/legacy_seq2seq/BUILD index 4ce91a140f8..8974f85a209 100644 --- a/tensorflow/contrib/legacy_seq2seq/BUILD +++ b/tensorflow/contrib/legacy_seq2seq/BUILD @@ -2,12 +2,13 @@ # Contains library to create sequence-to-sequence models on top of TensorFlow. # APIs here are meant to evolve over time. -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) -package(default_visibility = ["//visibility:public"]) - load("//tensorflow:tensorflow.bzl", "cuda_py_tests") py_library( diff --git a/tensorflow/contrib/libsvm/BUILD b/tensorflow/contrib/libsvm/BUILD index 4dccb9be7cd..7d83dc5818a 100644 --- a/tensorflow/contrib/libsvm/BUILD +++ b/tensorflow/contrib/libsvm/BUILD @@ -1,9 +1,8 @@ package( default_visibility = ["//visibility:private"], + licenses = ["notice"], # Apache 2.0 ) -licenses(["notice"]) # Apache 2.0 - exports_files(["LICENSE"]) load("//tensorflow:tensorflow.bzl", "tf_custom_op_library") diff --git a/tensorflow/contrib/linear_optimizer/BUILD b/tensorflow/contrib/linear_optimizer/BUILD index ec0cbf92dd2..db81ed7057d 100644 --- a/tensorflow/contrib/linear_optimizer/BUILD +++ b/tensorflow/contrib/linear_optimizer/BUILD @@ -2,12 +2,13 @@ # Contains ops to train linear models on top of TensorFlow. # APIs here are meant to evolve over time. -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//tensorflow:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) -package(default_visibility = ["//tensorflow:__subpackages__"]) - load("//tensorflow:tensorflow.bzl", "py_test") py_library( diff --git a/tensorflow/contrib/lookup/BUILD b/tensorflow/contrib/lookup/BUILD index 83e80f25bcf..c4053ba9679 100644 --- a/tensorflow/contrib/lookup/BUILD +++ b/tensorflow/contrib/lookup/BUILD @@ -1,12 +1,13 @@ # Description: # contains parts of TensorFlow that are experimental or unstable and which are not supported. -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//tensorflow:internal"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) -package(default_visibility = ["//tensorflow:internal"]) - load("//tensorflow:tensorflow.bzl", "tf_py_test") # TODO(yleon): Refactor after one we switching to the V2 kernels. diff --git a/tensorflow/contrib/losses/BUILD b/tensorflow/contrib/losses/BUILD index c51b651d1a4..4861bdab15b 100644 --- a/tensorflow/contrib/losses/BUILD +++ b/tensorflow/contrib/losses/BUILD @@ -1,12 +1,13 @@ # Description: # contains parts of TensorFlow that are experimental or unstable and which are not supported. 
-licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//tensorflow:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) -package(default_visibility = ["//tensorflow:__subpackages__"]) - load("//tensorflow:tensorflow.bzl", "py_test") py_library( diff --git a/tensorflow/contrib/makefile/BUILD b/tensorflow/contrib/makefile/BUILD index 1abb46f4d41..afd6a785705 100644 --- a/tensorflow/contrib/makefile/BUILD +++ b/tensorflow/contrib/makefile/BUILD @@ -1,5 +1,6 @@ # Necessary build rules for makefile build in our CI. -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = ["//visibility:private"]) +package( + default_visibility = ["//visibility:private"], + licenses = ["notice"], # Apache 2.0 +) diff --git a/tensorflow/contrib/makefile/Makefile b/tensorflow/contrib/makefile/Makefile index 13f84313314..ba0ea348ef8 100644 --- a/tensorflow/contrib/makefile/Makefile +++ b/tensorflow/contrib/makefile/Makefile @@ -109,7 +109,7 @@ $(HOST_NSYNC_LIB) \ # If we're on Linux, also link in the dl library. ifeq ($(HOST_OS),LINUX) - HOST_LIBS += -ldl -lpthread + HOST_LIBS += -ldl -lpthread -lrt endif # If we're on a Pi, link in pthreads and dl @@ -259,7 +259,7 @@ endif endif # If we're on Linux, also link in the dl library. ifeq ($(TARGET),LINUX) - LIBS += -ldl -lpthread + LIBS += -ldl -lpthread -lrt endif # If we're cross-compiling for the Raspberry Pi, use the right gcc. ifeq ($(TARGET),PI) @@ -636,6 +636,8 @@ CORE_CC_ALL_SRCS := \ $(ABSL_CC_SRCS) \ tensorflow/c/c_api.cc \ tensorflow/c/kernels.cc \ +tensorflow/c/tf_datatype.cc \ +tensorflow/c/tf_status.cc \ tensorflow/c/tf_status_helper.cc \ $(wildcard tensorflow/core/*.cc) \ $(wildcard tensorflow/core/common_runtime/*.cc) \ diff --git a/tensorflow/contrib/makefile/download_dependencies.sh b/tensorflow/contrib/makefile/download_dependencies.sh index 7566733680c..c41513a9096 100755 --- a/tensorflow/contrib/makefile/download_dependencies.sh +++ b/tensorflow/contrib/makefile/download_dependencies.sh @@ -26,7 +26,7 @@ if [ ! -f $BZL_FILE_PATH ]; then exit 1; fi -EIGEN_URL="$(grep -o 'http.*bitbucket.org/eigen/eigen/get/.*tar\.gz' "${BZL_FILE_PATH}" | grep -v mirror.bazel | head -n1)" +EIGEN_URL="$(grep -o 'https://bitbucket.org/eigen/eigen/get/.*tar\.gz' "${BZL_FILE_PATH}" | grep -v mirror.bazel | head -n1)" GEMMLOWP_URL="$(grep -o 'http://mirror.tensorflow.org/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)" GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.gz" NSYNC_URL="$(grep -o 'http://mirror.tensorflow.org/github.com/google/nsync/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" diff --git a/tensorflow/contrib/memory_stats/BUILD b/tensorflow/contrib/memory_stats/BUILD index 93701249cc8..0a35cb78704 100644 --- a/tensorflow/contrib/memory_stats/BUILD +++ b/tensorflow/contrib/memory_stats/BUILD @@ -1,12 +1,13 @@ # Description: # Ops that get statistics on memory allocators. 
-licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//tensorflow:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) -package(default_visibility = ["//tensorflow:__subpackages__"]) - load("//tensorflow:tensorflow.bzl", "tf_custom_op_library") load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs") load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") diff --git a/tensorflow/contrib/meta_graph_transform/BUILD b/tensorflow/contrib/meta_graph_transform/BUILD index d667b8e1449..a4228beb6e6 100644 --- a/tensorflow/contrib/meta_graph_transform/BUILD +++ b/tensorflow/contrib/meta_graph_transform/BUILD @@ -1,9 +1,10 @@ # Description: # Utility for applying the Graph Transform tool to a MetaGraphDef. -package(default_visibility = ["//visibility:public"]) - -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) diff --git a/tensorflow/contrib/metrics/BUILD b/tensorflow/contrib/metrics/BUILD index 858fd1ede45..9615f65ab1d 100644 --- a/tensorflow/contrib/metrics/BUILD +++ b/tensorflow/contrib/metrics/BUILD @@ -2,15 +2,16 @@ # Contains ops for evaluation metrics and summary statistics. # APIs here are meant to evolve over time. -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = [ + "//engedu/ml/tf_from_scratch:__pkg__", + "//tensorflow:internal", + ], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) -package(default_visibility = [ - "//engedu/ml/tf_from_scratch:__pkg__", - "//tensorflow:internal", -]) - load("//tensorflow:tensorflow.bzl", "py_test") py_library( diff --git a/tensorflow/contrib/mixed_precision/BUILD b/tensorflow/contrib/mixed_precision/BUILD index 3dfb95e0a00..5b41eed73f3 100644 --- a/tensorflow/contrib/mixed_precision/BUILD +++ b/tensorflow/contrib/mixed_precision/BUILD @@ -2,10 +2,9 @@ package( default_visibility = ["//tensorflow:internal"], + licenses = ["notice"], # Apache 2.0 ) -licenses(["notice"]) # Apache 2.0 - exports_files(["LICENSE"]) filegroup( diff --git a/tensorflow/contrib/mixed_precision/python/BUILD b/tensorflow/contrib/mixed_precision/python/BUILD index 39821399fc9..de1ac08bfe8 100644 --- a/tensorflow/contrib/mixed_precision/python/BUILD +++ b/tensorflow/contrib/mixed_precision/python/BUILD @@ -2,10 +2,9 @@ package( default_visibility = ["//tensorflow:internal"], + licenses = ["notice"], # Apache 2.0 ) -licenses(["notice"]) # Apache 2.0 - exports_files(["LICENSE"]) load("//tensorflow:tensorflow.bzl", "py_test") diff --git a/tensorflow/contrib/model_pruning/BUILD b/tensorflow/contrib/model_pruning/BUILD index ce77143e0c3..00a625ff2b8 100644 --- a/tensorflow/contrib/model_pruning/BUILD +++ b/tensorflow/contrib/model_pruning/BUILD @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -package(default_visibility = ["//tensorflow:__subpackages__"]) - -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//tensorflow:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) load("//tensorflow:tensorflow.bzl", "py_test") diff --git a/tensorflow/contrib/model_pruning/README.md b/tensorflow/contrib/model_pruning/README.md index 98760ea7050..01a58fdcdea 100644 --- a/tensorflow/contrib/model_pruning/README.md +++ b/tensorflow/contrib/model_pruning/README.md @@ -50,7 +50,7 @@ The pruning library allows for specification of the following hyper parameters: | name | string | model_pruning | Name of the pruning specification. Used for adding summaries and ops under a common tensorflow name_scope | | begin_pruning_step | integer | 0 | The global step at which to begin pruning | | end_pruning_step | integer | -1 | The global step at which to terminate pruning. Defaults to -1 implying that pruning continues till the training stops | -| weight_sparsity_map | list of strings | [""] | list of weight variable name (or layer name):target sparsity pairs. Eg. [conv1:0.9,conv2/kernel:0.8]. For layers/weights not in this list, sparsity as specified by the target_sparsity hyperparameter is used. | +| weight_sparsity_map | list of strings | [""] | list of weight variable name regex (or layer name regex):target sparsity pairs. Eg. [conv1:0.9,conv.*/kernel:0.8]. For layers/weights not in this list, sparsity as specified by the target_sparsity hyperparameter is used. | | threshold_decay | float | 0.0 | The decay factor to use for exponential decay of the thresholds | | pruning_frequency | integer | 10 | How often should the masks be updated? (in # of global_steps) | | block_height|integer | 1 | Number of rows in a block for block sparse matrices| diff --git a/tensorflow/contrib/model_pruning/examples/cifar10/BUILD b/tensorflow/contrib/model_pruning/examples/cifar10/BUILD index 805a6eab236..d75211086e6 100644 --- a/tensorflow/contrib/model_pruning/examples/cifar10/BUILD +++ b/tensorflow/contrib/model_pruning/examples/cifar10/BUILD @@ -19,10 +19,9 @@ package( default_visibility = [ "//tensorflow:internal", ], + licenses = ["notice"], # Apache 2.0 ) -licenses(["notice"]) # Apache 2.0 - py_library( name = "cifar10_input", srcs = ["cifar10_input.py"], diff --git a/tensorflow/contrib/model_pruning/python/pruning.py b/tensorflow/contrib/model_pruning/python/pruning.py index 9966f7cf798..85fcbad26c7 100644 --- a/tensorflow/contrib/model_pruning/python/pruning.py +++ b/tensorflow/contrib/model_pruning/python/pruning.py @@ -60,6 +60,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import re + from tensorflow.contrib.model_pruning.python import pruning_utils from tensorflow.contrib.model_pruning.python.layers import core_layers as core from tensorflow.contrib.training.python.training import hparam @@ -153,7 +155,7 @@ def get_pruning_hparams(): the global step at which to terminate pruning. Defaults to -1 implying that pruning continues till the training stops weight_sparsity_map: list of strings - comma separed list of weight variable name:target sparsity pairs. + comma separed list of weight variable name regex:target sparsity pairs. For layers/weights not in this list, sparsity as specified by the target_sparsity hyperparameter is used. Eg. 
[conv1:0.9,conv2/kernel:0.8] @@ -355,8 +357,8 @@ class Pruning(object): def _get_sparsity(self, weight_name): """Return target sparsity for the given layer/weight name.""" target_sparsity = [ - sparsity for name, sparsity in self._weight_sparsity_map.items() - if weight_name.find(name) != -1 + sparsity for regexp, sparsity in self._weight_sparsity_map.items() + if re.match(regexp, weight_name) ] if not target_sparsity: return self._sparsity diff --git a/tensorflow/contrib/nearest_neighbor/BUILD b/tensorflow/contrib/nearest_neighbor/BUILD index 6fa76244670..4d74da7962d 100644 --- a/tensorflow/contrib/nearest_neighbor/BUILD +++ b/tensorflow/contrib/nearest_neighbor/BUILD @@ -1,9 +1,10 @@ # Description: # Tensorflow ops for nearest neighbor queries etc. -package(default_visibility = ["//tensorflow:__subpackages__"]) - -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//tensorflow:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) diff --git a/tensorflow/contrib/nn/BUILD b/tensorflow/contrib/nn/BUILD index e3e36c4fdf5..5fe1396bced 100644 --- a/tensorflow/contrib/nn/BUILD +++ b/tensorflow/contrib/nn/BUILD @@ -1,12 +1,13 @@ # Description: # Contains deprecated ops to calculate cross entropy. -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) -package(default_visibility = ["//visibility:public"]) - load("//tensorflow:tensorflow.bzl", "py_test") py_library( diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD index 6c85533d774..63eb73940c4 100644 --- a/tensorflow/contrib/opt/BUILD +++ b/tensorflow/contrib/opt/BUILD @@ -1,12 +1,13 @@ # Description: # Optimization routines. -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//tensorflow:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) -package(default_visibility = ["//tensorflow:__subpackages__"]) - load("//tensorflow:tensorflow.bzl", "py_test") load("//tensorflow:tensorflow.bzl", "tf_py_test") diff --git a/tensorflow/contrib/optimizer_v2/BUILD b/tensorflow/contrib/optimizer_v2/BUILD index 6e401406308..8ecc0b09f4d 100644 --- a/tensorflow/contrib/optimizer_v2/BUILD +++ b/tensorflow/contrib/optimizer_v2/BUILD @@ -2,10 +2,9 @@ package( default_visibility = ["//tensorflow:internal"], + licenses = ["notice"], # Apache 2.0 ) -licenses(["notice"]) # Apache 2.0 - exports_files(["LICENSE"]) load("//tensorflow:tensorflow.bzl", "py_test") diff --git a/tensorflow/contrib/periodic_resample/BUILD b/tensorflow/contrib/periodic_resample/BUILD index 37674071e41..9d97b85d851 100644 --- a/tensorflow/contrib/periodic_resample/BUILD +++ b/tensorflow/contrib/periodic_resample/BUILD @@ -8,9 +8,10 @@ load( ) load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") -package(default_visibility = ["//visibility:public"]) - -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) diff --git a/tensorflow/contrib/pi_examples/label_image/Makefile b/tensorflow/contrib/pi_examples/label_image/Makefile index 9d054a3133a..58fbd18dc3a 100644 --- a/tensorflow/contrib/pi_examples/label_image/Makefile +++ b/tensorflow/contrib/pi_examples/label_image/Makefile @@ -34,12 +34,14 @@ CXXFLAGS := --std=c++11 $(OPTFLAGS) LDFLAGS := \ -L/usr/local/lib \ -L$(TFLIBDIR) \ +-L$(DOWNLOADSDIR)/nsync/builds/default.linux.c++11/ \ -Wl,--no-whole-archive 
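With this change the keys of weight_sparsity_map are treated as regular expressions anchored at the start of the weight name (re.match) rather than as plain substrings, as the updated README table describes. A minimal sketch of the lookup logic, with hypothetical map entries and a stand-in for the target_sparsity fallback:

import re

# Hypothetical entries; 'conv.*/kernel' now covers conv2/kernel, conv3/kernel, ...
weight_sparsity_map = {'conv1': 0.9, 'conv.*/kernel': 0.8}
target_sparsity = 0.5  # stand-in for the target_sparsity hyperparameter

def get_sparsity(weight_name):
  matches = [sparsity for regexp, sparsity in weight_sparsity_map.items()
             if re.match(regexp, weight_name)]
  return matches[0] if matches else target_sparsity

print(get_sparsity('conv1/weights'))  # 0.9 (prefix match on 'conv1')
print(get_sparsity('conv2/kernel'))   # 0.8; the old substring lookup found no entry here
print(get_sparsity('fc1/weights'))    # 0.5, falls back to target_sparsity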
INCLUDES := \ -I/usr/local/include \ -I. \ -I$(DOWNLOADSDIR) \ -I$(DOWNLOADSDIR)/eigen/ \ +-I$(DOWNLOADSDIR)/absl/ \ -I$(PROTOGENDIR) \ -I$(PBTGENDIR) LIBS := \ @@ -49,6 +51,7 @@ LIBS := \ -Wl,--no-whole-archive \ -lstdc++ \ -lprotobuf \ +-lnsync \ -ldl \ -lpthread \ -lm \ diff --git a/tensorflow/contrib/pi_examples/label_image/label_image.cc b/tensorflow/contrib/pi_examples/label_image/label_image.cc index c6935a093f7..97a6e69ac03 100644 --- a/tensorflow/contrib/pi_examples/label_image/label_image.cc +++ b/tensorflow/contrib/pi_examples/label_image/label_image.cc @@ -26,6 +26,7 @@ limitations under the License. #include #include #include + #include #include diff --git a/tensorflow/contrib/predictor/BUILD b/tensorflow/contrib/predictor/BUILD index 3189bb97ca3..279006843ef 100644 --- a/tensorflow/contrib/predictor/BUILD +++ b/tensorflow/contrib/predictor/BUILD @@ -1,8 +1,9 @@ # `Predictor` classes provide an interface for efficient, repeated inference. -package(default_visibility = ["//tensorflow/contrib/predictor:__subpackages__"]) - -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//tensorflow/contrib/predictor:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) diff --git a/tensorflow/contrib/proto/BUILD b/tensorflow/contrib/proto/BUILD index c167fd70189..403e9bc67c2 100644 --- a/tensorflow/contrib/proto/BUILD +++ b/tensorflow/contrib/proto/BUILD @@ -1,6 +1,7 @@ -package(default_visibility = ["//visibility:public"]) - -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) diff --git a/tensorflow/contrib/proto/python/ops/BUILD b/tensorflow/contrib/proto/python/ops/BUILD index ac09934b77d..cc5d319be27 100644 --- a/tensorflow/contrib/proto/python/ops/BUILD +++ b/tensorflow/contrib/proto/python/ops/BUILD @@ -1,6 +1,7 @@ -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = ["//tensorflow:__subpackages__"]) +package( + default_visibility = ["//tensorflow:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) # Placeholders for folks with old dependencies. py_library( diff --git a/tensorflow/contrib/quantization/BUILD b/tensorflow/contrib/quantization/BUILD index 2de10e8faef..80f0a10ec75 100644 --- a/tensorflow/contrib/quantization/BUILD +++ b/tensorflow/contrib/quantization/BUILD @@ -1,16 +1,17 @@ # Description: # contains parts of TensorFlow that are experimental or unstable and which are not supported. 
-licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) -package(default_visibility = ["//visibility:public"]) - load( "//tensorflow:tensorflow.bzl", - "tf_gen_op_wrapper_py", "tf_custom_op_library", + "tf_gen_op_wrapper_py", # @unused ) py_library( diff --git a/tensorflow/contrib/quantize/BUILD b/tensorflow/contrib/quantize/BUILD index 598f6d15676..8183fab5f32 100644 --- a/tensorflow/contrib/quantize/BUILD +++ b/tensorflow/contrib/quantize/BUILD @@ -1,6 +1,7 @@ -package(default_visibility = ["//visibility:public"]) - -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms.py b/tensorflow/contrib/quantize/python/fold_batch_norms.py index a70f748fad6..6b94de61a60 100644 --- a/tensorflow/contrib/quantize/python/fold_batch_norms.py +++ b/tensorflow/contrib/quantize/python/fold_batch_norms.py @@ -185,7 +185,7 @@ def _FindFusedBatchNorms(graph): graph_matcher.OpTypePattern('*')]) batch_norm_pattern = graph_matcher.OpTypePattern( - 'FusedBatchNorm', + 'FusedBatchNorm|FusedBatchNormV3', inputs=[ graph_matcher.OneofPattern( [matmul_reshape_pattern, layer_output_pattern]), gamma_pattern, @@ -489,8 +489,14 @@ def _CloneWithNewOperands(layer_op, input_tensor, weight_tensor, @ops.RegisterGradient('FoldFusedBatchNormGrad') -def _FoldFusedBatchNormGrad(op, unused_grad_y, grad_mean, grad_var, unused_1, - unused_2): +def _FoldFusedBatchNormGrad(op, + unused_grad_y, + grad_mean, + grad_var, + unused_1, + unused_2, + unused_3=None): + """Gradient function for the FusedBatchNorm ops matched by _GetLayerMatch.""" x = op.inputs[0] n = math_ops.cast( array_ops.size(x) / array_ops.size(grad_mean), dtypes.float32) diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms_test.py b/tensorflow/contrib/quantize/python/fold_batch_norms_test.py index 77b3f62e9d6..8616548bace 100644 --- a/tensorflow/contrib/quantize/python/fold_batch_norms_test.py +++ b/tensorflow/contrib/quantize/python/fold_batch_norms_test.py @@ -21,6 +21,7 @@ from __future__ import print_function from tensorflow.contrib.layers.python.layers import layers from tensorflow.contrib.quantize.python import fold_batch_norms from tensorflow.python.client import session +from tensorflow.python.compat import compat from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed @@ -167,7 +168,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): self.assertFalse('//' in op.name, 'Double slash in op %s' % op.name) def testFoldConv2d(self): - self._RunTestOverParameters(self._TestFoldConv2d) + with compat.forward_compatibility_horizon(2019, 6, 7): + self._RunTestOverParameters(self._TestFoldConv2d) def testMultipleLayerConv2d(self, relu=nn_ops.relu, @@ -337,7 +339,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): self.assertFalse('//' in op.name, 'Double slash in op %s' % op.name) def testFoldConv2dUnknownShape(self): - self._RunTestOverParameters(self._TestFoldConv2dUnknownShape) + with compat.forward_compatibility_horizon(2019, 6, 7): + self._RunTestOverParameters(self._TestFoldConv2dUnknownShape) def _TestFoldFullyConnectedLayer( self, relu, relu_op_name, with_bypass, has_scaling, fused_batch_norm, @@ -432,7 +435,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): 
self.assertFalse('//' in op.name, 'Double slash in op %s' % op.name) def testFoldFullyConnectedLayer(self): - self._RunTestOverParameters(self._TestFoldFullyConnectedLayer) + with compat.forward_compatibility_horizon(2019, 6, 7): + self._RunTestOverParameters(self._TestFoldFullyConnectedLayer) def _TestFoldDepthwiseConv2d(self, relu, relu_op_name, with_bypass, has_scaling, fused_batch_norm, @@ -543,7 +547,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): self.assertFalse('//' in op.name, 'Double slash in op %s' % op.name) def testFoldDepthwiseConv2d(self): - self._RunTestOverParameters(self._TestFoldDepthwiseConv2d) + with compat.forward_compatibility_horizon(2019, 6, 7): + self._RunTestOverParameters(self._TestFoldDepthwiseConv2d) def _TestFoldAtrousConv2d(self, relu, relu_op_name, with_bypass, has_scaling, fused_batch_norm, freeze_batch_norm_delay, @@ -660,7 +665,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): self.assertFalse('//' in op.name, 'Double slash in op %s' % op.name) def testFoldAtrousConv2d(self): - self._RunTestOverParameters(self._TestFoldAtrousConv2d) + with compat.forward_compatibility_horizon(2019, 6, 7): + self._RunTestOverParameters(self._TestFoldAtrousConv2d) def _TestCompareFoldAndUnfolded(self, relu, @@ -733,7 +739,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): self.assertAllClose(unfolded_backward, folded_backward, atol=1e-3) def testCompareFoldAndUnfolded(self): - self._RunTestOverParameters(self._TestCompareFoldAndUnfolded) + with compat.forward_compatibility_horizon(2019, 6, 7): + self._RunTestOverParameters(self._TestCompareFoldAndUnfolded) def _BatchNormParams(self, scale=True, fused=False): return { diff --git a/tensorflow/contrib/quantize/python/graph_matcher_test.py b/tensorflow/contrib/quantize/python/graph_matcher_test.py index be741644b61..95849d75b61 100644 --- a/tensorflow/contrib/quantize/python/graph_matcher_test.py +++ b/tensorflow/contrib/quantize/python/graph_matcher_test.py @@ -22,6 +22,7 @@ from tensorflow.contrib.framework.python import ops as contrib_ops from tensorflow.contrib.layers.python.layers import initializers from tensorflow.contrib.layers.python.layers import layers from tensorflow.contrib.quantize.python import graph_matcher +from tensorflow.python.compat import compat from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -36,48 +37,51 @@ from tensorflow.python.platform import googletest class GraphMatcherTest(test_util.TensorFlowTestCase): def test_conv_layer(self): - g = ops.Graph() - with g.as_default(): - inputs = array_ops.placeholder(dtypes.float32, shape=[8, 5, 5, 3]) + with compat.forward_compatibility_horizon(2019, 6, 7): + g = ops.Graph() + with g.as_default(): + inputs = array_ops.placeholder(dtypes.float32, shape=[8, 5, 5, 3]) - with contrib_ops.arg_scope( - [layers.batch_norm], fused=True, is_training=True, trainable=True): - return layers.convolution( - inputs, - num_outputs=16, - kernel_size=3, - stride=1, - padding='VALID', - activation_fn=nn_ops.relu, - normalizer_fn=layers.batch_norm, - normalizer_params={}, - weights_initializer=initializers.xavier_initializer(), - weights_regularizer=None, - biases_initializer=init_ops.zeros_initializer(), - biases_regularizer=None, - reuse=None, - trainable=True, - scope=None) + with contrib_ops.arg_scope([layers.batch_norm], + fused=True, + is_training=True, + trainable=True): + return layers.convolution( + inputs, + num_outputs=16, + kernel_size=3, 
+ stride=1, + padding='VALID', + activation_fn=nn_ops.relu, + normalizer_fn=layers.batch_norm, + normalizer_params={}, + weights_initializer=initializers.xavier_initializer(), + weights_regularizer=None, + biases_initializer=init_ops.zeros_initializer(), + biases_regularizer=None, + reuse=None, + trainable=True, + scope=None) - inputs_pattern = graph_matcher.OpTypePattern('*', name='inputs') - relu_pattern = graph_matcher.OpTypePattern( - 'Relu', - name='relu', - inputs=[ - graph_matcher.OpTypePattern( - 'FusedBatchNorm', - inputs=[ - graph_matcher.OpTypePattern( - 'Conv2D', inputs=[inputs_pattern, '*']), '*', '*', '*', - '*' - ]) - ]) - matcher = graph_matcher.GraphMatcher(relu_pattern) - match_results = list(matcher.match_graph(g)) - self.assertEqual(1, len(match_results)) - match_result = match_results[0] - self.assertEqual(match_result.get_tensor(inputs_pattern), inputs) - self.assertEqual(match_result.get_tensor('inputs'), inputs) + inputs_pattern = graph_matcher.OpTypePattern('*', name='inputs') + relu_pattern = graph_matcher.OpTypePattern( + 'Relu', + name='relu', + inputs=[ + graph_matcher.OpTypePattern( + 'FusedBatchNormV3', + inputs=[ + graph_matcher.OpTypePattern( + 'Conv2D', inputs=[inputs_pattern, '*']), '*', '*', + '*', '*' + ]) + ]) + matcher = graph_matcher.GraphMatcher(relu_pattern) + match_results = list(matcher.match_graph(g)) + self.assertEqual(1, len(match_results)) + match_result = match_results[0] + self.assertEqual(match_result.get_tensor(inputs_pattern), inputs) + self.assertEqual(match_result.get_tensor('inputs'), inputs) def test_multiple_outputs(self): # - + diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py index 7c973fe5971..c2053beae33 100644 --- a/tensorflow/contrib/quantize/python/quantize.py +++ b/tensorflow/contrib/quantize/python/quantize.py @@ -577,7 +577,7 @@ def _IsSkipLayer(activation_op): if activation_op.type == 'Identity' and len(activation_op.outputs) == 1: if len(activation_op.outputs[0].consumers()) == 1: consumer = activation_op.outputs[0].consumers()[0] - if consumer.type == 'FusedBatchNorm': + if consumer.type in ['FusedBatchNorm', 'FusedBatchNormV3']: skip_layer = True logging.info( 'Skipping quantizing %s, because it is the output of a conv/fc ' diff --git a/tensorflow/contrib/quantize/python/quantize_graph_test.py b/tensorflow/contrib/quantize/python/quantize_graph_test.py index 9aa6e2c24d4..054c66be9cd 100644 --- a/tensorflow/contrib/quantize/python/quantize_graph_test.py +++ b/tensorflow/contrib/quantize/python/quantize_graph_test.py @@ -23,6 +23,7 @@ import functools from tensorflow.contrib.layers.python.layers import layers from tensorflow.contrib.quantize.python import quantize_graph from tensorflow.python import training +from tensorflow.python.compat import compat from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops @@ -275,7 +276,8 @@ class QuantizeGraphTest(test_util.TensorFlowTestCase): self.assertEqual(graph_def_before, graph_def_after) def testIdentityNode(self): - self._RunTestOverAllRewrites(self._TestIdentityNode) + with compat.forward_compatibility_horizon(2019, 6, 7): + self._RunTestOverAllRewrites(self._TestIdentityNode) def _TestIdentityNode(self, rewrite_fn): graph = ops.Graph() @@ -293,10 +295,11 @@ class QuantizeGraphTest(test_util.TensorFlowTestCase): conv_out_identity = graph.get_operation_by_name('test/conv_out') self._AssertOutputGoesToOps(conv_out_identity, graph, - 
['test/BatchNorm/FusedBatchNorm']) + ['test/BatchNorm/FusedBatchNormV3']) def testActivationQuantization(self): - self._RunTestOverAllRewrites(self._TestActivationQuantization) + with compat.forward_compatibility_horizon(2019, 6, 7): + self._RunTestOverAllRewrites(self._TestActivationQuantization) def _TestActivationQuantization(self, rewrite_fn): graph = ops.Graph() diff --git a/tensorflow/contrib/quantize/python/quantize_parameterized_test.py b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py index f6bf57a789c..26a6e35c3c6 100644 --- a/tensorflow/contrib/quantize/python/quantize_parameterized_test.py +++ b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py @@ -21,6 +21,7 @@ from __future__ import print_function from tensorflow.contrib.layers.python.layers import layers from tensorflow.contrib.quantize.python import fold_batch_norms from tensorflow.contrib.quantize.python import quantize +from tensorflow.python.compat import compat from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops @@ -484,7 +485,9 @@ class QuantizeTest(test_util.TensorFlowTestCase): self._AssertIdempotent(graph) def testQuantize_Conv2dWithBatchNorm(self): - self._RunBatchNormTestOverParameters(self._TestQuantize_Conv2dWithBatchNorm) + with compat.forward_compatibility_horizon(2019, 6, 7): + self._RunBatchNormTestOverParameters( + self._TestQuantize_Conv2dWithBatchNorm) def _TestQuantize_Conv2dWithBatchNorm(self, activation, activation_op_name, with_bypass, delay, fused_batch_norm, @@ -541,7 +544,8 @@ class QuantizeTest(test_util.TensorFlowTestCase): use_resource) def testQuantize_FCWithBatchNorm(self): - self._RunBatchNormTestOverParameters(self._TestQuantize_FCWithBatchNorm) + with compat.forward_compatibility_horizon(2019, 6, 7): + self._RunBatchNormTestOverParameters(self._TestQuantize_FCWithBatchNorm) def _TestQuantize_FCWithBatchNorm(self, activation, activation_op_name, with_bypass, delay, fused_batch_norm, @@ -596,8 +600,9 @@ class QuantizeTest(test_util.TensorFlowTestCase): use_resource) def testQuantize_DepthwiseConv2dWithBatchNorm(self): - self._RunBatchNormTestOverParameters( - self._TestQuantize_DepthwiseConv2dWithBatchNorm) + with compat.forward_compatibility_horizon(2019, 6, 7): + self._RunBatchNormTestOverParameters( + self._TestQuantize_DepthwiseConv2dWithBatchNorm) def _TestQuantize_DepthwiseConv2dWithBatchNorm( self, activation, activation_op_name, with_bypass, delay, @@ -654,8 +659,9 @@ class QuantizeTest(test_util.TensorFlowTestCase): with_bypass, delay, use_resource) def testQuantize_AtrousConvWithBatchNorm(self): - self._RunBatchNormTestOverParameters( - self._TestQuantize_AtrousConvWithBatchNorm) + with compat.forward_compatibility_horizon(2019, 6, 7): + self._RunBatchNormTestOverParameters( + self._TestQuantize_AtrousConvWithBatchNorm) def _TestQuantize_AtrousConvWithBatchNorm( self, activation, activation_op_name, with_bypass, delay, @@ -723,18 +729,19 @@ class QuantizeTest(test_util.TensorFlowTestCase): self.assertEqual(graph_def_before, graph_def_after) def testBatchNormForcedUpdates(self): - parameter_list = [ - # (activation, activation_op_name, fused_batch_norm) - (nn_ops.relu6, 'Relu6', False), - (nn_ops.relu, 'Relu', False), - (array_ops.identity, 'Identity', False), - (nn_ops.relu6, 'Relu6', True), - (nn_ops.relu, 'Relu', True), - (array_ops.identity, 'Identity', True), - ] - for params in parameter_list: - self._TestBatchNormForcedUpdates(params[0], params[1], params[2], 
False) - self._TestBatchNormForcedUpdates(params[0], params[1], params[2], True) + with compat.forward_compatibility_horizon(2019, 6, 7): + parameter_list = [ + # (activation, activation_op_name, fused_batch_norm) + (nn_ops.relu6, 'Relu6', False), + (nn_ops.relu, 'Relu', False), + (array_ops.identity, 'Identity', False), + (nn_ops.relu6, 'Relu6', True), + (nn_ops.relu, 'Relu', True), + (array_ops.identity, 'Identity', True), + ] + for params in parameter_list: + self._TestBatchNormForcedUpdates(params[0], params[1], params[2], False) + self._TestBatchNormForcedUpdates(params[0], params[1], params[2], True) def _TestBatchNormForcedUpdates(self, activation, activation_op_name, fused_batch_norm, use_resource): diff --git a/tensorflow/contrib/rate/BUILD b/tensorflow/contrib/rate/BUILD index 4a60b4703ec..e67e62b127b 100644 --- a/tensorflow/contrib/rate/BUILD +++ b/tensorflow/contrib/rate/BUILD @@ -1,9 +1,10 @@ # Description: # contains parts of TensorFlow that are experimental or unstable and which are not supported. -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = ["//visibility:public"]) +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) diff --git a/tensorflow/contrib/receptive_field/BUILD b/tensorflow/contrib/receptive_field/BUILD index 18ef0205941..0eeec09b440 100644 --- a/tensorflow/contrib/receptive_field/BUILD +++ b/tensorflow/contrib/receptive_field/BUILD @@ -3,10 +3,9 @@ package( default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 ) -licenses(["notice"]) # Apache 2.0 - exports_files(["LICENSE"]) load("//tensorflow:tensorflow.bzl", "py_test") diff --git a/tensorflow/contrib/recurrent/BUILD b/tensorflow/contrib/recurrent/BUILD index f9827f766da..2db92600fb7 100644 --- a/tensorflow/contrib/recurrent/BUILD +++ b/tensorflow/contrib/recurrent/BUILD @@ -1,8 +1,9 @@ # Recurrent library. 
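Editor's illustrative aside on the quantize test changes above: they all wrap graph construction in compat.forward_compatibility_horizon(2019, 6, 7), which is why the expected op names switch from FusedBatchNorm to FusedBatchNormV3. Below is a minimal sketch of that pattern, assuming (as the updated tests do) that emission of the V3 op is gated on compat.forward_compatible(2019, 6, 6); the helper function and its toy constants are hypothetical and not part of this change.

    from tensorflow.python.compat import compat
    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import dtypes
    from tensorflow.python.framework import ops
    from tensorflow.python.ops import array_ops
    from tensorflow.python.ops import nn_impl


    def fused_batch_norm_op_types():
      """Hypothetical helper: op types of a tiny graph with one fused batch norm."""
      g = ops.Graph()
      with g.as_default():
        x = array_ops.placeholder(dtypes.float32, shape=[8, 5, 5, 3])
        scale = constant_op.constant([1.0, 1.0, 1.0])
        offset = constant_op.constant([0.0, 0.0, 0.0])
        nn_impl.fused_batch_norm(x, scale, offset, is_training=True)
      return {op.type for op in g.get_operations()}


    # Inside the horizon, compat.forward_compatible(2019, 6, 6) returns True,
    # so the graph builder is allowed to emit the newer kernel; the folded and
    # quantized graphs the tests inspect are then expected to contain
    # FusedBatchNormV3 rather than FusedBatchNorm.
    with compat.forward_compatibility_horizon(2019, 6, 7):
      assert 'FusedBatchNormV3' in fused_batch_norm_op_types()

Outside the context manager the same helper would be expected to report the older FusedBatchNorm op, which is the behavior the unmodified assertions and OpTypePatterns relied on.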
-package(default_visibility = ["//visibility:public"]) - -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) diff --git a/tensorflow/contrib/reduce_slice_ops/BUILD b/tensorflow/contrib/reduce_slice_ops/BUILD index 02b3d66e461..d9741286e41 100644 --- a/tensorflow/contrib/reduce_slice_ops/BUILD +++ b/tensorflow/contrib/reduce_slice_ops/BUILD @@ -1,4 +1,6 @@ -licenses(["notice"]) # Apache 2.0 +package( + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) diff --git a/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops_gpu.cu.cc b/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops_gpu.cu.cc index ea4026008ed..be09076e862 100644 --- a/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops_gpu.cu.cc +++ b/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops_gpu.cu.cc @@ -37,7 +37,7 @@ namespace functor { #define GPUReduceSliceFunctorReduceop(reduceop, beginning) \ template \ __global__ void ReduceSliceDeviceKernel##reduceop( \ - Cuda3DLaunchConfig config, Index indices_width, Index bound, \ + Gpu3DLaunchConfig config, Index indices_width, Index bound, \ const T begin, const Index *indices, const T *input, T *out) { \ CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count.x, X) { \ CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count.y, Y) { \ @@ -73,7 +73,7 @@ namespace functor { if (sizex * sizey * sizez == 0) { \ return; \ } \ - Cuda3DLaunchConfig config = GetCuda3DLaunchConfig( \ + Gpu3DLaunchConfig config = GetGpu3DLaunchConfig( \ sizex, sizey, sizez, d, ReduceSliceDeviceKernel##reduceop, \ 0, 0); \ \ diff --git a/tensorflow/contrib/remote_fused_graph/pylib/BUILD b/tensorflow/contrib/remote_fused_graph/pylib/BUILD index 274bdbeacf7..00552a077ff 100644 --- a/tensorflow/contrib/remote_fused_graph/pylib/BUILD +++ b/tensorflow/contrib/remote_fused_graph/pylib/BUILD @@ -1,9 +1,10 @@ # Description: # Contains ops for remote fused graph -package(default_visibility = ["//tensorflow:__subpackages__"]) - -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//tensorflow:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) diff --git a/tensorflow/contrib/resampler/BUILD b/tensorflow/contrib/resampler/BUILD index bbf10996759..4e5857b0a55 100644 --- a/tensorflow/contrib/resampler/BUILD +++ b/tensorflow/contrib/resampler/BUILD @@ -1,9 +1,10 @@ -licenses(["notice"]) # Apache 2.0 License +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 License +) exports_files(["LICENSE"]) -package(default_visibility = ["//visibility:public"]) - load( "//tensorflow:tensorflow.bzl", "tf_custom_op_library", diff --git a/tensorflow/contrib/resampler/kernels/resampler_ops_gpu.cu.cc b/tensorflow/contrib/resampler/kernels/resampler_ops_gpu.cu.cc index ecb6c187a07..bdadc36bbc7 100644 --- a/tensorflow/contrib/resampler/kernels/resampler_ops_gpu.cu.cc +++ b/tensorflow/contrib/resampler/kernels/resampler_ops_gpu.cu.cc @@ -117,8 +117,8 @@ struct Resampler2DFunctor { const int data_channels, const int num_sampling_points) { const int output_data_size = batch_size * num_sampling_points * data_channels; - ::tensorflow::CudaLaunchConfig config = - ::tensorflow::GetCudaLaunchConfig(output_data_size, d); + ::tensorflow::GpuLaunchConfig config = + ::tensorflow::GetGpuLaunchConfig(output_data_size, d); TF_CHECK_OK(CudaLaunchKernel( Resampler2DKernel, config.block_count, config.thread_per_block, 
0, d.stream(), data, warp, output, batch_size, data_height, data_width, @@ -252,20 +252,20 @@ struct ResamplerGrad2DFunctor { const int grad_data_size = batch_size * data_height * data_width * data_channels; - ::tensorflow::CudaLaunchConfig config = - ::tensorflow::GetCudaLaunchConfig(grad_warp_size, d); + ::tensorflow::GpuLaunchConfig config = + ::tensorflow::GetGpuLaunchConfig(grad_warp_size, d); TF_CHECK_OK(::tensorflow::CudaLaunchKernel( SetZero, config.block_count, config.thread_per_block, 0, d.stream(), grad_warp_size, grad_warp)); - config = ::tensorflow::GetCudaLaunchConfig(grad_data_size, d); + config = ::tensorflow::GetGpuLaunchConfig(grad_data_size, d); TF_CHECK_OK(::tensorflow::CudaLaunchKernel( SetZero, config.block_count, config.thread_per_block, 0, d.stream(), grad_data_size, grad_data)); const int resampler_output_size = batch_size * num_sampling_points * data_channels; - config = ::tensorflow::GetCudaLaunchConfig(resampler_output_size, d); + config = ::tensorflow::GetGpuLaunchConfig(resampler_output_size, d); TF_CHECK_OK(CudaLaunchKernel(ResamplerGrad2DKernel, config.block_count, config.thread_per_block, 0, d.stream(), data, warp, grad_output, grad_data, grad_warp, diff --git a/tensorflow/contrib/rnn/BUILD b/tensorflow/contrib/rnn/BUILD index 66fadcc16b5..4d3fc81199d 100644 --- a/tensorflow/contrib/rnn/BUILD +++ b/tensorflow/contrib/rnn/BUILD @@ -2,9 +2,10 @@ # Contains ops to train linear models on top of TensorFlow. # APIs here are meant to evolve over time. -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = ["//visibility:public"]) +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) @@ -12,12 +13,12 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_tests") load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") load( "//tensorflow:tensorflow.bzl", - "tf_custom_op_library", "tf_cc_test", - "tf_py_test", + "tf_custom_op_library", "tf_gen_op_libs", - "tf_kernel_library", "tf_gen_op_wrapper_py", + "tf_kernel_library", + "tf_py_test", ) cc_library( diff --git a/tensorflow/contrib/rpc/BUILD b/tensorflow/contrib/rpc/BUILD index dbd311a276b..f8463c050c1 100644 --- a/tensorflow/contrib/rpc/BUILD +++ b/tensorflow/contrib/rpc/BUILD @@ -1,6 +1,7 @@ -package(default_visibility = ["//visibility:public"]) - -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) diff --git a/tensorflow/contrib/rpc/python/kernel_tests/BUILD b/tensorflow/contrib/rpc/python/kernel_tests/BUILD index cb0b89ae55b..76f2ddc2d84 100644 --- a/tensorflow/contrib/rpc/python/kernel_tests/BUILD +++ b/tensorflow/contrib/rpc/python/kernel_tests/BUILD @@ -1,6 +1,7 @@ -package(default_visibility = ["//visibility:public"]) - -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) diff --git a/tensorflow/contrib/rpc/python/ops/BUILD b/tensorflow/contrib/rpc/python/ops/BUILD index 84d2a1832f1..4dee58ccaa4 100644 --- a/tensorflow/contrib/rpc/python/ops/BUILD +++ b/tensorflow/contrib/rpc/python/ops/BUILD @@ -1,6 +1,7 @@ -package(default_visibility = ["//visibility:public"]) - -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) diff --git a/tensorflow/contrib/saved_model/BUILD b/tensorflow/contrib/saved_model/BUILD 
index 969ff19eca6..173fb8f5ac9 100644 --- a/tensorflow/contrib/saved_model/BUILD +++ b/tensorflow/contrib/saved_model/BUILD @@ -16,12 +16,13 @@ # Description: # SavedModel contrib libraries. -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//tensorflow:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) -package(default_visibility = ["//tensorflow:__subpackages__"]) - load("//tensorflow:tensorflow.bzl", "py_test") py_library( diff --git a/tensorflow/contrib/saved_model/cc/saved_model/BUILD b/tensorflow/contrib/saved_model/cc/saved_model/BUILD index ea4d41d43b5..9d9a39e61e1 100644 --- a/tensorflow/contrib/saved_model/cc/saved_model/BUILD +++ b/tensorflow/contrib/saved_model/cc/saved_model/BUILD @@ -16,9 +16,10 @@ # Description: # SavedModel contrib libraries for C++. -package(default_visibility = ["//tensorflow:__subpackages__"]) - -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//tensorflow:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) diff --git a/tensorflow/contrib/seq2seq/BUILD b/tensorflow/contrib/seq2seq/BUILD index f42a2953ef9..4d3fdd689d7 100644 --- a/tensorflow/contrib/seq2seq/BUILD +++ b/tensorflow/contrib/seq2seq/BUILD @@ -1,12 +1,13 @@ # Description: # contains parts of TensorFlow that are experimental or unstable and which are not supported. -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = [ - "//learning/brain/google/xla/tests:__subpackages__", - "//tensorflow:__subpackages__", -]) +package( + default_visibility = [ + "//learning/brain/google/xla/tests:__subpackages__", + "//tensorflow:__subpackages__", + ], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) @@ -16,8 +17,8 @@ load( "//tensorflow:tensorflow.bzl", "tf_custom_op_library", "tf_gen_op_libs", - "tf_kernel_library", "tf_gen_op_wrapper_py", + "tf_kernel_library", ) tf_custom_op_py_library( diff --git a/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc b/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc index 3af6b1cb766..4af15095eec 100644 --- a/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc +++ b/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc @@ -90,7 +90,7 @@ struct GatherTree { // First kernel launch to "zero" things out beams.device(d) = beams.constant(end_token); - CudaLaunchConfig config = GetCudaLaunchConfig(batch_size * beam_width, d); + GpuLaunchConfig config = GetCudaLaunchConfig(batch_size * beam_width, d); TF_CHECK_OK(CudaLaunchKernel( GatherTreeOpKernel, config.block_count, config.thread_per_block, 0, d.stream(), batch_size, max_time, beam_width, step_ids.data(), diff --git a/tensorflow/contrib/seq2seq/python/ops/decoder.py b/tensorflow/contrib/seq2seq/python/ops/decoder.py index 33f7bac8159..ab124959001 100644 --- a/tensorflow/contrib/seq2seq/python/ops/decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/decoder.py @@ -113,6 +113,20 @@ class Decoder(object): raise NotImplementedError def finalize(self, outputs, final_state, sequence_lengths): + """Called after decoding iterations complete. + + Args: + outputs: RNNCell outputs (possibly nested tuple of) tensor[s] for all time + steps. + final_state: RNNCell final state (possibly nested tuple of) tensor[s] for + last time step. + sequence_lengths: 1-D `int32` tensor containing lengths of each sequence. 
+ + Returns: + `(final_outputs, final_state)`: `final_outputs` is an object containing + the final decoder output, `final_state` is a (structure of) state tensors + and TensorArrays. + """ raise NotImplementedError @property diff --git a/tensorflow/contrib/session_bundle/BUILD b/tensorflow/contrib/session_bundle/BUILD index 40774c2238a..97c7bb7918b 100644 --- a/tensorflow/contrib/session_bundle/BUILD +++ b/tensorflow/contrib/session_bundle/BUILD @@ -1,9 +1,10 @@ # Description: # TensorFlow Serving session bundle. -package(default_visibility = ["//visibility:public"]) - -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) diff --git a/tensorflow/contrib/session_bundle/example/BUILD b/tensorflow/contrib/session_bundle/example/BUILD index 18a075943c2..37c8656616f 100644 --- a/tensorflow/contrib/session_bundle/example/BUILD +++ b/tensorflow/contrib/session_bundle/example/BUILD @@ -2,10 +2,9 @@ package( default_visibility = ["//tensorflow/contrib/session_bundle:__subpackages__"], + licenses = ["notice"], # Apache 2.0 ) -licenses(["notice"]) # Apache 2.0 - exports_files(["LICENSE"]) # vardef("PYTHON_BIN_PATH", "/usr/bin/python") diff --git a/tensorflow/contrib/session_bundle/session_bundle_test.cc b/tensorflow/contrib/session_bundle/session_bundle_test.cc index 612623ae309..9e4b1c72195 100644 --- a/tensorflow/contrib/session_bundle/session_bundle_test.cc +++ b/tensorflow/contrib/session_bundle/session_bundle_test.cc @@ -240,8 +240,8 @@ TEST(LoadSessionBundleFromPath, BasicTestRunOptionsThreadPoolInvalid) { // Expect failed session run calls with invalid run-options. EXPECT_FALSE(status.ok()); - EXPECT_TRUE(str_util::StrContains(status.error_message(), - "Invalid inter_op_thread_pool: 2")) + EXPECT_TRUE(absl::StrContains(status.error_message(), + "Invalid inter_op_thread_pool: 2")) << status.error_message(); } @@ -315,8 +315,8 @@ TEST_F(SessionBundleTest, ServingGraphEmpty) { }); status_ = LoadSessionBundleFromPath(options_, path, &bundle_); EXPECT_FALSE(status_.ok()); - EXPECT_TRUE(str_util::StrContains(status_.error_message(), - "Expected exactly one serving GraphDef")) + EXPECT_TRUE(absl::StrContains(status_.error_message(), + "Expected exactly one serving GraphDef")) << status_.error_message(); } @@ -332,8 +332,8 @@ TEST_F(SessionBundleTest, ServingGraphAnyIncorrectType) { status_ = LoadSessionBundleFromPath(options_, path, &bundle_); EXPECT_FALSE(status_.ok()); EXPECT_TRUE( - str_util::StrContains(status_.error_message(), - "Expected Any type_url for: tensorflow.GraphDef")) + absl::StrContains(status_.error_message(), + "Expected Any type_url for: tensorflow.GraphDef")) << status_.error_message(); } @@ -349,8 +349,7 @@ TEST_F(SessionBundleTest, ServingGraphAnyValueCorrupted) { }); status_ = LoadSessionBundleFromPath(options_, path, &bundle_); EXPECT_FALSE(status_.ok()); - EXPECT_TRUE( - str_util::StrContains(status_.error_message(), "Failed to unpack")) + EXPECT_TRUE(absl::StrContains(status_.error_message(), "Failed to unpack")) << status_.error_message(); } @@ -365,7 +364,7 @@ TEST_F(SessionBundleTest, AssetFileAnyIncorrectType) { }); status_ = LoadSessionBundleFromPath(options_, path, &bundle_); EXPECT_FALSE(status_.ok()); - EXPECT_TRUE(str_util::StrContains( + EXPECT_TRUE(absl::StrContains( status_.error_message(), "Expected Any type_url for: tensorflow.serving.AssetFile")) << status_.error_message(); @@ -383,8 +382,7 @@ TEST_F(SessionBundleTest, AssetFileAnyValueCorrupted) { }); 
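Editor's illustrative aside on the Decoder.finalize docstring added above: it documents a post-processing hook that is invoked once the decode loop has finished. A minimal sketch, roughly mirroring how that contract is consumed; the helper name and its arguments are hypothetical, and decoders that do not override finalize simply keep the stacked per-step outputs.

    from tensorflow.python.util import nest


    def finalize_decode(decoder, final_outputs_ta, final_state,
                        final_sequence_lengths):
      """Hypothetical helper: stack per-step outputs, then let the decoder finalize."""
      # Stack each per-step TensorArray into a single time-major tensor.
      final_outputs = nest.map_structure(lambda ta: ta.stack(), final_outputs_ta)
      try:
        # A decoder may rewrite the results here; for example, a beam-search
        # decoder gathers the surviving beams at this point.
        final_outputs, final_state = decoder.finalize(
            final_outputs, final_state, final_sequence_lengths)
      except NotImplementedError:
        pass
      return final_outputs, final_state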
status_ = LoadSessionBundleFromPath(options_, path, &bundle_); EXPECT_FALSE(status_.ok()); - EXPECT_TRUE( - str_util::StrContains(status_.error_message(), "Failed to unpack")) + EXPECT_TRUE(absl::StrContains(status_.error_message(), "Failed to unpack")) << status_.error_message(); } @@ -399,8 +397,8 @@ TEST_F(SessionBundleTest, InitOpTooManyValues) { }); status_ = LoadSessionBundleFromPath(options_, path, &bundle_); EXPECT_FALSE(status_.ok()); - EXPECT_TRUE(str_util::StrContains(status_.error_message(), - "Expected exactly one serving init op")) + EXPECT_TRUE(absl::StrContains(status_.error_message(), + "Expected exactly one serving init op")) << status_.error_message(); } diff --git a/tensorflow/contrib/session_bundle/signature_test.cc b/tensorflow/contrib/session_bundle/signature_test.cc index b1ff55552e0..99b55e3c3be 100644 --- a/tensorflow/contrib/session_bundle/signature_test.cc +++ b/tensorflow/contrib/session_bundle/signature_test.cc @@ -35,7 +35,7 @@ namespace serving { namespace { static bool HasSubstr(StringPiece base, StringPiece substr) { - bool ok = str_util::StrContains(base, substr); + bool ok = absl::StrContains(base, substr); EXPECT_TRUE(ok) << base << ", expected substring " << substr; return ok; } @@ -70,8 +70,8 @@ TEST(GetClassificationSignature, MissingSignature) { ClassificationSignature signature; const Status status = GetClassificationSignature(meta_graph_def, &signature); ASSERT_FALSE(status.ok()); - EXPECT_TRUE(str_util::StrContains(status.error_message(), - "Expected a classification signature")) + EXPECT_TRUE(absl::StrContains(status.error_message(), + "Expected a classification signature")) << status.error_message(); } @@ -87,8 +87,8 @@ TEST(GetClassificationSignature, WrongSignatureType) { ClassificationSignature signature; const Status status = GetClassificationSignature(meta_graph_def, &signature); ASSERT_FALSE(status.ok()); - EXPECT_TRUE(str_util::StrContains(status.error_message(), - "Expected a classification signature")) + EXPECT_TRUE(absl::StrContains(status.error_message(), + "Expected a classification signature")) << status.error_message(); } @@ -123,8 +123,8 @@ TEST(GetNamedClassificationSignature, MissingSignature) { const Status status = GetNamedClassificationSignature("foo", meta_graph_def, &signature); ASSERT_FALSE(status.ok()); - EXPECT_TRUE(str_util::StrContains(status.error_message(), - "Missing signature named \"foo\"")) + EXPECT_TRUE(absl::StrContains(status.error_message(), + "Missing signature named \"foo\"")) << status.error_message(); } @@ -142,9 +142,9 @@ TEST(GetNamedClassificationSignature, WrongSignatureType) { const Status status = GetNamedClassificationSignature("foo", meta_graph_def, &signature); ASSERT_FALSE(status.ok()); - EXPECT_TRUE(str_util::StrContains( - status.error_message(), - "Expected a classification signature for name \"foo\"")) + EXPECT_TRUE( + absl::StrContains(status.error_message(), + "Expected a classification signature for name \"foo\"")) << status.error_message(); } @@ -177,8 +177,8 @@ TEST(GetRegressionSignature, MissingSignature) { RegressionSignature signature; const Status status = GetRegressionSignature(meta_graph_def, &signature); ASSERT_FALSE(status.ok()); - EXPECT_TRUE(str_util::StrContains(status.error_message(), - "Expected a regression signature")) + EXPECT_TRUE(absl::StrContains(status.error_message(), + "Expected a regression signature")) << status.error_message(); } @@ -194,8 +194,8 @@ TEST(GetRegressionSignature, WrongSignatureType) { RegressionSignature signature; const Status status = 
GetRegressionSignature(meta_graph_def, &signature); ASSERT_FALSE(status.ok()); - EXPECT_TRUE(str_util::StrContains(status.error_message(), - "Expected a regression signature")) + EXPECT_TRUE(absl::StrContains(status.error_message(), + "Expected a regression signature")) << status.error_message(); } @@ -228,8 +228,8 @@ TEST(GetNamedSignature, MissingSignature) { Signature signature; const Status status = GetNamedSignature("foo", meta_graph_def, &signature); ASSERT_FALSE(status.ok()); - EXPECT_TRUE(str_util::StrContains(status.error_message(), - "Missing signature named \"foo\"")) + EXPECT_TRUE(absl::StrContains(status.error_message(), + "Missing signature named \"foo\"")) << status.error_message(); } @@ -371,7 +371,7 @@ TEST(RunClassification, RunNotOk) { const Status status = RunClassification(signature, input_tensor, &session, &classes_tensor, nullptr); ASSERT_FALSE(status.ok()); - EXPECT_TRUE(str_util::StrContains(status.error_message(), "Data is gone")) + EXPECT_TRUE(absl::StrContains(status.error_message(), "Data is gone")) << status.error_message(); } @@ -387,8 +387,7 @@ TEST(RunClassification, TooManyOutputs) { const Status status = RunClassification(signature, input_tensor, &session, &classes_tensor, nullptr); ASSERT_FALSE(status.ok()); - EXPECT_TRUE( - str_util::StrContains(status.error_message(), "Expected 1 output")) + EXPECT_TRUE(absl::StrContains(status.error_message(), "Expected 1 output")) << status.error_message(); } @@ -405,8 +404,8 @@ TEST(RunClassification, WrongBatchOutputs) { &classes_tensor, nullptr); ASSERT_FALSE(status.ok()); EXPECT_TRUE( - str_util::StrContains(status.error_message(), - "Input batch size did not match output batch size")) + absl::StrContains(status.error_message(), + "Input batch size did not match output batch size")) << status.error_message(); } @@ -452,7 +451,7 @@ TEST_F(RunRegressionTest, RunNotOk) { const Status status = RunRegression(signature_, input_tensor_, &session_, &output_tensor_); ASSERT_FALSE(status.ok()); - EXPECT_TRUE(str_util::StrContains(status.error_message(), "Data is gone")) + EXPECT_TRUE(absl::StrContains(status.error_message(), "Data is gone")) << status.error_message(); } @@ -464,8 +463,8 @@ TEST_F(RunRegressionTest, MismatchedSizeForBatchInputAndOutput) { RunRegression(signature_, input_tensor_, &session_, &output_tensor_); ASSERT_FALSE(status.ok()); EXPECT_TRUE( - str_util::StrContains(status.error_message(), - "Input batch size did not match output batch size")) + absl::StrContains(status.error_message(), + "Input batch size did not match output batch size")) << status.error_message(); } @@ -491,8 +490,7 @@ TEST(GetSignatures, MissingSignature) { Signatures read_signatures; const auto status = GetSignatures(meta_graph_def, &read_signatures); EXPECT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code()); - EXPECT_TRUE( - str_util::StrContains(status.error_message(), "Expected exactly one")) + EXPECT_TRUE(absl::StrContains(status.error_message(), "Expected exactly one")) << status.error_message(); } @@ -506,9 +504,9 @@ TEST(GetSignatures, WrongProtoInAny) { Signatures read_signatures; const auto status = GetSignatures(meta_graph_def, &read_signatures); EXPECT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code()); - EXPECT_TRUE(str_util::StrContains(status.error_message(), - "Expected Any type_url for: " - "tensorflow.serving.Signatures")) + EXPECT_TRUE(absl::StrContains(status.error_message(), + "Expected Any type_url for: " + "tensorflow.serving.Signatures")) << status.error_message(); } @@ -523,7 +521,7 @@ 
TEST(GetSignatures, JunkInAny) { Signatures read_signatures; const auto status = GetSignatures(meta_graph_def, &read_signatures); EXPECT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code()); - EXPECT_TRUE(str_util::StrContains(status.error_message(), "Failed to unpack")) + EXPECT_TRUE(absl::StrContains(status.error_message(), "Failed to unpack")) << status.error_message(); } @@ -570,8 +568,7 @@ TEST(GetSignatures, MultipleSignaturesNotOK) { Signatures read_signatures; const auto status = GetSignatures(meta_graph_def, &read_signatures); EXPECT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code()); - EXPECT_TRUE( - str_util::StrContains(status.error_message(), "Expected exactly one")) + EXPECT_TRUE(absl::StrContains(status.error_message(), "Expected exactly one")) << status.error_message(); } @@ -645,8 +642,8 @@ TEST(GetGenericSignature, WrongSignatureType) { const Status status = GetGenericSignature("generic_bindings", meta_graph_def, &signature); ASSERT_FALSE(status.ok()); - EXPECT_TRUE(str_util::StrContains(status.error_message(), - "Expected a generic signature:")) + EXPECT_TRUE(absl::StrContains(status.error_message(), + "Expected a generic signature:")) << status.error_message(); } diff --git a/tensorflow/contrib/signal/BUILD b/tensorflow/contrib/signal/BUILD index 5e4f130b314..61798014da2 100644 --- a/tensorflow/contrib/signal/BUILD +++ b/tensorflow/contrib/signal/BUILD @@ -1,6 +1,7 @@ -package(default_visibility = ["//tensorflow:internal"]) - -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//tensorflow:internal"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) diff --git a/tensorflow/contrib/slim/BUILD b/tensorflow/contrib/slim/BUILD index 96e2dcecbdf..43c665d6687 100644 --- a/tensorflow/contrib/slim/BUILD +++ b/tensorflow/contrib/slim/BUILD @@ -1,12 +1,13 @@ # Description: # Contains the Slim library, including common neural networks and examples. -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//tensorflow:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) -package(default_visibility = ["//tensorflow:__subpackages__"]) - load("//tensorflow:tensorflow.bzl", "py_test") py_library( diff --git a/tensorflow/contrib/slim/python/slim/data/BUILD b/tensorflow/contrib/slim/python/slim/data/BUILD index f1b57361ac6..d6fe04ec410 100644 --- a/tensorflow/contrib/slim/python/slim/data/BUILD +++ b/tensorflow/contrib/slim/python/slim/data/BUILD @@ -1,12 +1,13 @@ # Description: # Contains packages used for creating and loading datasets. -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//tensorflow:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) -package(default_visibility = ["//tensorflow:__subpackages__"]) - load("//tensorflow:tensorflow.bzl", "py_test") py_library( diff --git a/tensorflow/contrib/slim/python/slim/nets/BUILD b/tensorflow/contrib/slim/python/slim/nets/BUILD index f19177b1881..36b1f048e79 100644 --- a/tensorflow/contrib/slim/python/slim/nets/BUILD +++ b/tensorflow/contrib/slim/python/slim/nets/BUILD @@ -1,19 +1,18 @@ # Description: # Contains typical networks definitions. 
-licenses(["notice"]) # Apache 2.0 - -exports_files(["LICENSE"]) - -load("//tensorflow:tensorflow.bzl", "py_test") - package( default_visibility = [ "//tensorflow:__subpackages__", "//tensorflow_models:__subpackages__", ], + licenses = ["notice"], # Apache 2.0 ) +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "py_test") + # Transitive dependencies of this target will be included in the pip package. py_library( name = "nets_pip", diff --git a/tensorflow/contrib/solvers/BUILD b/tensorflow/contrib/solvers/BUILD index 5247288d54a..0c30ab24439 100644 --- a/tensorflow/contrib/solvers/BUILD +++ b/tensorflow/contrib/solvers/BUILD @@ -2,12 +2,13 @@ # Contains ops for iterative solvers for linear systems, linear least-squares # problems, singular value decomposition and eigenvalue decomposition. -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//tensorflow:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) -package(default_visibility = ["//tensorflow:__subpackages__"]) - load("//tensorflow:tensorflow.bzl", "cuda_py_test") py_library( diff --git a/tensorflow/contrib/sparsemax/BUILD b/tensorflow/contrib/sparsemax/BUILD index ed4eca1a60a..cac8818febc 100644 --- a/tensorflow/contrib/sparsemax/BUILD +++ b/tensorflow/contrib/sparsemax/BUILD @@ -2,12 +2,13 @@ # Contains ops to train linear models on top of TensorFlow. # APIs here are meant to evolve over time. -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) -package(default_visibility = ["//visibility:public"]) - load("//tensorflow:tensorflow.bzl", "cuda_py_tests") load( "//tensorflow:tensorflow.bzl", diff --git a/tensorflow/contrib/specs/BUILD b/tensorflow/contrib/specs/BUILD index 055b04db8a5..8cd92293d9f 100644 --- a/tensorflow/contrib/specs/BUILD +++ b/tensorflow/contrib/specs/BUILD @@ -1,12 +1,13 @@ # Description: # A small domain-specific language (DSL) for defining deep learning networks. -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//tensorflow:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) -package(default_visibility = ["//tensorflow:__subpackages__"]) - load("//tensorflow:tensorflow.bzl", "tf_py_test") py_library( diff --git a/tensorflow/contrib/staging/BUILD b/tensorflow/contrib/staging/BUILD index 0c86f3db1d5..96f7066646c 100644 --- a/tensorflow/contrib/staging/BUILD +++ b/tensorflow/contrib/staging/BUILD @@ -1,8 +1,9 @@ -package(default_visibility = [ - "//visibility:public", -]) - -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = [ + "//visibility:public", + ], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) diff --git a/tensorflow/contrib/stat_summarizer/BUILD b/tensorflow/contrib/stat_summarizer/BUILD index 412a2c81a14..d32ccb0270b 100644 --- a/tensorflow/contrib/stat_summarizer/BUILD +++ b/tensorflow/contrib/stat_summarizer/BUILD @@ -1,12 +1,13 @@ # Description: # Contains a Python wrapper for the StatSummarizer C++ class. 
-licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//tensorflow:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) -package(default_visibility = ["//tensorflow:__subpackages__"]) - load("//tensorflow:tensorflow.bzl", "tf_py_test") py_library( diff --git a/tensorflow/contrib/stateless/BUILD b/tensorflow/contrib/stateless/BUILD index f16d99f64c1..bbc5f7d470e 100644 --- a/tensorflow/contrib/stateless/BUILD +++ b/tensorflow/contrib/stateless/BUILD @@ -1,8 +1,9 @@ # Stateless random ops -package(default_visibility = ["//tensorflow:__subpackages__"]) - -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//tensorflow:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) diff --git a/tensorflow/contrib/summary/BUILD b/tensorflow/contrib/summary/BUILD index 4085801342b..bea00d70918 100644 --- a/tensorflow/contrib/summary/BUILD +++ b/tensorflow/contrib/summary/BUILD @@ -1,4 +1,6 @@ -licenses(["notice"]) # Apache 2.0 +package( + licenses = ["notice"], # Apache 2.0 +) exports_files([ "LICENSE", diff --git a/tensorflow/contrib/tensor_forest/BUILD b/tensorflow/contrib/tensor_forest/BUILD index a7f8819915b..e27204dc0a9 100644 --- a/tensorflow/contrib/tensor_forest/BUILD +++ b/tensorflow/contrib/tensor_forest/BUILD @@ -1,7 +1,5 @@ # TensorFlow code for training random forests. -licenses(["notice"]) # Apache 2.0 - load("//tensorflow:tensorflow.bzl", "py_test") load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs") load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object") @@ -12,7 +10,10 @@ load("//tensorflow:tensorflow.bzl", "tf_kernel_library") load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static") -package(default_visibility = ["//visibility:public"]) +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) diff --git a/tensorflow/contrib/tensor_forest/hybrid/BUILD b/tensorflow/contrib/tensor_forest/hybrid/BUILD index 64176a0dd07..c5949881108 100644 --- a/tensorflow/contrib/tensor_forest/hybrid/BUILD +++ b/tensorflow/contrib/tensor_forest/hybrid/BUILD @@ -1,13 +1,14 @@ # TensorFlow code for training hybrid neural network / decision tree models. 
-licenses(["notice"]) # Apache 2.0 - load("//tensorflow:tensorflow.bzl", "py_test") load("//tensorflow:tensorflow.bzl", "tf_custom_op_library") load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") -package(default_visibility = ["//visibility:public"]) +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/BUILD b/tensorflow/contrib/tensor_forest/kernels/v4/BUILD index b1b1559383a..d205b255402 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/BUILD +++ b/tensorflow/contrib/tensor_forest/kernels/v4/BUILD @@ -5,10 +5,9 @@ load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static") package( default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 ) -licenses(["notice"]) # Apache 2.0 - exports_files(["LICENSE"]) DECISION_TREE_RESOURCE_DEPS = [ diff --git a/tensorflow/contrib/tensor_forest/proto/BUILD b/tensorflow/contrib/tensor_forest/proto/BUILD index 04fd6a98395..ae5fef78b5e 100644 --- a/tensorflow/contrib/tensor_forest/proto/BUILD +++ b/tensorflow/contrib/tensor_forest/proto/BUILD @@ -1,11 +1,12 @@ -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library") -package(default_visibility = ["//visibility:public"]) - tf_proto_library( name = "fertile_stats_proto", srcs = ["fertile_stats.proto"], diff --git a/tensorflow/contrib/tensorboard/BUILD b/tensorflow/contrib/tensorboard/BUILD index 85070cfad01..c2506d0346b 100644 --- a/tensorflow/contrib/tensorboard/BUILD +++ b/tensorflow/contrib/tensorboard/BUILD @@ -1,9 +1,10 @@ # Description: # TensorBoard module containing volatile or experimental code. -package(default_visibility = ["//tensorflow:internal"]) - -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//tensorflow:internal"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index 91b6d2614a8..d90d6af9ba3 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -3,9 +3,10 @@ # and provide TensorRT operators and converter package. # APIs are meant to change over time. -package(default_visibility = ["//visibility:public"]) - -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD index 0a2cf105baf..43788306880 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD @@ -4,9 +4,10 @@ # APIs are meant to change while upgrading TRT. # add init_py into pip package BUILD dependency to install it. 
-package(default_visibility = ["//tensorflow:__subpackages__"]) - -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//tensorflow:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) load( "//tensorflow:tensorflow.bzl", diff --git a/tensorflow/contrib/testing/BUILD b/tensorflow/contrib/testing/BUILD index 8a40e111d77..258026a6bdb 100644 --- a/tensorflow/contrib/testing/BUILD +++ b/tensorflow/contrib/testing/BUILD @@ -1,12 +1,13 @@ # Description: # contains parts of TensorFlow that are experimental or unstable and which are not supported. -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//tensorflow:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) -package(default_visibility = ["//tensorflow:__subpackages__"]) - py_library( name = "testing_py", srcs = [ diff --git a/tensorflow/contrib/text/BUILD b/tensorflow/contrib/text/BUILD index 9f9e19a7cd6..5f4e4dff3d9 100644 --- a/tensorflow/contrib/text/BUILD +++ b/tensorflow/contrib/text/BUILD @@ -2,12 +2,13 @@ # contains parts of TensorFlow that are experimental or unstable and which # are not supported. -package(default_visibility = [ - "//learning/brain:__subpackages__", - "//tensorflow:__subpackages__", -]) - -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = [ + "//learning/brain:__subpackages__", + "//tensorflow:__subpackages__", + ], + licenses = ["notice"], # Apache 2.0 +) load( "//tensorflow:tensorflow.bzl", diff --git a/tensorflow/contrib/tfprof/BUILD b/tensorflow/contrib/tfprof/BUILD index e7f4ebdd36a..c8846391ccd 100644 --- a/tensorflow/contrib/tfprof/BUILD +++ b/tensorflow/contrib/tfprof/BUILD @@ -1,9 +1,10 @@ -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//tensorflow:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) -package(default_visibility = ["//tensorflow:__subpackages__"]) - py_library( name = "tfprof", srcs = [ diff --git a/tensorflow/contrib/timeseries/BUILD b/tensorflow/contrib/timeseries/BUILD index 18933227b34..989085564a5 100644 --- a/tensorflow/contrib/timeseries/BUILD +++ b/tensorflow/contrib/timeseries/BUILD @@ -1,8 +1,9 @@ -package(default_visibility = [ - "//tensorflow:internal", -]) - -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = [ + "//tensorflow:internal", + ], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) diff --git a/tensorflow/contrib/timeseries/examples/BUILD b/tensorflow/contrib/timeseries/examples/BUILD index 235f3adb92f..03979932a96 100644 --- a/tensorflow/contrib/timeseries/examples/BUILD +++ b/tensorflow/contrib/timeseries/examples/BUILD @@ -3,10 +3,9 @@ load("//tensorflow:tensorflow.bzl", "py_binary") package( default_visibility = ["//tensorflow:internal"], + licenses = ["notice"], # Apache 2.0 ) -licenses(["notice"]) # Apache 2.0 - exports_files(["LICENSE"]) config_setting( diff --git a/tensorflow/contrib/timeseries/python/timeseries/BUILD b/tensorflow/contrib/timeseries/python/timeseries/BUILD index ae2c4a5cb72..02e475367af 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/BUILD +++ b/tensorflow/contrib/timeseries/python/timeseries/BUILD @@ -4,10 +4,9 @@ package( default_visibility = [ "//tensorflow:internal", ], + licenses = ["notice"], # Apache 2.0 ) -licenses(["notice"]) # Apache 2.0 - exports_files(["LICENSE"]) py_library( diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD index 
08eafece5d3..38c3ac4dc4d 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD +++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD @@ -5,10 +5,9 @@ load("//tensorflow:tensorflow.bzl", "py_test") package( default_visibility = ["//tensorflow:internal"], + licenses = ["notice"], # Apache 2.0 ) -licenses(["notice"]) # Apache 2.0 - exports_files(["LICENSE"]) py_library( diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD index a53cf2b86c0..be02f29b432 100644 --- a/tensorflow/contrib/tpu/BUILD +++ b/tensorflow/contrib/tpu/BUILD @@ -9,8 +9,6 @@ load( ) load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") -licenses(["notice"]) # Apache 2.0 - package( default_visibility = [ "//cloud/vmm/testing/tests/tpu:__subpackages__", @@ -23,6 +21,7 @@ package( "//tensorflow_models:__subpackages__", "//vr/perception:__subpackages__", ], + licenses = ["notice"], # Apache 2.0 ) py_library( diff --git a/tensorflow/contrib/tpu/profiler/BUILD b/tensorflow/contrib/tpu/profiler/BUILD index e2ce77e1181..461f9856b0d 100644 --- a/tensorflow/contrib/tpu/profiler/BUILD +++ b/tensorflow/contrib/tpu/profiler/BUILD @@ -1,4 +1,6 @@ -licenses(["notice"]) # Apache 2.0 +package( + licenses = ["notice"], # Apache 2.0 +) load("//tensorflow:tensorflow.bzl", "tf_cc_binary") load("//tensorflow:tensorflow.bzl", "tf_cc_test") diff --git a/tensorflow/contrib/training/BUILD b/tensorflow/contrib/training/BUILD index 8f1d5ce2fdf..22635592aed 100644 --- a/tensorflow/contrib/training/BUILD +++ b/tensorflow/contrib/training/BUILD @@ -1,14 +1,15 @@ # Description: # contains parts of TensorFlow that are experimental or unstable and which are not supported. -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = [ + "//tensorflow:internal", + ], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) -package(default_visibility = [ - "//tensorflow:internal", -]) - load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library") load("//tensorflow:tensorflow.bzl", "py_test") diff --git a/tensorflow/contrib/util/BUILD b/tensorflow/contrib/util/BUILD index 7b2bc30e3a8..3d4123062ac 100644 --- a/tensorflow/contrib/util/BUILD +++ b/tensorflow/contrib/util/BUILD @@ -1,12 +1,13 @@ # Description: # contains parts of TensorFlow that are experimental or unstable and which are not supported. -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = ["//tensorflow:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) exports_files(["LICENSE"]) -package(default_visibility = ["//tensorflow:__subpackages__"]) - load("//tensorflow:tensorflow.bzl", "tf_cc_binary") load("//tensorflow:tensorflow.bzl", "tf_cc_test") diff --git a/tensorflow/contrib/verbs/BUILD b/tensorflow/contrib/verbs/BUILD index 19cb8983b68..3cfd4d6e81d 100644 --- a/tensorflow/contrib/verbs/BUILD +++ b/tensorflow/contrib/verbs/BUILD @@ -1,11 +1,12 @@ # Description: # Verbs RDMA communication interfaces and implementations for TensorFlow. 
-package(default_visibility = [ - "//tensorflow:__subpackages__", -]) - -licenses(["notice"]) # Apache 2.0 +package( + default_visibility = [ + "//tensorflow:__subpackages__", + ], + licenses = ["notice"], # Apache 2.0 +) load("//tensorflow:tensorflow.bzl", "tf_cuda_library") diff --git a/tensorflow/contrib/verbs/verbs_util.cc b/tensorflow/contrib/verbs/verbs_util.cc index a6333d9f362..dc5815181f1 100644 --- a/tensorflow/contrib/verbs/verbs_util.cc +++ b/tensorflow/contrib/verbs/verbs_util.cc @@ -44,7 +44,7 @@ void VerbsUtil::GetKeyAndStepId(const string& key_with_step_id, string& key, CHECK(parts.size() == 6) << "Key with step_id must have 6 parts"; strings::safe_strto64(parts[5], &step_id); parts.pop_back(); // remove step_id - key.assign(str_util::Join(parts, ";")); // stitch them together + key.assign(absl::StrJoin(parts, ";")); // stitch them together } } // namespace tensorflow diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index bcd02aa8410..b07e018dd2a 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -349,6 +349,8 @@ cc_library( deps = [ ":lib_platform", "//tensorflow/core/platform/default/build_config:base", + "@com_google_absl//absl/base", + "@com_google_absl//absl/strings", ], ) @@ -1000,7 +1002,6 @@ cc_library( name = "allocator", srcs = [ "framework/allocator.cc", - "framework/allocator_registry.cc", "framework/allocator_registry.h", "framework/numeric_types.h", "framework/tracking_allocator.cc", @@ -1012,12 +1013,37 @@ cc_library( ], features = ["parse_headers"], visibility = ["//visibility:public"], + deps = [ + ":lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "//third_party/eigen3", + ] + if_static(extra_deps = [":allocator_registry_impl"]), + alwayslink = 1, +) + +# This target will be included in libtensorflow_framework.so via the +# framework_internal_impl target. +# All other dependencies on this target need to go through if_static guard, +# as otherwise duplicate registration in the registry will cause crashes. 
+cc_library( + name = "allocator_registry_impl", + srcs = [ + "framework/allocator.h", + "framework/allocator_registry.cc", + "framework/allocator_registry.h", + "framework/cpu_allocator_impl.cc", + "framework/numeric_types.h", + "framework/tracking_allocator.h", + "framework/type_traits.h", + ], deps = [ ":lib", "//third_party/eigen3", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], + alwayslink = 1, ) cc_library( @@ -1084,6 +1110,7 @@ cc_library( ":lib_internal", ":protos_all_cc", "//tensorflow/core/util/proto:proto_utils", + "@com_google_absl//absl/strings", ], ) @@ -2324,6 +2351,7 @@ tf_proto_library_cc( srcs = ["protobuf/eager_service.proto"], has_services = 1, cc_api_version = 2, + cc_grpc_version = 1, cc_stubby_versions = ["2"], protodeps = tf_additional_all_protos(), visibility = [ @@ -2872,6 +2900,7 @@ tf_cuda_library( "**/*test*", "**/*main.cc", "framework/allocator.cc", + "framework/cpu_allocator_impl.cc", "framework/allocator_registry.cc", "framework/tracking_allocator.cc", "example/example_parser_configuration.*", @@ -2905,6 +2934,7 @@ tf_cuda_library( ], }), deps = [ + ":allocator_registry_impl", ":allocator", ":feature_util", ":lib", @@ -3342,6 +3372,7 @@ cc_library( "//tensorflow/compiler:__subpackages__", "//tensorflow/core/kernels:__subpackages__", "//tensorflow/core/profiler:__subpackages__", + "//tensorflow/stream_executor:__subpackages__", ], deps = [":lib_internal"], ) @@ -3466,7 +3497,6 @@ GPU_RUNTIME_HEADERS = [ tf_cuda_library( name = "gpu_runtime_impl", srcs = [ - "common_runtime/gpu/gpu_bfc_allocator.cc", "common_runtime/gpu/gpu_cudamalloc_allocator.cc", "common_runtime/gpu/gpu_debug_allocator.cc", "common_runtime/gpu/gpu_device.cc", @@ -3484,6 +3514,7 @@ tf_cuda_library( ":core_cpu_lib", ":framework", ":framework_internal", + ":gpu_bfc_allocator", ":gpu_id_impl", ":gpu_init_impl", ":gpu_lib", @@ -3773,6 +3804,8 @@ tf_cc_tests( ":test", ":test_main", "//third_party/eigen3", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", "@zlib_archive//:zlib", ], ) @@ -3788,6 +3821,7 @@ tf_cc_test( ":protos_all_cc", ":test", "//third_party/eigen3", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/core/api_def/base_api/api_def_FusedBatchNormGradV3.pbtxt b/tensorflow/core/api_def/base_api/api_def_FusedBatchNormGradV3.pbtxt new file mode 100644 index 00000000000..76b33b959f6 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_FusedBatchNormGradV3.pbtxt @@ -0,0 +1,116 @@ +op { + graph_op_name: "FusedBatchNormGradV3" + in_arg { + name: "y_backprop" + description: <