diff --git a/.bazelrc b/.bazelrc index dbdadc98ea7..2b80063fd59 100644 --- a/.bazelrc +++ b/.bazelrc @@ -319,11 +319,13 @@ build:xla --define=with_xla_support=true # BEGIN TF REMOTE BUILD EXECUTION OPTIONS # Options when using remote execution # WARNING: THESE OPTIONS WONT WORK IF YOU DO NOT HAVE PROPER AUTHENTICATION AND PERMISSIONS + +# Flag to enable remote config +common --experimental_repo_remote_exec + build:rbe --action_env=BAZEL_DO_NOT_DETECT_CPP_TOOLCHAIN=1 -build:rbe --auth_enabled=true -build:rbe --auth_scope=https://www.googleapis.com/auth/cloud-source-tools +build:rbe --google_default_credentials build:rbe --bes_backend=buildeventservice.googleapis.com -build:rbe --bes_best_effort=false build:rbe --bes_results_url="https://source.cloud.google.com/results/invocations" build:rbe --bes_timeout=600s build:rbe --define=EXECUTOR=remote @@ -336,7 +338,7 @@ build:rbe --spawn_strategy=remote,worker,standalone,local test:rbe --test_env=USER=anon # Attempt to minimize the amount of data transfer between bazel and the remote # workers: -build:rbe --experimental_inmemory_jdeps_files --experimental_inmemory_dotd_files --experimental_remote_download_outputs=toplevel +build:rbe --remote_download_toplevel build:rbe_linux --config=rbe build:rbe_linux --action_env=PATH="/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/local/go/bin" diff --git a/.github/ISSUE_TEMPLATE/00-bug-issue.md b/.github/ISSUE_TEMPLATE/00-bug-issue.md index bb4a1a7ea14..0c2bcb27c7d 100644 --- a/.github/ISSUE_TEMPLATE/00-bug-issue.md +++ b/.github/ISSUE_TEMPLATE/00-bug-issue.md @@ -10,13 +10,20 @@ labels: 'type:bug' we only address code/doc bugs, performance issues, feature requests and build/installation issues on GitHub. tag:bug_template -**System information** - Have I written custom code (as opposed to using a stock -example script provided in TensorFlow): - OS Platform and Distribution (e.g., -Linux Ubuntu 16.04): - Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if -the issue happens on mobile device: - TensorFlow installed from (source or -binary): - TensorFlow version (use command below): - Python version: - Bazel -version (if compiling from source): - GCC/Compiler version (if compiling from -source): - CUDA/cuDNN version: - GPU model and memory: +**System information** +- Have I written custom code (as opposed to using a stock +example script provided in TensorFlow): +- OS Platform and Distribution (e.g., +Linux Ubuntu 16.04): +- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if +the issue happens on mobile device: +- TensorFlow installed from (source or +binary): - TensorFlow version (use command below): +- Python version: - Bazel +version (if compiling from source): +- GCC/Compiler version (if compiling from +source): +- CUDA/cuDNN version: - GPU model and memory: You can collect some of this information using our environment capture [script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh) @@ -28,8 +35,9 @@ tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"` 2. TF 2.0: `python -c **Describe the expected behavior** -**Code to reproduce the issue** Provide a reproducible test case that is the -bare minimum necessary to generate the problem. +**Standalone code to reproduce the issue** +Provide a reproducible test case that is the bare minimum necessary to generate +the problem. If possible, please share a link to Colab/Jupyter/any notebook. **Other info / logs** Include any logs or source code that would be helpful to diagnose the problem. 
If including tracebacks, please include the full diff --git a/.github/ISSUE_TEMPLATE/40-tflite-op-request.md b/.github/ISSUE_TEMPLATE/40-tflite-op-request.md index f4b6733c211..4f1e60b553a 100644 --- a/.github/ISSUE_TEMPLATE/40-tflite-op-request.md +++ b/.github/ISSUE_TEMPLATE/40-tflite-op-request.md @@ -17,8 +17,14 @@ labels: 'comp:lite' # Copy and paste here ``` +**Standalone code to reproduce the issue** +Provide a reproducible test case that is the bare minimum necessary to generate +the problem. If possible, please share a link to Colab/Jupyter/any notebook. + Also, please include a link to a GraphDef or the model if possible. **Any other info / logs** -Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached. +Include any logs or source code that would be helpful to diagnose the problem. +If including tracebacks, please include the full traceback. Large logs and files +should be attached. diff --git a/.github/ISSUE_TEMPLATE/60-tflite-converter-issue.md b/.github/ISSUE_TEMPLATE/60-tflite-converter-issue.md index 3cd6e977d2f..32ebaff1a9c 100644 --- a/.github/ISSUE_TEMPLATE/60-tflite-converter-issue.md +++ b/.github/ISSUE_TEMPLATE/60-tflite-converter-issue.md @@ -1,6 +1,7 @@ --- name: TensorFlow Lite New Converter Issue about: Use this template for reporting issues during model conversion to TFLite +labels: 'TFLiteConverter' --- @@ -12,6 +13,7 @@ about: Use this template for reporting issues during model conversion to TFLite **Command used to run the converter or code if you’re using the Python API** +If possible, please share a link to Colab/Jupyter/any notebook. ``` # Copy and paste here the exact command diff --git a/.github/ISSUE_TEMPLATE/80-performance-issue.md b/.github/ISSUE_TEMPLATE/80-performance-issue.md index 2090801742c..a1cbf23df4b 100644 --- a/.github/ISSUE_TEMPLATE/80-performance-issue.md +++ b/.github/ISSUE_TEMPLATE/80-performance-issue.md @@ -11,13 +11,20 @@ As per our we only address code/doc bugs, performance issues, feature requests and build/installation issues on GitHub. tag:performance_template -**System information** - Have I written custom code (as opposed to using a stock -example script provided in TensorFlow): - OS Platform and Distribution (e.g., -Linux Ubuntu 16.04): - Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if -the issue happens on mobile device: - TensorFlow installed from (source or -binary): - TensorFlow version (use command below): - Python version: - Bazel -version (if compiling from source): - GCC/Compiler version (if compiling from -source): - CUDA/cuDNN version: - GPU model and memory: +**System information** +- Have I written custom code (as opposed to using a stock +example script provided in TensorFlow): +- OS Platform and Distribution (e.g., +Linux Ubuntu 16.04): +- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if +the issue happens on mobile device: +- TensorFlow installed from (source or +binary): - TensorFlow version (use command below): +- Python version: - Bazel +version (if compiling from source): +- GCC/Compiler version (if compiling from +source): +- CUDA/cuDNN version: - GPU model and memory: You can collect some of this information using our environment capture [script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh) @@ -29,8 +36,9 @@ tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"` 2. 
TF 2.0: `python -c **Describe the expected behavior** -**Code to reproduce the issue** Provide a reproducible test case that is the -bare minimum necessary to generate the problem. +**Standalone code to reproduce the issue** +Provide a reproducible test case that is the bare minimum necessary to generate +the problem. If possible, please share a link to Colab/Jupyter/any notebook. **Other info / logs** Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full diff --git a/WORKSPACE b/WORKSPACE index 0139c4aa643..ad645add449 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -113,3 +113,28 @@ http_archive( "https://storage.googleapis.com/download.tensorflow.org/models/speech_commands_v0.01.zip", ], ) + +# Required for dependency @com_github_grpc_grpc + +load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps") + +grpc_deps() + +load( + "@build_bazel_rules_apple//apple:repositories.bzl", + "apple_rules_dependencies", +) + +apple_rules_dependencies() + +load( + "@build_bazel_apple_support//lib:repositories.bzl", + "apple_support_dependencies", +) + +apple_support_dependencies() + +load("@upb//bazel:repository_defs.bzl", "bazel_version_repository") + +bazel_version_repository(name = "bazel_version") + diff --git a/configure.py b/configure.py index 64956049c34..ed09a693fd4 100644 --- a/configure.py +++ b/configure.py @@ -49,7 +49,7 @@ _TF_BAZELRC_FILENAME = '.tf_configure.bazelrc' _TF_WORKSPACE_ROOT = '' _TF_BAZELRC = '' _TF_CURRENT_BAZEL_VERSION = None -_TF_MIN_BAZEL_VERSION = '1.2.1' +_TF_MIN_BAZEL_VERSION = '2.0.0' _TF_MAX_BAZEL_VERSION = '2.0.0' NCCL_LIB_PATHS = [ diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 4c6f15f5367..55406a5686a 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -505,13 +505,15 @@ selects.config_setting_group( package_group( name = "internal", packages = [ + # To pass open source testing in the pip Kokoros. + "//bazel_pip/tensorflow/...", "//learning/brain/swift/x10/...", "//perftools/accelerators/xprof/api/...", + "//third_party/py/autograph/...", + "//third_party/swift/tensorflow/x10/...", "//tensorflow/...", "//tensorflow_estimator/python/estimator/...", "//tensorflow_models/official/...", - "//third_party/py/autograph/...", - "//third_party/swift/tensorflow/x10/...", ], ) @@ -545,8 +547,8 @@ cc_library( name = "grpc", visibility = ["//visibility:public"], deps = select({ - ":linux_s390x": ["@grpc//:grpc_unsecure"], - "//conditions:default": ["@grpc"], + ":linux_s390x": ["@com_github_grpc_grpc//:grpc_unsecure"], + "//conditions:default": ["@com_github_grpc_grpc//:grpc"], }), ) @@ -554,8 +556,8 @@ cc_library( name = "grpc++", visibility = ["//visibility:public"], deps = select({ - ":linux_s390x": ["@grpc//:grpc++_unsecure"], - "//conditions:default": ["@grpc//:grpc++"], + ":linux_s390x": ["@com_github_grpc_grpc//:grpc++_unsecure"], + "//conditions:default": ["@com_github_grpc_grpc//:grpc++"], }), ) diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index c11ef3756d5..4e7ba3943ae 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -31,6 +31,7 @@ limitations under the License. 
#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/casts.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/net.h" #include "tensorflow/core/platform/platform.h" @@ -816,12 +817,15 @@ void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes, const int num_inputs = input_shapes->num_items; NodeDef node_def; - node_def.set_name(tfe_op->operation.Name()); - node_def.set_op(tfe_op->operation.Name()); + node_def.set_name(tfe_op->operation->Name()); + node_def.set_op(tfe_op->operation->Name()); for (int i = 0; i < num_inputs; ++i) { node_def.add_input("dummy_input"); } - tfe_op->operation.Attrs().FillAttrValueMap(node_def.mutable_attr()); + tensorflow::down_cast( + tfe_op->operation.get()) + ->Attrs() + .FillAttrValueMap(node_def.mutable_attr()); const tensorflow::OpRegistrationData* op_reg_data; status->status = diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 509a6205274..963bafe8ca1 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -28,6 +28,8 @@ tf_cuda_library( "c_api_debug.cc", "c_api_experimental.h", "c_api_internal.h", + "operation_interface.cc", + "operation_interface.h", "tensor_handle_interface.h", ], hdrs = ["c_api.h"], @@ -56,6 +58,7 @@ tf_cuda_library( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core/platform:casts", "//tensorflow/core/platform:errors", "//tensorflow/core:protos_all_cc", "//tensorflow/core/profiler/lib:traceme", @@ -93,6 +96,7 @@ filegroup( "c_api_experimental.h", "c_api_internal.h", "dlpack.h", + "operation_interface.h", "tensor_handle_interface.h", ], visibility = [ @@ -105,6 +109,7 @@ tf_cuda_library( name = "c_api_internal", srcs = [ "c_api_experimental.h", + "operation_interface.h", "tensor_handle_interface.h", ], hdrs = ["c_api_internal.h"], @@ -129,6 +134,7 @@ tf_cuda_library( "//tensorflow/core/common_runtime/eager:eager_operation", "//tensorflow/core/common_runtime/eager:kernel_and_device", "//tensorflow/core/common_runtime/eager:tensor_handle", + "@com_google_absl//absl/container:fixed_array", ], ) diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index fe31c317853..b6a87cc616d 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -27,7 +27,6 @@ limitations under the License. 
// clang-format on #include "absl/algorithm/container.h" -#include "absl/container/fixed_array.h" #include "absl/memory/memory.h" #include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api_internal.h" @@ -95,14 +94,6 @@ using tensorflow::string; namespace { -const tensorflow::OpDef* GetOpDef(TFE_Op* op, TF_Status* status) { - const tensorflow::OpDef* op_def = op->operation.OpDef(); - if (op_def) return op_def; - status->status = - tensorflow::OpDefForOp(op->operation.Name().c_str(), &op_def); - return op_def; -} - bool IsCPU( absl::variant variant) { if (VariantDeviceIsCustom(variant)) { @@ -1125,9 +1116,13 @@ TF_Tensor* tensorflow::TensorHandleInterface::Resolve(Status* status) { return retval; } else { tensorflow::Tensor tensor; - if (IsCPU(handle_->device())) { + if (IsCPU(handle_->device()) || handle_->HasLocalMirror(nullptr)) { const tensorflow::Tensor* src = nullptr; - *status = handle_->Tensor(&src); + if (handle_->HasLocalMirror(nullptr)) { + *status = handle_->TensorFromDevice(nullptr, &src); + } else { + *status = handle_->Tensor(&src); + } if (!status->ok()) return nullptr; tensor = *src; } else { @@ -1135,6 +1130,13 @@ TF_Tensor* tensorflow::TensorHandleInterface::Resolve(Status* status) { CHECK_NE(ctx, nullptr); *status = handle_->CopyToDevice(*ctx, ctx->HostCPU(), &tensor); if (!status->ok()) return nullptr; + if (handle_->ImplicitMirroring()) { + *status = handle_->AddEmptyLocalMirror(nullptr); + if (!status->ok()) return nullptr; + Tensor mirror = tensor; + *status = handle_->SetTensor(std::move(mirror), nullptr); + if (!status->ok()) return nullptr; + } } return tensorflow::TF_TensorFromTensor(tensor, status); } @@ -1199,18 +1201,11 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory( dimvec[i] = static_cast(dims[i]); } - if (dtype == TF_STRING || dtype == TF_RESOURCE || - !tensorflow::DataTypeCanUseMemcpy( - static_cast(dtype))) { - status->status = tensorflow::errors::InvalidArgument( - "Trying to create a tensor with a pointer to non-pod memory."); - deallocator(data, len, deallocator_arg); - return nullptr; - } // TODO(apassos) do we need to wrap the deallocator here to make sure to sync // the device? TF_ManagedBuffer* buf = - new TF_ManagedBuffer(data, len, deallocator, deallocator_arg); + new TF_ManagedBuffer(data, len, deallocator, deallocator_arg, + /*owns_memory=*/false); tensorflow::Tensor t(static_cast(dtype), tensorflow::TensorShape(dimvec), buf); @@ -1261,9 +1256,8 @@ size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h, TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, TF_Status* status) { std::unique_ptr new_op( - new TFE_Op{tensorflow::EagerOperation(ctx->context)}); - status->status = - new_op->operation.Reset(op_or_function_name, nullptr, false, nullptr); + new TFE_Op{std::make_unique(ctx)}); + status->status = new_op->operation->Reset(op_or_function_name, nullptr); if (!status->status.ok()) { new_op.reset(); } @@ -1273,49 +1267,51 @@ TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, void TFE_DeleteOp(TFE_Op* op) { delete op; } void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) { - status->status = op->operation.SetDeviceName(device_name); + status->status = op->operation->SetDeviceName(device_name); } const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) { - tensorflow::Device* device = (op->operation.Device() == nullptr) - ? 
op->operation.EagerContext().HostCPU() - : op->operation.Device(); - return device->name().c_str(); + return op->operation->DeviceName().c_str(); } void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) { - op->operation.SetUseXla(enable); -#ifndef TENSORFLOW_EAGER_USE_XLA +#ifdef TENSORFLOW_EAGER_USE_XLA + tensorflow::Status s = op->operation->SetUseXla(enable); + if (!s.ok()) { + LOG(ERROR) << "Could not enable XLA compilation for op: " << s; + } +#else LOG(WARNING) << "This call is a no-op, as the TensorFlow library is not " "built with XLA support."; #endif // TENSORFLOW_EAGER_USE_XLA } void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) { - tensorflow::TensorHandle* h = - tensorflow::down_cast( - input->handle.get()) - ->Handle(); - op->operation.AddInput(h); - status->status = op->operation.MaybeInferSingleInputAttrs(h); + status->status = op->operation->AddInput(input->handle); } void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs, TF_Status* status) { + absl::FixedArray> handles( + num_inputs); for (int i = 0; i < num_inputs; ++i) { - op->operation.AddInput( - tensorflow::down_cast( - inputs[i]->handle.get()) - ->Handle()); + handles[i].reset(inputs[i]->handle->Copy()); } - status->status = op->operation.InferInputListAttrs(num_inputs); + status->status = op->operation->AddInputList(handles); } TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name, unsigned char* is_list, TF_Status* status) { TF_AttrType ret = TF_ATTR_INT; - status->status = tensorflow::AttrTypeByName(*op->operation.AttrTypes(), - attr_name, &ret, is_list); + const tensorflow::AttrTypeMap* attr_types_; + bool is_function; + status->status = tensorflow::AttrTypeMapForOp(op->operation->Name().c_str(), + &attr_types_, &is_function); + if (!status->status.ok()) { + return ret; + } + status->status = + tensorflow::AttrTypeByName(*attr_types_, attr_name, &ret, is_list); return ret; } @@ -1336,221 +1332,150 @@ TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx, void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const void* value, size_t length) { - op->operation.MutableAttrs()->Set( - attr_name, - tensorflow::StringPiece(static_cast(value), length)); + auto s = op->operation->SetAttrString( + attr_name, static_cast(value), length); + if (!s.ok()) { + LOG(WARNING) << "Unable to set attribute: " << attr_name; + } } void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) { - op->operation.MutableAttrs()->Set(attr_name, static_cast(value)); + auto s = op->operation->SetAttrInt(attr_name, value); + if (!s.ok()) { + LOG(WARNING) << "Unable to set attribute: " << attr_name; + } } void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value) { - op->operation.MutableAttrs()->Set(attr_name, value); + auto s = op->operation->SetAttrFloat(attr_name, value); + if (!s.ok()) { + LOG(WARNING) << "Unable to set attribute: " << attr_name; + } } void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, unsigned char value) { - op->operation.MutableAttrs()->Set(attr_name, (value == 0) ? false : true); + auto s = op->operation->SetAttrBool(attr_name, (value == 0) ? 
false : true); + if (!s.ok()) { + LOG(WARNING) << "Unable to set attribute: " << attr_name; + } } void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) { - op->operation.MutableAttrs()->Set(attr_name, - static_cast(value)); + auto s = op->operation->SetAttrType(attr_name, value); + if (!s.ok()) { + LOG(WARNING) << "Unable to set attribute: " << attr_name; + } } void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims, const int num_dims, TF_Status* out_status) { - if (num_dims > tensorflow::TensorShape::MaxDimensions()) { - TF_SetStatus(out_status, TF_INVALID_ARGUMENT, - tensorflow::strings::StrCat( - "Value specified for `", attr_name, "` has ", num_dims, - " dimensions which is over the limit of ", - tensorflow::TensorShape::MaxDimensions(), ".") - .c_str()); - return; - } - tensorflow::TensorShapeProto proto; - if (num_dims < 0) { - proto.set_unknown_rank(true); - } else { - for (int d = 0; d < num_dims; ++d) { - proto.add_dim()->set_size(dims[d]); - } - } - op->operation.MutableAttrs()->Set(attr_name, proto); + out_status->status = op->operation->SetAttrShape(attr_name, dims, num_dims); } void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name, const TFE_Op* value) { - tensorflow::AttrValue attr_value; - tensorflow::NameAttrList* func = attr_value.mutable_func(); - func->set_name(value->operation.Name()); - value->operation.Attrs().FillAttrValueMap(func->mutable_attr()); - op->operation.MutableAttrs()->Set(attr_name, attr_value); + auto s = op->operation->SetAttrFunction(attr_name, value->operation); + if (!s.ok()) { + LOG(WARNING) << "Unable to set attribute: " << attr_name; + } } void TFE_OpSetAttrFunctionName(TFE_Op* op, const char* attr_name, const char* data, size_t length) { - tensorflow::AttrValue attr_value; - tensorflow::NameAttrList* func = attr_value.mutable_func(); - func->set_name(data, length); - op->operation.MutableAttrs()->Set(attr_name, attr_value); + auto s = op->operation->SetAttrFunctionName(attr_name, data, length); + if (!s.ok()) { + LOG(WARNING) << "Unable to set attribute: " << attr_name; + } } void TFE_OpSetAttrTensor(TFE_Op* op, const char* attr_name, TF_Tensor* tensor, TF_Status* status) { - tensorflow::Tensor t; - status->status = TF_TensorToTensor(tensor, &t); - if (status->status.ok()) op->operation.MutableAttrs()->Set(attr_name, t); + status->status = op->operation->SetAttrTensor(attr_name, tensor); } void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name, const void* const* values, const size_t* lengths, int num_values) { - std::vector v(num_values); - for (int i = 0; i < num_values; ++i) { - v[i] = tensorflow::StringPiece(static_cast(values[i]), - lengths[i]); + auto s = + op->operation->SetAttrStringList(attr_name, values, lengths, num_values); + if (!s.ok()) { + LOG(WARNING) << "Unable to set attribute: " << attr_name; } - op->operation.MutableAttrs()->Set(attr_name, v); } void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name, const float* values, int num_values) { - op->operation.MutableAttrs()->Set( - attr_name, tensorflow::gtl::ArraySlice(values, num_values)); + auto s = op->operation->SetAttrFloatList(attr_name, values, num_values); + if (!s.ok()) { + LOG(WARNING) << "Unable to set attribute: " << attr_name; + } } void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name, const int64_t* values, int num_values) { - op->operation.MutableAttrs()->Set( - attr_name, tensorflow::gtl::ArraySlice( - reinterpret_cast(values), num_values)); + auto s = 
op->operation->SetAttrIntList(attr_name, values, num_values); + if (!s.ok()) { + LOG(WARNING) << "Unable to set attribute: " << attr_name; + } } void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name, const TF_DataType* values, int num_values) { - op->operation.MutableAttrs()->Set( - attr_name, - tensorflow::gtl::ArraySlice( - reinterpret_cast(values), num_values)); + auto s = op->operation->SetAttrTypeList(attr_name, values, num_values); + if (!s.ok()) { + LOG(WARNING) << "Unable to set attribute: " << attr_name; + } } void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name, const unsigned char* values, int num_values) { - std::unique_ptr b(new bool[num_values]); - for (int i = 0; i < num_values; ++i) { - b[i] = values[i]; + auto s = op->operation->SetAttrBoolList(attr_name, values, num_values); + if (!s.ok()) { + LOG(WARNING) << "Unable to set attribute: " << attr_name; } - op->operation.MutableAttrs()->Set( - attr_name, tensorflow::gtl::ArraySlice(b.get(), num_values)); } void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name, const int64_t** dims, const int* num_dims, int num_values, TF_Status* out_status) { - std::unique_ptr proto( - new tensorflow::TensorShapeProto[num_values]); - for (int i = 0; i < num_values; ++i) { - const auto num_dims_i = num_dims[i]; - - if (num_dims_i > tensorflow::TensorShape::MaxDimensions()) { - TF_SetStatus(out_status, TF_INVALID_ARGUMENT, - tensorflow::strings::StrCat( - "Value specified for `", attr_name, "` has ", num_dims_i, - " dimensions which is over the limit of ", - tensorflow::TensorShape::MaxDimensions(), ".") - .c_str()); - return; - } - if (num_dims_i < 0) { - proto[i].set_unknown_rank(true); - } else { - const int64_t* dims_i = dims[i]; - auto proto_i = &proto[i]; - for (int d = 0; d < num_dims_i; ++d) { - proto_i->add_dim()->set_size(dims_i[d]); - } - } - } - op->operation.MutableAttrs()->Set( - attr_name, tensorflow::gtl::ArraySlice( - proto.get(), num_values)); + out_status->status = + op->operation->SetAttrShapeList(attr_name, dims, num_dims, num_values); } void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name, const TFE_Op** value, int num_values) { - std::unique_ptr funcs( - new tensorflow::NameAttrList[num_values]); - for (int i = 0; i < num_values; i++) { - funcs[i].set_name(value[i]->operation.Name()); - value[i]->operation.Attrs().FillAttrValueMap(funcs[i].mutable_attr()); + auto s = op->operation->SetAttrFunctionList(attr_name, value, num_values); + if (!s.ok()) { + LOG(WARNING) << "Unable to set attribute: " << attr_name; } - op->operation.MutableAttrs()->Set( - attr_name, tensorflow::gtl::ArraySlice( - funcs.get(), num_values)); } TF_CAPI_EXPORT extern int TFE_OpGetInputLength(TFE_Op* op, const char* input_name, TF_Status* status) { - const tensorflow::OpDef* op_def = GetOpDef(op, status); - if (!status->status.ok()) { - return -1; - } - tensorflow::AttrValueMap attrs; - op->operation.Attrs().FillAttrValueMap(&attrs); - tensorflow::NameRangeMap name_ranges; - status->status = tensorflow::NameRangesForNode( - tensorflow::AttrSlice(&attrs), *op_def, &name_ranges, nullptr); - if (!status->status.ok()) { - return -1; - } - auto iter = name_ranges.find(input_name); - if (iter == name_ranges.end()) { - status->status = tensorflow::errors::InvalidArgument("Input '", input_name, - "' not found"); - return -1; - } - return iter->second.second - iter->second.first; + int ret = -1; + status->status = op->operation->InputLength(input_name, &ret); + return ret; } TF_CAPI_EXPORT extern int 
TFE_OpGetOutputLength(TFE_Op* op, const char* output_name, TF_Status* status) { - const tensorflow::OpDef* op_def = GetOpDef(op, status); - if (!status->status.ok()) { - return -1; - } - tensorflow::AttrValueMap attrs; - op->operation.Attrs().FillAttrValueMap(&attrs); - tensorflow::NameRangeMap name_ranges; - status->status = tensorflow::NameRangesForNode( - tensorflow::AttrSlice(&attrs), *op_def, nullptr, &name_ranges); - if (!status->status.ok()) { - return -1; - } - auto iter = name_ranges.find(output_name); - if (iter == name_ranges.end()) { - status->status = tensorflow::errors::InvalidArgument( - "Output '", output_name, "' not found"); - return -1; - } - return iter->second.second - iter->second.first; + int ret = -1; + status->status = op->operation->OutputLength(output_name, &ret); + return ret; } void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, TF_Status* status) { - absl::FixedArray handle_retvals(*num_retvals); - VLOG(1) << "Calling TFE_Execute() on op " << op; - status->status = tensorflow::EagerExecute(&op->operation, - handle_retvals.data(), num_retvals); + absl::FixedArray> handles( + *num_retvals); + status->status = op->operation->Execute(&handles, num_retvals); if (!status->status.ok()) { return; } for (int i = 0; i < *num_retvals; ++i) { - retvals[i] = new TFE_TensorHandle{ - std::make_unique(handle_retvals[i])}; + retvals[i] = new TFE_TensorHandle{std::move(handles[i])}; } } @@ -1678,6 +1603,23 @@ void TFE_ContextStartStep(TFE_Context* ctx) { ctx->context->StartStep(); } void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context->EndStep(); } +void TFE_OpGetAttrs(TFE_Op* op, TFE_OpAttrs* attrs) { + auto operation = tensorflow::down_cast( + op->operation.get()); + *attrs = TFE_OpAttrs(&operation->Attrs()); +} + +void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) { + tensorflow::AttrValueMap m; + attrs->attributes->FillAttrValueMap(&m); + auto operation = tensorflow::down_cast( + op->operation.get()); + tensorflow::AttrBuilder* destination = operation->MutableAttrs(); + for (auto attribute : m) { + destination->Set(attribute.first, attribute.second); + } +} + namespace tensorflow { void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, const tensorflow::AttrValue& default_value, @@ -1797,10 +1739,10 @@ class CustomDeviceAPI : public tensorflow::CustomDevice { op->Inputs()[i])}); } std::vector outputs(*num_retvals); - // TODO(allenl): figure out how to get attrs from EagerOperation TF_Status status; + TFE_OpAttrs attributes(&op->Attrs()); device_.execute(inputs.size(), inputs.data(), op->Name().c_str(), - num_retvals, outputs.data(), &status, info_); + &attributes, num_retvals, outputs.data(), &status, info_); if (status.status.ok()) { for (int i = 0; i < *num_retvals; ++i) { retvals[i] = tensorflow::down_cast( diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc index 4f97d7b0517..afa36fe1210 100644 --- a/tensorflow/c/eager/c_api_experimental.cc +++ b/tensorflow/c/eager/c_api_experimental.cc @@ -31,20 +31,14 @@ using tensorflow::string; void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name, const char* raw_device_name, TF_Status* status) { if (op_to_reset) { - status->status = op_to_reset->operation.Reset( - op_or_function_name, raw_device_name, false, nullptr); + status->status = + op_to_reset->operation->Reset(op_or_function_name, raw_device_name); } else { TF_SetStatus(status, TF_INVALID_ARGUMENT, "op_to_reset should not be nullptr"); } } -void TFE_OpConsumeInput(TFE_Op* op, 
TFE_TensorHandle* h, TF_Status* status) { - op->operation.ConsumeInput( - tensorflow::down_cast(h->handle.get()) - ->Handle()); -} - void TFE_ContextEnableGraphCollection(TFE_Context* ctx) { ctx->context->SetShouldStoreGraphs(true); } @@ -520,8 +514,7 @@ void TFE_DeleteCancellationManager( void TFE_OpSetCancellationManager(TFE_Op* op, TFE_CancellationManager* cancellation_manager, TF_Status* status) { - op->operation.SetCancellationManager( - &cancellation_manager->cancellation_manager); + status->status = op->operation->SetCancellationManager(cancellation_manager); } TFE_Executor* TFE_NewExecutor(bool is_async) { @@ -569,3 +562,22 @@ void TFE_TensorHandleEnableImplicitMirroring(TFE_TensorHandle* h, h->handle->EnableImplicitMirroring(); status->status = tensorflow::Status::OK(); } + +void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name, + TF_Buffer* buf, TF_Status* status) { + auto* function_def = ctx->context->FindFunctionDef(function_name); + if (function_def == nullptr) { + status->status = tensorflow::errors::NotFound( + "Unable to find FunctionDef with name: ", function_name); + return; + } + string str = function_def->SerializeAsString(); + void* data = tensorflow::port::Malloc(str.length()); + str.copy(static_cast(data), str.length(), 0); + buf->data = data; + buf->length = str.length(); + buf->data_deallocator = [](void* data, size_t length) { + tensorflow::port::Free(data); + }; + status->status = tensorflow::Status::OK(); +} diff --git a/tensorflow/c/eager/c_api_experimental.h b/tensorflow/c/eager/c_api_experimental.h index 075b5d02fdc..92dab6a36c6 100644 --- a/tensorflow/c/eager/c_api_experimental.h +++ b/tensorflow/c/eager/c_api_experimental.h @@ -34,9 +34,6 @@ TF_CAPI_EXPORT extern void TFE_OpReset(TFE_Op* op_to_reset, const char* raw_device_name, TF_Status* status); -TF_CAPI_EXPORT extern void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h, - TF_Status* status); - // Enables only graph collection in RunMetadata on the functions executed from // this context. TF_CAPI_EXPORT extern void TFE_ContextEnableGraphCollection(TFE_Context* ctx); @@ -424,7 +421,27 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory( TF_CAPI_EXPORT extern void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf); -#define TFE_CUSTOM_DEVICE_VERSION 0 +// APIs for generically dealing with op attributes (e.g. when forwarding them +// through custom device implementations). +// +// TODO(allenl): Currently these are black boxes, but we should have some way to +// inspect values. This would let people e.g. copy over most attributes and then +// modify some based on their values. + +// A reference to an op's name -> attribute mapping +typedef struct TFE_OpAttrs TFE_OpAttrs; + +// Fetch a struct with a reference to information about attributes of `op`. +// +// The `attrs` struct does not own any memory, and `op` must outlive it. +TF_CAPI_EXPORT extern void TFE_OpGetAttrs(TFE_Op* op, TFE_OpAttrs* attrs); + +// Add attributes in `attrs` to `op`. +// +// Does not overwrite or update existing attributes, but adds new ones. +TF_CAPI_EXPORT extern void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs); + +#define TFE_CUSTOM_DEVICE_VERSION 1 // Struct to be filled in typedef struct TFE_CustomDevice { @@ -441,10 +458,10 @@ typedef struct TFE_CustomDevice { void* device_info); // Method to execute an operation. 
- // TODO(allenl) figure out a generic way of passing attrs here void (*execute)(int num_inputs, TFE_TensorHandle** inputs, - const char* operation_name, int* num_outputs, - TFE_TensorHandle** outputs, TF_Status* s, void* device_info); + const char* operation_name, const TFE_OpAttrs* attributes, + int* num_outputs, TFE_TensorHandle** outputs, TF_Status* s, + void* device_info); // Method to delete a device. void (*delete_device)(void* device_info); @@ -475,6 +492,11 @@ typedef struct TFE_CustomDevice { void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device, const char* device_name, void* device_info); +TF_CAPI_EXPORT extern void TFE_ContextGetFunctionDef(TFE_Context* ctx, + const char* function_name, + TF_Buffer* buf, + TF_Status* status); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index f4bdcc05489..943890b6259 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -27,12 +27,12 @@ limitations under the License. #include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/c/eager/operation_interface.h" #include "tensorflow/c/eager/tensor_handle_interface.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/eager/attr_builder.h" #include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/common_runtime/eager/eager_executor.h" -#include "tensorflow/core/common_runtime/eager/eager_operation.h" #include "tensorflow/core/common_runtime/eager/kernel_and_device.h" #include "tensorflow/core/common_runtime/eager/tensor_handle.h" #include "tensorflow/core/common_runtime/function.h" @@ -89,7 +89,7 @@ struct TFE_TensorDebugInfo { }; struct TFE_Op { - tensorflow::EagerOperation operation; + std::unique_ptr operation; }; struct TFE_MonitoringCounterCell { @@ -236,4 +236,13 @@ struct TFE_Executor { tensorflow::EagerExecutor* unowned_executor; }; +struct TFE_OpAttrs { + explicit TFE_OpAttrs() : attributes(nullptr) {} + + explicit TFE_OpAttrs(const tensorflow::AttrBuilder* value) + : attributes(value) {} + + const tensorflow::AttrBuilder* attributes; +}; + #endif // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_ diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 9ae1e7b896b..7a089a30164 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -17,6 +17,8 @@ limitations under the License. 
#include +#include + #include "absl/strings/match.h" #include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/c/eager/c_api_internal.h" @@ -367,7 +369,7 @@ TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevicesAsync) { void TensorHandleSilentCopy(bool async, TFE_ContextDevicePlacementPolicy global_policy, TFE_ContextDevicePlacementPolicy thread_policy, - bool cpu_op) { + bool mirror, bool cpu_op) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); TFE_ContextOptions* opts = TFE_NewContextOptions(); @@ -390,6 +392,12 @@ void TensorHandleSilentCopy(bool async, TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice( hcpu, ctx, gpu_device_name.c_str(), status.get()); ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); + if (mirror) { + TFE_TensorHandleEnableImplicitMirroring(hcpu, status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); + TFE_TensorHandleEnableImplicitMirroring(hgpu, status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); + } TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu); if (cpu_op) { @@ -414,10 +422,23 @@ void TensorHandleSilentCopy(bool async, hgpu->handle.get()) ->Handle(); - // The input handles should never change since they have been mirrored. - ASSERT_EQ(matmul->operation.Inputs()[0], arg0); - ASSERT_EQ(matmul->operation.Inputs()[1], arg1); - + auto op = tensorflow::down_cast( + matmul->operation.get()); + if (mirror) { + // The input handles should never change since they have been mirrored. + ASSERT_EQ(op->GetInput(0), arg0); + ASSERT_EQ(op->GetInput(1), arg1); + } else { + if (cpu_op) { + ASSERT_EQ(op->GetInput(0), arg0); + // The GPU handle should be replaced with a CPU copy + ASSERT_NE(op->GetInput(1), arg1); + } else { + // The CPU handle should be replaced with a GPU copy + ASSERT_NE(op->GetInput(0), arg0); + ASSERT_EQ(op->GetInput(1), arg1); + } + } TFE_DeleteOp(matmul); TFE_DeleteTensorHandle(retvals[0]); TFE_DeleteTensorHandle(hgpu); @@ -433,19 +454,27 @@ void TensorHandleSilentCopy(bool async, } TEST(CAPI, TensorHandleSilentCopy) { TensorHandleSilentCopy(false, TFE_DEVICE_PLACEMENT_SILENT, - TFE_DEVICE_PLACEMENT_SILENT, false); + TFE_DEVICE_PLACEMENT_SILENT, false, false); } TEST(CAPI, TensorHandleSilentCopyAsync) { TensorHandleSilentCopy(true, TFE_DEVICE_PLACEMENT_SILENT, - TFE_DEVICE_PLACEMENT_SILENT, false); + TFE_DEVICE_PLACEMENT_SILENT, false, false); } TEST(CAPI, TensorHandleSilentCopyLocalPolicy) { TensorHandleSilentCopy(false, TFE_DEVICE_PLACEMENT_EXPLICIT, - TFE_DEVICE_PLACEMENT_SILENT, false); + TFE_DEVICE_PLACEMENT_SILENT, false, false); } TEST(CAPI, TensorHandleSilentCopyLocalPolicyAsync) { TensorHandleSilentCopy(true, TFE_DEVICE_PLACEMENT_EXPLICIT, - TFE_DEVICE_PLACEMENT_SILENT, false); + TFE_DEVICE_PLACEMENT_SILENT, false, false); +} +TEST(CAPI, TensorHandleMirrorCopy) { + TensorHandleSilentCopy(false, TFE_DEVICE_PLACEMENT_SILENT, + TFE_DEVICE_PLACEMENT_SILENT, true, false); +} +TEST(CAPI, TensorHandleMirrorCopyCpu) { + TensorHandleSilentCopy(false, TFE_DEVICE_PLACEMENT_SILENT, + TFE_DEVICE_PLACEMENT_SILENT, true, true); } void SetAndGetOpDevices(bool async) { @@ -581,6 +610,91 @@ TEST(CAPI, TensorHandleDevices) { TFE_DeleteContext(ctx); } +void ExecuteAdd(bool async, bool forward_input) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); + TFE_Context* ctx = TFE_NewContext(opts, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + 
TFE_DeleteContextOptions(opts); + + TFE_TensorHandle* n = TestMatrixTensorHandle100x100(); + // If a GPU exists, copy the handle to GPU so that we can exercise + // unprotecting a mirror. + std::string gpu_device_name; + if (GetDeviceName(ctx, &gpu_device_name, "GPU")) { + TFE_TensorHandle* n_gpu = + TFE_TensorHandleCopyToDevice(n, ctx, gpu_device_name.c_str(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_TensorHandleEnableImplicitMirroring(n_gpu, status); + TFE_DeleteTensorHandle(n); + n = n_gpu; + } + + TFE_TensorHandle* m = TestMatrixTensorHandle100x100(); + + // Store pointer to raw buffer for validation of forwarding behaviour. + TF_Tensor* orig = TFE_TensorHandleResolve(n, status); + void* orig_ptr = TF_TensorData(orig); + TF_DeleteTensor(orig); + + TFE_Op* add_op = AddOp(ctx, n, m); + std::string cpu_device_name; + ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU")); + TFE_OpSetDevice(add_op, cpu_device_name.c_str(), status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + if (forward_input) { + TFE_DeleteTensorHandle(n); + } + + int num_retvals = 1; + + if (async) { + // Enqueue dummy ops so we backlog async execution & actually test async. + for (int i = 0; i < 10000; ++i) { + TFE_TensorHandle* dummy = nullptr; + TFE_Execute(add_op, &dummy, &num_retvals, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteTensorHandle(dummy); + } + } + + TFE_TensorHandle* retval = nullptr; + TFE_Execute(add_op, &retval, &num_retvals, status); + EXPECT_EQ(1, num_retvals); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + if (!forward_input) { + TFE_DeleteTensorHandle(n); + } + TFE_DeleteOp(add_op); + + TF_Tensor* t = TFE_TensorHandleResolve(retval, status); + if (forward_input || async) { + EXPECT_EQ(orig_ptr, TF_TensorData(t)); + } else { + EXPECT_NE(orig_ptr, TF_TensorData(t)); + } + + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteTensorHandle(m); + TFE_DeleteTensorHandle(retval); + TFE_DeleteContext(ctx); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + float result[100 * 100] = {0}; + EXPECT_EQ(sizeof(result), TF_TensorByteSize(t)); + memcpy(&result[0], TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + for (int i = 0; i < 100 * 100; ++i) { + EXPECT_EQ(2.0f, result[i]); + } + TF_DeleteStatus(status); +} +TEST(CAPI, ExecuteAdd) { ExecuteAdd(false, false); } +TEST(CAPI, ExecuteAddAsync) { ExecuteAdd(true, false); } +TEST(CAPI, ExecuteAddForward) { ExecuteAdd(false, true); } +TEST(CAPI, ExecuteAddForwardAsync) { ExecuteAdd(true, true); } + void Execute_MatMul_CPU(bool async) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); @@ -1219,6 +1333,14 @@ TEST(CAPI, TestTFE_TensorHandleCopySharingUnderlyingTensorHandle) { TFE_DeleteTensorHandle(h_shares_tensor); } +tensorflow::AttrValueMap ExtractAttrs(TFE_Op* op) { + tensorflow::AttrValueMap attr_values; + tensorflow::down_cast(op->operation.get()) + ->Attrs() + .FillAttrValueMap(&attr_values); + return attr_values; +} + TEST(CAPI, TestTFE_OpInferSingleInputAttrs) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); @@ -1235,8 +1357,7 @@ TEST(CAPI, TestTFE_OpInferSingleInputAttrs) { TFE_OpAddInput(minOp, axis, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - tensorflow::AttrValueMap attr_values; - minOp->operation.Attrs().FillAttrValueMap(&attr_values); + tensorflow::AttrValueMap attr_values = 
ExtractAttrs(minOp); tensorflow::AttrValueMap::const_iterator attr_found = attr_values.find("T"); EXPECT_NE(attr_found, attr_values.cend()); EXPECT_EQ(attr_found->second.type(), tensorflow::DataType::DT_FLOAT); @@ -1275,8 +1396,7 @@ TEST(CAPI, TestTFE_OpInferSingleTypeInputListAttrs) { TFE_OpAddInputList(concatOp, inputs, 2, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - tensorflow::AttrValueMap attr_values; - concatOp->operation.Attrs().FillAttrValueMap(&attr_values); + tensorflow::AttrValueMap attr_values = ExtractAttrs(concatOp); tensorflow::AttrValueMap::const_iterator attr_found = attr_values.find("T"); EXPECT_NE(attr_found, attr_values.cend()); EXPECT_EQ(attr_found->second.type(), tensorflow::DataType::DT_FLOAT); @@ -1316,8 +1436,7 @@ TEST(CAPI, TestTFE_OpInferMixedTypeInputListAttrs) { TFE_OpAddInputList(assertOp, data, 3, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - tensorflow::AttrValueMap attr_values; - assertOp->operation.Attrs().FillAttrValueMap(&attr_values); + tensorflow::AttrValueMap attr_values = ExtractAttrs(assertOp); tensorflow::AttrValueMap::const_iterator attr_found = attr_values.find("T"); EXPECT_NE(attr_found, attr_values.cend()); EXPECT_EQ(attr_found->second.list().type(0), tensorflow::DataType::DT_BOOL); @@ -1353,16 +1472,15 @@ TEST(CAPI, TestTFE_OpAttrsInferenceDisabledWhenNotCallingOpAddInputList) { TFE_TensorHandle* inputs[] = {input1, input2}; TFE_OpAddInput(concatOp, dim, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - CHECK(concatOp->operation.OpDef()); + CHECK(concatOp->operation->OpDef()); TFE_OpAddInput(concatOp, inputs[0], status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - EXPECT_FALSE(concatOp->operation.OpDef()) + EXPECT_FALSE(concatOp->operation->OpDef()) << "Inference context is still present"; TFE_OpAddInput(concatOp, inputs[1], status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - tensorflow::AttrValueMap attr_values; - concatOp->operation.Attrs().FillAttrValueMap(&attr_values); + tensorflow::AttrValueMap attr_values = ExtractAttrs(concatOp); EXPECT_EQ(attr_values.find("T"), attr_values.end()); EXPECT_EQ(attr_values.find("N"), attr_values.end()); @@ -1449,4 +1567,40 @@ TEST(CAPI, TestTFE_OpGetInputAndOutputLengthsFailForUnknownArguments) { TFE_DeleteContext(ctx); } +TEST(CAPI, TestTFE_OpGetAttrs) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_Op* var_op = TFE_NewOp(ctx, "VarHandleOp", status); + TFE_OpSetAttrType(var_op, "dtype", TF_INT64); + TFE_OpSetAttrShape(var_op, "shape", {}, 0, status); + TFE_OpAttrs attributes; + TFE_OpGetAttrs(var_op, &attributes); + + TFE_Op* copy_op = TFE_NewOp(ctx, "VarHandleOp", status); + TFE_OpSetAttrType(copy_op, "dtype", TF_FLOAT); + TFE_OpAddAttrs(copy_op, &attributes); + unsigned char is_list = 0; + ASSERT_EQ(TF_ATTR_TYPE, + TFE_OpGetAttrType(copy_op, "dtype", &is_list, status)); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + ASSERT_EQ(TF_ATTR_SHAPE, + TFE_OpGetAttrType(copy_op, "shape", &is_list, status)); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + tensorflow::AttrValueMap attr_values; + auto op = tensorflow::down_cast( + copy_op->operation.get()); + op->Attrs().FillAttrValueMap(&attr_values); + EXPECT_EQ(tensorflow::DT_FLOAT, attr_values.find("dtype")->second.type()); + + 
TF_DeleteStatus(status);
+  TFE_DeleteOp(var_op);
+  TFE_DeleteOp(copy_op);
+  TFE_DeleteContext(ctx);
+}
+
 }  // namespace
diff --git a/tensorflow/c/eager/c_api_test_util.cc b/tensorflow/c/eager/c_api_test_util.cc
index 51566b35a9f..bee76fe296f 100644
--- a/tensorflow/c/eager/c_api_test_util.cc
+++ b/tensorflow/c/eager/c_api_test_util.cc
@@ -131,6 +131,21 @@ TFE_TensorHandle* TestMatrixTensorHandle3X2() {
   return th;
 }
 
+TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
+  TF_Status* status = TF_NewStatus();
+
+  TFE_Op* op = TFE_NewOp(ctx, "AddV2", status);
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_OpAddInput(op, a, status);
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_OpAddInput(op, b, status);
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TF_DeleteStatus(status);
+  TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));
+
+  return op;
+}
+
 TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
   TF_Status* status = TF_NewStatus();
 
diff --git a/tensorflow/c/eager/c_api_test_util.h b/tensorflow/c/eager/c_api_test_util.h
index 28062222cf0..2c2f8323363 100644
--- a/tensorflow/c/eager/c_api_test_util.h
+++ b/tensorflow/c/eager/c_api_test_util.h
@@ -42,6 +42,9 @@ TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2();
 // Return a tensor handle containing a 3x2 matrix of floats
 TFE_TensorHandle* TestMatrixTensorHandle3X2();
 
+// Return an add op adding `a` and `b`.
+TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b);
+
 // Return a matmul op multiplying `a` by `b`.
 TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b);
 
diff --git a/tensorflow/c/eager/custom_device_test.cc b/tensorflow/c/eager/custom_device_test.cc
index be2cdd3bd1c..742844c3f75 100644
--- a/tensorflow/c/eager/custom_device_test.cc
+++ b/tensorflow/c/eager/custom_device_test.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/c/eager/c_api_test_util.h" #include "tensorflow/c/tf_status.h" +#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/platform/test.h" namespace { @@ -31,6 +32,8 @@ struct LoggingDevice { tensorflow::string underlying_device; // Set to true whenever a TensorHandle is copied onto the device bool* arrived_flag; + // Set to true whenever an operation is executed + bool* executed_flag; }; struct LoggedTensor { @@ -81,12 +84,14 @@ TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_TensorHandle* tensor, } void LoggingDeviceExecute(int num_inputs, TFE_TensorHandle** inputs, - const char* operation_name, int* num_outputs, + const char* operation_name, + const TFE_OpAttrs* attributes, int* num_outputs, TFE_TensorHandle** outputs, TF_Status* s, void* device_info) { LoggingDevice* dev = reinterpret_cast(device_info); TFE_Op* op(TFE_NewOp(dev->ctx, operation_name, s)); if (TF_GetCode(s) != TF_OK) return; + TFE_OpAddAttrs(op, attributes); TFE_OpSetDevice(op, dev->underlying_device.c_str(), s); for (int j = 0; j < num_inputs; ++j) { TFE_TensorHandle* input = inputs[j]; @@ -115,6 +120,7 @@ void LoggingDeviceExecute(int num_inputs, TFE_TensorHandle** inputs, outputs[i] = MakeLoggedTensorHandle(dev->ctx, dev->device_name, std::move(logged_tensor), s); } + *(dev->executed_flag) = true; } void DeleteLoggingDevice(void* device_info) { @@ -122,7 +128,7 @@ void DeleteLoggingDevice(void* device_info) { } void RegisterLoggingDevice(TFE_Context* context, const char* name, - bool* arrived_flag) { + bool* arrived_flag, bool* executed_flag) { TFE_CustomDevice custom_device; custom_device.copy_tensor_to_device = &CopyToLoggingDevice; custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice; @@ -131,6 +137,7 @@ void RegisterLoggingDevice(TFE_Context* context, const char* name, LoggingDevice* device = new LoggingDevice; device->ctx = context; device->arrived_flag = arrived_flag; + device->executed_flag = executed_flag; device->device_name = name; device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0"; TFE_RegisterCustomDevice(context, custom_device, name, device); @@ -144,13 +151,15 @@ TEST(CUSTOM_DEVICE, RegisterSimpleDevice) { TFE_DeleteContextOptions(opts); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); bool arrived = false; + bool executed = false; const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0"; - RegisterLoggingDevice(context, name, &arrived); + RegisterLoggingDevice(context, name, &arrived, &executed); TFE_TensorHandle* hcpu = TestMatrixTensorHandle(); ASSERT_FALSE(arrived); TFE_TensorHandle* hdevice = TFE_TensorHandleCopyToDevice(hcpu, context, name, status.get()); ASSERT_TRUE(arrived); + ASSERT_FALSE(executed); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); std::unique_ptr matmul( MatMulOp(context, hcpu, hdevice), TFE_DeleteOp); @@ -160,6 +169,7 @@ TEST(CUSTOM_DEVICE, RegisterSimpleDevice) { int num_retvals = 1; TFE_Execute(matmul.get(), &retval, &num_retvals, status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_TRUE(executed); TFE_DeleteTensorHandle(retval); TFE_DeleteTensorHandle(hcpu); @@ -167,4 +177,118 @@ TEST(CUSTOM_DEVICE, RegisterSimpleDevice) { TFE_DeleteContext(context); } +TEST(CUSTOM_DEVICE, ResetOperation) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + std::unique_ptr context( + 
TFE_NewContext(opts, status.get()), TFE_DeleteContext); + TFE_DeleteContextOptions(opts); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + bool arrived = false; + bool executed = false; + const char* custom_device_name = + "/job:localhost/replica:0/task:0/device:CUSTOM:0"; + RegisterLoggingDevice(context.get(), custom_device_name, &arrived, &executed); + + std::unique_ptr reused_op( + TFE_NewOp(context.get(), "Identity", status.get()), TFE_DeleteOp); + TFE_OpReset(reused_op.get(), "Identity", custom_device_name, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_EQ(tensorflow::string(TFE_OpGetDevice(reused_op.get(), status.get())), + tensorflow::string(custom_device_name)); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + TFE_OpReset(reused_op.get(), "Identity", + "/job:localhost/replica:0/task:0/device:CPU:0", status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_EQ(tensorflow::string(TFE_OpGetDevice(reused_op.get(), status.get())), + tensorflow::string("/job:localhost/replica:0/task:0/device:CPU:0")); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); +} + +TEST(CUSTOM_DEVICE, MakeVariable) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + std::unique_ptr opts( + TFE_NewContextOptions(), TFE_DeleteContextOptions); + std::unique_ptr context( + TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + bool arrived = false; + bool executed = false; + const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0"; + RegisterLoggingDevice(context.get(), name, &arrived, &executed); + + // Create a variable handle placed on the custom device. + std::unique_ptr op( + TFE_NewOp(context.get(), "VarHandleOp", status.get()), TFE_DeleteOp); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT); + TFE_OpSetAttrShape(op.get(), "shape", {}, 0, status.get()); + TFE_OpSetAttrString(op.get(), "container", "", 0); + TFE_OpSetAttrString(op.get(), "shared_name", "", 0); + TFE_OpSetDevice(op.get(), name, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + TFE_TensorHandle* var_handle = nullptr; + int num_retvals = 1; + executed = false; + TFE_Execute(op.get(), &var_handle, &num_retvals, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_TRUE(executed); + auto handle_cleaner = tensorflow::gtl::MakeCleanup( + [var_handle]() { TFE_DeleteTensorHandle(var_handle); }); + + // Assign to the variable, copying to the custom device. + std::unique_ptr one( + TestScalarTensorHandle(111.f), TFE_DeleteTensorHandle); + op.reset(TFE_NewOp(context.get(), "AssignVariableOp", status.get())); + TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT); + TFE_OpAddInput(op.get(), var_handle, status.get()); + TFE_OpAddInput(op.get(), one.get(), status.get()); + TFE_OpSetDevice(op.get(), name, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + executed = false; + num_retvals = 0; + TFE_Execute(op.get(), nullptr, &num_retvals, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_TRUE(executed); + + // Read the variable's value. 
+  op.reset(TFE_NewOp(context.get(), "ReadVariableOp", status.get()));
+  TFE_OpAddInput(op.get(), var_handle, status.get());
+  TFE_OpSetDevice(op.get(), name, status.get());
+  TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
+  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+  executed = false;
+  num_retvals = 1;
+  TFE_TensorHandle* var_value = nullptr;
+  TFE_Execute(op.get(), &var_value, &num_retvals, status.get());
+  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+  ASSERT_TRUE(executed);
+  auto value_cleaner = tensorflow::gtl::MakeCleanup(
+      [var_value]() { TFE_DeleteTensorHandle(var_value); });
+  ASSERT_EQ(tensorflow::string(name),
+            tensorflow::string(
+                TFE_TensorHandleBackingDeviceName(var_value, status.get())));
+  TFE_TensorHandle* var_value_unpacked =
+      reinterpret_cast<LoggedTensor*>(
+          TFE_TensorHandleDevicePointer(var_value, status.get()))
+          ->tensor;
+  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+  std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> resolved_value(
+      TFE_TensorHandleResolve(var_value_unpacked, status.get()),
+      TF_DeleteTensor);
+  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+  ASSERT_EQ(111., *static_cast<float*>(TF_TensorData(resolved_value.get())));
+
+  // Free the backing buffer for the variable.
+  op.reset(TFE_NewOp(context.get(), "DestroyResourceOp", status.get()));
+  TFE_OpAddInput(op.get(), var_handle, status.get());
+  TFE_OpSetDevice(op.get(), name, status.get());
+  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+  num_retvals = 0;
+  TFE_Execute(op.get(), nullptr, &num_retvals, status.get());
+  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+}
+
 }  // namespace
diff --git a/tensorflow/c/eager/operation_interface.cc b/tensorflow/c/eager/operation_interface.cc
new file mode 100644
index 00000000000..5703d3231bd
--- /dev/null
+++ b/tensorflow/c/eager/operation_interface.cc
@@ -0,0 +1,312 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/c/eager/operation_interface.h"
+
+#include "absl/container/fixed_array.h"
+#include "tensorflow/c/eager/c_api.h"
+#include "tensorflow/c/eager/c_api_internal.h"
+#include "tensorflow/c/eager/tensor_handle_interface.h"
+#include "tensorflow/core/common_runtime/eager/eager_operation.h"
+#include "tensorflow/core/common_runtime/eager/execute.h"
+#include "tensorflow/core/platform/casts.h"
+#include "tensorflow/core/platform/errors.h"
+
+namespace tensorflow {
+
+OperationInterface::OperationInterface(TFE_Context* ctx)
+    : operation_(ctx->context) {}
+
+const string& OperationInterface::DeviceName() const {
+  absl::variant<Device*, CustomDevice*> variant_device =
+      (operation_.Device() == kVariantDeviceNull)
+          ? operation_.EagerContext().HostCPU()
+          : operation_.Device();
+  return absl::visit([](auto* d) -> const string& { return d->name(); },
+                     variant_device);
+}
+
+Status OperationInterface::SetDeviceName(const char* name) {
+  return operation_.SetDeviceName(name);
+}
+
+Status OperationInterface::SetAttrString(const char* attr_name,
+                                         const char* data, size_t length) {
+  operation_.MutableAttrs()->Set(attr_name, StringPiece(data, length));
+  return Status::OK();
+}
+
+Status OperationInterface::SetAttrInt(const char* attr_name, int64_t value) {
+  operation_.MutableAttrs()->Set(attr_name, static_cast<int64>(value));
+  return Status::OK();
+}
+
+Status OperationInterface::SetAttrFloat(const char* attr_name, float value) {
+  operation_.MutableAttrs()->Set(attr_name, value);
+  return Status::OK();
+}
+
+Status OperationInterface::SetAttrBool(const char* attr_name, bool value) {
+  operation_.MutableAttrs()->Set(attr_name, value);
+  return Status::OK();
+}
+
+Status OperationInterface::SetAttrType(const char* attr_name,
+                                       TF_DataType value) {
+  operation_.MutableAttrs()->Set(attr_name, static_cast<DataType>(value));
+  return Status::OK();
+}
+
+Status OperationInterface::SetAttrShape(const char* attr_name,
+                                        const int64_t* dims,
+                                        const int num_dims) {
+  if (num_dims > TensorShape::MaxDimensions()) {
+    return errors::InvalidArgument("Value specified for `", attr_name, "` has ",
+                                   num_dims,
+                                   " dimensions which is over the limit of ",
+                                   TensorShape::MaxDimensions(), ".");
+  }
+
+  TensorShapeProto proto;
+  if (num_dims < 0) {
+    proto.set_unknown_rank(true);
+  } else {
+    for (int d = 0; d < num_dims; ++d) {
+      proto.add_dim()->set_size(dims[d]);
+    }
+  }
+
+  operation_.MutableAttrs()->Set(attr_name, proto);
+
+  return Status::OK();
+}
+
+Status OperationInterface::SetAttrFunction(
+    const char* attr_name,
+    const std::unique_ptr<AbstractOperationInterface>& value) {
+  AttrValue attr_value;
+  NameAttrList* func = attr_value.mutable_func();
+  func->set_name(value->Name());
+  OperationInterface* value_operation =
+      tensorflow::down_cast<OperationInterface*>(value.get());
+  value_operation->operation_.Attrs().FillAttrValueMap(func->mutable_attr());
+  operation_.MutableAttrs()->Set(attr_name, attr_value);
+  return Status::OK();
+}
+
+Status OperationInterface::SetAttrFunctionName(const char* attr_name,
+                                               const char* data,
+                                               size_t length) {
+  AttrValue attr_value;
+  NameAttrList* func = attr_value.mutable_func();
+  func->set_name(data, length);
+  operation_.MutableAttrs()->Set(attr_name, attr_value);
+  return Status::OK();
+}
+
+Status OperationInterface::SetAttrTensor(const char* attr_name,
+                                         TF_Tensor* tensor) {
+  Tensor t;
+  TF_RETURN_IF_ERROR(TF_TensorToTensor(tensor, &t));
+  operation_.MutableAttrs()->Set(attr_name, t);
+  return Status::OK();
+}
+
+Status OperationInterface::SetAttrStringList(const char* attr_name,
+                                             const void* const* values,
+                                             const size_t* lengths,
+                                             int num_values) {
+  std::vector<StringPiece> v(num_values);
+  for (int i = 0; i < num_values; ++i) {
+    v[i] = StringPiece(static_cast<const char*>(values[i]), lengths[i]);
+  }
+  operation_.MutableAttrs()->Set(attr_name, v);
+
+  return Status::OK();
+}
+
+Status OperationInterface::SetAttrFloatList(const char* attr_name,
+                                            const float* values,
+                                            int num_values) {
+  operation_.MutableAttrs()->Set(
+      attr_name, gtl::ArraySlice<const float>(values, num_values));
+  return Status::OK();
+}
+
+Status OperationInterface::SetAttrIntList(const char* attr_name,
+                                          const int64_t* values,
+                                          int num_values) {
+  operation_.MutableAttrs()->Set(
+      attr_name, gtl::ArraySlice<const int64>(
+                     reinterpret_cast<const int64*>(values), num_values));
+  return Status::OK();
+}
+
+Status OperationInterface::SetAttrTypeList(const char* attr_name,
+                                           const TF_DataType* values,
+                                           int num_values) {
+  operation_.MutableAttrs()->Set(
+      attr_name, gtl::ArraySlice<const DataType>(
+                     reinterpret_cast<const DataType*>(values), num_values));
+  return Status::OK();
+}
+
+Status OperationInterface::SetAttrBoolList(const char* attr_name,
+                                           const unsigned char* values,
+                                           int num_values) {
+  std::unique_ptr<bool[]> b(new bool[num_values]);
+  for (int i = 0; i < num_values; ++i) {
+    b[i] = values[i];
+  }
+  operation_.MutableAttrs()->Set(
+      attr_name, gtl::ArraySlice<const bool>(b.get(), num_values));
+  return Status::OK();
+}
+
+Status OperationInterface::SetAttrShapeList(const char* attr_name,
+                                            const int64_t** dims,
+                                            const int* num_dims,
+                                            int num_values) {
+  std::unique_ptr<TensorShapeProto[]> proto(new TensorShapeProto[num_values]);
+  for (int i = 0; i < num_values; ++i) {
+    const auto num_dims_i = num_dims[i];
+
+    if (num_dims_i > TensorShape::MaxDimensions()) {
+      return errors::InvalidArgument(
+          strings::StrCat("Value specified for `", attr_name, "` has ",
+                          num_dims_i, " dimensions which is over the limit of ",
+                          TensorShape::MaxDimensions(), "."));
+    }
+    if (num_dims_i < 0) {
+      proto[i].set_unknown_rank(true);
+    } else {
+      const int64_t* dims_i = dims[i];
+      auto proto_i = &proto[i];
+      for (int d = 0; d < num_dims_i; ++d) {
+        proto_i->add_dim()->set_size(dims_i[d]);
+      }
+    }
+  }
+  operation_.MutableAttrs()->Set(
+      attr_name, gtl::ArraySlice<TensorShapeProto>(proto.get(), num_values));
+  return Status::OK();
+}
+
+Status OperationInterface::SetAttrFunctionList(const char* attr_name,
+                                               const TFE_Op** value,
+                                               int num_values) {
+  std::unique_ptr<NameAttrList[]> funcs(new NameAttrList[num_values]);
+  for (int i = 0; i < num_values; i++) {
+    auto value_operation =
+        tensorflow::down_cast<OperationInterface*>(value[i]->operation.get());
+    funcs[i].set_name(value_operation->operation_.Name());
+    value_operation->operation_.Attrs().FillAttrValueMap(
+        funcs[i].mutable_attr());
+  }
+  operation_.MutableAttrs()->Set(
+      attr_name, gtl::ArraySlice<const NameAttrList>(funcs.get(), num_values));
+  return Status::OK();
+}
+
+const OpDef* OperationInterface::GetOpDef(Status* status) {
+  const tensorflow::OpDef* op_def = operation_.OpDef();
+  if (op_def) return op_def;
+  *status = OpDefForOp(Name(), &op_def);
+  return op_def;
+}
+
+Status OperationInterface::InputLength(const char* input_name, int* length) {
+  Status status;
+  const tensorflow::OpDef* op_def = GetOpDef(&status);
+  if (!status.ok()) {
+    return status;
+  }
+  AttrValueMap attrs;
+  operation_.Attrs().FillAttrValueMap(&attrs);
+  NameRangeMap name_ranges;
+  TF_RETURN_IF_ERROR(
+      NameRangesForNode(AttrSlice(&attrs), *op_def, &name_ranges, nullptr));
+  auto iter = name_ranges.find(input_name);
+  if (iter == name_ranges.end()) {
+    return errors::InvalidArgument("Input '", input_name, "' not found");
+  }
+  *length = iter->second.second - iter->second.first;
+  return Status::OK();
+}
+
+Status OperationInterface::OutputLength(const char* output_name, int* length) {
+  Status status;
+  const tensorflow::OpDef* op_def = GetOpDef(&status);
+  if (!status.ok()) {
+    return status;
+  }
+  AttrValueMap attrs;
+  operation_.Attrs().FillAttrValueMap(&attrs);
+  NameRangeMap name_ranges;
+  TF_RETURN_IF_ERROR(
+      NameRangesForNode(AttrSlice(&attrs), *op_def, nullptr, &name_ranges));
+  auto iter = name_ranges.find(output_name);
+  if (iter == name_ranges.end()) {
+    return errors::InvalidArgument("Output '", output_name, "' not found");
+  }
+  *length = iter->second.second - iter->second.first;
+  return Status::OK();
+}
+
+Status OperationInterface::AddInput(
+    const std::unique_ptr<AbstractTensorHandleInterface>& input) {
+  TensorHandle* h
= + tensorflow::down_cast(input.get())->Handle(); + operation_.AddInput(h); + return operation_.MaybeInferSingleInputAttrs(h); +} + +Status OperationInterface::AddInputList( + const absl::FixedArray>& + inputs) { + for (auto& input : inputs) { + TensorHandle* h = + tensorflow::down_cast(input.get())->Handle(); + operation_.AddInput(h); + } + return operation_.InferInputListAttrs(inputs.size()); +} + +Status OperationInterface::Execute( + absl::FixedArray>* retvals, + int* num_retvals) { + absl::FixedArray handle_retvals(*num_retvals); + TF_RETURN_IF_ERROR( + EagerExecute(&operation_, handle_retvals.data(), num_retvals)); + for (int i = 0; i < *num_retvals; ++i) { + retvals->at(i).reset( + new tensorflow::TensorHandleInterface(handle_retvals[i])); + } + return Status::OK(); +} + +Status OperationInterface::SetCancellationManager( + TFE_CancellationManager* cancellation_manager) { + operation_.SetCancellationManager( + &cancellation_manager->cancellation_manager); + return Status::OK(); +} + +Status OperationInterface::SetUseXla(bool enable) { + operation_.SetUseXla(enable); + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/c/eager/operation_interface.h b/tensorflow/c/eager/operation_interface.h new file mode 100644 index 00000000000..900c5112c08 --- /dev/null +++ b/tensorflow/c/eager/operation_interface.h @@ -0,0 +1,188 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_ +#define TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_ + +#include + +#include "absl/container/fixed_array.h" +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/c/eager/tensor_handle_interface.h" +#include "tensorflow/core/common_runtime/eager/eager_operation.h" + +// Abstract interface to an operation. 
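Editorial sketch, not part of the patch: the pure-virtual interface declared just below is what TFE_Execute and the other eager C API entry points now program against. A minimal sketch of driving it, assuming a TensorFlow source tree with the headers added in this change and assuming the elided unique_ptr element type is AbstractTensorHandleInterface; RunIdentity is an illustrative helper, not an API from the patch.

```
// Illustrative helper, not part of the patch: reset an op, set its attrs,
// add an input, and execute through the abstract interface.
#include <memory>
#include <utility>

#include "absl/container/fixed_array.h"
#include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/core/platform/errors.h"

tensorflow::Status RunIdentity(
    AbstractOperationInterface* op,
    const std::unique_ptr<AbstractTensorHandleInterface>& input,
    std::unique_ptr<AbstractTensorHandleInterface>* output) {
  TF_RETURN_IF_ERROR(op->Reset("Identity", /*raw_device_name=*/nullptr));
  TF_RETURN_IF_ERROR(op->SetAttrType("T", TF_FLOAT));
  TF_RETURN_IF_ERROR(op->AddInput(input));
  int num_retvals = 1;
  absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>> retvals(
      num_retvals);
  TF_RETURN_IF_ERROR(op->Execute(&retvals, &num_retvals));
  *output = std::move(retvals[0]);
  return tensorflow::Status::OK();
}
```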
+class AbstractOperationInterface { + public: + virtual ~AbstractOperationInterface() {} + + virtual void Clear() = 0; + virtual tensorflow::Status Reset(const char* op, + const char* raw_device_name) = 0; + + virtual const tensorflow::string& Name() const = 0; + virtual const tensorflow::string& DeviceName() const = 0; + virtual tensorflow::Status SetDeviceName(const char* name) = 0; + + virtual tensorflow::Status AddInput( + const std::unique_ptr& input) = 0; + virtual tensorflow::Status AddInputList( + const absl::FixedArray>& + inputs) = 0; + virtual tensorflow::Status Execute( + absl::FixedArray>* retvals, + int* num_retvals) = 0; + virtual const tensorflow::OpDef* OpDef() const = 0; + + virtual tensorflow::Status SetAttrString(const char* attr_name, + const char* data, size_t length) = 0; + virtual tensorflow::Status SetAttrInt(const char* attr_name, + int64_t value) = 0; + virtual tensorflow::Status SetAttrFloat(const char* attr_name, + float value) = 0; + virtual tensorflow::Status SetAttrBool(const char* attr_name, bool value) = 0; + virtual tensorflow::Status SetAttrType(const char* attr_name, + TF_DataType value) = 0; + virtual tensorflow::Status SetAttrShape(const char* attr_name, + const int64_t* dims, + const int num_dims) = 0; + virtual tensorflow::Status SetAttrFunction( + const char* attr_name, + const std::unique_ptr& value) = 0; + virtual tensorflow::Status SetAttrFunctionName(const char* attr_name, + const char* value, + size_t length) = 0; + virtual tensorflow::Status SetAttrTensor(const char* attr_name, + TF_Tensor* tensor) = 0; + virtual tensorflow::Status SetAttrStringList(const char* attr_name, + const void* const* values, + const size_t* lengths, + int num_values) = 0; + virtual tensorflow::Status SetAttrFloatList(const char* attr_name, + const float* values, + int num_values) = 0; + virtual tensorflow::Status SetAttrIntList(const char* attr_name, + const int64_t* values, + int num_values) = 0; + virtual tensorflow::Status SetAttrTypeList(const char* attr_name, + const TF_DataType* values, + int num_values) = 0; + virtual tensorflow::Status SetAttrBoolList(const char* attr_name, + const unsigned char* values, + int num_values) = 0; + virtual tensorflow::Status SetAttrShapeList(const char* attr_name, + const int64_t** dims, + const int* num_dims, + int num_values) = 0; + virtual tensorflow::Status SetAttrFunctionList(const char* attr_name, + const TFE_Op** value, + int num_values) = 0; + + virtual tensorflow::Status InputLength(const char* input_name, + int* length) = 0; + virtual tensorflow::Status OutputLength(const char* output_name, + int* length) = 0; + + // Experimental + virtual tensorflow::Status SetUseXla(bool enable) { + return tensorflow::errors::Unimplemented("SetUseXla not implemented"); + } + virtual tensorflow::Status SetCancellationManager( + TFE_CancellationManager* cancellation_manager) { + return tensorflow::errors::Unimplemented( + "SetCancellationManager not implemented"); + } +}; + +namespace tensorflow { + +class OpDef; + +class OperationInterface : public AbstractOperationInterface { + public: + explicit OperationInterface(TFE_Context* ctx); + ~OperationInterface() override{}; + + void Clear() override { operation_.Clear(); } + Status Reset(const char* op, const char* raw_device_name) override { + return operation_.Reset(op, raw_device_name, false, nullptr); + } + + const string& Name() const override { return operation_.Name(); } + const string& DeviceName() const override; + Status SetDeviceName(const char* name) override; + + Status 
AddInput( + const std::unique_ptr& input) override; + Status AddInputList( + const absl::FixedArray>& + inputs) override; + Status Execute( + absl::FixedArray>* retvals, + int* num_retvals) override; + const tensorflow::OpDef* OpDef() const override { + return operation_.OpDef(); + }; + + Status SetAttrString(const char* attr_name, const char* data, + size_t length) override; + Status SetAttrInt(const char* attr_name, int64_t value) override; + Status SetAttrFloat(const char* attr_name, float value) override; + Status SetAttrBool(const char* attr_name, bool value) override; + Status SetAttrType(const char* attr_name, TF_DataType value) override; + Status SetAttrShape(const char* attr_name, const int64_t* dims, + const int num_dims) override; + Status SetAttrFunction( + const char* attr_name, + const std::unique_ptr& value) override; + Status SetAttrFunctionName(const char* attr_name, const char* data, + size_t length) override; + Status SetAttrTensor(const char* attr_name, TF_Tensor* tensor) override; + Status SetAttrStringList(const char* attr_name, const void* const* values, + const size_t* lengths, int num_values) override; + Status SetAttrFloatList(const char* attr_name, const float* values, + int num_values) override; + Status SetAttrIntList(const char* attr_name, const int64_t* values, + int num_values) override; + Status SetAttrTypeList(const char* attr_name, const TF_DataType* values, + int num_values) override; + Status SetAttrBoolList(const char* attr_name, const unsigned char* values, + int num_values) override; + Status SetAttrShapeList(const char* attr_name, const int64_t** dims, + const int* num_dims, int num_values) override; + Status SetAttrFunctionList(const char* attr_name, const TFE_Op** value, + int num_values) override; + + Status InputLength(const char* input_name, int* length) override; + Status OutputLength(const char* output_name, int* length) override; + + Status SetUseXla(bool enable) override; + Status SetCancellationManager( + TFE_CancellationManager* cancellation_manager) override; + + // TODO(gjn): Remove once TFE_InferShapes is removed + const tensorflow::AttrBuilder& Attrs() const { return operation_.Attrs(); } + tensorflow::AttrBuilder* MutableAttrs() { return operation_.MutableAttrs(); } + + const TensorHandle* GetInput(int i) const { return operation_.Inputs()[i]; } + + private: + const tensorflow::OpDef* GetOpDef(Status* status); + EagerOperation operation_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_ diff --git a/tensorflow/c/tf_tensor.cc b/tensorflow/c/tf_tensor.cc index 6bb2cafbbc5..4e75beceb3e 100644 --- a/tensorflow/c/tf_tensor.cc +++ b/tensorflow/c/tf_tensor.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/c/tf_tensor.h" #include +#include #include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status_helper.h" @@ -64,25 +65,41 @@ void deallocate_buffer(void* data, size_t len, void* arg) { } } // namespace tensorflow +namespace { +TF_Tensor* CreateTensor(TF_ManagedBuffer* buf, TF_DataType dtype, + const int64_t* dims, int num_dims, size_t len) { + std::vector dimvec(num_dims); + for (int i = 0; i < num_dims; ++i) { + dimvec[i] = static_cast(dims[i]); + } + + // TODO(gjn): Make the choice of interface a compile-time configuration. 
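An aside on the length validation CreateTensor() performs just below (editorial, not part of the patch): for fixed-size dtypes the supplied byte length must cover elem_size * num_elements, otherwise the constructor returns nullptr. A small helper computing that minimum, using only the public C header; DenseByteLength is an illustrative name.

```
#include <stddef.h>
#include <stdint.h>

#include "tensorflow/c/tf_datatype.h"

// Minimum byte length CreateTensor() accepts for fixed-size dtypes:
// TF_DataTypeSize(dtype) * product(dims). Variable-size dtypes such as
// TF_STRING report size 0 and skip the check.
size_t DenseByteLength(TF_DataType dtype, const int64_t* dims, int num_dims) {
  size_t n = 1;
  for (int i = 0; i < num_dims; ++i) n *= static_cast<size_t>(dims[i]);
  return n * TF_DataTypeSize(dtype);
}
```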
+ tensorflow::TensorInterface ret( + Tensor(static_cast(dtype), + tensorflow::TensorShape(dimvec), buf)); + buf->Unref(); + size_t elem_size = TF_DataTypeSize(dtype); + if (elem_size > 0 && len < (elem_size * ret.NumElements())) { + return nullptr; + } + return new TF_Tensor{std::make_unique(ret)}; +} +} // namespace TF_Tensor* TF_AllocateTensor(TF_DataType dtype, const int64_t* dims, int num_dims, size_t len) { void* data = tensorflow::allocate_tensor("TF_AllocateTensor", len, tensorflow::cpu_allocator()); - return TF_NewTensor(dtype, dims, num_dims, data, len, - tensorflow::deallocate_buffer, - tensorflow::cpu_allocator()); + TF_ManagedBuffer* buf = + new TF_ManagedBuffer(data, len, tensorflow::deallocate_buffer, + tensorflow::cpu_allocator(), /*owns_memory=*/true); + return CreateTensor(buf, dtype, dims, num_dims, len); } TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims, void* data, size_t len, void (*deallocator)(void* data, size_t len, void* arg), void* deallocator_arg) { - std::vector dimvec(num_dims); - for (int i = 0; i < num_dims; ++i) { - dimvec[i] = static_cast(dims[i]); - } - TF_ManagedBuffer* buf = nullptr; if (dtype != TF_STRING && dtype != TF_RESOURCE && tensorflow::DataTypeCanUseMemcpy( @@ -97,24 +114,17 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims, // Other types have the same representation, so copy only if it is safe to // do so. buf = new TF_ManagedBuffer(tensorflow::allocate_tensor("TF_NewTensor", len), - len, tensorflow::deallocate_buffer, nullptr); + len, tensorflow::deallocate_buffer, nullptr, + /*owns_memory=*/true); std::memcpy(buf->data(), data, len); // Free the original buffer. deallocator(data, len, deallocator_arg); } else { - buf = new TF_ManagedBuffer(data, len, deallocator, deallocator_arg); + buf = new TF_ManagedBuffer(data, len, deallocator, deallocator_arg, + /*owns_memory=*/false); } - // TODO(gjn): Make the choice of interface a compile-time configuration. - tensorflow::TensorInterface ret( - Tensor(static_cast(dtype), - tensorflow::TensorShape(dimvec), buf)); - buf->Unref(); - size_t elem_size = TF_DataTypeSize(dtype); - if (elem_size > 0 && len < (elem_size * ret.NumElements())) { - return nullptr; - } - return new TF_Tensor{std::make_unique(ret)}; + return CreateTensor(buf, dtype, dims, num_dims, len); } TF_Tensor* TF_TensorMaybeMove(TF_Tensor* t) { diff --git a/tensorflow/c/tf_tensor_internal.h b/tensorflow/c/tf_tensor_internal.h index 7ce6e637b2b..08a55f26a83 100644 --- a/tensorflow/c/tf_tensor_internal.h +++ b/tensorflow/c/tf_tensor_internal.h @@ -38,11 +38,12 @@ class TF_ManagedBuffer : public tensorflow::TensorBuffer { public: TF_ManagedBuffer(void* data, size_t len, void (*deallocator)(void* data, size_t len, void* arg), - void* deallocator_arg) + void* deallocator_arg, bool owns_memory) : TensorBuffer(data), len_(len), deallocator_(deallocator), - deallocator_arg_(deallocator_arg) {} + deallocator_arg_(deallocator_arg), + owns_memory_(owns_memory) {} ~TF_ManagedBuffer() override { (*deallocator_)(data(), len_, deallocator_arg_); @@ -57,13 +58,13 @@ class TF_ManagedBuffer : public tensorflow::TensorBuffer { proto->set_allocator_name(tensorflow::cpu_allocator()->Name()); } - // Prevents input forwarding from mutating this buffer. 
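Context for the OwnsMemory() change just below (editorial, not part of the patch): with the new owns_memory flag, buffers TensorFlow allocates itself via TF_AllocateTensor report true, while buffers wrapped around caller memory via TF_NewTensor report false, preserving the old guarantee that input forwarding never mutates caller-owned storage. A sketch of the caller-owned path using only public C API calls; NoOpDeallocator and WrapFloatBuffer are illustrative names.

```
#include <stddef.h>
#include <stdint.h>

#include "tensorflow/c/tf_tensor.h"

// No-op deallocator: the caller keeps ownership of `data`.
static void NoOpDeallocator(void* data, size_t len, void* arg) {}

// Wraps caller-owned memory. For aligned, sufficiently large buffers this
// takes the non-copying branch shown above, so the TF_ManagedBuffer is built
// with owns_memory=false and OwnsMemory() stays false for this tensor.
TF_Tensor* WrapFloatBuffer(float* data, int64_t n) {
  const int64_t dims[1] = {n};
  return TF_NewTensor(TF_FLOAT, dims, /*num_dims=*/1, data,
                      static_cast<size_t>(n) * sizeof(float), &NoOpDeallocator,
                      /*deallocator_arg=*/nullptr);
}
```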
- bool OwnsMemory() const override { return false; } + bool OwnsMemory() const override { return owns_memory_; } private: const size_t len_; void (*const deallocator_)(void* data, size_t len, void* arg); void* const deallocator_arg_; + bool owns_memory_; }; namespace tensorflow { diff --git a/tensorflow/cc/gradients/array_grad.cc b/tensorflow/cc/gradients/array_grad.cc index e9173227aad..3c0813bfe23 100644 --- a/tensorflow/cc/gradients/array_grad.cc +++ b/tensorflow/cc/gradients/array_grad.cc @@ -15,13 +15,12 @@ limitations under the License. #include +#include "tensorflow/cc/framework/grad_op_registry.h" +#include "tensorflow/cc/framework/gradients.h" #include "tensorflow/cc/ops/array_ops_internal.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/cc/framework/grad_op_registry.h" -#include "tensorflow/cc/framework/gradients.h" - namespace tensorflow { namespace ops { namespace { @@ -90,15 +89,25 @@ Status QuantizeAndDequantizeGrad(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("QuantizeAndDequantize", QuantizeAndDequantizeGrad); -Status QuantizeAndDequantizeV2Grad(const Scope& scope, const Operation& op, - const std::vector& grad_inputs, - std::vector* grad_outputs) { - grad_outputs->push_back(Identity(scope, grad_inputs[0])); - grad_outputs->push_back(NoGradient()); - grad_outputs->push_back(NoGradient()); +Status QuantizeAndDequantizeV2GradHelper(const Scope& scope, + const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + Input input = Shape(scope, op.input(0)); + Input input_min = op.input(1); + Input input_max = op.input(2); + int64 axis; + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis)); + auto qdq_v2_grad = QuantizeAndDequantizeV2Grad( + scope, grad_inputs[0], input, input_min, input_max, + QuantizeAndDequantizeV2Grad::Axis(axis)); + grad_outputs->push_back(qdq_v2_grad.input_backprop); + grad_outputs->push_back(qdq_v2_grad.input_min_backprop); + grad_outputs->push_back(qdq_v2_grad.input_max_backprop); return scope.status(); } -REGISTER_GRADIENT_OP("QuantizeAndDequantizeV2", QuantizeAndDequantizeV2Grad); +REGISTER_GRADIENT_OP("QuantizeAndDequantizeV2", + QuantizeAndDequantizeV2GradHelper); Status QuantizeAndDequantizeV3Grad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index e680cc72b3b..882b4032f76 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -68,6 +68,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", + "//tensorflow/core/platform:resource_loader", ], ) diff --git a/tensorflow/cc/saved_model/reader_test.cc b/tensorflow/cc/saved_model/reader_test.cc index e898664c221..bc630bcaede 100644 --- a/tensorflow/cc/saved_model/reader_test.cc +++ b/tensorflow/cc/saved_model/reader_test.cc @@ -21,15 +21,22 @@ limitations under the License. 
#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/path.h" +#include "tensorflow/core/platform/resource_loader.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { namespace { -constexpr char kTestDataPbTxt[] = - "cc/saved_model/testdata/half_plus_two_pbtxt/00000123"; -constexpr char kTestDataSharded[] = - "cc/saved_model/testdata/half_plus_two/00000123"; +string TestDataPbTxt() { + return io::JoinPath("tensorflow", "cc", "saved_model", "testdata", + "half_plus_two_pbtxt", "00000123"); +} + +string TestDataSharded() { + return io::JoinPath("tensorflow", "cc", "saved_model", "testdata", + "half_plus_two", "00000123"); +} class ReaderTest : public ::testing::Test { protected: @@ -49,8 +56,7 @@ class ReaderTest : public ::testing::Test { TEST_F(ReaderTest, TagMatch) { MetaGraphDef meta_graph_def; - const string export_dir = - io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded); + const string export_dir = GetDataDependencyFilepath(TestDataSharded()); TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe}, &meta_graph_def)); CheckMetaGraphDef(meta_graph_def); @@ -59,8 +65,7 @@ TEST_F(ReaderTest, TagMatch) { TEST_F(ReaderTest, NoTagMatch) { MetaGraphDef meta_graph_def; - const string export_dir = - io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded); + const string export_dir = GetDataDependencyFilepath(TestDataSharded()); Status st = ReadMetaGraphDefFromSavedModel(export_dir, {"missing-tag"}, &meta_graph_def); EXPECT_FALSE(st.ok()); @@ -73,8 +78,7 @@ TEST_F(ReaderTest, NoTagMatch) { TEST_F(ReaderTest, NoTagMatchMultiple) { MetaGraphDef meta_graph_def; - const string export_dir = - io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded); + const string export_dir = GetDataDependencyFilepath(TestDataSharded()); Status st = ReadMetaGraphDefFromSavedModel( export_dir, {kSavedModelTagServe, "missing-tag"}, &meta_graph_def); EXPECT_FALSE(st.ok()); @@ -87,8 +91,7 @@ TEST_F(ReaderTest, NoTagMatchMultiple) { TEST_F(ReaderTest, PbtxtFormat) { MetaGraphDef meta_graph_def; - const string export_dir = - io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPbTxt); + const string export_dir = GetDataDependencyFilepath(TestDataPbTxt()); TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe}, &meta_graph_def)); CheckMetaGraphDef(meta_graph_def); @@ -97,8 +100,7 @@ TEST_F(ReaderTest, PbtxtFormat) { TEST_F(ReaderTest, InvalidExportPath) { MetaGraphDef meta_graph_def; - const string export_dir = - io::JoinPath(testing::TensorFlowSrcRoot(), "missing-path"); + const string export_dir = GetDataDependencyFilepath("missing-path"); Status st = ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe}, &meta_graph_def); EXPECT_FALSE(st.ok()); diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index a53d5265459..dfbea9c49eb 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -84,6 +84,7 @@ tf_cc_test( "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core/platform:resource_loader", "@com_google_absl//absl/strings", "@llvm-project//llvm:support", # fixdeps: keep "@llvm-project//llvm:x86_code_gen", # fixdeps: keep diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc index a7294323d1d..6206f68faf9 100644 --- a/tensorflow/compiler/aot/codegen_test.cc +++ 
b/tensorflow/compiler/aot/codegen_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/aot/codegen.h" +#include #include #include @@ -29,6 +30,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/resource_loader.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -139,23 +141,40 @@ TEST_F(ParseCppClassTest, ParseFail) { static void CompareWithGoldenFile( const string& tensorflow_relative_golden_file_name, - const string& expected_contents) { + const string& expected_contents, bool ignore_cr) { + // Get rid of all CR characters, we may be running under windows. + string sanitized_expected_contents(expected_contents); + if (ignore_cr) { + sanitized_expected_contents.erase( + std::remove(sanitized_expected_contents.begin(), + sanitized_expected_contents.end(), '\r'), + sanitized_expected_contents.end()); + } + // To update the golden file, flip update_golden to true and run the // following: // bazel test --test_strategy=local \ // third_party/tensorflow/compiler/aot:codegen_test const bool update_golden = false; - const string golden_file_name = io::JoinPath( - testing::TensorFlowSrcRoot(), tensorflow_relative_golden_file_name); + string golden_file_name; if (update_golden) { + golden_file_name = io::JoinPath(testing::TensorFlowSrcRoot(), + tensorflow_relative_golden_file_name); TF_EXPECT_OK( WriteStringToFile(Env::Default(), golden_file_name, expected_contents)); } + golden_file_name = + GetDataDependencyFilepath(tensorflow_relative_golden_file_name); string golden_file_contents; TF_ASSERT_OK(ReadFileToString(Env::Default(), golden_file_name, &golden_file_contents)); + if (ignore_cr) { + golden_file_contents.erase(std::remove(golden_file_contents.begin(), + golden_file_contents.end(), '\r'), + golden_file_contents.end()); + } EXPECT_EQ(golden_file_contents, expected_contents); } @@ -229,14 +248,18 @@ TEST(CodegenTest, Golden) { // The other fields in metadata_result are tested as part of the generated // header test. - CompareWithGoldenFile("compiler/aot/codegen_test_o.golden", - metadata_result.object_file_data); + // This specific golden test checks a binary file. It can potentially run into + // issues due to ABIs not being stable, but has not so far. + // If we see any ABI issues, we should reconsider this specific test case. 
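Editorial aside on the pattern shared by reader_test.cc and codegen_test.cc above (not part of the patch): test data is now resolved through the Bazel runfiles with GetDataDependencyFilepath instead of TensorFlowSrcRoot(), and carriage returns are stripped with the erase/remove idiom so golden comparisons are line-ending agnostic on Windows. A condensed sketch of that recipe; ReadGolden is an illustrative helper and assumes the golden file is declared as a data dependency of the test.

```
#include <algorithm>
#include <string>

#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/resource_loader.h"
#include "tensorflow/core/platform/status.h"

// Reads a golden file from the test's runfiles and normalizes line endings.
tensorflow::Status ReadGolden(std::string* contents) {
  const std::string path = tensorflow::GetDataDependencyFilepath(
      tensorflow::io::JoinPath("tensorflow", "compiler", "aot",
                               "codegen_test_h.golden"));
  TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(tensorflow::Env::Default(),
                                                  path, contents));
  // Drop '\r' so the comparison also passes on Windows checkouts.
  contents->erase(std::remove(contents->begin(), contents->end(), '\r'),
                  contents->end());
  return tensorflow::Status::OK();
}
```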
+ CompareWithGoldenFile("tensorflow/compiler/aot/codegen_test_o.golden", + metadata_result.object_file_data, false); string header; TF_ASSERT_OK( GenerateHeader(opts, config, compile_result, metadata_result, &header)); - CompareWithGoldenFile("compiler/aot/codegen_test_h.golden", header); + CompareWithGoldenFile("tensorflow/compiler/aot/codegen_test_h.golden", header, + true); } } // namespace } // namespace tfcompile diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 4bb1fde7a9b..08dc1b13db6 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -1883,6 +1883,8 @@ absl::flat_hash_set GetKnownXLAWhitelistOp() { "EmptyTensorList", "ExtractImagePatches", "Igamma", + "IgammaGradA", + "RandomGammaGrad", "Igammac", "FFT", "FFT2D", @@ -1909,7 +1911,6 @@ absl::flat_hash_set GetKnownXLAWhitelistOp() { "LinSpace", "ListDiff", "LogMatrixDeterminant", - "LowerBound", "MatMul", "MatrixBandPart", "MatrixDiag", @@ -2036,7 +2037,6 @@ absl::flat_hash_set GetKnownXLAWhitelistOp() { "TensorScatterUpdate", "TridiagonalSolve", "TruncatedNormal", - "UpperBound", "UnsortedSegmentMax", "UnsortedSegmentMin", "UnsortedSegmentProd", diff --git a/tensorflow/compiler/jit/xla_kernel_creator.cc b/tensorflow/compiler/jit/xla_kernel_creator.cc index 6ee1db2c7c5..fd6fd4b5b58 100644 --- a/tensorflow/compiler/jit/xla_kernel_creator.cc +++ b/tensorflow/compiler/jit/xla_kernel_creator.cc @@ -20,15 +20,17 @@ limitations under the License. namespace tensorflow { -bool XlaKernelCreator::CanCreateKernel(const FunctionLibraryRuntime& flr, - const NodeDef& node_def) const { - return CanCreateXlaKernel(node_def); +bool XlaKernelCreator::CanCreateKernel( + const FunctionLibraryRuntime& flr, + const std::shared_ptr& props) const { + return CanCreateXlaKernel(props->node_def); } -Status XlaKernelCreator::CreateKernel(FunctionLibraryRuntime* flr, - const NodeDef& node_def, - std::unique_ptr* kernel) const { - return CreateXlaKernel(flr, node_def, kernel); +Status XlaKernelCreator::CreateKernel( + FunctionLibraryRuntime* flr, + const std::shared_ptr& props, + std::unique_ptr* kernel) const { + return CreateXlaKernel(flr, props->node_def, kernel); } namespace { diff --git a/tensorflow/compiler/jit/xla_kernel_creator.h b/tensorflow/compiler/jit/xla_kernel_creator.h index 8815ee49ce5..856701a791d 100644 --- a/tensorflow/compiler/jit/xla_kernel_creator.h +++ b/tensorflow/compiler/jit/xla_kernel_creator.h @@ -29,11 +29,13 @@ class XlaKernelCreator : public CustomKernelCreator { // Given a NodeDef 'node_def' and the function library runtime 'flr', returns // true if 'node_def' is a call to a compilable function defined in 'flr', // with the kXlaCompileAttr set. - bool CanCreateKernel(const FunctionLibraryRuntime& flr, - const NodeDef& node_def) const override; + bool CanCreateKernel( + const FunctionLibraryRuntime& flr, + const std::shared_ptr& props) const override; // Given a supported NodeDef, returns a XlaLaunchOp that computes the node. 
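Editorial aside (not part of the patch): both CustomKernelCreator hooks now receive std::shared_ptr<const NodeProperties> rather than a bare NodeDef. A sketch of adapting an existing NodeDef to the new signature, mirroring the ToNodeProperties() helper in the test changes below; WrapNodeDef, the header path, and the empty type vectors are assumptions for illustration.

```
#include <memory>
#include <utility>

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_properties.h"
#include "tensorflow/core/framework/types.h"

// Wraps a NodeDef the same way the updated test helper does. The null OpDef
// and empty input/output type vectors are sufficient here because
// CanCreateKernel/CreateKernel only consult props->node_def.
std::shared_ptr<const tensorflow::NodeProperties> WrapNodeDef(
    tensorflow::NodeDef node_def) {
  tensorflow::DataTypeVector no_types;
  return std::make_shared<tensorflow::NodeProperties>(
      /*op_def=*/nullptr, std::move(node_def), no_types, no_types);
}
```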
- Status CreateKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def, + Status CreateKernel(FunctionLibraryRuntime* flr, + const std::shared_ptr& props, std::unique_ptr* kernel) const override; }; diff --git a/tensorflow/compiler/jit/xla_kernel_creator_test.cc b/tensorflow/compiler/jit/xla_kernel_creator_test.cc index 7ec37332906..ad94d60d9b5 100644 --- a/tensorflow/compiler/jit/xla_kernel_creator_test.cc +++ b/tensorflow/compiler/jit/xla_kernel_creator_test.cc @@ -30,10 +30,12 @@ limitations under the License. namespace tensorflow { -NodeDef ToNodeDef(const string& text) { +std::shared_ptr ToNodeProperties(const string& text) { NodeDef node_def; + DataTypeVector dummy; EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &node_def)); - return node_def; + return std::make_shared(nullptr, std::move(node_def), dummy, + dummy); } // Create a FunctionDef that takes one resource and one regular param @@ -98,11 +100,11 @@ TEST_F(XlaKernelCreatorTest, OneFloatOneResourceArgument) { (*fdef.mutable_attr())["_XlaMustCompile"] = BoolAttr(true); Init({fdef}); XlaKernelCreator xla_kernel_creator; - NodeDef callsite = - ToNodeDef(R"pb( + auto callsite = + ToNodeProperties(R"pb( name: 'XTimesY' op: 'XTimesY' input: 'a' input: 'b' )pb"); - (*callsite.mutable_attr())["_XlaMustCompile"] = BoolAttr(true); + (*(callsite->node_def.mutable_attr()))["_XlaMustCompile"] = BoolAttr(true); // Note: need to set attribute on the created node. Status status = xla_kernel_creator.CreateKernel(flr_, callsite, &kernel_); @@ -127,13 +129,14 @@ TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrNotSet) { Init({fdef}); XlaKernelCreator xla_kernel_creator; - Status status = xla_kernel_creator.CreateKernel(flr_, ToNodeDef(R"proto( - name: 'XTimesY' - op: 'XTimesY' - input: 'a' - input: 'b' - )proto"), - &kernel_); + Status status = + xla_kernel_creator.CreateKernel(flr_, ToNodeProperties(R"proto( + name: 'XTimesY' + op: 'XTimesY' + input: 'a' + input: 'b' + )proto"), + &kernel_); EXPECT_TRUE(errors::IsInternal(status)) << status.ToString(); } @@ -143,13 +146,14 @@ TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrIsSetToFalse) { Init({fdef}); XlaKernelCreator xla_kernel_creator; - Status status = xla_kernel_creator.CreateKernel(flr_, ToNodeDef(R"proto( - name: 'XTimesY' - op: 'XTimesY' - input: 'a' - input: 'b' - )proto"), - &kernel_); + Status status = + xla_kernel_creator.CreateKernel(flr_, ToNodeProperties(R"proto( + name: 'XTimesY' + op: 'XTimesY' + input: 'a' + input: 'b' + )proto"), + &kernel_); EXPECT_TRUE(errors::IsInternal(status)) << status.ToString(); } diff --git a/tensorflow/compiler/jit/xla_kernel_creator_util.cc b/tensorflow/compiler/jit/xla_kernel_creator_util.cc index 5aab0ff3bd6..de091fc93b4 100644 --- a/tensorflow/compiler/jit/xla_kernel_creator_util.cc +++ b/tensorflow/compiler/jit/xla_kernel_creator_util.cc @@ -218,12 +218,13 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def, TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function)); Device* dev = flr->device(); Status s; - OpKernelConstruction construction( - DeviceType(dev->device_type()), dev, - dev->GetAllocator(AllocatorAttributes()), &node_def, - &fbody->fdef.signature(), flr, dev->resource_manager(), fbody->arg_types, - input_memory_types, fbody->ret_types, output_memory_types, - flr->graph_def_version(), &s); + auto props = std::make_shared( + &fbody->fdef.signature(), node_def, fbody->arg_types, fbody->ret_types); + OpKernelConstruction construction(DeviceType(dev->device_type()), dev, + 
dev->GetAllocator(AllocatorAttributes()), + flr, dev->resource_manager(), props, + input_memory_types, output_memory_types, + flr->graph_def_version(), &s); *kernel = absl::make_unique( &construction, constant_arg_indices, resource_arg_indices, function, diff --git a/tensorflow/compiler/mlir/g3doc/README.md b/tensorflow/compiler/mlir/g3doc/README.md new file mode 100644 index 00000000000..39734828d19 --- /dev/null +++ b/tensorflow/compiler/mlir/g3doc/README.md @@ -0,0 +1,3 @@ +# TensorFlow MLIR + +These are the docs for: https://www.tensorflow.org/mlir diff --git a/tensorflow/compiler/mlir/g3doc/_book.yaml b/tensorflow/compiler/mlir/g3doc/_book.yaml new file mode 100644 index 00000000000..a75a2137536 --- /dev/null +++ b/tensorflow/compiler/mlir/g3doc/_book.yaml @@ -0,0 +1,26 @@ +upper_tabs: +# Tabs left of dropdown menu +- include: /_upper_tabs_left.yaml +- include: /api_docs/_upper_tabs_api.yaml +# Dropdown menu +- name: Resources + path: /resources + is_default: true + menu: + - include: /resources/_menu_toc.yaml + lower_tabs: + # Subsite tabs + other: + - name: Guide + contents: + - title: Overview + path: /mlir/overview + - heading: Dialects + - title: Overview + path: /mlir/dialects + - title: TensorFlow + path: /mlir/tf_ops + - title: TensorFlow Lite + path: /mlir/tfl_ops + +- include: /_upper_tabs_right.yaml diff --git a/tensorflow/compiler/mlir/g3doc/_index.yaml b/tensorflow/compiler/mlir/g3doc/_index.yaml new file mode 100644 index 00000000000..affd0926af5 --- /dev/null +++ b/tensorflow/compiler/mlir/g3doc/_index.yaml @@ -0,0 +1,54 @@ +book_path: /mlir/_book.yaml +project_path: /mlir/_project.yaml +description: +landing_page: + custom_css_path: /site-assets/css/style.css + rows: + - heading: MLIR unifies the infrastructure for high-performance ML models in TensorFlow. + items: + - description: > + The MLIR project defines a common + intermediate representation (IR) that unifies the infrastructure required to execute high + performance machine learning models in TensorFlow and similar ML frameworks. This project + will include the application of HPC techniques, along with integration of + search algorithms like reinforcement learning. MLIR aims to reduce the + cost to bring up new hardware, and improve usability for existing + TensorFlow users. + + - code_block: | +
+        // Syntactically similar to LLVM:
+        func @testFunction(%arg0: i32) {
+          %x = call @thingToCall(%arg0) : (i32) -> i32
+          br ^bb1
+        ^bb1:
+          %y = addi %x, %x : i32
+          return %y : i32
+        }
+        
+ + - classname: devsite-landing-row-cards + items: + - heading: "Multi-Level Intermediate Representation for Compiler Infrastructure" + youtube_id: qzljG6DKgic + buttons: + - label: Watch the video + path: https://www.youtube.com/watch?v=qzljG6DKgic + - heading: "A new intermediate representation and compiler framework" + image_path: /resources/images/tf-logo-card-16x9.png + path: https://blog.tensorflow.org/2019/04/mlir-new-intermediate-representation.html + buttons: + - label: Read on TensorFlow blog + path: https://blog.tensorflow.org/2019/04/mlir-new-intermediate-representation.html + - heading: MLIR on GitHub + image_path: /resources/images/github-card-16x9.png + path: https://github.com/llvm/llvm-project/tree/master/mlir + buttons: + - label: View on GitHub + path: https://github.com/llvm/llvm-project/tree/master/mlir + - heading: TensorFlow MLIR on GitHub + image_path: /resources/images/github-card-16x9.png + path: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/mlir + buttons: + - label: View on GitHub + path: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/mlir diff --git a/tensorflow/compiler/mlir/g3doc/dialects.md b/tensorflow/compiler/mlir/g3doc/dialects.md new file mode 100644 index 00000000000..fa6c4605b27 --- /dev/null +++ b/tensorflow/compiler/mlir/g3doc/dialects.md @@ -0,0 +1,37 @@ +# MLIR dialects + +## Overview + + +To separate different hardware and software targets, MLIR has “dialects”, +including: + +* TensorFlow IR, which represents all things possible in TensorFlow graphs. +* XLA HLO IR, which is designed to take advantage of XLA’s compilation + abilities (with output to, among other things, TPUs). +* An experimental affine dialect, which focuses on + [polyhedral representations](https://en.wikipedia.org/wiki/Polytope_model) + and optimizations. +* LLVM IR, which has a 1:1 mapping between it and LLVM’s own representation, + allowing MLIR to emit GPU and CPU code through LLVM. +* TensorFlow Lite, which will translate to running code on mobile platforms. + +Each dialect consists of a set of defined operations which have invariants +placed on them, like: “This is a binary operator, and the inputs and outputs +have the same types.” + +## Adding to MLIR + +MLIR has no fixed/built-in list of globally known operations (no “intrinsics”). +Dialects can define entirely custom types, which is how MLIR can model things +like the LLVM IR type system (which has first class aggregates), domain +abstractions important for ML-optimized accelerators like quantized types, and +even the Swift or Clang type systems (which are built around Swift/Clang +declaration nodes) in the future. + +If you want to connect a new low-level compiler, you would create a new dialect +and the lowerings between the TensorFlow Graph dialect and your dialect. +This smooths the path for hardware and compiler makers. You can even target +dialects at different levels in the same model; the higher-level optimizers +will respect the unfamiliar parts of the IR and wait for a lower level to handle +it. 
diff --git a/tensorflow/compiler/mlir/g3doc/images/mlir-infra.svg b/tensorflow/compiler/mlir/g3doc/images/mlir-infra.svg new file mode 100644 index 00000000000..aec0986ba02 --- /dev/null +++ b/tensorflow/compiler/mlir/g3doc/images/mlir-infra.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/tensorflow/compiler/mlir/g3doc/overview.md b/tensorflow/compiler/mlir/g3doc/overview.md new file mode 100644 index 00000000000..4cf99ba3800 --- /dev/null +++ b/tensorflow/compiler/mlir/g3doc/overview.md @@ -0,0 +1,36 @@ +# MLIR + +## Overview + +MLIR, or Multi-Level Intermediate Representation, is a representation format +and library of compiler utilities that sits between the model representation +and low-level compilers/executors that generate hardware-specific code. + +MLIR is, at its heart, a flexible infrastructure for modern optimizing +compilers. This means it consists of a specification for intermediate +representations (IR) and a code toolkit to perform transformations on that +representation. (In compiler parlance, as you move from higher-level +representations to lower-level representations, these transformations can be +called “lowerings”) + +MLIR is highly influenced by [LLVM](https://llvm.org/) and unabashedly reuses +many great ideas from it. It has a flexible type system, and allows +representing, analyzing and transforming graphs combining multiple levels of +abstraction in the same compilation unit. These abstractions include TensorFlow +operations, nested polyhedral loop regions, and even LLVM instructions and fixed +hardware operations and types. + +We expect MLIR to be of interest to many groups, including: + +* Compiler researchers and implementers looking to optimize performance and + memory consumption of machine learning models +* Hardware makers looking for a way to connect their hardware to TensorFlow, + such as TPUs, portable neural hardware in phones, and other custom ASICs +* People writing language bindings that want to take advantage of optimizing + compilers and hardware acceleration. + +The TensorFlow ecosystem contains a number of compilers and optimizers that +operate at multiple levels of the software and hardware stack. We expect the +gradual adoption of MLIR to simplify every aspect of this stack. + +MLIR overview diagram diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 7f5da2ad3de..8d51dd3cfc2 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -208,6 +208,7 @@ cc_library( "ir/tfl_ops.h.inc", "ir/tfl_ops_interface.cc.inc", "ir/tfl_ops_interface.h.inc", + "runtime_verifiers.inc", "utils/attribute_utils.cc", ], hdrs = [ @@ -303,12 +304,14 @@ cc_library( "transforms/optimize_functional_ops.cc", "transforms/prepare_composite_functions_tf.cc", "transforms/prepare_tf.cc", + "transforms/runtime_type_verify.cc", "transforms/split_merged_operands.cc", "transforms/trim_functions_tf.cc", "transforms/unroll_batch_matmul.cc", "transforms/while_loop_outline.cc", ], hdrs = [ + "ir/tfl_ops_interface.h.inc", "transforms/dilated_conv.h", "transforms/passes.h", "transforms/unroll_batch_matmul.h", @@ -461,9 +464,9 @@ cc_library( ) tf_native_cc_binary( - name = "operator-converter-gen", + name = "converter-gen", srcs = [ - "operator_converter_gen.cc", + "converter_gen.cc", ], deps = [ "@llvm-project//llvm:support", @@ -473,14 +476,18 @@ tf_native_cc_binary( ) gentbl( - name = "operator_converter_inc", + name = "converter_inc", tbl_outs = [ ( - "", # This driver has no options. 
+ "--gen-operator-converters", "operator_converters.inc", ), + ( + "--gen-runtime-verifiers", + "runtime_verifiers.inc", + ), ], - tblgen = ":operator-converter-gen", + tblgen = ":converter-gen", td_file = "ir/tfl_ops.td", td_srcs = [ ":tensorflow_lite_ops_td_files", @@ -582,7 +589,6 @@ cc_library( "@com_google_absl//absl/strings", "@flatbuffers", "@llvm-project//llvm:support", - "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:IR", "@llvm-project//mlir:QuantOps", "@llvm-project//mlir:StandardOps", @@ -645,12 +651,14 @@ tf_cc_binary( "//tensorflow/compiler/mlir:init_mlir", "//tensorflow/compiler/mlir/tensorflow:translate_cl_options", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:errors", "//tensorflow/lite:framework", "//tensorflow/lite/schema:schema_fbs", "//tensorflow/stream_executor/lib", "@com_google_absl//absl/strings", "@llvm-project//llvm:support", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", ], ) @@ -694,7 +702,6 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes", "//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass", "//tensorflow/compiler/mlir/tensorflow:translate_lib", - "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:Analysis", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", @@ -727,7 +734,6 @@ cc_library( "//tensorflow/lite/tools/optimize:quantize_weights", "//tensorflow/stream_executor/lib", "@llvm-project//llvm:support", - "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:Analysis", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", diff --git a/tensorflow/compiler/mlir/lite/operator_converter_gen.cc b/tensorflow/compiler/mlir/lite/converter_gen.cc similarity index 75% rename from tensorflow/compiler/mlir/lite/operator_converter_gen.cc rename to tensorflow/compiler/mlir/lite/converter_gen.cc index 6ebc71fd029..02d9ef45591 100644 --- a/tensorflow/compiler/mlir/lite/operator_converter_gen.cc +++ b/tensorflow/compiler/mlir/lite/converter_gen.cc @@ -28,6 +28,9 @@ limitations under the License. #include "llvm/TableGen/Record.h" #include "llvm/TableGen/TableGenBackend.h" #include "mlir/TableGen/Attribute.h" // TF:llvm-project +#include "mlir/TableGen/Format.h" // TF:llvm-project +#include "mlir/TableGen/Operator.h" // TF:llvm-project +#include "mlir/TableGen/Predicate.h" // TF:llvm-project using llvm::DefInit; using llvm::dyn_cast; @@ -41,6 +44,19 @@ using llvm::SmallVector; using llvm::StringInit; using llvm::StringRef; +enum ActionType { + OpConv, + RuntimeVerify, +}; + +// NOLINTNEXTLINE +llvm::cl::opt action( + llvm::cl::desc("Action to perform:"), + llvm::cl::values(clEnumValN(OpConv, "gen-operator-converters", + "Generate operator converters"), + clEnumValN(RuntimeVerify, "gen-runtime-verifiers", + "Generate TFLite runtime verifiers"))); + // Returns the associated option name for the given op definition. 
static inline std::string GetOperatorOptionName(const Record &def) { assert(def.getName().startswith("TFL_") && "unexpected op prefix"); @@ -342,8 +358,101 @@ static bool OperatorWritersMain(raw_ostream &os, RecordKeeper &records) { return false; } +static void GenOperandResultVerifier(raw_ostream &os, + llvm::ArrayRef values, + StringRef valueKind) { + mlir::tblgen::FmtContext fctx; + + bool first = true; + for (auto static_value : llvm::enumerate(values)) { + auto *definit = llvm::cast(static_value.value()); + auto *val = definit->getDef()->getValue("tflRuntimeTypePredicate"); + if (!val) continue; + + // Create code block on first type to verify. + if (first) { + os << " {\n"; + os << " unsigned index = " << static_value.index() << ";\n"; + first = false; + } + + mlir::tblgen::Pred pred(dyn_cast(val->getValue())); + auto desc = + definit->getDef()->getValueAsString("tflRuntimeTypeDescription"); + + // Emit a loop to check all the dynamic values in the pack. + os << formatv(" for (Value v : top.getODS{0}{1}s({2})) {{\n", + // Capitalize the first letter to match the function name + valueKind.substr(0, 1).upper(), valueKind.substr(1), + static_value.index()); + + os << " (void)v;\n" + << " if (!(" + << tgfmt(pred.getCondition(), &fctx.withSelf("v.getType()")) << ")) {\n" + << formatv( + " return op->emitOpError(\"{0} #\") << index " + "<< \" must be {1}, but got \" << v.getType();\n", + valueKind, desc) + << " }\n" // if + << " ++index;\n" + << " }\n"; // for + } + + // Emit closing brace if needed. + if (!first) os << " }\n"; +} + +// NOLINTNEXTLINE +static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) { + emitSourceFileHeader("MLIR TFLite Runtime Verifiers", os); + + // Retrieve all the definitions derived from TFL_Op and sort by record name. + std::vector defs = records.getAllDerivedDefinitions("Op"); + llvm::sort(defs, LessRecord()); + + // Iterate through all the ops defined. + for (const auto *def : defs) { + mlir::tblgen::Operator op(*def); + if (!op.getTrait("TflRuntimeVerifyOpInterface::Trait")) continue; + + mlir::tblgen::FmtContext verify_ctx; + os << "::mlir::LogicalResult " << op.getCppClassName() + << "::VerifyTflRuntimeTypes(::mlir::Operation *op) {\n"; + os << " auto top = cast<" << op.getCppClassName() << ">(op); (void)top;\n"; + verify_ctx.withOp("top"); + + for (int i = 0, e = op.getNumOperands(); i < e; ++i) { + for (int i = 0, e = op.getNumOperands(); i < e; ++i) { + auto &value = op.getOperand(i); + // Skip from from first variadic operands for now. Else getOperand index + // used below doesn't match. + if (value.isVariadic()) break; + if (!value.name.empty()) + verify_ctx.addSubst(value.name, formatv("op->getOperand({0})", i)); + } + for (int i = 0, e = op.getNumResults(); i < e; ++i) { + auto &value = op.getResult(i); + // Skip from from first variadic results for now. Else getResult index + // used below doesn't match. 
+ if (value.isVariadic()) break; + if (!value.name.empty()) + verify_ctx.addSubst(value.name, formatv("op->getResult({0})", i)); + } + } + GenOperandResultVerifier(os, def->getValueAsDag("arguments")->getArgs(), + "operand"); + GenOperandResultVerifier(os, def->getValueAsDag("results")->getArgs(), + "result"); + os << " return mlir::success();\n}\n"; + } + + return false; +} + int main(int argc, char **argv) { llvm::InitLLVM y(argc, argv); llvm::cl::ParseCommandLineOptions(argc, argv); - return TableGenMain(argv[0], &OperatorWritersMain); + if (action == ActionType::OpConv) + return TableGenMain(argv[0], &OperatorWritersMain); + return TableGenMain(argv[0], &RuntimeVerifierWriterMain); } diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td b/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td index 8c72e93d1aa..8e100538659 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td @@ -71,4 +71,23 @@ def TFL_SparseOp : OpInterface<"SparseOpInterface"> { ]; } +//===----------------------------------------------------------------------===// +// TFL runtime type verification of operand/result types. + +def TFL_RuntimeVerification : OpInterface<"TflRuntimeVerifyOpInterface"> { + let description = [{ + Interface to verify TFLite runtime op verification. + + This verifies that the converted TFLite ops has operand/result type + supported by the TFLite runtime. + }]; + + let methods = [ + StaticInterfaceMethod< + [{Returns whether the op's operands/results are supported by runtime.}], + "LogicalResult", "VerifyTflRuntimeTypes", (ins "Operation*":$op) + >, + ]; +} + #endif // TFL_OP_INTERFACES diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc index 2c9f7badb23..e73f6b732eb 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc @@ -723,12 +723,11 @@ static LogicalResult Verify(PackOp op) { } // Make sure all inputs have the same shape and element type. - // TODO(rahulsp): Simplify once b/135032064 is fixed. - for (Value operand : op.getOperands()) { - auto other_type = operand.getType().cast(); - if (input_type != other_type) + // TODO(b/135032063): Simplify once fixed. + for (Type operand_type : op.getOperandTypes()) { + if (failed(mlir::verifyCompatibleShape(input_type, operand_type))) return op.emitOpError("operands should be of the same type. got ") - << input_type << ", " << other_type; + << input_type << ", " << operand_type; } return success(); @@ -1872,6 +1871,7 @@ LogicalResult WhileOp::moveOutOfLoop(llvm::ArrayRef ops) { #include "tensorflow/compiler/mlir/lite/ir/tfl_ops_interface.cc.inc" #define GET_OP_CLASSES #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.cc.inc" +#include "tensorflow/compiler/mlir/lite/runtime_verifiers.inc" Operation *TensorFlowLiteDialect::materializeConstant(OpBuilder &builder, Attribute value, diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index 5b247a43442..d4127e53fa9 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -109,29 +109,63 @@ def TensorTypeAttr : TypeAttrBase<"TensorType", "Tensor type attribute">; // Derived shape attribute class. 
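Editorial illustration of the generator's output (hand-written, not actual generated code): for a hypothetical TFL op FooOp whose single operand is constrained to a 32-bit float tensor, RuntimeVerifierWriterMain and GenOperandResultVerifier above would emit a definition shaped roughly like the following into runtime_verifiers.inc, with the predicate text substituted from the TFL_RuntimeType predicate attached to the operand's type constraint.

```
// Hand-written sketch of the emitted verifier for a hypothetical FooOp; the
// real predicate condition comes from the op's TFL_TensorOf<...> constraint
// rather than being written by hand.
::mlir::LogicalResult FooOp::VerifyTflRuntimeTypes(::mlir::Operation *op) {
  auto top = cast<FooOp>(op);
  (void)top;
  {
    unsigned index = 0;
    for (Value v : top.getODSOperands(0)) {
      (void)v;
      if (!(v.getType().isa<TensorType>() &&
            v.getType().cast<TensorType>().getElementType().isF32())) {
        return op->emitOpError("operand #")
               << index << " must be 32-bit float tensor, but got "
               << v.getType();
      }
      ++index;
    }
  }
  return mlir::success();
}
```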
//===----------------------------------------------------------------------===// class DerivedShapeAttr : DerivedAttr<"ArrayRef", body>; -class DerivedTFLiteTypeAttr : DerivedAttr<"tflite::TensorType", body>; +class DerivedTFLiteTypeAttr : + DerivedAttr<"tflite::TensorType", body>; + +// These additional types/type constraints here are used to decouple the ops +// from runtime support for the ops. Prefer to use these types when defining +// new TF_Ops for uniformity. + +// TFL Runtime type predicate. +class TFL_RuntimeType { + Pred tflRuntimeTypePredicate = t.predicate; + string tflRuntimeTypeDescription = t.description; +} + +class TFL_AnyTypeOf allowedRuntimeTypes, string description = "", + list allowedOpTypes = [AnyType]> : + AnyTypeOf, + TFL_RuntimeType>; + +class TFL_TensorOf allowedRuntimeTypes, + list allowedOpTypes = [AnyType]> : + TensorOf, TFL_RuntimeType>; + +class TFL_TensorOfOrNone allowedRuntimeTypes, string description = "", + list allowedOpTypes = [AnyType]> : + AnyTypeOf<[TFL_TensorOf, NoneType], description>, + TFL_RuntimeType, NoneType]>>; + +class TFL_VariadicTensorOf allowedRuntimeTypes, + list allowedOpTypes = [AnyType]> : + Variadic>, + TFL_RuntimeType>>; def TFL_Int32Or64 : IntOfWidths<[32, 64]>; -def TFL_FpTensor : TensorOf<[AnyFloat]>; - -def TFL_I32OrI64Tensor : TensorOf<[TFL_Int32Or64]>; - -def TFL_BoolTensor : TypeAlias; - +def TFL_BoolTensor : TFL_TensorOf<[I1]>; +def TFL_FpOrI32OrI64Tensor : TFL_TensorOf<[AnyFloat, TFL_Int32Or64]>; +def TFL_FpTensor : TFL_TensorOf<[AnyFloat]>; +def TFL_I32OrI64Tensor : TFL_TensorOf<[TFL_Int32Or64]>; +def TFL_I32Tensor : TFL_TensorOf<[I32]>; +def TFL_I64Tensor : TFL_TensorOf<[I64]>; // TODO(jpienaar): Expand to all int types. -def TFL_IntTensor : TypeAlias; +def TFL_IntTensor : TypeAlias; + +class TFL_0DTensorOf allowedRuntimeTypes, + list allowedOpTypes = [AnyType]> : + 0DTensorOf, TFL_RuntimeType>; +class TFL_1DTensorOf allowedRuntimeTypes, + list allowedOpTypes = [AnyType]> : + 1DTensorOf, TFL_RuntimeType>; +class TFL_2DTensorOf allowedRuntimeTypes, + list allowedOpTypes = [AnyType]> : + 2DTensorOf, TFL_RuntimeType>; // This is used to represent the type of "ref tensors" or tensors that are // used as variables to track state. def TFL_StatefulTensor : TypeAlias; -// Tensor or None type. -class TFL_TensorOfOrNone allowedTypes, string description = ""> : - AnyTypeOf<[TensorOf, NoneType], description>; - -def TFL_FpOrI32OrI64Tensor : TensorOf<[AnyFloat, TFL_Int32Or64]>; - //===----------------------------------------------------------------------===// // Rank/Shape helpers. //===----------------------------------------------------------------------===// @@ -255,7 +289,8 @@ def TFL_ComparisonBinaryBuilder : OpBuilder< //===----------------------------------------------------------------------===// class TFL_Op traits = []> : - Op { + Op])> { // FlatBuffer generation specific information. // ------------------------------------------- // When generating the FlatBuffer output some operations have @@ -360,11 +395,11 @@ def TFL_AddNOp : TFL_Op<"add_n", [Commutative, NoSideEffect, SameOperandsAndResu }]; let arguments = (ins - Variadic>:$inputs + TFL_VariadicTensorOf<[F32, I32, QI16, QUI16]>:$inputs ); let results = (outs - TensorOf<[F32, I32, QI16, QUI16]>:$sum + TFL_TensorOf<[F32, I32, QI16, QUI16]>:$sum ); } @@ -381,14 +416,14 @@ retained with length 1. 
}]; let arguments = (ins - I1Tensor:$input, - I32Tensor:$reduction_indices, + TFL_BoolTensor:$input, + TFL_I32Tensor:$reduction_indices, DefaultValuedAttr:$keep_dims ); let results = (outs - I1Tensor:$output + TFL_BoolTensor:$output ); let hasOptions = 1; @@ -403,10 +438,10 @@ def TFL_TransposeConvOp: Performs transpose convolution operation on input. }]; - let arguments = ( - ins 1DTensorOf<[I32]>:$output_shape, - TensorOf<[F32, TFL_Uint8, QI8, QUI8]>:$weights, - TensorOf<[F32, TFL_Uint8, QI8, QUI8]>:$input, + let arguments = (ins + TFL_1DTensorOf<[I32]>:$output_shape, + TFL_TensorOf<[F32, TFL_Uint8, QI8, QUI8]>:$weights, + TFL_TensorOf<[F32, TFL_Uint8, QI8, QUI8]>:$input, TFL_PaddingAttr:$padding, I32Attr:$stride_h, I32Attr:$stride_w @@ -478,7 +513,7 @@ def TFL_ArgMaxOp : TFL_Op<"arg_max", [NoSideEffect]> { }]; let arguments = ( - ins TensorOf<[F32, I32, I8, TFL_Uint8, QI8, QUI8]>:$input, + ins TFL_TensorOf<[F32, I32, I8, TFL_Uint8, QI8, QUI8]>:$input, TFL_I32OrI64Tensor:$dim ); @@ -506,7 +541,7 @@ def TFL_ArgMinOp : TFL_Op<"arg_min", [NoSideEffect]> { }]; let arguments = ( - ins TensorOf<[F32, I32, I8, TFL_Uint8, QI8, QUI8]>:$input, + ins TFL_TensorOf<[F32, I32, I8, TFL_Uint8, QI8, QUI8]>:$input, TFL_I32OrI64Tensor:$dim ); @@ -549,14 +584,14 @@ def TFL_ConcatenationOp : TFL_Op<"concatenation", }]; let arguments = ( - ins Variadic>:$values, + ins TFL_VariadicTensorOf< + [F32, I64, I32, I16, I8, QI8, QUI8, QI16, TFL_Uint8]>:$values, I32Attr:$axis, TFL_AFAttr:$fused_activation_function ); let results = (outs - TensorOf< + TFL_TensorOf< [F32, I64, I32, I16, I8, QI8, QUI8, QI16, TFL_Uint8]>:$output ); @@ -708,8 +743,8 @@ def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [ let summary = "Fully connected op"; let arguments = (ins - TensorOf<[F32, QI8, QUI8, QI16, QUI16]>:$input, - TensorOf<[F32, QI8, QUI8, QI16, QUI16]>:$filter, + TFL_TensorOf<[F32, QI8, QUI8, QI16, QUI16]>:$input, + TFL_TensorOf<[F32, QI8, QUI8, QI16, QUI16]>:$filter, TFL_TensorOfOrNone<[F32, QI32, QUI32]>:$bias, TFL_AFAttr:$fused_activation_function, @@ -719,7 +754,7 @@ def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [ // Depending on the weights format, this op can have one or two outputs. 
let results = (outs - Variadic>:$output + TFL_VariadicTensorOf<[F32, QI8, QUI8, QI16, QUI16]>:$output ); let verifier = [{ return Verify(*this); }]; @@ -748,8 +783,8 @@ def TFL_GatherOp : TFL_Op<"gather", [ }]; let arguments = (ins - TensorOf<[F32, I1, I8, I32, I64, TFL_Str, QI8, QUI8, QI16]>:$params, - TensorOf<[I32, I64]>:$indices, + TFL_TensorOf<[F32, I1, I8, I32, I64, TFL_Str, QI8, QUI8, QI16]>:$params, + TFL_TensorOf<[I32, I64]>:$indices, I32Attr:$axis ); @@ -761,7 +796,7 @@ def TFL_GatherOp : TFL_Op<"gather", [ ]; let results = (outs - TensorOf<[F32, I1, I8, I32, I64, TFL_Str, QI8, QUI8, QI16]>:$output + TFL_TensorOf<[F32, I1, I8, I32, I64, TFL_Str, QI8, QUI8, QI16]>:$output ); let hasOptions = 1; @@ -775,12 +810,12 @@ def TFL_GatherNdOp : TFL_Op<"gather_nd", [NoSideEffect]> { }]; let arguments = (ins - TensorOf<[F32, I8, I64, I32, TFL_Uint8]>:$params, + TFL_TensorOf<[F32, I8, I64, I32, TFL_Uint8]>:$params, TFL_I32OrI64Tensor:$indices ); let results = (outs - TensorOf<[F32, I8, I64, I32, TFL_Uint8]>:$output + TFL_TensorOf<[F32, I8, I64, I32, TFL_Uint8]>:$output ); } @@ -794,8 +829,8 @@ def TFL_LessEqualOp : TFL_Op<"less_equal", [ }]; let arguments = ( - ins TensorOf<[F32, I32, I64, I8, QI8, QUI8, TFL_Uint8]>:$lhs, - TensorOf<[F32, I32, I64, I8, QI8, QUI8, TFL_Uint8]>:$rhs); + ins TFL_TensorOf<[F32, I32, I64, I8, QI8, QUI8, TFL_Uint8]>:$lhs, + TFL_TensorOf<[F32, I32, I64, I8, QI8, QUI8, TFL_Uint8]>:$rhs); let results = (outs TFL_BoolTensor:$output); @@ -827,7 +862,7 @@ convolutional neural networks (NIPS 2012)](http://papers.nips.cc/paper/4824-imag }]; let arguments = (ins - TensorOf<[F32, QI8, QUI8]>:$input, + TFL_TensorOf<[F32, QI8, QUI8]>:$input, I32Attr:$radius, F32Attr:$bias, F32Attr:$alpha, @@ -835,7 +870,7 @@ convolutional neural networks (NIPS 2012)](http://papers.nips.cc/paper/4824-imag ); let results = (outs - TensorOf<[F32, QI8, QUI8]>:$output + TFL_TensorOf<[F32, QI8, QUI8]>:$output ); let hasOptions = 1; @@ -881,11 +916,34 @@ def TFL_MatrixDiagOp : TFL_Op<"matrix_diag", [ }]; let arguments = (ins - TensorOf<[F32, I8, I64, I32, TFL_Uint8]>:$diagonal + TFL_TensorOf<[F32, I8, I64, I32, TFL_Uint8]>:$diagonal ); let results = (outs - TensorOf<[F32, I8, I64, I32, TFL_Uint8]>:$output + TFL_TensorOf<[F32, I8, I64, I32, TFL_Uint8]>:$output + ); + + let hasOptions = 0; +} + +def TFL_MatrixSetDiagOp : TFL_Op<"matrix_set_diag", [NoSideEffect]> { + let summary = [{ + Returns a batched matrix tensor with new batched diagonal values. + }]; + + let description = [{ +Given `input` and `diagonal`, this operation returns a tensor with the +same shape and values as `input`, except for the main diagonal of the +innermost matrices. These will be overwritten by the values in `diagonal`. + }]; + + let arguments = (ins + TensorOf<[F32, I32, I64, I8, QI8, QI16, QUI8, TFL_Uint8, TFL_Quint8]>:$input, + TensorOf<[F32, I32, I64, I8, QI8, QI16, QUI8, TFL_Uint8, TFL_Quint8]>:$diagonal + ); + + let results = (outs + TensorOf<[F32, I32, I64, I8, QI8, QI16, QUI8, TFL_Uint8, TFL_Quint8]>:$output ); let hasOptions = 0; @@ -935,14 +993,14 @@ using the `tf.gather operation`. For example: let arguments = (ins TFL_FpTensor:$boxes, TFL_FpTensor:$scores, - I32Tensor:$max_output_size, + TFL_I32Tensor:$max_output_size, TFL_FpTensor:$iou_threshold, TFL_FpTensor:$score_threshold ); let results = (outs - I32Tensor:$selected_indices, - I32Tensor:$valid_outputs + TFL_I32Tensor:$selected_indices, + TFL_I32Tensor:$valid_outputs ); } @@ -989,16 +1047,16 @@ larger than 0. 
let arguments = (ins TFL_FpTensor:$boxes, TFL_FpTensor:$scores, - I32Tensor:$max_output_size, + TFL_I32Tensor:$max_output_size, TFL_FpTensor:$iou_threshold, TFL_FpTensor:$score_threshold, TFL_FpTensor:$soft_nms_sigma ); let results = (outs - I32Tensor:$selected_indices, + TFL_I32Tensor:$selected_indices, TFL_FpTensor:$selected_scores, - I32Tensor:$valid_outputs + TFL_I32Tensor:$valid_outputs ); } @@ -1082,11 +1140,11 @@ def TFL_EmbeddingLookupOp: TFL_Op<"embedding_lookup", }]; let arguments = (ins - TensorOf<[I32]>:$lookup, - TensorOf<[F32, I8, TFL_Uint8]>:$value + TFL_TensorOf<[I32]>:$lookup, + TFL_TensorOf<[F32, I8, TFL_Uint8]>:$value ); - let results = (outs TensorOf<[F32, I8, TFL_Uint8]>:$output); + let results = (outs TFL_TensorOf<[F32, I8, TFL_Uint8]>:$output); } def TFL_EqualOp: TFL_Op<"equal", [Commutative, ResultsBroadcastableShape, @@ -1100,8 +1158,8 @@ def TFL_EqualOp: TFL_Op<"equal", [Commutative, ResultsBroadcastableShape, let arguments = ( ins - TensorOf<[I1, F32, I32, I64, I8, QI8, QUI8, TFL_Uint8]>:$x, - TensorOf<[I1, F32, I32, I64, I8, QI8, QUI8, TFL_Uint8]>:$y + TFL_TensorOf<[I1, F32, I32, I64, I8, QI8, QUI8, TFL_Uint8]>:$x, + TFL_TensorOf<[I1, F32, I32, I64, I8, QI8, QUI8, TFL_Uint8]>:$y ); let results = (outs TFL_BoolTensor:$output); @@ -1261,10 +1319,10 @@ def TFL_FloorModOp : TFL_Op<"floor_mod", [ResultsBroadcastableShape, NoSideEffec }]; let arguments = ( - ins TensorOf<[I32, I64, F32]>:$lhs, - TensorOf<[I32, I64, F32]>:$rhs); + ins TFL_TensorOf<[I32, I64, F32]>:$lhs, + TFL_TensorOf<[I32, I64, F32]>:$rhs); - let results = (outs TensorOf<[I32, I64, F32]>:$output); + let results = (outs TFL_TensorOf<[I32, I64, F32]>:$output); let builders = [TFL_BroadcastableBinaryBuilder]; } @@ -1291,7 +1349,7 @@ def TFL_GreaterOp : TFL_Op<"greater", [ } def TFL_HardSwishOp: TFL_Op<"hard_swish", [NoSideEffect, - SameOperandsAndResultType]> { + SameOperandsAndResultShape]> { let summary = "Hardswish activation function."; let description = [{ Computes hard-swish activation function @@ -1299,9 +1357,9 @@ def TFL_HardSwishOp: TFL_Op<"hard_swish", [NoSideEffect, element-wise. }]; - let arguments = (ins TensorOf<[F32, QUI8, QI8]>:$input); + let arguments = (ins TFL_TensorOf<[F32, QUI8, QI8]>:$input); - let results = (outs TensorOf<[F32, QUI8, QI8]>:$out); + let results = (outs TFL_TensorOf<[F32, QUI8, QI8]>:$out); let hasOptions = 0; } @@ -1319,11 +1377,11 @@ def TFL_L2NormalizationOp : TFL_Op<"l2_normalization", [NoSideEffect, }]; let arguments = (ins - TensorOf<[F32, QUI8, QI8, QUI16, QI16, I8]>:$input, + TFL_TensorOf<[F32, QUI8, QI8, QUI16, QI16, I8]>:$input, TFL_AFAttr:$fused_activation_function ); - let results = (outs TensorOf<[F32, QUI8, QI8, QUI16, QI16, I8]>:$output); + let results = (outs TFL_TensorOf<[F32, QUI8, QI8, QUI16, QI16, I8]>:$output); let hasOptions = 1; @@ -1380,10 +1438,10 @@ def TFL_LogicalAndOp : TFL_Op<"logical_and", [NoSideEffect]> { }]; let arguments = ( - ins I1Tensor:$lhs, - I1Tensor:$rhs); + ins TFL_BoolTensor:$lhs, + TFL_BoolTensor:$rhs); - let results = (outs I1Tensor:$output); + let results = (outs TFL_BoolTensor:$output); let parser = [{ return mlir::impl::parseOneResultSameOperandTypeOp(parser, result); }]; @@ -1397,9 +1455,9 @@ def TFL_LogicalNotOp : TFL_Op<"logical_not", [NoSideEffect, NoQuantizableResult] Element-wise logical NOT operation. 
}]; - let arguments = (ins I1Tensor:$lhs); + let arguments = (ins TFL_BoolTensor:$lhs); - let results = (outs I1Tensor:$output); + let results = (outs TFL_BoolTensor:$output); } def TFL_LogicalOrOp : TFL_Op<"logical_or", [NoSideEffect]> { @@ -1410,10 +1468,10 @@ def TFL_LogicalOrOp : TFL_Op<"logical_or", [NoSideEffect]> { }]; let arguments = ( - ins I1Tensor:$lhs, - I1Tensor:$rhs); + ins TFL_BoolTensor:$lhs, + TFL_BoolTensor:$rhs); - let results = (outs I1Tensor:$output); + let results = (outs TFL_BoolTensor:$output); let parser = [{ return mlir::impl::parseOneResultSameOperandTypeOp(parser, result); }]; @@ -1433,9 +1491,9 @@ def TFL_LogisticOp: TFL_Op<"logistic", [ Computes element-wise Sigmoid of input }]; - let arguments = (ins TensorOf<[AnyFloat, QI8, QUI8, QI16, QUI16]>:$x); + let arguments = (ins TFL_TensorOf<[AnyFloat, QI8, QUI8, QI16, QUI16]>:$x); - let results = (outs TensorOf<[AnyFloat, QI8, QUI8, QI16, QUI16]>:$y); + let results = (outs TFL_TensorOf<[AnyFloat, QI8, QUI8, QI16, QUI16]>:$y); } def TFL_LogOp: TFL_Op<"log", [ @@ -1585,12 +1643,12 @@ def TFL_MaximumOp : TFL_Op<"maximum", [ }]; let arguments = ( - ins TensorOf<[AnyFloat, TFL_Int32Or64, QI8, QUI8]>:$lhs, - TensorOf<[AnyFloat, TFL_Int32Or64, QI8, QUI8]>:$rhs + ins TFL_TensorOf<[AnyFloat, TFL_Int32Or64, QI8, QUI8]>:$lhs, + TFL_TensorOf<[AnyFloat, TFL_Int32Or64, QI8, QUI8]>:$rhs ); let results = (outs - TensorOf<[AnyFloat, TFL_Int32Or64, QI8, QUI8]>:$max + TFL_TensorOf<[AnyFloat, TFL_Int32Or64, QI8, QUI8]>:$max ); let builders = [TFL_BroadcastableBinaryBuilder]; @@ -1610,13 +1668,13 @@ def TFL_MeanOp : TFL_Op<"mean", [NoSideEffect, SameOperandsAndResultsScale]> { }]; let arguments = (ins - TensorOf<[F32, I8, I32, I64, QI8, QUI8, TFL_Uint8]>:$input, - TensorOf<[I32, I64]>:$axis, + TFL_TensorOf<[F32, I8, I32, I64, QI8, QUI8, TFL_Uint8]>:$input, + TFL_TensorOf<[I32, I64]>:$axis, BoolAttr:$keep_dims ); let results = (outs - TensorOf<[F32, I32, I64, I8, QI8, QUI8, TFL_Uint8]>:$output); + TFL_TensorOf<[F32, I32, I64, I8, QI8, QUI8, TFL_Uint8]>:$output); let hasOptions = 1; let customOption = "ReducerOptions"; @@ -1635,16 +1693,16 @@ def TFL_OneHotOp : TFL_Op<"one_hot", [NoSideEffect]> { }]; let arguments = (ins - TensorOf<[I32, I64]>:$indices, - I32Tensor:$depth, - TensorOf<[F32, I32, I64, I1]>:$on_value, - TensorOf<[F32, I32, I64, I1]>:$off_value, + TFL_TensorOf<[I32, I64]>:$indices, + TFL_I32Tensor:$depth, + TFL_TensorOf<[F32, I32, I64, I1]>:$on_value, + TFL_TensorOf<[F32, I32, I64, I1]>:$off_value, I32Attr:$axis ); let results = (outs - TensorOf<[F32, I32, I64, I1]>:$output + TFL_TensorOf<[F32, I32, I64, I1]>:$output ); let hasOptions = 1; @@ -1658,11 +1716,11 @@ Rounds the values of a tensor to the nearest integer, element-wise. 
}]; let arguments = (ins - TensorOf<[F32]>:$x + TFL_TensorOf<[F32]>:$x ); let results = (outs - TensorOf<[F32]>:$y + TFL_TensorOf<[F32]>:$y ); } @@ -1706,7 +1764,7 @@ def TFL_SumOp: TFL_Op<"sum", [NoSideEffect]> { let arguments = (ins AnyTensor:$input, - I32Tensor:$axes, + TFL_I32Tensor:$axes, BoolAttr:$keep_dims ); @@ -1726,7 +1784,7 @@ def TFL_ReduceMinOp: TFL_Op<"reduce_min", [ let arguments = (ins AnyTensor:$input, - I32Tensor:$axes, + TFL_I32Tensor:$axes, BoolAttr:$keep_dims ); @@ -1746,7 +1804,7 @@ def TFL_ReduceMaxOp: TFL_Op<"reduce_max", [ let arguments = (ins AnyTensor:$input, - I32Tensor:$axes, + TFL_I32Tensor:$axes, BoolAttr:$keep_dims ); @@ -1764,8 +1822,8 @@ def TFL_ReduceProdOp: TFL_Op<"reduce_prod", [NoSideEffect]> { }]; let arguments = (ins - TensorOf<[F32, I8, I32, I64]>:$input, - I32Tensor:$axes, + TFL_TensorOf<[F32, I8, I32, I64]>:$input, + TFL_I32Tensor:$axes, BoolAttr:$keep_dims ); @@ -1784,12 +1842,12 @@ def TFL_MinimumOp : TFL_Op<"minimum", [ }]; let arguments = ( - ins TensorOf<[AnyFloat, TFL_Int32Or64, QI8, QUI8]>:$lhs, - TensorOf<[AnyFloat, TFL_Int32Or64, QI8, QUI8]>:$rhs + ins TFL_TensorOf<[AnyFloat, TFL_Int32Or64, QI8, QUI8]>:$lhs, + TFL_TensorOf<[AnyFloat, TFL_Int32Or64, QI8, QUI8]>:$rhs ); let results = (outs - TensorOf<[AnyFloat, TFL_Int32Or64, QI8, QUI8]>:$min + TFL_TensorOf<[AnyFloat, TFL_Int32Or64, QI8, QUI8]>:$min ); let builders = [TFL_BroadcastableBinaryBuilder]; @@ -1869,14 +1927,14 @@ def TFL_PackOp : TFL_Op<"pack", [NoSideEffect, SameOperandsAndResultsScale]> { }]; let arguments = (ins - Variadic>:$values, + TFL_VariadicTensorOf<[F32, I8, I16, I32, I64, QI8, QUI8, QI16]>:$values, I32Attr:$values_count, I32Attr:$axis ); let results = (outs - TensorOf<[F32, I8, I16, I32, I64, QI8, QUI8, QI16]>:$output + TFL_TensorOf<[F32, I8, I16, I32, I64, QI8, QUI8, QI16]>:$output ); let verifier = [{ return Verify(*this); }]; @@ -1918,11 +1976,10 @@ def TFL_PadOp : TFL_Op<"pad", [ ``` }]; - let arguments = ( - ins TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$input, + let arguments = (ins TFL_TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$input, TFL_I32OrI64Tensor:$padding); - let results = (outs TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$output); + let results = (outs TFL_TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$output); let hasOptions = 1; } @@ -1965,11 +2022,11 @@ def TFL_PadV2Op : TFL_Op<"padv2", [ }]; let arguments = ( - ins TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$input, + ins TFL_TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$input, TFL_I32OrI64Tensor:$padding, - TensorOf<[F32, I8, I32, I64]>:$constant_values); + TFL_TensorOf<[F32, I8, I32, I64]>:$constant_values); - let results = (outs TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$output); + let results = (outs TFL_TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$output); let hasOptions = 1; } @@ -2007,11 +2064,11 @@ def TFL_PReluOp : TFL_Op<"prelu", [NoSideEffect]> { }]; let arguments = ( - ins TensorOf<[F32, QUI8]>:$input, - TensorOf<[F32, QUI8]>:$alpha + ins TFL_TensorOf<[F32, QUI8]>:$input, + TFL_TensorOf<[F32, QUI8]>:$alpha ); - let results = (outs TensorOf<[F32, QUI8]>:$output); + let results = (outs TFL_TensorOf<[F32, QUI8]>:$output); let verifier = [{ return Verify(*this); }]; } @@ -2039,9 +2096,9 @@ def TFL_ReluOp: TFL_Op<"relu", [NoSideEffect, x -> max(0, x) }]; - let arguments = (ins TensorOf<[F32, QUI8, I8]>:$x); + let arguments = (ins TFL_TensorOf<[F32, QUI8, I8]>:$x); - let results = (outs TensorOf<[F32, QUI8, I8]>:$y); + let results = (outs TFL_TensorOf<[F32, QUI8, I8]>:$y); } def TFL_Relu6Op: TFL_Op<"relu6", 
[NoSideEffect, @@ -2054,9 +2111,9 @@ def TFL_Relu6Op: TFL_Op<"relu6", [NoSideEffect, x -> max(0, min(6, x)) }]; - let arguments = (ins TensorOf<[F32, QUI8, I8]>:$x); + let arguments = (ins TFL_TensorOf<[F32, QUI8, I8]>:$x); - let results = (outs TensorOf<[F32, QUI8, I8]>:$y); + let results = (outs TFL_TensorOf<[F32, QUI8, I8]>:$y); } def TFL_Relu1Op: TFL_Op<"relu_n1_to_1", [NoSideEffect, @@ -2069,9 +2126,9 @@ def TFL_Relu1Op: TFL_Op<"relu_n1_to_1", [NoSideEffect, x -> max(-1, min(1, x)) }]; - let arguments = (ins TensorOf<[F32, QUI8, I8]>:$x); + let arguments = (ins TFL_TensorOf<[F32, QUI8, I8]>:$x); - let results = (outs TensorOf<[F32, QUI8, I8]>:$y); + let results = (outs TFL_TensorOf<[F32, QUI8, I8]>:$y); } def TFL_ReshapeOp: TFL_Op<"reshape", [ @@ -2085,7 +2142,7 @@ def TFL_ReshapeOp: TFL_Op<"reshape", [ let arguments = ( ins AnyTensor:$input, - I32Tensor:$shape); + TFL_I32Tensor:$shape); let results = (outs AnyTensor:$output); let hasCanonicalizer = 0b1; @@ -2109,7 +2166,7 @@ slice `i`, with the first `seq_lengths[i]` slices along dimension }]; let arguments = (ins - TensorOf<[F32, I16, I32, I64, TFL_Uint8]>:$input, + TFL_TensorOf<[F32, I16, I32, I64, TFL_Uint8]>:$input, TFL_I32OrI64Tensor:$seq_lengths, I32Attr:$seq_dim, @@ -2117,7 +2174,7 @@ slice `i`, with the first `seq_lengths[i]` slices along dimension ); let results = (outs - TensorOf<[F32, I16, I32, I64, TFL_Uint8]>:$output + TFL_TensorOf<[F32, I16, I32, I64, TFL_Uint8]>:$output ); let hasOptions = 1; @@ -2201,12 +2258,12 @@ def TFL_ReverseV2Op: TFL_Op<"reverse_v2", let arguments = ( ins - TensorOf<[F32, I16, I32, I64, TFL_Uint8, I1]>:$input, - TensorOf<[I32, I64]>:$axis + TFL_TensorOf<[F32, I16, I32, I64, TFL_Uint8, I1]>:$input, + TFL_TensorOf<[I32, I64]>:$axis ); let results = (outs - TensorOf<[F32, I16, I32, I64, TFL_Uint8, I1]>:$output + TFL_TensorOf<[F32, I16, I32, I64, TFL_Uint8, I1]>:$output ); } @@ -2228,8 +2285,8 @@ def TFL_SelectOp : TFL_Op<"select", [NoSideEffect, let arguments = (ins TFL_BoolTensor:$condition, - TensorOf<[F32, I1, I8, I16, I32, I64, TFL_Uint8]>:$x, - TensorOf<[F32, I1, I8, I16, I32, I64, TFL_Uint8]>:$y); + TFL_TensorOf<[F32, I1, I8, I16, I32, I64, TFL_Uint8]>:$x, + TFL_TensorOf<[F32, I1, I8, I16, I32, I64, TFL_Uint8]>:$y); let results = (outs AnyTensor:$output); // TODO(jpienaar): autogenerate this. 
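[Editor's note] Because the TFL_Op base class above now declares DeclareOpInterfaceMethods<TFL_RuntimeVerification>, every TFL op exposes VerifyTflRuntimeTypes, and a pass can verify a function by walking ops through that interface. The sketch below shows one way such a driver could look; it assumes the usual TFL dialect headers and is not the actual source of CreateRuntimeTypeVerifyPass registered later in this patch.

// Illustrative driver for TflRuntimeVerifyOpInterface; not the real pass.
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"

namespace mlir {
namespace TFL {
namespace {

struct RuntimeVerifySketchPass
    : public FunctionPass<RuntimeVerifySketchPass> {
  void runOnFunction() override {
    // Any op implementing the interface gets its operand/result types checked
    // against what the TFLite runtime kernels accept.
    getFunction().walk([&](TflRuntimeVerifyOpInterface op) {
      if (failed(op.VerifyTflRuntimeTypes(op.getOperation())))
        signalPassFailure();
    });
  }
};

}  // namespace
}  // namespace TFL
}  // namespace mlir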
@@ -2257,8 +2314,8 @@ def TFL_SelectV2Op : TFL_Op<"select_v2", [NoSideEffect]> { let arguments = (ins TFL_BoolTensor:$condition, - TensorOf<[F32, I1, I8, I16, I32, I64, TFL_Uint8]>:$x, - TensorOf<[F32, I1, I8, I16, I32, I64, TFL_Uint8]>:$y); + TFL_TensorOf<[F32, I1, I8, I16, I32, I64, TFL_Uint8]>:$x, + TFL_TensorOf<[F32, I1, I8, I16, I32, I64, TFL_Uint8]>:$y); let results = (outs AnyTensor:$output); let builders = [OpBuilder<"Builder *builder, OperationState &result, " @@ -2405,9 +2462,9 @@ def TFL_TanhOp: TFL_Op<"tanh", [ Computes element-wise Hyperbolic tangent of input }]; - let arguments = (ins TensorOf<[F32, I16, I8, QI8, QUI8, QI16, QUI16, TFL_Uint8]>:$x); + let arguments = (ins TFL_TensorOf<[F32, I16, I8, QI8, QUI8, QI16, QUI16, TFL_Uint8]>:$x); - let results = (outs TensorOf<[F32, I16, I8, QI8, QUI8, QI16, QUI16, TFL_Uint8]>:$y); + let results = (outs TFL_TensorOf<[F32, I16, I8, QI8, QUI8, QI16, QUI16, TFL_Uint8]>:$y); } def TFL_TileOp: TFL_Op<"tile", [NoSideEffect, SameOperandsAndResultsScale, @@ -2425,11 +2482,11 @@ def TFL_TileOp: TFL_Op<"tile", [NoSideEffect, SameOperandsAndResultsScale, }]; let arguments = (ins - TensorOf<[F32, I1, I32, I64, TFL_Uint8, QUI8]>:$input, + TFL_TensorOf<[F32, I1, I32, I64, TFL_Uint8, QUI8, TFL_Str]>:$input, TFL_I32OrI64Tensor:$multiples); let results = (outs - TensorOf<[F32, I1, I32, I64, TFL_Uint8, QUI8]>:$output); + TFL_TensorOf<[F32, I1, I32, I64, TFL_Uint8, QUI8, TFL_Str]>:$output); let hasOptions = 0; } @@ -2449,12 +2506,12 @@ def TFL_TopKV2Op: TFL_Op<"topk_v2", [NoSideEffect, TFL_OperandHasRank<1,0>, }]; let arguments = (ins - TensorOf<[F32, I8, I32, I64, TFL_Uint8, QI8, QUI8]>:$input, - I32Tensor:$k); + TFL_TensorOf<[F32, I8, I32, I64, TFL_Uint8, QI8, QUI8]>:$input, + TFL_I32Tensor:$k); let results = (outs - TensorOf<[F32, I8, I32, I64, TFL_Uint8, QI8, QUI8]>:$values, - I32Tensor:$indices); + TFL_TensorOf<[F32, I8, I32, I64, TFL_Uint8, QI8, QUI8]>:$values, + TFL_I32Tensor:$indices); let builders = [OpBuilder<"Builder *builder, OperationState &result, " "Value input, Value k", @@ -2480,7 +2537,7 @@ def TFL_TransposeOp : TFL_Op<"transpose", let arguments = ( ins AnyTensor:$x, - TensorOf<[I32]>:$perm + TFL_TensorOf<[I32]>:$perm ); let results = (outs @@ -2513,14 +2570,14 @@ def TFL_UnpackOp : TFL_Op<"unpack", [NoSideEffect, SameOperandsAndResultsScale]> }]; let arguments = (ins - TensorOf<[F32, I1, I8, I32, QI8, QUI8]>:$input, + TFL_TensorOf<[F32, I1, I8, I32, QI8, QUI8]>:$input, I32Attr:$num, I32Attr:$axis ); let results = (outs - Variadic>:$outputs + TFL_VariadicTensorOf<[F32, I1, I8, I32, QI8, QUI8]>:$outputs ); let verifier = [{ return Verify(*this); }]; @@ -2555,13 +2612,13 @@ def TFL_BatchToSpaceNdOp: TFL_Op<"batch_to_space_nd", [ }]; let arguments = (ins - TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$input, - TensorOf<[I32]>:$block_shape, - TensorOf<[I32]>:$indices + TFL_TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$input, + TFL_TensorOf<[I32]>:$block_shape, + TFL_TensorOf<[I32]>:$indices ); let results = (outs - TensorOf<[F32, I16, I32, I64, QI8, QUI8]>:$output + TFL_TensorOf<[F32, I16, I32, I64, QI8, QUI8]>:$output ); } @@ -2578,13 +2635,13 @@ def TFL_SpaceToBatchNdOp: TFL_Op<"space_to_batch_nd", [ }]; let arguments = (ins - TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$input, - TensorOf<[I32]>:$block_shape, - TensorOf<[I32]>:$paddings + TFL_TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$input, + TFL_TensorOf<[I32]>:$block_shape, + TFL_TensorOf<[I32]>:$paddings ); let results = (outs - TensorOf<[F32, I16, I32, I64, QI8, QUI8]>:$output + 
TFL_TensorOf<[F32, I16, I32, I64, QI8, QUI8]>:$output ); } @@ -2604,12 +2661,12 @@ def TFL_SpaceToDepthOp: TFL_Op<"space_to_depth", [ }]; let arguments = (ins - TensorOf<[F32, I8, I32, I64, TFL_Uint8, QUI8]>:$input, + TFL_TensorOf<[F32, I8, I32, I64, TFL_Uint8, QUI8]>:$input, I32Attr:$block_size ); let results = (outs - TensorOf<[F32, I8, I32, I64, TFL_Uint8, QUI8]>:$output + TFL_TensorOf<[F32, I8, I32, I64, TFL_Uint8, QUI8]>:$output ); let hasOptions = 1; @@ -2633,12 +2690,12 @@ def TFL_DepthToSpaceOp: TFL_Op<"depth_to_space", [ }]; let arguments = (ins - TensorOf<[F32, I8, I32, I64, TFL_Uint8, TFL_Quint8, QUI8]>:$input, + TFL_TensorOf<[F32, I8, I32, I64, TFL_Uint8, TFL_Quint8, QUI8]>:$input, I32Attr:$block_size ); let results = (outs - TensorOf<[F32, I8, I32, I64, TFL_Uint8, TFL_Quint8, QUI8]>:$output + TFL_TensorOf<[F32, I8, I32, I64, TFL_Uint8, TFL_Quint8, QUI8]>:$output ); let hasOptions = 1; @@ -2657,13 +2714,13 @@ def TFL_SplitOp : TFL_Op<"split", [ }]; let arguments = (ins - TensorOf<[I32]>:$split_dim, - TensorOf<[F32, I16, I32, I64, QI8, QUI8, QI16]>:$value, + TFL_TensorOf<[I32]>:$split_dim, + TFL_TensorOf<[F32, I16, I32, I64, QI8, QUI8, QI16]>:$value, PositiveI32Attr:$num_splits ); let results = (outs - Variadic>:$outputs + TFL_VariadicTensorOf<[F32, I16, I32, I64, QI8, QUI8, QI16]>:$outputs ); let verifier = [{ return Verify(*this); }]; @@ -2681,14 +2738,14 @@ def TFL_SplitVOp : TFL_Op<"split_v", [NoSideEffect, SameOperandsAndResultsScale] }]; let arguments = (ins - TensorOf<[F32, I16, I32, I64, QI8, QUI8, QI16]>:$value, - 1DTensorOf<[I32]>:$size_splits, - 0DTensorOf<[I32]>:$split_dim, + TFL_TensorOf<[F32, I16, I32, I64, QI8, QUI8, QI16]>:$value, + TFL_1DTensorOf<[I32], [I32]>:$size_splits, + TFL_0DTensorOf<[I32], [I32]>:$split_dim, PositiveI32Attr:$num_splits ); let results = (outs - Variadic>:$outputs + TFL_VariadicTensorOf<[F32, I16, I32, I64, QI8, QUI8, QI16]>:$outputs ); let verifier = [{ return Verify(*this); }]; @@ -2706,14 +2763,14 @@ def TFL_ResizeBilinearOp: TFL_Op<"resize_bilinear", [ let arguments = (ins // TODO(ycling): Support quantized types. - TensorOf<[F32, I32, QI8, QUI8]>:$input, - TensorOf<[I32]>:$size, + TFL_TensorOf<[F32, I32, QI8, QUI8]>:$input, + TFL_TensorOf<[I32]>:$size, BoolAttr:$align_corners, DefaultValuedAttr:$half_pixel_centers ); let results = (outs - TensorOf<[F32, QI8, QUI8]>:$output + TFL_TensorOf<[F32, QI8, QUI8]>:$output ); let hasOptions = 1; @@ -2729,13 +2786,13 @@ def TFL_ResizeNearestNeighborOp : TFL_Op<"resize_nearest_neighbor", }]; let arguments = (ins - TensorOf<[F32, I8, TFL_Uint8, QUI8, QI8]>:$input, - TensorOf<[I32]>:$size, + TFL_TensorOf<[F32, I8, TFL_Uint8, QUI8, QI8]>:$input, + TFL_TensorOf<[I32]>:$size, BoolAttr:$align_corners ); let results = (outs - TensorOf<[F32, I8, TFL_Uint8, QUI8, QI8]>:$output + TFL_TensorOf<[F32, I8, TFL_Uint8, QUI8, QI8]>:$output ); let hasOptions = 1; @@ -2769,12 +2826,12 @@ are checked during execution. 
let arguments = (ins TFL_I32OrI64Tensor:$sparse_indices, TFL_I32OrI64Tensor:$output_shape, - TensorOf<[I32, I64, I8, TFL_Uint8, F32]>:$sparse_values, - TensorOf<[I32, I64, I8, TFL_Uint8, F32]>:$default_value + TFL_TensorOf<[I32, I64, I8, TFL_Uint8, F32]>:$sparse_values, + TFL_TensorOf<[I32, I64, I8, TFL_Uint8, F32]>:$default_value ); let results = (outs - TensorOf<[I32, I64, I8, TFL_Uint8, F32]>:$dense + TFL_TensorOf<[I32, I64, I8, TFL_Uint8, F32]>:$dense ); } @@ -2792,10 +2849,10 @@ def TFL_StridedSliceOp: TFL_Op<"strided_slice", }]; let arguments = (ins - TensorOf<[F32, I32, I64, I8, QI8, QUI8, I1, TFL_Quint8, TFL_Uint8]>:$input, - TensorOf<[I32]>:$begin, - TensorOf<[I32]>:$end, - TensorOf<[I32]>:$strides, + TFL_TensorOf<[F32, I32, I64, I8, QI8, QUI8, I1, TFL_Quint8, TFL_Uint8]>:$input, + TFL_TensorOf<[I32]>:$begin, + TFL_TensorOf<[I32]>:$end, + TFL_TensorOf<[I32]>:$strides, I32Attr:$begin_mask, I32Attr:$end_mask, @@ -2805,7 +2862,7 @@ def TFL_StridedSliceOp: TFL_Op<"strided_slice", ); let results = (outs - TensorOf<[F32, I32, I64, I8, QI8, QUI8, I1, TFL_Quint8, TFL_Uint8]>:$output + TFL_TensorOf<[F32, I32, I64, I8, QI8, QUI8, I1, TFL_Quint8, TFL_Uint8]>:$output ); let hasOptions = 1; @@ -2820,10 +2877,10 @@ def TFL_CastOp : TFL_Op<"cast", [ }]; let arguments = (ins - TensorOf<[F32, I1, I32, I64, TFL_Quint8, TFL_Uint8, Complex>]>:$input + TFL_TensorOf<[F32, I1, I32, I64, TFL_Quint8, TFL_Uint8, Complex>]>:$input ); - let results = (outs TensorOf<[F32, I1, I32, I64, Complex>]>:$output); + let results = (outs TFL_TensorOf<[F32, I1, I32, I64, Complex>]>:$output); // TFLite's cast op does not utilize CastOptions, instead derives types // from the TfLiteTensors. @@ -2855,13 +2912,13 @@ def TFL_MirrorPadOp: TFL_Op<"mirror_pad", [ let arguments = (ins // TODO: add uint8 support when ready. - TensorOf<[F32, I32, I64]>:$input, - TensorOf<[I32, I64]>:$pad, + TFL_TensorOf<[F32, I32, I64]>:$input, + TFL_TensorOf<[I32, I64]>:$pad, TFL_MirrorPaddingAttr:$mode ); let results = (outs - TensorOf<[F32, I32, I64]>:$output + TFL_TensorOf<[F32, I32, I64]>:$output ); let hasOptions = 1; @@ -2879,12 +2936,12 @@ in the unique output `y`. In other words: let arguments = (ins // TODO: add uint8 support after quantize support. 
- TensorOf<[I8, I16, I32, I64, F32]>:$input + TFL_TensorOf<[I8, I16, I32, I64, F32]>:$input ); let results = (outs - TensorOf<[I8, I16, I32, I64, F32]>:$output, - TensorOf<[I32, I64]>:$idx + TFL_TensorOf<[I8, I16, I32, I64, F32]>:$output, + TFL_TensorOf<[I32, I64]>:$idx ); DerivedTFLiteTypeAttr idx_out_type = DerivedTFLiteTypeAttr<[{ @@ -3084,11 +3141,11 @@ def TFL_BasicLSTMOp : TFL_Op<"basic_lstm", [NoSideEffect, }]; let arguments = ( - ins TensorOf<[F32, I8, QI8, QUI8, QI16, QUI16]>:$data_input, - TensorOf<[F32, I8, QI8, QUI8, QI16, QUI16]>:$prev_activ_input, - TensorOf<[F32, I8, QI8, QUI8, QI16, QUI16]>:$weights_input, - TensorOf<[F32, QI32, QUI32]>:$biases_input, - TensorOf<[F32, I8, QI8, QUI8, QI16, QUI16]>:$prev_state_input, + ins TFL_TensorOf<[F32, I8, QI8, QUI8, QI16, QUI16]>:$data_input, + TFL_TensorOf<[F32, I8, QI8, QUI8, QI16, QUI16]>:$prev_activ_input, + TFL_TensorOf<[F32, I8, QI8, QUI8, QI16, QUI16]>:$weights_input, + TFL_TensorOf<[F32, QI32, QUI32]>:$biases_input, + TFL_TensorOf<[F32, I8, QI8, QUI8, QI16, QUI16]>:$prev_state_input, // Attributes DefaultValuedAttr:$fused_activation_function, @@ -3102,10 +3159,10 @@ def TFL_BasicLSTMOp : TFL_Op<"basic_lstm", [NoSideEffect, let hasOptions = 1; - let results = (outs 2DTensorOf<[F32, I8, QI8, QUI8, QI16, QUI16]>:$activ_output, - 2DTensorOf<[F32, I8, QI8, QUI8, QI16, QUI16]>:$state_output, - 2DTensorOf<[F32, I8, QI8, QUI8, QI16, QUI16]>:$concat_temp, - 2DTensorOf<[F32, I8, QI8, QUI8, QI16, QUI16]>:$activ_temp); + let results = (outs TFL_2DTensorOf<[F32, I8, QI8, QUI8, QI16, QUI16]>:$activ_output, + TFL_2DTensorOf<[F32, I8, QI8, QUI8, QI16, QUI16]>:$state_output, + TFL_2DTensorOf<[F32, I8, QI8, QUI8, QI16, QUI16]>:$concat_temp, + TFL_2DTensorOf<[F32, I8, QI8, QUI8, QI16, QUI16]>:$activ_temp); } // This is the FULL kernel type LSTM op. @@ -3138,19 +3195,19 @@ Ba et al. “Layer Normalization” }]; let arguments = ( - ins TensorOf<[F32]>:$input, + ins TFL_TensorOf<[F32]>:$input, // Weights TFL_TensorOfOrNone<[F32, I8]>:$input_to_input_weights, - TensorOf<[F32, I8]>:$input_to_forget_weights, - TensorOf<[F32, I8]>:$input_to_cell_weights, - TensorOf<[F32, I8]>:$input_to_output_weights, + TFL_TensorOf<[F32, I8]>:$input_to_forget_weights, + TFL_TensorOf<[F32, I8]>:$input_to_cell_weights, + TFL_TensorOf<[F32, I8]>:$input_to_output_weights, // Recurrent weights TFL_TensorOfOrNone<[F32, I8]>:$recurrent_to_input_weights, - TensorOf<[F32, I8]>:$recurrent_to_forget_weights, - TensorOf<[F32, I8]>:$recurrent_to_cell_weights, - TensorOf<[F32, I8]>:$recurrent_to_output_weights, + TFL_TensorOf<[F32, I8]>:$recurrent_to_forget_weights, + TFL_TensorOf<[F32, I8]>:$recurrent_to_cell_weights, + TFL_TensorOf<[F32, I8]>:$recurrent_to_output_weights, // Cell weights TFL_TensorOfOrNone<[F32, I8]>:$cell_to_input_weights, @@ -3161,9 +3218,9 @@ Ba et al. 
“Layer Normalization” // Bias TFL_TensorOfOrNone<[F32]>:$input_gate_bias, - TensorOf<[F32]>:$forget_gate_bias, - TensorOf<[F32]>:$cell_bias, - TensorOf<[F32]>:$output_gate_bias, + TFL_TensorOf<[F32]>:$forget_gate_bias, + TFL_TensorOf<[F32]>:$cell_bias, + TFL_TensorOf<[F32]>:$output_gate_bias, // Projection weight and bias TFL_TensorOfOrNone<[F32, I8]>:$projection_weights, @@ -3230,19 +3287,19 @@ def TFL_UnidirectionalSequenceLSTMOp : }]; let arguments = ( - ins TensorOf<[F32, I8]>:$input, + ins TFL_TensorOf<[F32, I8]>:$input, // Weights TFL_TensorOfOrNone<[F32, I8]>:$input_to_input_weights, - TensorOf<[F32, I8]>:$input_to_forget_weights, - TensorOf<[F32, I8]>:$input_to_cell_weights, - TensorOf<[F32, I8]>:$input_to_output_weights, + TFL_TensorOf<[F32, I8]>:$input_to_forget_weights, + TFL_TensorOf<[F32, I8]>:$input_to_cell_weights, + TFL_TensorOf<[F32, I8]>:$input_to_output_weights, // Recurrent weights TFL_TensorOfOrNone<[F32, I8]>:$recurrent_to_input_weights, - TensorOf<[F32, I8]>:$recurrent_to_forget_weights, - TensorOf<[F32, I8]>:$recurrent_to_cell_weights, - TensorOf<[F32, I8]>:$recurrent_to_output_weights, + TFL_TensorOf<[F32, I8]>:$recurrent_to_forget_weights, + TFL_TensorOf<[F32, I8]>:$recurrent_to_cell_weights, + TFL_TensorOf<[F32, I8]>:$recurrent_to_output_weights, // Cell weights TFL_TensorOfOrNone<[F32, I8]>:$cell_to_input_weights, @@ -3253,9 +3310,9 @@ def TFL_UnidirectionalSequenceLSTMOp : // Bias TFL_TensorOfOrNone<[F32]>:$input_gate_bias, - TensorOf<[F32]>:$forget_gate_bias, - TensorOf<[F32]>:$cell_bias, - TensorOf<[F32]>:$output_gate_bias, + TFL_TensorOf<[F32]>:$forget_gate_bias, + TFL_TensorOf<[F32]>:$cell_bias, + TFL_TensorOf<[F32]>:$output_gate_bias, // Projection weight and bias TFL_TensorOfOrNone<[F32, I8]>:$projection_weights, @@ -3316,16 +3373,16 @@ def TFL_UnidirectionalSequenceRNNOp : }]; let arguments = ( - ins TensorOf<[F32, I8]>:$input, + ins TFL_TensorOf<[F32, I8]>:$input, // Weights - TensorOf<[F32, I8]>:$input_to_input_weights, + TFL_TensorOf<[F32, I8]>:$input_to_input_weights, // Recurrent weights - TensorOf<[F32, I8]>:$recurrent_to_input_weights, + TFL_TensorOf<[F32, I8]>:$recurrent_to_input_weights, // Bias - TensorOf<[F32]>:$input_gate_bias, + TFL_TensorOf<[F32]>:$input_gate_bias, // Hidden state. TFL_StatefulTensor:$hidden_state, @@ -3335,7 +3392,7 @@ def TFL_UnidirectionalSequenceRNNOp : TFL_AFAttr:$fused_activation_function ); - let results = (outs TensorOf<[F32, I8]>:$output); + let results = (outs TFL_TensorOf<[F32, I8]>:$output); let hasOptions = 1; @@ -3362,11 +3419,11 @@ the output tensor can vary depending on how many true values there are in }]; let arguments = (ins - I1Tensor:$input + TFL_BoolTensor:$input ); let results = (outs - I64Tensor:$index + TFL_I64Tensor:$index ); } @@ -3381,8 +3438,8 @@ def TFL_NumericVerifyOp : Op:$input, - TensorOf<[F32]>:$ref, + TFL_TensorOf<[QI8, QUI8, QI16, QUI16]>:$input, + TFL_TensorOf<[F32]>:$ref, // Attributes DefaultValuedAttr:$tolerance @@ -3410,13 +3467,13 @@ def TFL_SVDFOp : }]; let arguments = ( - ins TensorOf<[F32, I8]>:$input, + ins TFL_TensorOf<[F32, I8]>:$input, // Feature Weights. 
- TensorOf<[F32, I8]>:$feature_weights, + TFL_TensorOf<[F32, I8]>:$feature_weights, // Time weights - TensorOf<[F32, I8]>:$time_weights, + TFL_TensorOf<[F32, I8]>:$time_weights, // Bias TFL_TensorOfOrNone<[F32]>:$input_gate_bias, @@ -3429,7 +3486,7 @@ def TFL_SVDFOp : TFL_AFAttr:$fused_activation_function ); - let results = (outs TensorOf<[F32, I8]>:$output); + let results = (outs TFL_TensorOf<[F32, I8]>:$output); let hasOptions = 1; @@ -3449,10 +3506,10 @@ def TFL_SegmentSumOp: TFL_Op<"segment_sum", [NoSideEffect]> { }]; let arguments = (ins - TensorOf<[F32, I32]>:$data, - I32Tensor:$segment_ids + TFL_TensorOf<[F32, I32]>:$data, + TFL_I32Tensor:$segment_ids ); - let results = (outs TensorOf<[F32, I32]>:$output); + let results = (outs TFL_TensorOf<[F32, I32]>:$output); } def TFL_YieldOp : Op { diff --git a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc index e7a6cf7f47d..f2b89aebb44 100644 --- a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc @@ -282,6 +282,7 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags, if (pass_config.legalize_tf_while) { pm.addPass(mlir::TFL::CreateWhileOutlinePass()); } + pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass()); auto status = ConvertTFExecutorToTFLOrFlatbuffer( module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops, diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h index 749ee7a9f57..ed998510328 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h @@ -150,7 +150,8 @@ struct QuantizationPattern : public RewritePattern { explicit QuantizationPattern(MLIRContext* context, bool enable_verify, float error_tolerance, bool single_layer_verify) - : RewritePattern(DQ::getOperationName(), 1, context), + // Set the score to a large number so it is always preferred. 
+ : RewritePattern(DQ::getOperationName(), 300, context), enable_verify(enable_verify), error_tolerance(error_tolerance), single_layer_verify(single_layer_verify) {} diff --git a/tensorflow/compiler/mlir/lite/tests/extract-ophint.mlir b/tensorflow/compiler/mlir/lite/tests/extract-ophint.mlir index bde800897c5..a18ba9cd91a 100644 --- a/tensorflow/compiler/mlir/lite/tests/extract-ophint.mlir +++ b/tensorflow/compiler/mlir/lite/tests/extract-ophint.mlir @@ -178,15 +178,20 @@ func @inputsAfterOutputs() { // ----- -// expected-error@+1 {{Found malformed ophint regions: missing inputs or outputs.}} module { -func @extractOphintFailure() { +func @extractOphintSame() { %0 = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x16x1xf32> %1 = call @AnotherFunc(%0) : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> %2 = "tf.Sigmoid"(%1) {T = "tfdtype$DT_FLOAT", name = "Sigmoid"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> %3 = "tf.Mul"(%2, %1) {T = "tfdtype$DT_FLOAT", name = "mul"} : (tensor<1x16x16x1xf32>, tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> %4 = "tf.Identity"(%3) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "cool_activation", _tflite_function_output_index = 0 : i64, _tflite_function_uuid = "d4b1eb00b81211e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation-d4b1eb00b81211e99426dc4a3e957995-0-None-None"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> return + +// CHECK: [[VAL_0:%.*]] = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x16x1xf32> +// CHECK: [[VAL_1:%.*]] = call @AnotherFunc([[VAL_0]]) : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> +// CHECK: [[VAL_2:%.*]] = "tf.Sigmoid"([[VAL_1]]) {T = "tfdtype$DT_FLOAT", name = "Sigmoid"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> +// CHECK: [[VAL_3:%.*]] = "tf.Mul"([[VAL_2]], [[VAL_1]]) {T = "tfdtype$DT_FLOAT", name = "mul"} : (tensor<1x16x16x1xf32>, tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> +// CHECK: [[VAL_4:%.*]] = "tf.Identity"([[VAL_3]]) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "cool_activation", _tflite_function_output_index = 0 : i64, _tflite_function_uuid = "d4b1eb00b81211e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation-d4b1eb00b81211e99426dc4a3e957995-0-None-None"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> } func @AnotherFunc(%arg0: tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> { diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir index e44128d587f..e40047ea216 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir @@ -739,6 +739,15 @@ func @matrix_diag_v3(%arg0: tensor<8x16xf32>) -> tensor<8x16x16xf32> { // CHECK: return [[VAL_6]] : tensor<8x16x16xf32> } +func @matrix_set_diag(%arg0: tensor<3x3xi32>, %arg1: tensor<3xi32>) -> tensor<3x3xi32> { + %0 = "tf.MatrixSetDiag"(%arg0, %arg1) : (tensor<3x3xi32>, tensor<3xi32>) -> tensor<3x3xi32> + return %0 : tensor<3x3xi32> + +// CHECK-LABEL: func @matrix_set_diag( +// CHECK: [[VAL_0:%.*]] = "tfl.matrix_set_diag"(%arg0, %arg1) : (tensor<3x3xi32>, tensor<3xi32>) -> tensor<3x3xi32> +// CHECK: return [[VAL_0]] +} + func @maximum(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> 
tensor<8x16xf32> { %0 = "tf.Maximum"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> return %0 : tensor<8x16xf32> @@ -1364,3 +1373,83 @@ func @reciprocal_i64(%arg0: tensor<8xi64>) -> tensor<8xi64> { // CHECK: "tfl.div"(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor<1xi64>, tensor<8xi64>) -> tensor<8xi64> // CHECK: return } + +func @random_uniform() -> tensor<2x5xf32> { + %0 = "tf.Const"() { value = dense<[2, 5]> : tensor<2xi32> } : () -> tensor<2xi32> + %1 = "tf.RandomUniform"(%0) { seed = 1, seed2 = 0} : (tensor<2xi32>) -> tensor<2x5xf32> + return %1 : tensor<2x5xf32> + + // CHECK-LABEL: random_uniform + // CHECK: %[[CST:.*]] = constant dense + // CHECK: return %[[CST:.*]] : tensor<2x5xf32> +} + +func @random_uniform_no_fold(%arg0: tensor<2xi32>) -> tensor<2x5xf32> { + %1 = "tf.RandomUniform"(%arg0) { seed = 0, seed2 = 0} : (tensor<2xi32>) -> tensor<2x5xf32> + return %1 : tensor<2x5xf32> + + // CHECK-LABEL: random_uniform_no_fold + // CHECK: %[[RANDOM:.*]] = "tf.RandomUniform" +} + +func @random_uniform_no_fold2(%arg0: tensor<2xi32>) -> tensor<*xf32> { + %1 = "tf.RandomUniform"(%arg0) { seed = 1, seed2 = 2} : (tensor<2xi32>) -> tensor<*xf32> + return %1 : tensor<*xf32> + + // CHECK-LABEL: random_uniform_no_fold2 + // CHECK: %[[RANDOM:.*]] = "tf.RandomUniform" +} + +func @random_uniform_no_fold3() -> tensor<2x5xf64> { + %0 = "tf.Const"() { value = dense<[2, 5]> : tensor<2xi32> } : () -> tensor<2xi32> + %1 = "tf.RandomUniform"(%0) { seed = 1, seed2 = 0} : (tensor<2xi32>) -> tensor<2x5xf64> + return %1 : tensor<2x5xf64> + + // CHECK-LABEL: random_uniform_no_fold3 + // CHECK: %[[RANDOM:.*]] = "tf.RandomUniform" +} + +func @LstmWithoutProjection(%arg: tensor<28x1x28xf32>) -> (tensor<28x1x16xf32>) { + %1 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<16x28xf32>} : () -> tensor<16x28xf32> + %2 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<16x16xf32>} : () -> tensor<16x16xf32> + %3 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<16xf32>} : () -> tensor<16xf32> + %4 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<1x16xf32>} : () -> tensor<1x16xf32> + %5 = "tf.Const"() {device = "", dtype = f32, value = dense<-1.000000e+00> : tensor<1xf32>} : () -> tensor<1xf32> + %6:3 = "tf.UnidirectionalSequenceLstm"(%arg, %1, %1, %1, %1, %2, %2, %2, %2, %3, %3, %3, %3, %3, %3, %3, %5, %5, %4, %4) {_tflite_input_indices = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 18, 19], device = ""} : (tensor<28x1x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1x16xf32>, tensor<1x16xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<28x1x16xf32>) + return %6#2 : tensor<28x1x16xf32> +} + +// CHECK: func @LstmWithoutProjection([[VAL_0:%.*]]: tensor<28x1x28xf32>) -> tensor<28x1x16xf32> { +// CHECK: [[VAL_1:%.*]] = constant dense<0.000000e+00> : tensor<16x28xf32> +// CHECK: [[VAL_2:%.*]] = constant dense<0.000000e+00> : tensor<16x16xf32> +// CHECK: [[VAL_3:%.*]] = constant dense<0.000000e+00> : tensor<16xf32> +// CHECK: [[VAL_4:%.*]] = constant dense<0.000000e+00> : tensor<1x16xf32> +// CHECK: [[VAL_5:%.*]] = constant unit +// CHECK: [[VAL_6:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_1]], 
[[VAL_1]], [[VAL_1]], [[VAL_1]], [[VAL_2]], [[VAL_2]], [[VAL_2]], [[VAL_2]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_5]], [[VAL_5]], [[VAL_4]], [[VAL_4]], [[VAL_5]], [[VAL_5]], [[VAL_5]], [[VAL_5]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor<28x1x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, none, none, tensor<1x16xf32>, tensor<1x16xf32>, none, none, none, none) -> tensor<28x1x16xf32> +// CHECK: return [[VAL_6]] : tensor<28x1x16xf32> +// CHECK: } + +func @LstmWithProjection(%arg: tensor<28x1x16xf32>) -> (tensor<28x1x8xf32>) { + %1 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<16x16xf32>} : () -> tensor<16x16xf32> + %2 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<16x8xf32>} : () -> tensor<16x8xf32> + %3 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<16xf32>} : () -> tensor<16xf32> + %4 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<1x16xf32>} : () -> tensor<1x16xf32> + %5 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<8x16xf32>} : () -> tensor<8x16xf32> + %6 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<1x8xf32>} : () -> tensor<1x8xf32> + %7 = "tf.Const"() {device = "", dtype = f32, value = dense<-1.000000e+00> : tensor<1xf32>} : () -> tensor<1xf32> + %8:3 = "tf.UnidirectionalSequenceLstm"(%arg, %1, %1, %1, %1, %2, %2, %2, %2, %7, %7, %7, %3, %3, %3, %3, %5, %7, %6, %4) {_tflite_input_indices = [0, 1, 2, 3, 4, 5, 6, 7, 8, 12, 13, 14, 15, 16, 18, 19], device = ""} : (tensor<28x1x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x8xf32>, tensor<16x8xf32>, tensor<16x8xf32>, tensor<16x8xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<8x16xf32>, tensor<1xf32>, tensor<1x8xf32>, tensor<1x16xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<28x1x8xf32>) + return %8#2 : tensor<28x1x8xf32> +} + +// CHECK-LABEL: func @LstmWithProjection( +// CHECK-SAME: [[VAL_7:%.*]]: tensor<28x1x16xf32>) -> tensor<28x1x8xf32> { +// CHECK: [[VAL_8:%.*]] = constant dense<0.000000e+00> : tensor<16x16xf32> +// CHECK: [[VAL_9:%.*]] = constant dense<0.000000e+00> : tensor<16x8xf32> +// CHECK: [[VAL_10:%.*]] = constant dense<0.000000e+00> : tensor<16xf32> +// CHECK: [[VAL_11:%.*]] = constant dense<0.000000e+00> : tensor<1x16xf32> +// CHECK: [[VAL_12:%.*]] = constant dense<0.000000e+00> : tensor<8x16xf32> +// CHECK: [[VAL_13:%.*]] = constant dense<0.000000e+00> : tensor<1x8xf32> +// CHECK: [[VAL_14:%.*]] = constant unit +// CHECK: [[VAL_15:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_7]], [[VAL_8]], [[VAL_8]], [[VAL_8]], [[VAL_8]], [[VAL_9]], [[VAL_9]], [[VAL_9]], [[VAL_9]], [[VAL_14]], [[VAL_14]], [[VAL_14]], [[VAL_10]], [[VAL_10]], [[VAL_10]], [[VAL_10]], [[VAL_12]], [[VAL_14]], [[VAL_13]], [[VAL_11]], [[VAL_14]], [[VAL_14]], [[VAL_14]], [[VAL_14]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor<28x1x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, 
tensor<16x8xf32>, tensor<16x8xf32>, tensor<16x8xf32>, tensor<16x8xf32>, none, none, none, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<8x16xf32>, none, tensor<1x8xf32>, tensor<1x16xf32>, none, none, none, none) -> tensor<28x1x8xf32> +// CHECK: return [[VAL_15]] : tensor<28x1x8xf32> +// CHECK: } diff --git a/tensorflow/compiler/mlir/lite/tests/ops.mlir b/tensorflow/compiler/mlir/lite/tests/ops.mlir index 6c9836005fc..da58b3704d0 100644 --- a/tensorflow/compiler/mlir/lite/tests/ops.mlir +++ b/tensorflow/compiler/mlir/lite/tests/ops.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt -split-input-file -verify-diagnostics %s | FileCheck %s --dump-input-on-failure +// RUN: tf-opt -split-input-file -verify-diagnostics -tfl-runtime-verify %s | FileCheck %s --dump-input-on-failure // Unary math ops // ----- @@ -878,6 +878,14 @@ func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> { // ----- +func @packUnranked(%arg0: tensor<2xi32>, %arg1: tensor<*xi32>) -> tensor<2x2xi32> { + // CHECK: "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} + %0 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<2xi32>, tensor<*xi32>) -> tensor<2x2xi32> + return %0 : tensor<2x2xi32> +} + +// ----- + func @packInputRank(%arg0: tensor<1x4xi32>, %arg1: tensor<1x4xi32>) -> tensor<1x4x2xi32> { // CHECK: "tfl.pack"(%arg0, %arg1) {axis = 2 : i32, values_count = 2 : i32} %0 = "tfl.pack"(%arg0, %arg1) {axis = 2 : i32, values_count = 2 : i32} : (tensor<1x4xi32>, tensor<1x4xi32>) -> tensor<1x4x2xi32> diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir index 6c635bd3500..1aa1311318a 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir @@ -511,3 +511,34 @@ func @PadStridedSliceNewAxisMask2(%arg0: tensor<4x64x64x1xf32>) -> tensor<1x4x64 %1 = "tf.StridedSlice"(%0, %cst, %cst, %cst_0) {Index = i32, T = f32, _output_shapes = ["tfshape$dim { size: 1 } dim { size: 4 } dim { size: 64 } dim { size: 64 }"], begin_mask = 6 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 1 : i64, shrink_axis_mask = 0 : i64} : (tensor<4x64x64xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x4x64x64xf32> return %1 : tensor<1x4x64x64xf32> } + +// CHECK-LABEL: @MatrixSetDiagV2Conversion +func @MatrixSetDiagV2Conversion(%arg0: tensor<3x3xi32>, %arg1: tensor<3xi32>) -> tensor<3x3xi32> { + %cst = constant dense<0> : tensor + %0 = "tf.MatrixSetDiagV2"(%arg0, %arg1, %cst) : (tensor<3x3xi32>, tensor<3xi32>, tensor) -> tensor<3x3xi32> + return %0 : tensor<3x3xi32> + + // CHECK: %[[RES:.*]] = "tf.MatrixSetDiag"(%arg0, %arg1) : (tensor<3x3xi32>, tensor<3xi32>) -> tensor<3x3xi32> + // CHECK: return %[[RES]] +} + +// CHECK-LABEL: @MatrixSetDiagV2NonZeroK +func @MatrixSetDiagV2NonZeroK(%arg0: tensor<3x3xi32>, %arg1: tensor<3xi32>) -> tensor<3x3xi32> { + %cst = constant dense<1> : tensor + %0 = "tf.MatrixSetDiagV2"(%arg0, %arg1, %cst) : (tensor<3x3xi32>, tensor<3xi32>, tensor) -> tensor<3x3xi32> + return %0 : tensor<3x3xi32> + + // CHECK: %[[CST:.*]] = constant dense<1> : tensor + // CHECK: %[[RES:.*]] = "tf.MatrixSetDiagV2"(%arg0, %arg1, %[[CST]]) : (tensor<3x3xi32>, tensor<3xi32>, tensor) -> tensor<3x3xi32> + // CHECK: return %[[RES]] +} + +// CHECK-LABEL: @MatrixSetDiagV3Conversion +func @MatrixSetDiagV3Conversion(%arg0: tensor<3x3xi32>, %arg1: tensor<3xi32>) -> tensor<3x3xi32> { + %cst = constant dense<0> : tensor + %0 
= "tf.MatrixSetDiagV3"(%arg0, %arg1, %cst) : (tensor<3x3xi32>, tensor<3xi32>, tensor) -> tensor<3x3xi32> + return %0 : tensor<3x3xi32> + + // CHECK: %[[RES:.*]] = "tf.MatrixSetDiag"(%arg0, %arg1) : (tensor<3x3xi32>, tensor<3xi32>) -> tensor<3x3xi32> + // CHECK: return %[[RES]] +} diff --git a/tensorflow/compiler/mlir/lite/tests/quantize.mlir b/tensorflow/compiler/mlir/lite/tests/quantize.mlir index 89d1e7cb7f4..0261644e6de 100644 --- a/tensorflow/compiler/mlir/lite/tests/quantize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/quantize.mlir @@ -2,39 +2,44 @@ // RUN: tf-opt %s -tfl-prepare-quantize -tfl-quantize -tfl-numeric-verify | FileCheck --check-prefix=DEBUG %s // CHECK-LABEL: QuantizeFloatConst -func @QuantizeFloatConst() -> tensor { +func @QuantizeFloatConst() -> tensor<2x2x!quant.uniform> { %0 = constant dense<-0.1> : tensor<2x2xf32> - %1 = "tfl.quantize"(%0) {qtype = tensor>} : (tensor<2x2xf32>) -> tensor> - %2 = "tfl.dequantize"(%1) : (tensor>) -> tensor - return %2 : tensor + %1 = "tfl.quantize"(%0) {qtype = tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> + return %1 : tensor<2x2x!quant.uniform> -// CHECK: %[[cst:.*]] = "tfl.pseudo_qconst"() {qtype = tensor>, value = dense<0> : tensor<2x2xi8>} -// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[cst]]) -// CHECK: return %[[dq]] : tensor +// CHECK: %[[cst:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x2x!quant.uniform>, value = dense<0> : tensor<2x2xi8>} +// CHECK: return %[[cst]] } // CHECK-LABEL: QuantizeDenseFloatConst -func @QuantizeDenseFloatConst() -> tensor<2x2xf32> { +func @QuantizeDenseFloatConst() -> tensor<2x2x!quant.uniform> { %0 = constant dense<[[-0.1, 1.0], [1.0, 3.0]]> : tensor<2x2xf32> %1 = "tfl.quantize"(%0) {qtype = tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> - %2 = "tfl.dequantize"(%1) : (tensor<2x2x!quant.uniform>) -> tensor<2x2xf32> - return %2 : tensor<2x2xf32> + return %1 : tensor<2x2x!quant.uniform> // CHECK: %[[cst:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x2x!quant.uniform>, value = dense<{{\[\[}}0, -1], {{\[}}-1, -1]]> : tensor<2x2xi8>} -// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[cst]]) -// CHECK: return %[[dq]] : tensor<2x2xf32> +// CHECK: return %[[cst]] } // CHECK-LABEL: QuantizeSplatFloatConst -func @QuantizeSplatFloatConst() -> tensor<2x2xf32> { +func @QuantizeSplatFloatConst() -> tensor<2x2x!quant.uniform> { %0 = constant dense<3.0> : tensor<2x2xf32> %1 = "tfl.quantize"(%0) {qtype = tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> + return %1 : tensor<2x2x!quant.uniform> + +// CHECK: %[[cst:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x2x!quant.uniform>, value = dense<-1> : tensor<2x2xi8>} +// CHECK: return %[[cst]] +} + +// CHECK-LABEL: NotQuantizeFloatConst +func @NotQuantizeFloatConst() -> tensor<2x2xf32> { + %0 = constant dense<-0.1> : tensor<2x2xf32> + %1 = "tfl.quantize"(%0) {qtype = tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> %2 = "tfl.dequantize"(%1) : (tensor<2x2x!quant.uniform>) -> tensor<2x2xf32> return %2 : tensor<2x2xf32> -// CHECK: %[[cst:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x2x!quant.uniform>, value = dense<-1> : tensor<2x2xi8>} -// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[cst]]) -// CHECK: return %[[dq]] : tensor<2x2xf32> +// CHECK: %[[cst:.*]] = constant dense<-1.000000e-01> : tensor<2x2xf32> +// CHECK: return %[[cst]] : tensor<2x2xf32> } // CHECK-LABEL: DequantizeAndQuantize diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc 
b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc index 648f469e9b0..7f8ce4cf3d4 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc @@ -24,6 +24,7 @@ limitations under the License. #include "mlir/IR/Function.h" // TF:llvm-project #include "mlir/IR/MLIRContext.h" // TF:llvm-project #include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project #include "mlir/Support/FileUtilities.h" // TF:llvm-project #include "tensorflow/compiler/mlir/init_mlir.h" #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" @@ -32,8 +33,10 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h" #include "tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h" #include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h" +#include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h" #include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/stream_executor/lib/statusor.h" @@ -130,12 +133,24 @@ int main(int argc, char **argv) { llvm::SourceMgr source_mgr; mlir::SourceMgrDiagnosticHandler sourceMgrHandler(source_mgr, &context); - StatusOr module = - tensorflow::LoadFromGraphdefOrMlirSource( - input_file_name, input_mlir, use_splatted_constant, custom_opdefs, - debug_info_file, input_arrays, input_dtypes, input_shapes, - output_arrays, - /*prune_unused_nodes=*/true, &source_mgr, &context); + StatusOr module; + + // TODO(b/147435528): We need to test the e2e behavior once the graph freezing + // inside mlir is done. + if (import_saved_model || import_saved_model_v1) { + if (input_mlir) + module = tensorflow::errors::InvalidArgument( + "Importing saved model should not have input_mlir set"); + module = tensorflow::ImportSavedModel( + import_saved_model, import_saved_model_v1, input_file_name, + saved_model_tags, saved_model_exported_names, &context); + } else { + module = tensorflow::LoadFromGraphdefOrMlirSource( + input_file_name, input_mlir, use_splatted_constant, custom_opdefs, + debug_info_file, input_arrays, input_dtypes, input_shapes, + output_arrays, + /*prune_unused_nodes=*/true, &source_mgr, &context); + } // If errors occur, the library call in the above already logged the error // message. So we can just return here. 
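[Editor's note] The new SavedModel path added above can also be driven programmatically through the ImportSavedModel helper declared later in this patch (tf_to_tfl_flatbuffer.h). The sketch below mirrors the flag semantics shown here (comma-separated tags, empty exported-names meaning export everything); the SavedModel path is a placeholder and the snippet is an illustration, not code from this patch.

// Illustrative call into the new ImportSavedModel helper.
#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"

int SketchSavedModelImport() {
  mlir::MLIRContext context;
  auto module_or = tensorflow::ImportSavedModel(
      /*import_saved_model=*/true,        // TF2 SavedModel; v1 flag left off
      /*import_saved_model_v1=*/false,
      /*input_filename=*/"/tmp/saved_model",   // placeholder directory
      /*saved_model_tags=*/"serve",             // comma-separated MetaGraphDef tags
      /*saved_model_exported_names=*/"",        // empty => export all signatures
      &context);
  if (!module_or.ok()) return 1;  // the StatusOr already carries the error
  // module_or.ValueOrDie() now holds the MLIR module fed to the TFLite pipeline.
  return 0;
}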
@@ -182,6 +197,7 @@ int main(int argc, char **argv) { pass_config.inline_functions = inline_functions; tensorflow::AddTFToTFLConversionPasses(pass_config, &pm); + pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass()); std::string result; auto status = tensorflow::ConvertTFExecutorToTFLOrFlatbuffer( diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.cc b/tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.cc index 3ec0769db30..de569a3496c 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.cc @@ -22,6 +22,33 @@ using llvm::cl::opt; opt input_file_name(llvm::cl::Positional, llvm::cl::desc(""), llvm::cl::init("-")); + +// NOLINTNEXTLINE +opt import_saved_model( + "savedmodel-to-mlir", + llvm::cl::desc("Import a saved model to its MLIR representation"), + llvm::cl::value_desc("dir")); + +// NOLINTNEXTLINE +opt import_saved_model_v1( + "savedmodel-v1-to-mlir", + llvm::cl::desc("Import a saved model V1 to its MLIR representation"), + llvm::cl::value_desc("dir")); + +// NOLINTNEXTLINE +opt saved_model_tags( + "tf-savedmodel-tags", + llvm::cl::desc("Tags used to indicate which MetaGraphDef to import, " + "separated by ','"), + llvm::cl::init("serve")); + +// NOLINTNEXTLINE +opt saved_model_exported_names( + "tf-savedmodel-exported-names", + llvm::cl::desc("Names to export from SavedModel, separated by ','. Empty " + "(the default) means export all."), + llvm::cl::init("")); + // NOLINTNEXTLINE opt output_file_name("o", llvm::cl::desc(""), llvm::cl::value_desc("filename"), diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h b/tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h index faa74865f5f..d7e54d70b81 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h +++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h @@ -39,4 +39,10 @@ extern llvm::cl::opt inline_functions; extern llvm::cl::list custom_opdefs; extern llvm::cl::opt emit_quant_adaptor_ops; extern llvm::cl::opt quant_stats_file_name; + +// Import saved model. +extern llvm::cl::opt import_saved_model; +extern llvm::cl::opt import_saved_model_v1; +extern llvm::cl::opt saved_model_tags; +extern llvm::cl::opt saved_model_exported_names; #endif // TENSORFLOW_COMPILER_MLIR_LITE_TF_TFL_TRANSLATE_CL_H_ diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc index 6ea1ca26d62..f5097e1c01b 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc @@ -15,6 +15,10 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h" +#include +#include + +#include "llvm/Support/raw_ostream.h" #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/Module.h" // TF:llvm-project #include "mlir/Parser.h" // TF:llvm-project @@ -155,4 +159,37 @@ Status ConvertTFExecutorToTFLOrFlatbuffer( return Status::OK(); } +StatusOr ImportSavedModel( + bool import_saved_model, bool import_saved_model_v1, + const std::string& input_filename, const std::string& saved_model_tags, + const std::string& saved_model_exported_names, mlir::MLIRContext* context) { + if (import_saved_model) { + std::unordered_set tags = + absl::StrSplit(saved_model_tags, ','); + std::vector exported_names = + absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty()); + + auto module = tensorflow::SavedModelToMlirImport( + input_filename, tags, absl::Span(exported_names), context); + if (!module) + return tensorflow::errors::InvalidArgument("fail to open input file"); + + return module; + } else if (import_saved_model_v1) { + std::unordered_set tags = + absl::StrSplit(saved_model_tags, ','); + + auto module = + tensorflow::SavedModelV1ToMlirImport(input_filename, tags, context); + + if (!module) + return tensorflow::errors::InvalidArgument("fail to open input file"); + + return module; + } else { + return tensorflow::errors::InvalidArgument( + "Should be either saved model v1 or v2"); + } +} + } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h index 6f002af463b..f670ac8e52b 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h @@ -40,6 +40,12 @@ LoadFromGraphdefOrMlirSource( absl::string_view output_arrays, bool prune_unused_nodes, llvm::SourceMgr* source_mgr, mlir::MLIRContext* context); +// Load Saved model (either v1 or v2) into MLIR. +stream_executor::port::StatusOr ImportSavedModel( + bool import_saved_model, bool import_saved_model_v1, + const std::string& input_filename, const std::string& saved_model_tags, + const std::string& saved_model_exported_names, mlir::MLIRContext* context); + // Taking a MLIR module in TF executor dialect and a set of parameters, // applies a set of passes to convert the module to TF Lite dialect and // serializes the result to a string. Depending on an attribute in the module diff --git a/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc b/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc index 7aab9f08732..e07cea8535e 100644 --- a/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc +++ b/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc @@ -698,11 +698,10 @@ void ExtractOphintPass::runOnModule() { if (ophint_composite_ops.empty()) continue; // Verify: Make sure all ophint_composite_ops are valid. + // If not valid, we just don't do anything. 
for (const auto& kv : ophint_composite_ops) { if (failed(kv.getValue().VerifyOphint())) { - module.emitError() - << "Found malformed ophint regions: missing inputs or outputs."; - return signalPassFailure(); + return; } } diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td index d638a5f1a60..7bc08ee1c76 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td @@ -365,3 +365,7 @@ def : Pat< /*padding=*/ $padding, /*stride_h=*/ ExtractI32At<1>:$strides, /*stride_w=*/ ExtractI32At<2>:$strides)>; + +def : Pat< + (TF_MatrixSetDiagOp $input, $diagonal), + (TFL_MatrixSetDiagOp $input, $diagonal)>; diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc index 062895e9b9f..7501832099a 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc @@ -49,6 +49,8 @@ limitations under the License. #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/random/philox_random.h" +#include "tensorflow/core/lib/random/random_distributions.h" #include "tensorflow/core/protobuf/error_codes.pb.h" namespace mlir { @@ -61,6 +63,9 @@ namespace { using xla::Status; using xla::StatusOr; +constexpr char kUnidirectionalSequenceLstm[] = "tf.UnidirectionalSequenceLstm"; +constexpr char kTfLiteInputIndices[] = "_tflite_input_indices"; + // Legalize operations in functions. struct LegalizeTF : public FunctionPass { void runOnFunction() override; @@ -114,9 +119,54 @@ DECL_CONVERT_OP(SplitV); DECL_CONVERT_OP(StridedSlice); DECL_CONVERT_OP(Unpack); DECL_CONVERT_OP(Reciprocal); +DECL_CONVERT_OP(RandomUniform); #undef DECL_CONVERT_OP +PatternMatchResult ConvertTFRandomUniformOp::matchAndRewrite( + Operation* op, PatternRewriter& rewriter) const { + auto random_uniform_op = cast(op); + if (random_uniform_op.seed() == 0 && random_uniform_op.seed2() == 0) { + return matchFailure(); + } + if (!random_uniform_op.dtype().isF32()) { + return matchFailure(); + } + typedef tensorflow::random::UniformDistribution< + tensorflow::random::PhiloxRandom, float> + Distribution; + + tensorflow::random::PhiloxRandom generator( + random_uniform_op.seed().getSExtValue(), + random_uniform_op.seed2().getSExtValue()); + Distribution dist; + int num_elements = 0; + if (auto output_type = + random_uniform_op.output().getType().dyn_cast_or_null()) { + if (auto ranked_output = output_type.dyn_cast_or_null()) { + if (!ranked_output.hasRank() || ranked_output.getNumDynamicDims() != 0) { + return matchFailure(); + } + num_elements = output_type.getNumElements(); + size_t offset = 0; + size_t num_samples = Distribution::kResultElementCount; + llvm::SmallVector data; + data.resize(num_elements); + while (offset < num_elements) { + const typename Distribution::ResultType samples = dist(&generator); + std::copy(&samples[0], + &samples[0] + std::min(num_samples, data.size() - offset), + &data[0] + offset); + offset += num_samples; + } + auto output_data = DenseFPElementsAttr::get(output_type, data); + rewriter.replaceOpWithNewOp(op, output_type, output_data); + return matchSuccess(); + } + } + return matchFailure(); +} + PatternMatchResult ConvertTFConcatOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto 
tf_concat_op = cast(op); @@ -514,6 +564,74 @@ PatternMatchResult ConvertTFReciprocalOp::matchAndRewrite( return matchSuccess(); } +// Legalize unidirectional sequence lstm. +struct LegalizeUnidirectionalSequenceLstm : public RewritePattern { + explicit LegalizeUnidirectionalSequenceLstm(MLIRContext* context) + : RewritePattern(kUnidirectionalSequenceLstm, 1, context) {} + + PatternMatchResult matchAndRewrite(Operation* op, + PatternRewriter& rewriter) const override { + auto tflite_indices_attr = + op->getAttrOfType(kTfLiteInputIndices); + if (!tflite_indices_attr) return matchFailure(); + + SmallVector tflite_indices; + for (auto index_attr : tflite_indices_attr.getValue()) { + IntegerAttr index = index_attr.cast(); + tflite_indices.push_back(index.getInt()); + } + + // Optional input placeholder. + Value none = rewriter.create( + op->getLoc(), rewriter.getNoneType(), rewriter.getUnitAttr()); + + // Populate inputs. + // UnidirectionalSequenceLstm is expected to have 24 inputs. + SmallVector inputs; + int count = 0; + int total_ophint_converted_inputs = tflite_indices.size(); + for (int i = 0; i < 24; ++i) { + if (count < total_ophint_converted_inputs && tflite_indices[count] == i) { + // specified input. + inputs.push_back(op->getOperand(i)); + count++; + } else { + // Non specified input. + inputs.push_back(none); + } + } + + // Populate outputs. + // UnidirectionalSequenceLstm should only have 1 output, and that is the + // original ophint converted node's 3rd output. + SmallVector result_types; + result_types.push_back(op->getOpResult(2).getType()); + + // Populate attributes. + SmallVector attributes; + // Activation will always be tanh. + attributes.push_back(rewriter.getNamedAttr("fused_activation_function", + rewriter.getStringAttr("TANH"))); + // cell_clip. + attributes.push_back( + rewriter.getNamedAttr("cell_clip", rewriter.getF32FloatAttr(10.0))); + // proj_clip. + attributes.push_back( + rewriter.getNamedAttr("proj_clip", rewriter.getF32FloatAttr(0.0))); + // will always be time_majored. + attributes.push_back( + rewriter.getNamedAttr("time_major", rewriter.getBoolAttr(true))); + + auto lstm_op = rewriter.create( + op->getLoc(), result_types, inputs, attributes); + + // Rewire the output. + op->getResult(2).replaceAllUsesWith(lstm_op.getResult()); + op->erase(); + return matchSuccess(); + } +}; + void LegalizeTF::runOnFunction() { OwningRewritePatternList patterns; auto* ctx = &getContext(); @@ -521,11 +639,15 @@ void LegalizeTF::runOnFunction() { // Add the generated patterns to the list. populateWithGenerated(ctx, &patterns); - patterns.insert(ctx); + patterns + .insert(ctx); + + // Ophint python converter converted tf node pattern. + patterns.insert(ctx); applyPatternsGreedily(func, patterns); } diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td index bdf73ff3787..71017fe2801 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td @@ -199,6 +199,22 @@ def : Pat< (TFL_HardSwishOp $x), [(EqualOperands $x, $y)]>; +// Matching HardSwish with extra FakeQuant. These FakeQuant ops were due to +// incorrect placement in the quantization aware training. +// TODO(b/149735743): We should make the placement automatically. 
+def : Pat< + (TFL_MulOp (TFL_DequantizeOp (TFL_QuantizeOp + (TFL_MulOp + $x, (TFL_DequantizeOp (TFL_QuantizeOp (TFL_AddOp + $y, + (ConstantOp ConstantAttr, "3.0f">), + TFL_AF_Relu6), $qattr2)), + TFL_AF_None), $qattr1)), + (ConstantOp ConstantAttr, "0.166666666f">), + TFL_AF_None), + (TFL_HardSwishOp $x), + [(EqualOperands $x, $y)]>; + // Constraint that the attribute value is less than 'n' class ConstDoubleValueLessThan : Constraint< CPred<"$0.isa() && " diff --git a/tensorflow/compiler/mlir/lite/transforms/passes.h b/tensorflow/compiler/mlir/lite/transforms/passes.h index 559bdc6d8e6..b713b474b3d 100644 --- a/tensorflow/compiler/mlir/lite/transforms/passes.h +++ b/tensorflow/compiler/mlir/lite/transforms/passes.h @@ -91,6 +91,9 @@ std::unique_ptr> CreateLegalizeTFWhilePass(); // Creates an instance of the TensorFlow Lite dialect WhileOp outline pass. std::unique_ptr> CreateWhileOutlinePass(); +// Verifies runtime supports types used. +std::unique_ptr> CreateRuntimeTypeVerifyPass(); + } // namespace TFL } // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td b/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td index 7db615327e7..aed99a70bff 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td @@ -190,3 +190,16 @@ def : Pat<(TF_ReshapeOp:$old_value // parameters of the input, so we can remove the quantization ops. def : Pat<(TF_RankOp (TFL_DequantizeOp (TFL_QuantizeOp $input, $qtype))), (TF_RankOp $input)>; + +// `k` is expected to be 0, other values are not supported currently. +def : Pat<(TF_MatrixSetDiagV2Op $input, $diagonal, + (ConstantOp ConstantAttr)), + (TF_MatrixSetDiagOp $input, $diagonal)>; + +// `align` attribute can be ignored because we only support converting +// `MatrixSetDiagV3` to `MatrixSetDiag` with default `k` inputs. +def : Pat<(TF_MatrixSetDiagV3Op $input, $diagonal, + (ConstantOp ConstantAttr), + $align), + (TF_MatrixSetDiagOp $input, $diagonal)>; + diff --git a/tensorflow/compiler/mlir/lite/transforms/quantize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/quantize_patterns.td index 5f61ae3efc3..07dd8ab4455 100644 --- a/tensorflow/compiler/mlir/lite/transforms/quantize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/quantize_patterns.td @@ -21,12 +21,20 @@ include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td" // Quantize attribute $0 by using quantization parameter from %1. def QuantizeByQuantizedType : NativeCodeCall<"quant::Quantize($0, $1.getValue())">; +def F32ElementsAttr : ElementsAttrBase< + CPred<"$_self.cast().getType().getElementType().isF32()">, "float constant tensor">; // Squash tfl.dequantize and tfl.quantize pairs. // TODO(fengliuai): Compare the scale of input and output. This can also be // squashed to a requantize op if the scales are different. def : Pat<(TFL_QuantizeOp (TFL_DequantizeOp $in), $qt), (replaceWithValue $in)>; +// If the tfl.dequantize op wasn't fused, we shouldn't quantize the floating +// point constant. +def : Pat<(TFL_DequantizeOp + (TFL_QuantizeOp (ConstantOp F32ElementsAttr:$cst), $qt)), + (ConstantOp $cst)>; + // Quantize the value of a constant op if the quantization parameters have been // propagated to the output. 
def : Pat<(TFL_QuantizeOp diff --git a/tensorflow/compiler/mlir/lite/transforms/runtime_type_verify.cc b/tensorflow/compiler/mlir/lite/transforms/runtime_type_verify.cc new file mode 100644 index 00000000000..2a35701f0e6 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/transforms/runtime_type_verify.cc @@ -0,0 +1,52 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir/IR/OperationSupport.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" + +namespace mlir { +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops_interface.h.inc" +namespace TFL { +namespace { + +// This pass verifies that the operands and results types are supported by +// TFLite runtime. +class RuntimeTypeVerifyPass : public mlir::FunctionPass { + public: + explicit RuntimeTypeVerifyPass() {} + + private: + void runOnFunction() override; +}; + +void RuntimeTypeVerifyPass::runOnFunction() { + getFunction().walk([&](TflRuntimeVerifyOpInterface op) { + if (failed(op.VerifyTflRuntimeTypes(op.getOperation()))) + signalPassFailure(); + }); +} +} // namespace + +// Verifies runtime supports types used. 
+std::unique_ptr<OpPassBase<FuncOp>> CreateRuntimeTypeVerifyPass() {
+  return std::make_unique<RuntimeTypeVerifyPass>();
+}
+
+static PassRegistration<RuntimeTypeVerifyPass> pass(
+    "tfl-runtime-verify", "TFLite runtime verification");
+
+} // namespace TFL
+} // namespace mlir
diff --git a/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc b/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc
index babfb478881..63f558bc9c5 100644
--- a/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc
+++ b/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc
@@ -168,6 +168,10 @@ std::string OpOrArgLocNameMapper::GetName(OpOrVal op_or_val) {
                           result.getResultNumber());
    return std::string(result.getOwner()->getName().getStringRef());
  }
+  // Use the ASM syntax for BlockArgument
+  if (auto arg = val.dyn_cast<BlockArgument>()) {
+    return "arg" + std::to_string(arg.getArgNumber());
+  }
  return "";
}
diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD
index 0058e949969..2cfed42d76a 100644
--- a/tensorflow/compiler/mlir/tensorflow/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/BUILD
@@ -287,6 +287,7 @@ cc_library(
        "transforms/materialize_mlir_passthrough_op.cc",
        "transforms/optimize.cc",
        "transforms/optimize_global_tensors.cc",
+        "transforms/parallel_execute_to_islands.cc",
        "transforms/promote_resources_to_args.cc",
        "transforms/raise_control_flow.cc",
        "transforms/replicate_invariant_op_hoisting.cc",
@@ -708,7 +709,6 @@ cc_library(
    deps = [
        ":tensorflow_dialect_registration",
        ":tf_dialect_passes",
-        "@llvm-project//mlir:AllPassesAndDialects",
    ],
)

@@ -913,7 +913,6 @@ cc_library(
        "//tensorflow/core/platform:logging",
        "//tensorflow/stream_executor/lib",
        "@llvm-project//llvm:support",
-        "@llvm-project//mlir:AllPassesAndDialects",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Parser",
        "@llvm-project//mlir:Pass",
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc
index 5c277eeb9db..c88ddaf7806 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc
@@ -41,11 +41,52 @@ limitations under the License.
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
#include "mlir/Support/STLExtras.h" // TF:llvm-project
+#include "mlir/Transforms/InliningUtils.h" // TF:llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/core/platform/logging.h"

namespace mlir {
namespace tf_device {

+//===----------------------------------------------------------------------===//
+// TF Device Dialect Interfaces
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct TFInlinerInterface : public DialectInlinerInterface {
+  using DialectInlinerInterface::DialectInlinerInterface;
+
+  //===--------------------------------------------------------------------===//
+  // Analysis Hooks
+  //===--------------------------------------------------------------------===//
+
+  // Defines the legality of inlining TF Device operations.
+  bool isLegalToInline(Operation*, Region*, BlockAndValueMapping&) const final {
+    // For now, enable inlining all operations.
+    return true;
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Transformation Hooks
+  //===--------------------------------------------------------------------===//
+
+  // Attempts to materialize a conversion for a type mismatch between a call
+  // from this dialect, and a callable region.
This method should generate an + // operation that takes 'input' as the only operand, and produces a single + // result of 'resultType'. If a conversion can not be generated, nullptr + // should be returned. + // This is just re-using the same logic as the TensorFlow dialect right now. + Operation* materializeCallConversion(OpBuilder& builder, Value input, + Type result_type, + Location conversion_loc) const final { + if (!result_type.isa() || !input.getType().isa()) + return nullptr; + return builder.create(conversion_loc, result_type, input, + /*truncate=*/builder.getBoolAttr(false)); + } +}; +} // end anonymous namespace + TensorFlowDeviceDialect::TensorFlowDeviceDialect(MLIRContext* context) : Dialect(/*name=*/"tf_device", context) { addOperations< @@ -54,6 +95,8 @@ TensorFlowDeviceDialect::TensorFlowDeviceDialect(MLIRContext* context) >(); addOperations(); + + addInterfaces(); } //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc index 4b6ff55e5ea..c6144ec21e3 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc @@ -573,9 +573,9 @@ void Print(SwitchNOp switchn, OpAsmPrinter &p) { ParseResult ParseSwitchNOp(OpAsmParser &parser, OperationState &result) { // Parsing: - // %2:6 = tf_executor.SwitchN %0, %1 by 5 : tensor + // %2:6 = tf_executor.SwitchN %0, %1 of 5 : tensor // Where the first operand is the data to replicate, the second is an i32 - // indicating which output to populate, followed by the keyword `by` and the + // indicating which output to populate, followed by the keyword `of` and the // number of outputs (+1 for the control token). SmallVector op_infos; SmallVector types; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td index 0987ae3d668..38f72f24bd1 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td @@ -165,7 +165,7 @@ def TfExecutor_IslandOp : TfExecutor_Op<"island", The `tf_executor.island` operation has a single region with a single block attached (only functional control flow is allowed). The block is terminated by a `tf_executor.yield` operation. The operands of the terminator - correspond to the result values of the `tf_executor.graph` operation. An + correspond to the result values of the `tf_executor.island` operation. An extra result of type `!tf_executor.control` is always produced by every `tf_executor.island`. Within an island, execution semantics follow standard sequential behavior as @@ -299,7 +299,7 @@ def TfExecutor_SwitchNOp : TfExecutor_Op<"SwitchN", .SetShapeFn(SwitchNShape); For example: - %2:6 = tf_executor.SwitchN %0, %1 by 5 : tensor + %2:6 = tf_executor.SwitchN %0, %1 of 5 : tensor Note: One additional result corresponds to the control output. }]; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index 9b9a727d66e..411ba653bec 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -49,7 +49,7 @@ an output element, this operation computes \\(y = |x|\\). 
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_AddOp : TF_Op<"Add", [NoSideEffect, ResultsBroadcastableShape]>, +def TF_AddOp : TF_Op<"Add", [NoSideEffect, ResultsBroadcastableShape, TF_LayoutAgnostic]>, WithBroadcastableBinOpBuilder { let summary = "Returns x + y element-wise."; @@ -98,7 +98,7 @@ Inputs must be of same size and shape. let hasFolder = 1; } -def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastableShape]>, +def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastableShape, TF_LayoutAgnostic]>, WithBroadcastableBinOpBuilder { let summary = "Returns x + y element-wise."; @@ -508,8 +508,9 @@ Broadcasting is supported, so `value` may have any number of dimensions. let extraClassDeclaration = [{ // TF_LayoutSensitiveInterface: - SmallVector GetLayoutDependentArgs() { return {0}; } - SmallVector GetLayoutDependentResults() { return {0}; } + SmallVector GetLayoutDependentArgs() { return {0}; } + SmallVector GetLayoutDependentResults() { return {0}; } + LogicalResult UpdateDataFormat(StringRef data_format); }]; } @@ -980,7 +981,7 @@ tf.conj(input) ==> [-2.25 - 4.75j, 3.25 - 5.75j] let hasCanonicalizer = 1; } -def TF_Conv2DOp : TF_Op<"Conv2D", [NoSideEffect]> { +def TF_Conv2DOp : TF_Op<"Conv2D", [NoSideEffect, TF_LayoutSensitiveInterface]> { let summary = [{ Computes a 2-D convolution given 4-D `input` and `filter` tensors. }]; @@ -1030,6 +1031,13 @@ horizontal and vertices strides, `strides = [1, stride, stride, 1]`. let verifier = [{ return Verify(*this); }]; + + let extraClassDeclaration = [{ + // TF_LayoutSensitiveInterface: + SmallVector GetLayoutDependentArgs() { return {0}; } + SmallVector GetLayoutDependentResults() { return {0}; } + LogicalResult UpdateDataFormat(StringRef data_format); + }]; } def TF_Conv2DBackpropFilterOp : TF_Op<"Conv2DBackpropFilter", [NoSideEffect]> { @@ -2091,7 +2099,7 @@ The size of 1D Tensors matches the dimension C of the 4D Tensors. TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<3>; } -def TF_FusedBatchNormV3Op : TF_Op<"FusedBatchNormV3", [NoSideEffect]> { +def TF_FusedBatchNormV3Op : TF_Op<"FusedBatchNormV3", [NoSideEffect, TF_FoldOperandsTransposeInterface]> { let summary = "Batch normalization."; let description = [{ @@ -2122,6 +2130,13 @@ The size of 1D Tensors matches the dimension C of the 4D Tensors. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<1>; + + let extraClassDeclaration = [{ + // TF_FoldOperandsTransposeInterface: + SmallVector GetLayoutDependentArgs() { return {0}; } + SmallVector GetLayoutDependentResults() { return {0}; } + LogicalResult FoldOperandsPermutation(ArrayRef permutation); + }]; } def TF_GatherOp : TF_Op<"Gather", [NoSideEffect]> { @@ -3392,6 +3407,130 @@ tf.matrix_diag(diagonal, k = -1, num_rows = 3, padding_value = 9) TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_MatrixSetDiagOp : TF_Op<"MatrixSetDiag", [NoSideEffect]> { + let summary = [{ +Returns a batched matrix tensor with new batched diagonal values. + }]; + + let description = [{ +Given `input` and `diagonal`, this operation returns a tensor with the +same shape and values as `input`, except for the main diagonal of the +innermost matrices. These will be overwritten by the values in `diagonal`. + +The output is computed as follows: + +Assume `input` has `k+1` dimensions `[I, J, K, ..., M, N]` and `diagonal` has +`k` dimensions `[I, J, K, ..., min(M, N)]`. 
Then the output is a +tensor of rank `k+1` with dimensions `[I, J, K, ..., M, N]` where: + + * `output[i, j, k, ..., m, n] = diagonal[i, j, k, ..., n]` for `m == n`. + * `output[i, j, k, ..., m, n] = input[i, j, k, ..., m, n]` for `m != n`. + }]; + + let arguments = (ins + TF_Tensor:$input, + TF_Tensor:$diagonal + ); + + let results = (outs + TF_Tensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_MatrixSetDiagV2Op : TF_Op<"MatrixSetDiagV2", [NoSideEffect]> { + let summary = [{ +Returns a batched matrix tensor with new batched diagonal values. + }]; + + let description = [{ +Given `input` and `diagonal`, this operation returns a tensor with the +same shape and values as `input`, except for the specified diagonals of the +innermost matrices. These will be overwritten by the values in `diagonal`. + +`input` has `r+1` dimensions `[I, J, ..., L, M, N]`. When `k` is scalar or +`k[0] == k[1]`, `diagonal` has `r` dimensions `[I, J, ..., L, max_diag_len]`. +Otherwise, it has `r+1` dimensions `[I, J, ..., L, num_diags, max_diag_len]`. +`num_diags` is the number of diagonals, `num_diags = k[1] - k[0] + 1`. +`max_diag_len` is the longest diagonal in the range `[k[0], k[1]]`, +`max_diag_len = min(M + min(k[1], 0), N + min(-k[0], 0))` + +The output is a tensor of rank `k+1` with dimensions `[I, J, ..., L, M, N]`. +If `k` is scalar or `k[0] == k[1]`: + +``` +output[i, j, ..., l, m, n] + = diagonal[i, j, ..., l, n-max(k[1], 0)] ; if n - m == k[1] + input[i, j, ..., l, m, n] ; otherwise +``` + +Otherwise, + +``` +output[i, j, ..., l, m, n] + = diagonal[i, j, ..., l, diag_index, index_in_diag] ; if k[0] <= d <= k[1] + input[i, j, ..., l, m, n] ; otherwise +``` +where `d = n - m`, `diag_index = k[1] - d`, and `index_in_diag = n - max(d, 0)`. + +For example: + +``` +# The main diagonal. +input = np.array([[[7, 7, 7, 7], # Input shape: (2, 3, 4) + [7, 7, 7, 7], + [7, 7, 7, 7]], + [[7, 7, 7, 7], + [7, 7, 7, 7], + [7, 7, 7, 7]]]) +diagonal = np.array([[1, 2, 3], # Diagonal shape: (2, 3) + [4, 5, 6]]) +tf.matrix_set_diag(diagonal) ==> [[[1, 7, 7, 7], # Output shape: (2, 3, 4) + [7, 2, 7, 7], + [7, 7, 3, 7]], + [[4, 7, 7, 7], + [7, 5, 7, 7], + [7, 7, 6, 7]]] + +# A superdiagonal (per batch). +tf.matrix_set_diag(diagonal, k = 1) + ==> [[[7, 1, 7, 7], # Output shape: (2, 3, 4) + [7, 7, 2, 7], + [7, 7, 7, 3]], + [[7, 4, 7, 7], + [7, 7, 5, 7], + [7, 7, 7, 6]]] + +# A band of diagonals. +diagonals = np.array([[[1, 2, 3], # Diagonal shape: (2, 2, 3) + [4, 5, 0]], + [[6, 1, 2], + [3, 4, 0]]]) +tf.matrix_set_diag(diagonals, k = (-1, 0)) + ==> [[[1, 7, 7, 7], # Output shape: (2, 3, 4) + [4, 2, 7, 7], + [0, 5, 3, 7]], + [[6, 7, 7, 7], + [3, 1, 7, 7], + [7, 4, 2, 7]]] + +``` + }]; + + let arguments = (ins + TF_Tensor:$input, + TF_Tensor:$diagonal, + I32Tensor:$k + ); + + let results = (outs + TF_Tensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_MatrixSetDiagV3Op : TF_Op<"MatrixSetDiagV3", [NoSideEffect]> { let summary = [{ Returns a batched matrix tensor with new batched diagonal values. @@ -3551,7 +3690,7 @@ retained with length 1. 
>]; } -def TF_MaxPoolOp : TF_Op<"MaxPool", [NoSideEffect]> { +def TF_MaxPoolOp : TF_Op<"MaxPool", [NoSideEffect, TF_FoldOperandsTransposeInterface]> { let summary = "Performs max pooling on the input."; let description = [{ @@ -3571,6 +3710,13 @@ def TF_MaxPoolOp : TF_Op<"MaxPool", [NoSideEffect]> { ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + + let extraClassDeclaration = [{ + // TF_FoldOperandsTransposeInterface: + SmallVector GetLayoutDependentArgs() { return {0}; } + SmallVector GetLayoutDependentResults() { return {0}; } + LogicalResult FoldOperandsPermutation(ArrayRef permutation); + }]; } def TF_MaxPoolGradOp : TF_Op<"MaxPoolGrad", [NoSideEffect]> { @@ -4714,7 +4860,7 @@ I.e., \\(y = 1 / x\\). let hasCanonicalizer = 1; } -def TF_ReluOp : TF_Op<"Relu", [NoSideEffect, SameOperandsAndResultType]> { +def TF_ReluOp : TF_Op<"Relu", [NoSideEffect, SameOperandsAndResultType, TF_LayoutAgnostic]> { let summary = "Computes rectified linear: `max(features, 0)`."; let description = [{ @@ -6657,7 +6803,7 @@ variables. TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>; } -def TF_TanhOp : TF_Op<"Tanh", [NoSideEffect, SameOperandsAndResultType]> { +def TF_TanhOp : TF_Op<"Tanh", [NoSideEffect, SameOperandsAndResultType, TF_LayoutAgnostic]> { let summary = "Computes hyperbolic tangent of `x` element-wise."; let description = [{ diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td index b8d5e59f1a8..f3fdab674e4 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td @@ -58,6 +58,10 @@ TODO: Make invariants more structured so that we can reference them in ops. def TF_OperandsSameAsResultsTypeOrRef : NativeOpTrait< "TF::OperandsSameAsResultsTypeOrRef">; +// Layout agnostic operations do not depend on the operands data layout (data +// format), as an example all element wise operations are layout agnostic. +def TF_LayoutAgnostic : NativeOpTrait<"TF::LayoutAgnostic">; + //===----------------------------------------------------------------------===// // TensorFlow op definitions //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.td index b887f966cbd..cc0819d71c9 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.td @@ -44,11 +44,17 @@ def TF_LayoutSensitiveInterface : OpInterface<"LayoutSensitiveInterface"> { >, InterfaceMethod< [{Returns indices of layout dependent arguments.}], - "SmallVector", "GetLayoutDependentArgs", (ins) + "SmallVector", "GetLayoutDependentArgs", (ins) >, InterfaceMethod< [{Returns indices of layout dependent results.}], - "SmallVector", "GetLayoutDependentResults", (ins) + "SmallVector", "GetLayoutDependentResults", (ins) + >, + InterfaceMethod< + [{Updates operation attributes and operands to account for the updated + data format. If data format is not supported, must return failure.}], + "LogicalResult", "UpdateDataFormat", + (ins "StringRef":$data_format) >, ]; @@ -57,4 +63,42 @@ def TF_LayoutSensitiveInterface : OpInterface<"LayoutSensitiveInterface"> { }]; } +def TF_FoldOperandsTransposeInterface : OpInterface<"FoldOperandsTransposeInterface"> { + let description = [{ + Operation supports folding operand(s) transposes into the operation itself. 
+ + (1) Operation might have layout dependent operands and results... + + Example: MaxPool(Transpose($arg, $perm)) + -> Transpose(MaxPool($arg, $perm)) + + (2) ... or it might have only layout dependent operands: + + Example: Mean(Transpose($arg, $reduction_dims)) + -> Mean($arg, Transpose($reduction_dims)) + }]; + + let methods = [ + InterfaceMethod< + [{Returns indices of layout dependent arguments.}], + "SmallVector", "GetLayoutDependentArgs", (ins) + >, + InterfaceMethod< + [{Returns indices of layout dependent results.}], + "SmallVector", "GetLayoutDependentResults", (ins) + >, + InterfaceMethod< + [{Updates operation attributes and operands to account for the folded + permutation. If folding of permutation is not possible, must return + failure.}], + "LogicalResult", "FoldOperandsPermutation", + (ins "ArrayRef":$permutation) + >, + ]; + + let verify = [{ + return VerifyFoldOperandsTransposeInterface($_op); + }]; +} + #endif // TF_OP_INTERFACES diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc index 0d70d8793ee..b206b281754 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc @@ -292,6 +292,156 @@ static LogicalResult VerifyTypesCompatibility( return success(); } +//===----------------------------------------------------------------------===// +// TF op helper functions to work with layout transformation. +//===----------------------------------------------------------------------===// + +SmallVector GetDataFormatPermutation(StringRef from, StringRef to) { + if (from == "NHWC" && to == "NCHW") { + return {0, 3, 1, 2}; + } else if (from == "NCHW" && to == "NHWC") { + return {0, 2, 3, 1}; + } else { + return {}; + } +} + +// Shuffle elements in the `attr` according to the permutation. Optional +// `inner_size` allows to shuffle array attributes created from rank 2 tensors +// on outer dimension only. +ArrayAttr ShuffleArrayAttr(ArrayAttr attr, ArrayRef permutation, + int inner_size = 1) { + if (attr.size() == 0) return attr; + + assert(attr.size() % inner_size == 0); + assert(attr.size() / inner_size == permutation.size()); + + SmallVector values{attr.begin(), attr.end()}; + SmallVector shuffled(values.size()); + + for (size_t i = 0; i < permutation.size(); ++i) { + for (size_t j = 0; j < inner_size; ++j) { + shuffled[i * inner_size + j] = values[permutation[i] * inner_size + j]; + } + } + + return ArrayAttr::get(shuffled, attr.getContext()); +} + +// Shuffle ranked tensor dimensions according to the permutation. 
+Type ShuffleRankedTensorType(Type type, ArrayRef permutation) { + if (auto ranked_type = type.dyn_cast()) { + ArrayRef shape = ranked_type.getShape(); + assert(permutation.size() == shape.size()); + + SmallVector new_shape(permutation.size()); + for (size_t i = 0; i < permutation.size(); ++i) + new_shape[i] = shape[permutation[i]]; + + return RankedTensorType::get(new_shape, ranked_type.getElementType()); + } + + return type; +} + +static bool AreCancellablePermutations(DenseIntElementsAttr perm0, + DenseIntElementsAttr perm1) { + if (perm0.getNumElements() == 0 || perm1.getNumElements() == 0) return false; + if (perm0.getNumElements() != perm1.getNumElements()) return false; + + SmallVector perm0_values; + for (auto value : perm0.getIntValues()) + perm0_values.push_back(value.getSExtValue()); + + SmallVector perm1_values; + for (auto value : perm1.getIntValues()) + perm1_values.push_back(value.getSExtValue()); + + for (int i = 0; i < perm0_values.size(); ++i) { + if (perm0_values[perm1_values[i]] != i) return false; + } + + return true; +} + +// Default implementation of `LayoutSensitiveInterface::UpdateDataFormat` for +// layout sensitive operations that do not have any additional layout dependent +// attributes besides `data_format` string. +template +LogicalResult UpdateDataFormat(StringRef data_format, Op *op) { + auto perm = GetDataFormatPermutation(op->data_format(), data_format); + if (perm.empty()) return failure(); + + // Update data format attribute. + op->setAttr("data_format", StringAttr::get(data_format, op->getContext())); + + // Update types for all layout sensitive results. + auto layout_sensitive = cast(op->getOperation()); + for (unsigned idx : layout_sensitive.GetLayoutDependentResults()) { + OpResult result = op->getOperation()->getResult(idx); + result.setType(ShuffleRankedTensorType(result.getType(), perm)); + } + + return success(); +} + +// Default implementation for folding operand transpose into the operation. +// See `FoldOperandsTransposeInterface::FoldOperandsPermutation`. +template +LogicalResult FoldOperandsPermutation( + ArrayRef permutation, Op *op, + ArrayRef> shuffle_attrs = {}) { + MLIRContext *context = op->template getParentOfType().getContext(); + + // We only support NHWC <-> NCHW permutations. + static constexpr std::array kNchwToNhwc = {0, 2, 3, 1}; + static constexpr std::array kNhwcToNchw = {0, 3, 1, 2}; + + // Operation data format after folding `permutation`. + StringRef target_data_format = [&]() -> StringRef { + if (op->data_format() == "NHWC" && permutation.equals(kNchwToNhwc)) { + return "NCHW"; // cancel NCHW->NHWC operand permutation + } else if (op->data_format() == "NCHW" && permutation.equals(kNhwcToNchw)) { + return "NHWC"; // cancel NHWC->NCHW operand permutation + } else { + return ""; + } + }(); + if (target_data_format.empty()) return failure(); + + // To fold operand `permutation` into the `op` we need shuffle all layout + // dependent attributes and types with a reverse permutation, and change + // operation data format to `target_data_format`. + // + // Example: + // %1 = SomeOp(...) {data_format = NHWC} + // %2 = Transpose(%1) {permutation = NHWC->NCHW} + // %3 = Op(%2) {data_format = NCHW} + // + // To bypass %2 we have to change data format to shuffle data format from NCHW + // to NHWC, which is the reverse of operand permutation (function argument). 
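// Editorial worked example (not part of the patch; values are illustrative):
// folding a NCHW->NHWC Transpose ({0, 2, 3, 1}) into a MaxPool that currently
// has data_format = "NHWC", ksize = [1, 3, 3, 1], strides = [1, 2, 2, 1]:
//   target_data_format  = "NCHW"
//   reverse_permutation = GetDataFormatPermutation("NHWC", "NCHW") = {0, 3, 1, 2}
//   ksize   [1, 3, 3, 1] -> [1, 1, 3, 3]
//   strides [1, 2, 2, 1] -> [1, 1, 2, 2]
// and every layout dependent result type is shuffled with the same reverse
// permutation.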
+ auto reverse_permutation = + GetDataFormatPermutation(op->data_format(), target_data_format); + if (reverse_permutation.empty()) return failure(); + + op->setAttr("data_format", StringAttr::get(target_data_format, context)); + + for (auto pair : shuffle_attrs) { + StringRef attr_name = pair.first; + ArrayAttr attr_value = pair.second; + op->setAttr(attr_name, ShuffleArrayAttr(attr_value, reverse_permutation)); + } + + auto fold = cast(op->getOperation()); + for (unsigned idx : fold.GetLayoutDependentResults()) { + OpResult result = op->getOperation()->getResult(idx); + result.setType( + ShuffleRankedTensorType(result.getType(), reverse_permutation)); + } + + return success(); +} + namespace { #include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc" } // namespace @@ -459,6 +609,15 @@ static LogicalResult Verify(BiasAddOp op) { return success(); } +// TODO(ezhulenev): BiasAddOp is not really layout sensitive, it must only +// support folding operand transposes. +LogicalResult BiasAddOp::UpdateDataFormat(StringRef data_format) { + auto ranked = value().getType().dyn_cast(); + if (!ranked || ranked.getRank() != 4) return failure(); + + return ::mlir::TF::UpdateDataFormat(data_format, this); +} + //===----------------------------------------------------------------------===// // BiasAddGradOp //===----------------------------------------------------------------------===// @@ -817,6 +976,21 @@ static LogicalResult Verify(OpT op) { return success(); } +LogicalResult Conv2DOp::UpdateDataFormat(StringRef data_format) { + auto perm = GetDataFormatPermutation(this->data_format(), data_format); + if (perm.empty()) return failure(); + + // Update data_format attribute and result types. + if (failed(::mlir::TF::UpdateDataFormat(data_format, this))) return failure(); + + // Update convolution attributes. 
+ setAttr("dilations", ShuffleArrayAttr(dilations(), perm)); + setAttr("strides", ShuffleArrayAttr(strides(), perm)); + setAttr("explicit_paddings", ShuffleArrayAttr(explicit_paddings(), perm, 2)); + + return success(); +} + //===----------------------------------------------------------------------===// // Conv2dBackpropInputOp //===----------------------------------------------------------------------===// @@ -1138,6 +1312,11 @@ static LogicalResult Verify(FusedBatchNormOp op) { return success(); } +LogicalResult FusedBatchNormV3Op::FoldOperandsPermutation( + ArrayRef permutation) { + return ::mlir::TF::FoldOperandsPermutation(permutation, this); +} + //===----------------------------------------------------------------------===// // GatherV2Op //===----------------------------------------------------------------------===// @@ -1330,6 +1509,16 @@ void MaxOp::build(Builder *builder, OperationState &result, Value input, build(builder, result, out_ty, input, reduction_indices, keep_dims); } +//===----------------------------------------------------------------------===// +// MaxPoolOp +//===----------------------------------------------------------------------===// + +LogicalResult MaxPoolOp::FoldOperandsPermutation( + ArrayRef permutation) { + return ::mlir::TF::FoldOperandsPermutation( + permutation, this, {{"strides", strides()}, {"ksize", ksize()}}); +} + //===----------------------------------------------------------------------===// // MaxPoolGradOp //===----------------------------------------------------------------------===// @@ -1347,6 +1536,38 @@ static LogicalResult Verify(MaxPoolGradOp op) { return success(); } +//===----------------------------------------------------------------------===// +// MeanOp +//===----------------------------------------------------------------------===// + +LogicalResult MeanOp::FoldOperandsPermutation(ArrayRef permutation) { + // Reduction indices must be defined by a constant operation. + auto reduction_op = + dyn_cast_or_null(reduction_indices().getDefiningOp()); + if (!reduction_op) return failure(); + + auto reductions_value = reduction_op.value().dyn_cast(); + if (!reductions_value) return failure(); + + // Prepare new reduction indices according to operand permutation. + SmallVector shuffled_reduction; + llvm::transform(reductions_value.getIntValues(), + std::back_inserter(shuffled_reduction), + [&](APInt idx) { return permutation[idx.getSExtValue()]; }); + + // Add constant operation with a new reduction indices. + OpBuilder builder(getOperation()); + auto type = mlir::RankedTensorType::get(shuffled_reduction.size(), + builder.getIntegerType(64)); + auto values = mlir::DenseIntElementsAttr::get(type, shuffled_reduction); + auto shuffled_reduction_op = builder.create(getLoc(), values); + + // Use new reduction indices. 
+ setOperand(1, shuffled_reduction_op); + + return success(); +} + //===----------------------------------------------------------------------===// // NegOp //===----------------------------------------------------------------------===// @@ -2723,23 +2944,46 @@ void TransposeOp::build(Builder *builder, OperationState &result, Value x, perm); } -OpFoldResult TransposeOp::fold(ArrayRef operands) { - auto const_perm = dyn_cast_or_null(perm().getDefiningOp()); +namespace { - if (!const_perm) { - return {}; - } +OpFoldResult FoldIdentityTranspose(TransposeOp op) { + auto const_perm = dyn_cast_or_null(op.perm().getDefiningOp()); + if (!const_perm) return {}; auto const_value = const_perm.value(); - const auto &elements = const_value.getValues(); + for (auto it : llvm::enumerate(elements)) { - if (it.index() != it.value()) { - return {}; - } + if (it.index() != it.value()) return {}; } - return x(); + return op.x(); +} + +OpFoldResult FoldCancellableTranspose(TransposeOp op) { + // Operand is a TransposeOp. + auto transpose = dyn_cast_or_null(op.x().getDefiningOp()); + if (!transpose) return {}; + + // Permutations defined by constant operations. + auto perm0 = dyn_cast_or_null(op.perm().getDefiningOp()); + auto perm1 = dyn_cast_or_null(transpose.perm().getDefiningOp()); + if (!perm0 || !perm1) return {}; + + // With permutation indices that cancel each other + auto perm0_value = perm0.value().cast(); + auto perm1_value = perm1.value().cast(); + if (!AreCancellablePermutations(perm0_value, perm1_value)) return {}; + + return transpose.x(); +} + +} // namespace + +OpFoldResult TransposeOp::fold(ArrayRef operands) { + if (auto folded = FoldIdentityTranspose(*this)) return folded; + if (auto folded = FoldCancellableTranspose(*this)) return folded; + return {}; } //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td index b391d5284a5..e95fcbbdad3 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td @@ -172,7 +172,7 @@ else_branch: A function that takes 'inputs' and returns a list of }]; } -def TF_MeanOp : TF_Op<"Mean", [NoSideEffect]> { +def TF_MeanOp : TF_Op<"Mean", [NoSideEffect, TF_FoldOperandsTransposeInterface]> { let summary = "Computes the mean of elements across dimensions of a tensor."; let description = [{ @@ -195,6 +195,13 @@ retained with length 1. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>; + + let extraClassDeclaration = [{ + // TF_FoldOperandsTransposeInterface: + SmallVector GetLayoutDependentArgs() { return {0}; } + SmallVector GetLayoutDependentResults() { return {}; } + LogicalResult FoldOperandsPermutation(ArrayRef permutation); + }]; } def TF_LegacyCallOp : TF_Op<"LegacyCall", diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc index 21b5354eeb8..8d3253ef81f 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc @@ -112,12 +112,26 @@ static LogicalResult VerifyIndexPath(Operation *op, NamedAttribute named_attr) { return mlir::success(); } -// Return true if `type` is a tensor of `!tf.resource`. This is the type that is -// used to represent mutable variables on exported functions' bound inputs. 
-static bool IsResourceVarType(Type type) { - TensorType tensor_type = type.dyn_cast(); - if (!tensor_type) return false; - return tensor_type.getElementType().isa(); +static LogicalResult VerifyBoundInputArgType(Operation *op_for_diagnostics, + Type arg_type, + GlobalTensorOp global_tensor) { + if (global_tensor.is_mutable()) { + auto expected_type = RankedTensorType::get( + {}, TF::ResourceType::get({global_tensor.type().cast()}, + arg_type.getContext())); + if (arg_type != expected_type) { + return op_for_diagnostics->emitError() + << "mutable bound input with type " << arg_type + << " expected to have type " << expected_type; + } + } else { + if (arg_type != global_tensor.type()) { + return op_for_diagnostics->emitError() + << "bound input for immutable 'tf_saved_model.global_tensor' must " + "match the global tensor's type"; + } + } + return success(); } LogicalResult TensorFlowSavedModelDialect::verifyRegionArgAttribute( @@ -137,20 +151,7 @@ LogicalResult TensorFlowSavedModelDialect::verifyRegionArgAttribute( << symbol_name << "'"; } auto arg_type = cast(op).getArgument(arg_index).getType(); - if (global_tensor.is_mutable()) { - if (!IsResourceVarType(arg_type)) { - return op->emitError() - << "bound inputs for mutable 'tf_saved_model.global_tensor's " - "must be tensors of '!tf.resource'"; - } - } else { - if (arg_type != global_tensor.type()) { - return op->emitError() << "bound input for immutable " - "'tf_saved_model.global_tensor' must " - "match the global tensor's type"; - } - } - return success(); + return VerifyBoundInputArgType(op, arg_type, global_tensor); } if (named_attr.first == "tf_saved_model.index_path") { return VerifyIndexPath(op, named_attr); diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h index 51315c4f90c..18beb23663c 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h @@ -68,6 +68,11 @@ class OperandsSameAsResultsTypeOrRef } }; +// Layout agnostic operations do not depend on the operands data layout (data +// format), as and example all element wise operations are layout agnostic. +template +class LayoutAgnostic : public TraitBase {}; + } // namespace TF } // namespace OpTrait } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.cc index 379797c99e4..247df44a90a 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.cc @@ -21,23 +21,35 @@ limitations under the License. 
namespace mlir { namespace TF { -LogicalResult VerifyLayoutSensitiveInterface(Operation* op) { - auto layout_sensitive_interface = cast(op); +namespace { - if (!llvm::all_of( - layout_sensitive_interface.GetLayoutDependentArgs(), - [&](int64_t index) { return index < op->getNumOperands(); })) { +template +LogicalResult VerifyLayoutDependentArgsAndResults(Operation* op, + Interface interface) { + auto valid_operand = [&](int64_t idx) { return idx < op->getNumOperands(); }; + if (!llvm::all_of(interface.GetLayoutDependentArgs(), valid_operand)) { return op->emitOpError("layout dependent argument index is out of bound"); } - if (!llvm::all_of( - layout_sensitive_interface.GetLayoutDependentResults(), - [&](int64_t index) { return index < op->getNumResults(); })) { + auto valid_result = [&](int64_t idx) { return idx < op->getNumResults(); }; + if (!llvm::all_of(interface.GetLayoutDependentResults(), valid_result)) { return op->emitOpError("layout dependent result index is out of bound"); } return success(); } +} // namespace + +LogicalResult VerifyLayoutSensitiveInterface(Operation* op) { + auto layout_sensitive_interface = cast(op); + return VerifyLayoutDependentArgsAndResults(op, layout_sensitive_interface); +} + +LogicalResult VerifyFoldOperandsTransposeInterface(Operation* op) { + auto fold_operands_transpose = cast(op); + return VerifyLayoutDependentArgsAndResults(op, fold_operands_transpose); +} + } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h index 776f0a9022a..5289328e73f 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h @@ -29,6 +29,12 @@ namespace TF { // [0, getNumOperands/getNumResults) range. LogicalResult VerifyLayoutSensitiveInterface(Operation* op); +// Verifies correctness of ops implementing FoldOperandsTransposeInterface (see +// definition in tf_op_base.td): +// (1) Layout dependent arguments and results indices must be in +// [0, getNumOperands/getNumResults) range. +LogicalResult VerifyFoldOperandsTransposeInterface(Operation* op); + } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/tests/breakup-islands.mlir b/tensorflow/compiler/mlir/tensorflow/tests/breakup-islands.mlir index d90c9201a83..61e0772726c 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/breakup-islands.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/breakup-islands.mlir @@ -3,6 +3,10 @@ // All tests also test for idempotence. +// Test that external functions aren't processed (used to crash). +// CHECK-LABEL: func @unused_external_func +func @unused_external_func() + func @multiple_return(%arg0: tensor<*xi32>, %arg1: tensor) -> (tensor<*xi32>, tensor<*xi32>) { %graph:2 = tf_executor.graph { %island:3 = tf_executor.island { @@ -276,3 +280,67 @@ func @empty_island_multiple_data_results(%arg0: tensor<*xf32>, %arg1: tensor<*xi } return } + +// The following tests check that certain control dependencies between islands +// and certain tf_executor ops are added correctly. 
+ +// CHECK: %[[CONTROL:[^ ,]*]] = tf_executor.island wraps "tf.Print" +// CHECK: tf_executor.NextIteration.Sink [{{.*}}] {{.*}}, %[[CONTROL]] +func @next_iteration_sink_control_input() { + tf_executor.graph { + %source:3 = tf_executor.NextIteration.Source : tensor<*xi32> + %island:2 = tf_executor.island { + %const = "tf.Const"() {value = dense<1> : tensor} : () -> tensor<*xi32> + %print = "tf.Print"(%const) : (tensor<*xi32>) -> (tensor<*xi32>) + tf_executor.yield %const : tensor<*xi32> + } + tf_executor.NextIteration.Sink[%source#1] %island#0 : tensor<*xi32> + tf_executor.fetch %island#0 : tensor<*xi32> + } + return +} + +// CHECK: %[[CONTROL:[^ ,]*]] = tf_executor.island wraps "tf.Print" +// CHECK: tf_executor.LoopCond {{.*}}, %[[CONTROL]] +func @loop_cond_control_input() { + tf_executor.graph { + %island:2 = tf_executor.island { + %const = "tf.Const"() {value = dense<1> : tensor} : () -> tensor<*xi1> + %print = "tf.Print"(%const) : (tensor<*xi1>) -> (tensor<*xi1>) + tf_executor.yield %const : tensor<*xi1> + } + %loop_cond:2 = tf_executor.LoopCond %island#0 : tensor<*xi1> + tf_executor.fetch %loop_cond#0 : tensor<*xi1> + } + return +} + +// CHECK: %[[CONTROL:[^ ,]*]] = tf_executor.island wraps "tf.Print" +// CHECK: tf_executor.Enter {{.*}}, %[[CONTROL]] +func @enter_control_input() { + tf_executor.graph { + %island:2 = tf_executor.island { + %const = "tf.Const"() {value = dense<1> : tensor} : () -> tensor<*xi32> + %print = "tf.Print"(%const) : (tensor<*xi32>) -> (tensor<*xi32>) + tf_executor.yield %const : tensor<*xi32> + } + %enter:2 = tf_executor.Enter %island#0 frame "some/frame" : tensor<*xi32> + tf_executor.fetch %enter#0 : tensor<*xi32> + } + return +} + +// CHECK: %[[CONTROL:[^ ,]*]] = tf_executor.island wraps "tf.Print" +// CHECK: tf_executor.SwitchN {{.*}}, {{.*}} of {{[0-9]*}} (%[[CONTROL]]) +func @switchn_control_input(%arg1: tensor) { + tf_executor.graph { + %island:2 = tf_executor.island { + %const = "tf.Const"() {value = dense<1> : tensor} : () -> tensor<*xi32> + %print = "tf.Print"(%const) : (tensor<*xi32>) -> (tensor<*xi32>) + tf_executor.yield %const : tensor<*xi32> + } + %switchn:4 = tf_executor.SwitchN %island#0, %arg1 of 3: tensor<*xi32> + tf_executor.fetch %switchn#0 : tensor<*xi32> + } + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir index c91c1e2f7b5..5bf5b0610ae 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir @@ -383,6 +383,28 @@ func @nonIdentityTranspose(%arg0: tensor<2x3x4x5x6xf32>) -> tensor<2x3x4x6x5xf32 // CHECK: return %1 } +// CHECK-LABEL: @cancellableTranspose +func @cancellableTranspose(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32> { + %0 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32> + %1 = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32> + %2 = "tf.Transpose"(%arg0, %0) : (tensor<1x4x4x8xf32>, tensor<4xi32>) -> tensor<1x8x4x4xf32> + %3 = "tf.Transpose"(%2, %1) : (tensor<1x8x4x4xf32>, tensor<4xi32>) -> tensor<1x4x4x8xf32> + + return %3 : tensor<1x4x4x8xf32> + // CHECK: return %arg0 +} + +// CHECK-LABEL: @nonCancellableTranspose +func @nonCancellableTranspose(%arg0: tensor<1x4x4x8xf32>) -> tensor<4x1x4x8xf32> { + %0 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32> + %1 = "tf.Const"() {value = dense<[2, 0, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32> + %2 = 
"tf.Transpose"(%arg0, %0) : (tensor<1x4x4x8xf32>, tensor<4xi32>) -> tensor<1x8x4x4xf32> + %3 = "tf.Transpose"(%2, %1) : (tensor<1x8x4x4xf32>, tensor<4xi32>) -> tensor<4x1x4x8xf32> + + return %3 : tensor<4x1x4x8xf32> + // CHECK: return %3 +} + // CHECK-LABEL: func @addN func @addN(%arg0: tensor<*xf32>) -> tensor<*xf32> { // CHECK: return %arg0 diff --git a/tensorflow/compiler/mlir/tensorflow/tests/device_assignment.mlir b/tensorflow/compiler/mlir/tensorflow/tests/device_assignment.mlir index 1f1e6c63f30..6971cf06648 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/device_assignment.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/device_assignment.mlir @@ -9,5 +9,7 @@ func @device_test(%arg0: tensor<3x1xf32>) -> (tensor<3x3xf32>) { %1 = "tf.MatMul"(%arg0, %0) {T = f32, _output_shapes = ["tfshape$dim { size: 3 } dim { size: 3 }"], device = "", transpose_a = false, transpose_b = false} : (tensor<3x1xf32>, tensor<1x3xf32>) -> tensor<3x3xf32> // CHECK: device = "cpu" %2 = "tf.Relu"(%1) {T = f32, _output_shapes = ["tfshape$dim { size: 3 } dim { size: 3 }"], device = "cpu"} : (tensor<3x3xf32>) -> tensor<3x3xf32> - return %2 : tensor<3x3xf32> + // CHECK: device = "gpu" + %3 = "tf.Relu"(%2) {T = f32, _output_shapes = ["tfshape$dim { size: 3 } dim { size: 3 }"]} : (tensor<3x3xf32>) -> tensor<3x3xf32> + return %3 : tensor<3x3xf32> } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_island_coarsening.mlir b/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_island_coarsening/executor_tpuv1_island_coarsening.mlir similarity index 100% rename from tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_island_coarsening.mlir rename to tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_island_coarsening/executor_tpuv1_island_coarsening.mlir diff --git a/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_island_coarsening/while_op.mlir b/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_island_coarsening/while_op.mlir new file mode 100644 index 00000000000..59ece992756 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_island_coarsening/while_op.mlir @@ -0,0 +1,57 @@ +// RUN: tf-opt %s -tf-executor-tpu-v1-island-coarsening | FileCheck %s --dump-input=fail + + +// Test that islands with a function call are merged if the call is to a function +// that contains ops with the same attribute. 
+// CHECK-LABEL: func @control_input +func @control_input(%arg0 : tensor) -> tensor { + %0:6 = tf_executor.graph { + %1:2 = tf_executor.island wraps "tf.opA"(%arg0) {_tpu_replicate = "cluster"} : (tensor) -> tensor + %2:2 = tf_executor.island wraps "tf.While"(%1#0) {name = "A", body = @while_body_with_cluster_attr, cond = @while_cond_with_cluster_attr, is_stateless = false, parallel_iterations = 10 : i64} : (tensor) -> tensor + %3:2 = tf_executor.island wraps "tf.While"(%1#0) {name = "B", body = @while_body_with_wrong_cluster_attr, cond = @while_cond_with_wrong_cluster_attr, is_stateless = false, parallel_iterations = 10 : i64} : (tensor) -> tensor + %4:2 = tf_executor.island wraps "tf.While"(%1#0) {name = "C", body = @while_body_without_cluster_attr, cond = @while_cond_with_cluster_attr, is_stateless = false, parallel_iterations = 10 : i64} : (tensor) -> tensor + %6:2 = tf_executor.island wraps "tf.While"(%1#0) {name = "D", body = @while_body_without_cluster_attr, cond = @while_cond_without_cluster_attr, is_stateless = false, parallel_iterations = 10 : i64} : (tensor) -> tensor + %5:2 = tf_executor.island wraps "tf.While"(%1#0) {name = "E", body = @while_body_with_cluster_attr, cond = @while_cond_without_cluster_attr, is_stateless = false, parallel_iterations = 10 : i64} : (tensor) -> tensor + +// CHECK: "tf.opA" +// CHECK-NOT: island +// CHECK: name = "A" +// CHECK-NOT: island +// CHECK: name = "C" +// CHECK-NOT: island +// CHECK: name = "E" +// CHECK: island {{.*}}name = "B" +// CHECK: island {{.*}}name = "D" + + tf_executor.fetch %1#0, %2#0, %3#0, %4#0, %5#0, %6#0 : tensor, tensor, tensor, tensor, tensor, tensor + } + return %0#0 : tensor +} + +func @while_body_with_cluster_attr(%arg0: tensor) -> tensor { + %0 = "some.op"(%arg0) {_tpu_replicate = "cluster"} : (tensor) -> tensor + return %0 : tensor +} +func @while_cond_with_cluster_attr(%arg0: tensor) -> tensor { + %0 = "some.op"(%arg0) {_tpu_replicate = "cluster"} : (tensor) -> tensor + return %0 : tensor +} + +func @while_body_with_wrong_cluster_attr(%arg0: tensor) -> tensor { + %0 = "some.op"(%arg0) {_tpu_replicate = "wrong_cluster"} : (tensor) -> tensor + return %0 : tensor +} +func @while_cond_with_wrong_cluster_attr(%arg0: tensor) -> tensor { + %0 = "some.op"(%arg0) {_tpu_replicate = "wrong_cluster"} : (tensor) -> tensor + return %0 : tensor +} + +func @while_body_without_cluster_attr(%arg0: tensor) -> tensor { + %0 = "some.op"(%arg0) : (tensor) -> tensor + return %0 : tensor +} +func @while_cond_without_cluster_attr(%arg0: tensor) -> tensor { + %0 = "some.op"(%arg0) : (tensor) -> tensor + return %0 : tensor +} + diff --git a/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_inline_tpu_island.mlir b/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_island_inlining/executor_tpuv1_inline_tpu_island.mlir similarity index 100% rename from tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_inline_tpu_island.mlir rename to tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_island_inlining/executor_tpuv1_inline_tpu_island.mlir diff --git a/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_island_inlining/while_op.mlir b/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_island_inlining/while_op.mlir new file mode 100644 index 00000000000..010b5346e1e --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_island_inlining/while_op.mlir @@ -0,0 +1,44 @@ +// RUN: tf-opt %s -tf-executor-tpu-v1-island-inlining | FileCheck %s --dump-input=fail + +// CHECK-NOT: tf.PartitionedCall +// 
CHECK-NOT: module @_tpu_v1_compat_outlined + +module { + func @control_input(%arg0: tensor) -> tensor { + %0:4 = tf_executor.graph { + %outputs:4, %control = tf_executor.island wraps "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @_tpu_v1_compat_outlined::@_tpu_v1_compat_outlined_func0} : (tensor) -> (tensor, tensor, tensor, tensor) + tf_executor.fetch %outputs#0, %outputs#1, %outputs#2, %outputs#3 : tensor, tensor, tensor, tensor + } + return %0#0 : tensor + } + module @_tpu_v1_compat_outlined { + func @_tpu_v1_compat_outlined_func0(%arg0: tensor) -> (tensor, tensor, tensor, tensor) { + "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", device = "device", num_replicas = 1 : i64, topology = "topology"} : () -> () + %0 = "tf.opA"(%arg0) {_tpu_replicate = "cluster"} : (tensor) -> tensor + %1 = "tf.While"(%0) {body = @while_body_with_cluster_attr, cond = @while_cond_with_cluster_attr, is_stateless = false, name = "A", parallel_iterations = 10 : i64} : (tensor) -> tensor + %2 = "tf.While"(%0) {body = @while_body_without_cluster_attr, cond = @while_cond_with_cluster_attr, is_stateless = false, name = "C", parallel_iterations = 10 : i64} : (tensor) -> tensor + %3 = "tf.While"(%0) {body = @while_body_with_cluster_attr, cond = @while_cond_without_cluster_attr, is_stateless = false, name = "E", parallel_iterations = 10 : i64} : (tensor) -> tensor + return %0, %1, %2, %3 : tensor, tensor, tensor, tensor + } + func @while_body_with_cluster_attr(%arg0: tensor) -> tensor { + %0 = "some.op"(%arg0) {_tpu_replicate = "cluster"} : (tensor) -> tensor + return %0 : tensor + } + func @while_cond_with_cluster_attr(%arg0: tensor) -> tensor { + %0 = "some.op"(%arg0) {_tpu_replicate = "cluster"} : (tensor) -> tensor + return %0 : tensor + } + func @while_body_without_cluster_attr(%arg0: tensor) -> tensor { + %0 = "some.op"(%arg0) : (tensor) -> tensor + return %0 : tensor + } + func @while_cond_without_cluster_attr(%arg0: tensor) -> tensor { + %0 = "tf.PartionedCalledOp"(%arg0) {f = @callee_func} : (tensor) -> tensor + return %0 : tensor + } + func @callee_func(%arg0: tensor) -> tensor { + %0 = "some.op"(%arg0) : (tensor) -> tensor + return %0 : tensor + } + } +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_outline_tpu_island.mlir b/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_outline_island/executor_tpuv1_outline_tpu_island.mlir similarity index 100% rename from tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_outline_tpu_island.mlir rename to tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_outline_island/executor_tpuv1_outline_tpu_island.mlir diff --git a/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_outline_island/while_op.mlir b/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_outline_island/while_op.mlir new file mode 100644 index 00000000000..b1dee63ca03 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_outline_island/while_op.mlir @@ -0,0 +1,48 @@ +// RUN: tf-opt %s -tf-executor-tpu-v1-island-outlining | FileCheck %s --dump-input=fail + +// CHECK: func @control_input +// CHECK-NOT: func @ +// CHECK-LABEL: module @_tpu_v1_compat_outlined +// CHECK: @_tpu_v1_compat_outlined_func0 +// CHECK: func @while_body_with_cluster_attr +// CHECK: func @while_cond_with_cluster_attr +// CHECK: func @while_body_without_cluster_attr +// CHECK: func @while_cond_without_cluster_attr +// CHECK: func @callee_func +module { + func @control_input(%arg0: tensor) -> tensor { + %0:4 = 
tf_executor.graph { + %outputs:4, %control = tf_executor.island { + "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", device = "device", num_replicas = 1, topology = "topology"} : () -> () + %1 = "tf.opA"(%arg0) {_tpu_replicate = "cluster"} : (tensor) -> tensor + %2 = "tf.While"(%1) {body = @while_body_with_cluster_attr, cond = @while_cond_with_cluster_attr, is_stateless = false, name = "A", parallel_iterations = 10 : i64} : (tensor) -> tensor + %3 = "tf.While"(%1) {body = @while_body_without_cluster_attr, cond = @while_cond_with_cluster_attr, is_stateless = false, name = "C", parallel_iterations = 10 : i64} : (tensor) -> tensor + %4 = "tf.While"(%1) {body = @while_body_with_cluster_attr, cond = @while_cond_without_cluster_attr, is_stateless = false, name = "E", parallel_iterations = 10 : i64} : (tensor) -> tensor + tf_executor.yield %1, %2, %3, %4 : tensor, tensor, tensor, tensor + } + tf_executor.fetch %outputs#0, %outputs#1, %outputs#2, %outputs#3 : tensor, tensor, tensor, tensor + + } + return %0#0 : tensor + } + func @while_body_with_cluster_attr(%arg0: tensor) -> tensor { + %0 = "some.op"(%arg0) {_tpu_replicate = "cluster"} : (tensor) -> tensor + return %0 : tensor + } + func @while_cond_with_cluster_attr(%arg0: tensor) -> tensor { + %0 = "some.op"(%arg0) {_tpu_replicate = "cluster"} : (tensor) -> tensor + return %0 : tensor + } + func @while_body_without_cluster_attr(%arg0: tensor) -> tensor { + %0 = "some.op"(%arg0) : (tensor) -> tensor + return %0 : tensor + } + func @while_cond_without_cluster_attr(%arg0: tensor) -> tensor { + %0 = "tf.PartionedCalledOp"(%arg0) { f = @callee_func} : (tensor) -> tensor + return %0 : tensor + } + func @callee_func(%arg0: tensor) -> tensor { + %0 = "some.op"(%arg0) : (tensor) -> tensor + return %0 : tensor + } +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization.mlir b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization.mlir index f632e657421..44330d675e2 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization.mlir @@ -1,41 +1,24 @@ -// RUN: tf-opt %s -tf-layout-assignment=force-data-format=NCHW -verify-diagnostics | FileCheck %s +// RUN: tf-opt %s -tf-layout-optimization=force-data-format=NCHW -verify-diagnostics | FileCheck %s --dump-input=always // CHECK-LABEL: func @transposeBiasAdd -func @transposeBiasAdd(%arg0: tensor<1x4x4x8xf32>, %arg1: tensor<8xf32>) -> tensor<1x4x4x8xf32> { +func @transposeBiasAdd(%arg0: tensor<1x8x4x4xf32>, %arg1: tensor<8xf32>) -> tensor<1x8x4x4xf32> { - // Check that BiasAdd was converted to forced data format, and layout - // dependent arguments and results passed through transpose nodes. 
+ // Convert input: NCHW -> NHWC + %0 = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>} : () -> tensor<4xi64> + %1 = "tf.Transpose"(%arg0, %0) : (tensor<1x8x4x4xf32>, tensor<4xi64>) -> tensor<1x4x4x8xf32> - // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} - // CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]]) - // CHECK: %[[BIAS_ADD:[0-9]*]] = "tf.BiasAdd"(%[[ARG_TRANSPOSE]], %arg1) {data_format = "NCHW"} {{.*}} tensor<1x8x4x4xf32> - // CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>} - // CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[BIAS_ADD]], %[[RES_PERM]]) - // CHECK: return %[[RES_TRANSPOSE]] - %0 = "tf.BiasAdd"(%arg0, %arg1) {data_format = "NHWC"} : (tensor<1x4x4x8xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32> + // Compute in NHWC + %2 = "tf.BiasAdd"(%1, %arg1) {data_format = "NHWC"} : (tensor<1x4x4x8xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32> - return %0 : tensor<1x4x4x8xf32> -} + // Convert result back: NHWC -> NCHW + %3 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64> + %4 = "tf.Transpose"(%2, %3) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32> -// CHECK-LABEL: func @transposeBiasAddWithDefaultAttr -func @transposeBiasAddWithDefaultAttr(%arg0: tensor<1x4x4x8xf32>, %arg1: tensor<8xf32>) -> tensor<1x4x4x8xf32> { + // Check that BiasAdd computed in NCHW format, and all redundant transpose + // operations removed from the function. - // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} - // CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]]) - // CHECK: %[[BIAS_ADD:[0-9]*]] = "tf.BiasAdd"(%[[ARG_TRANSPOSE]], %arg1) {data_format = "NCHW"} {{.*}} tensor<1x8x4x4xf32> - // CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>} - // CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[BIAS_ADD]], %[[RES_PERM]]) - // CHECK: return %[[RES_TRANSPOSE]] - %0 = "tf.BiasAdd"(%arg0, %arg1) : (tensor<1x4x4x8xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32> + // CHECK: %[[BIAS_ADD:[0-9]*]] = "tf.BiasAdd"(%arg0, %arg1) {data_format = "NCHW"} {{.*}} tensor<1x8x4x4xf32> + // CHECK: return %[[BIAS_ADD]] - return %0 : tensor<1x4x4x8xf32> -} - -// CHECK-LABEL: func @transposeBiasWithUnknownShape -func @transposeBiasWithUnknownShape(%arg0: tensor<1x4x4x8xf32>, %arg1: tensor<8xf32>) -> tensor<*xf32> { - - // CHECK: %[[BIAS_ADD:[0-9]*]] = "tf.BiasAdd"(%[[ARG_TRANSPOSE]], %arg1) {data_format = "NCHW"} {{.*}} tensor<*xf32> - %0 = "tf.BiasAdd"(%arg0, %arg1) : (tensor<1x4x4x8xf32>, tensor<8xf32>) -> tensor<*xf32> - - return %0 : tensor<*xf32> + return %4 : tensor<1x8x4x4xf32> } \ No newline at end of file diff --git a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nchw.mlir b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nchw.mlir new file mode 100644 index 00000000000..983eabbbb02 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nchw.mlir @@ -0,0 +1,75 @@ +// RUN: tf-opt %s -tf-layout-assignment=force-data-format=NCHW -verify-diagnostics | FileCheck %s --dump-input=always + +// CHECK-LABEL: func @transposeBiasAdd +func @transposeBiasAdd(%arg0: tensor<1x4x4x8xf32>, %arg1: tensor<8xf32>) -> tensor<1x4x4x8xf32> { + + // Check that BiasAdd was converted to forced data format, and layout + // dependent arguments and results passed 
through transpose nodes. + + // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} + // CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]]) + // CHECK: %[[BIAS_ADD:[0-9]*]] = "tf.BiasAdd"(%[[ARG_TRANSPOSE]], %arg1) {data_format = "NCHW"} {{.*}} tensor<1x8x4x4xf32> + // CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>} + // CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[BIAS_ADD]], %[[RES_PERM]]) + // CHECK: return %[[RES_TRANSPOSE]] + %0 = "tf.BiasAdd"(%arg0, %arg1) {data_format = "NHWC"} : (tensor<1x4x4x8xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32> + + return %0 : tensor<1x4x4x8xf32> +} + +// CHECK-LABEL: func @transposeBiasAddWithDefaultAttr +func @transposeBiasAddWithDefaultAttr(%arg0: tensor<1x4x4x8xf32>, %arg1: tensor<8xf32>) -> tensor<1x4x4x8xf32> { + + // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} + // CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]]) + // CHECK: %[[BIAS_ADD:[0-9]*]] = "tf.BiasAdd"(%[[ARG_TRANSPOSE]], %arg1) {data_format = "NCHW"} {{.*}} tensor<1x8x4x4xf32> + // CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>} + // CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[BIAS_ADD]], %[[RES_PERM]]) + // CHECK: return %[[RES_TRANSPOSE]] + %0 = "tf.BiasAdd"(%arg0, %arg1) : (tensor<1x4x4x8xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32> + + return %0 : tensor<1x4x4x8xf32> +} + +// CHECK-LABEL: func @transposeBiasWithUnknownShape +func @transposeBiasWithUnknownShape(%arg0: tensor<1x4x4x8xf32>, %arg1: tensor<8xf32>) -> tensor<*xf32> { + + // CHECK: %[[BIAS_ADD:[0-9]*]] = "tf.BiasAdd"(%[[ARG_TRANSPOSE]], %arg1) {data_format = "NCHW"} {{.*}} tensor<*xf32> + %0 = "tf.BiasAdd"(%arg0, %arg1) : (tensor<1x4x4x8xf32>, tensor<8xf32>) -> tensor<*xf32> + + return %0 : tensor<*xf32> +} + +// CHECK-LABEL: func @transposeConv2D +func @transposeConv2D(%input: tensor<1x32x32x3xf32>, %filter: tensor<1x1x3x8xf32>) -> tensor<1x32x32x8xf32> { + + // IMPORTANT: Tensor shapes do not match convolution parameters (stride, + // dilations, etc...). This test only verifies that changing convolution data + // layout will update all the attributes. 
+ + // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} + // CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]]) + + // CHECK: %[[CONV2D:[0-9]*]] = "tf.Conv2D"(%[[ARG_TRANSPOSE]], %arg1) + // CHECK-SAME: data_format = "NCHW" + // CHECK-SAME: dilations = [1, 4, 2, 3] + // CHECK-SAME: explicit_paddings = [1, 2, 7, 8, 3, 4, 5, 6] + // CHECK-SAME: padding = "EXPLICIT" + // CHECK-SAME: strides = [5, 8, 6, 7] + // CHECK-SAME: (tensor<1x3x32x32xf32>, tensor<1x1x3x8xf32>) -> tensor<1x8x32x32xf32> + + // CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>} + // CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[CONV2D]], %[[RES_PERM]]) + // CHECK: return %[[RES_TRANSPOSE]] + + %0 = "tf.Conv2D"(%input, %filter) + { + data_format = "NHWC", + dilations = [1, 2, 3, 4], + explicit_paddings = [1, 2, 3, 4, 5, 6, 7, 8], + padding = "EXPLICIT", + strides = [5, 6, 7, 8] + } : (tensor<1x32x32x3xf32>, tensor<1x1x3x8xf32>) -> tensor<1x32x32x8xf32> + + return %0 : tensor<1x32x32x8xf32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nhwc.mlir b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nhwc.mlir new file mode 100644 index 00000000000..2d87d5ccd9c --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nhwc.mlir @@ -0,0 +1,35 @@ +// RUN: tf-opt %s -tf-layout-assignment=force-data-format=NHWC -verify-diagnostics | FileCheck %s --dump-input=always + +// CHECK-LABEL: func @transposeConv2D +func @transposeConv2D(%input: tensor<1x3x32x32xf32>, %filter: tensor<1x1x3x8xf32>) -> tensor<1x8x32x32xf32> { + + // IMPORTANT: Tensor shapes do not match convolution parameters (stride, + // dilations, etc...). This test only verifies that changing convolution data + // layout will update all the attributes. 
+ + // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>} + // CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]]) + + // CHECK: %[[CONV2D:[0-9]*]] = "tf.Conv2D"(%[[ARG_TRANSPOSE]], %arg1) + // CHECK-SAME: data_format = "NHWC" + // CHECK-SAME: dilations = [1, 3, 4, 2] + // CHECK-SAME: explicit_paddings = [1, 2, 5, 6, 7, 8, 3, 4] + // CHECK-SAME: padding = "EXPLICIT" + // CHECK-SAME: strides = [5, 7, 8, 6] + // CHECK-SAME: (tensor<1x32x32x3xf32>, tensor<1x1x3x8xf32>) -> tensor<1x32x32x8xf32> + + // CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} + // CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[CONV2D]], %[[RES_PERM]]) + // CHECK: return %[[RES_TRANSPOSE]] + + %0 = "tf.Conv2D"(%input, %filter) + { + data_format = "NCHW", + dilations = [1, 2, 3, 4], + explicit_paddings = [1, 2, 3, 4, 5, 6, 7, 8], + padding = "EXPLICIT", + strides = [5, 6, 7, 8] + } : (tensor<1x3x32x32xf32>, tensor<1x1x3x8xf32>) -> tensor<1x8x32x32xf32> + + return %0 : tensor<1x8x32x32xf32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_move_transposes_begin.mlir b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_move_transposes_begin.mlir new file mode 100644 index 00000000000..f61f1216064 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_move_transposes_begin.mlir @@ -0,0 +1,67 @@ +// RUN: tf-opt %s -tf-move-transposes=direction=begin -verify-diagnostics | FileCheck %s --dump-input=always + +// CHECK-LABEL: func @move_across_single_op +func @move_across_single_op(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> { + + // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} + // CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]]) + // CHECK: %[[TANH:[0-9]*]] = "tf.Tanh"(%[[ARG_TRANSPOSE]]) {{.*}} tensor<1x8x4x4xf32> + // CHECK: return %[[TANH]] + + %0 = "tf.Tanh"(%arg0) : (tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32> + %1 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64> + %2 = "tf.Transpose"(%0, %1) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32> + + return %2 : tensor<1x8x4x4xf32> +} + +// CHECK-LABEL: func @move_across_multiple_ops +func @move_across_multiple_ops(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> { + + // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} + // CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]]) + // CHECK: %[[TANH:[0-9]*]] = "tf.Tanh"(%[[ARG_TRANSPOSE]]) {{.*}} tensor<1x8x4x4xf32> + // CHECK: %[[RELU:[0-9]*]] = "tf.Relu"(%[[TANH]]) {{.*}} tensor<1x8x4x4xf32> + // CHECK: return %[[RELU]] + + %0 = "tf.Tanh"(%arg0) : (tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32> + %1 = "tf.Relu"(%0) : (tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32> + + %2 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64> + %3 = "tf.Transpose"(%1, %2) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32> + + return %3 : tensor<1x8x4x4xf32> +} + +// CHECK-LABEL: func @move_across_multi_operand_op +func @move_across_multi_operand_op(%arg0: tensor<1x4x4x8xf32>, %arg1: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> { + + // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} + // CHECK: %[[ARG0_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]]) + // CHECK: %[[ARG1_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg1, 
%[[ARG_PERM]]) + // CHECK: %[[ADD:[0-9]*]] = "tf.AddV2"(%[[ARG0_TRANSPOSE]], %[[ARG1_TRANSPOSE]]) {{.*}} tensor<1x8x4x4xf32> + // CHECK: return %[[ADD]] + + %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<1x4x4x8xf32>, tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32> + %1 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64> + %2 = "tf.Transpose"(%0, %1) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32> + + return %2 : tensor<1x8x4x4xf32> +} + +// CHECK-LABEL: func @move_with_multiple_uses +func @move_with_multiple_uses(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> { + + // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} + // CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]]) + // CHECK: %[[TANH:[0-9]*]] = "tf.Tanh"(%[[ARG_TRANSPOSE]]) {{.*}} tensor<1x8x4x4xf32> + // CHECK: %[[ADD:[0-9]*]] = "tf.AddV2"(%[[TANH]], %[[TANH]]) {{.*}} tensor<1x8x4x4xf32> + // CHECK: return %[[ADD]] + + %0 = "tf.Tanh"(%arg0) : (tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32> + %1 = "tf.AddV2"(%0, %0) : (tensor<1x4x4x8xf32>, tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32> + %2 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64> + %3 = "tf.Transpose"(%1, %2) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32> + + return %3 : tensor<1x8x4x4xf32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_move_transposes_end.mlir b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_move_transposes_end.mlir new file mode 100644 index 00000000000..1bc61387a0d --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_move_transposes_end.mlir @@ -0,0 +1,120 @@ +// RUN: tf-opt %s -tf-move-transposes=direction=end -verify-diagnostics | FileCheck %s --dump-input=always + +// CHECK-LABEL: func @move_across_single_op +func @move_across_single_op(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> { + + // CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} + // CHECK: %[[TANH:[0-9]*]] = "tf.Tanh"(%arg0) {{.*}} tensor<1x4x4x8xf32> + // CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[TANH]], %[[RES_PERM]]) {{.*}} tensor<1x8x4x4xf32> + // CHECK: return %[[RES_TRANSPOSE]] + + %0 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64> + %1 = "tf.Transpose"(%arg0, %0) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32> + %2 = "tf.Tanh"(%1) : (tensor<1x8x4x4xf32>) -> tensor<1x8x4x4xf32> + + return %2 : tensor<1x8x4x4xf32> +} + +// CHECK-LABEL: func @move_across_multiple_ops +func @move_across_multiple_ops(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> { + + // CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} + // CHECK: %[[TANH:[0-9]*]] = "tf.Tanh"(%arg0) {{.*}} tensor<1x4x4x8xf32> + // CHECK: %[[RELU:[0-9]*]] = "tf.Relu"(%[[TANH]]) {{.*}} tensor<1x4x4x8xf32> + // CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[RELU]], %[[RES_PERM]]) + // CHECK: return %[[RES_TRANSPOSE]] + + %0 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64> + %1 = "tf.Transpose"(%arg0, %0) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32> + %2 = "tf.Tanh"(%1) : (tensor<1x8x4x4xf32>) -> tensor<1x8x4x4xf32> + %3 = "tf.Relu"(%2) : (tensor<1x8x4x4xf32>) -> tensor<1x8x4x4xf32> + + return %3 : tensor<1x8x4x4xf32> +} + +// CHECK-LABEL: func @move_across_multi_operand_op +func @move_across_multi_operand_op(%arg0: 
tensor<1x4x4x8xf32>, %arg1: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> { + + // CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} + // CHECK: %[[ADD:[0-9]*]] = "tf.AddV2"(%arg0, %arg1) {{.*}} tensor<1x4x4x8xf32> + // CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[ADD]], %[[RES_PERM]]) + // CHECK: return %[[RES_TRANSPOSE]] + + %0 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64> + %1 = "tf.Transpose"(%arg0, %0) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32> + %2 = "tf.Transpose"(%arg1, %0) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32> + %3 = "tf.AddV2"(%1, %2) : (tensor<1x8x4x4xf32>, tensor<1x8x4x4xf32>) -> tensor<1x8x4x4xf32> + + return %3 : tensor<1x8x4x4xf32> +} + +// CHECK-LABEL: func @fold_into_max_pool +func @fold_into_max_pool(%arg0: tensor<1x64x112x112xf32>) -> tensor<1x56x56x64xf32> { + + // MaxPool operand transpose must be folded into the op and MaxPool + // must use NCHW data format with updated kernel size and strides. + + // CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>} + // CHECK: %[[MAX_POOL:[0-9]*]] = "tf.MaxPool"(%arg0) {data_format = "NCHW", ksize = [1, 1, 3, 3], padding = "SAME", strides = [1, 1, 2, 2]} : (tensor<1x64x112x112xf32>) -> tensor<1x64x56x56xf32> + // CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[MAX_POOL]], %[[RES_PERM]]) + // CHECK: return %[[RES_TRANSPOSE]] + + // Transpose NCHW -> NHWC + %0 = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>} : () -> tensor<4xi64> + %1 = "tf.Transpose"(%arg0, %0) : (tensor<1x64x112x112xf32>, tensor<4xi64>) -> tensor<1x112x112x64xf32> + + // Compute MaxPool in NHWC format + %2 = "tf.MaxPool"(%1) + { + data_format = "NHWC", ksize = [1, 3, 3, 1], + padding = "SAME", strides = [1, 2, 2, 1] + } : (tensor<1x112x112x64xf32>) -> tensor<1x56x56x64xf32> + + return %2 : tensor<1x56x56x64xf32> +} + +// CHECK-LABEL: func @fold_into_mean +func @fold_into_mean(%arg0: tensor<1x64x112x112xf32>) -> tensor<1x64xf32> { + + // CHECK: %[[RED_IDX:[0-9]*]] = "tf.Const"() {value = dense<[2, 3]> : tensor<2xi64>} + // CHECK: %[[MEAN:[0-9]*]] = "tf.Mean"(%arg0, %[[RED_IDX]]) + // CHECK-SAME: (tensor<1x64x112x112xf32>, tensor<2xi64>) -> tensor<1x64xf32> + // CHECK: return %[[MEAN]] + + // Transpose NCHW -> NHWC + %0 = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>} : () -> tensor<4xi64> + %1 = "tf.Transpose"(%arg0, %0) : (tensor<1x64x112x112xf32>, tensor<4xi64>) -> tensor<1x112x112x64xf32> + + // Compute Mean over spatial dimensions in NHWC format. 
+ %2 = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi64>} : () -> tensor<2xi64> + %3 = "tf.Mean"(%1, %2) : (tensor<1x112x112x64xf32>, tensor<2xi64>) -> tensor<1x64xf32> + + return %3 : tensor<1x64xf32> +} + +// CHECK-LABEL: func @fold_into_fused_batch_norm +func @fold_into_fused_batch_norm(%arg0: tensor<1x64x112x112xf32>, %arg1: tensor<64xf32>) -> tensor<1x112x112x64xf32> { + + // CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>} + // CHECK: "tf.FusedBatchNormV3"(%arg0, {{.*}} {data_format = "NCHW" + // CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%y, %[[RES_PERM]]) + // CHECK: return %[[RES_TRANSPOSE]] + + // Transpose NCHW -> NHWC + %0 = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>} : () -> tensor<4xi64> + %1 = "tf.Transpose"(%arg0, %0) : (tensor<1x64x112x112xf32>, tensor<4xi64>) -> tensor<1x112x112x64xf32> + + // Compute FusedBatchNormV3 in NHWC format + %2, %batch_mean, %batch_var, %reserve_1, %reserve_2, %reserve_3 + = "tf.FusedBatchNormV3"(%1, %arg1, %arg1, %arg1, %arg1) + { + data_format = "NHWC", + epsilon = 1.001 : f32, + exponential_avg_factor = 1.0 : f32, + is_training = false + } + : (tensor<1x112x112x64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) + -> (tensor<1x112x112x64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) + + return %2#0 : tensor<1x112x112x64xf32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/parallel_execute_to_islands.mlir b/tensorflow/compiler/mlir/tensorflow/tests/parallel_execute_to_islands.mlir new file mode 100644 index 00000000000..be23da672e5 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/parallel_execute_to_islands.mlir @@ -0,0 +1,194 @@ +// RUN: tf-opt %s -tf-parallel-execute-to-islands | FileCheck %s --dump-input=fail + +// CHECK-LABEL: func @check_regions_to_islands +func @check_regions_to_islands() { + tf_executor.graph { + tf_executor.island() { + "tf_device.parallel_execute"() ({ + tf_device.return + }, + { + tf_device.return + }) {} : () -> () + tf_executor.yield + } + tf_executor.fetch + } + return +} + +// CHECK: %[[ISLAND_INPUT_CTL:[a-z_0-9]*]] = tf_executor.island { +// CHECK-NEXT: tf_executor.yield +// CHECK: %[[ISLAND_1_CTL:[a-z_0-9]*]] = tf_executor.island(%[[ISLAND_INPUT_CTL]]) { +// CHECK: tf_executor.yield +// CHECK: %[[ISLAND_2_CTL:[a-z_0-9]*]] = tf_executor.island(%[[ISLAND_INPUT_CTL]]) { +// CHECK: tf_executor.yield +// CHECK: %{{.*}} = tf_executor.island(%[[ISLAND_1_CTL]], %[[ISLAND_2_CTL]]) { +// CHECK-NEXT: tf_executor.yield + + +// CHECK-LABEL: func @check_regions_to_islands_with_inputs +// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) +func @check_regions_to_islands_with_inputs(%arg0 : tensor) { + tf_executor.graph { + %1:2 = tf_executor.island { + %2 = "tf.opA"(%arg0) : (tensor) -> tensor + tf_executor.yield %2 : tensor + } + tf_executor.island() { + "tf_device.parallel_execute"() ({ + %3 = "tf.opB"(%1#0) : (tensor) -> tensor + tf_device.return %3 : tensor + }, + { + %5 = "tf.opC"(%1#0) : (tensor) -> tensor + tf_device.return %5 : tensor + }) {} : () -> (tensor, tensor) + tf_executor.yield + } + tf_executor.fetch + } + return +} + +// CHECK: %[[INPUT_A:[a-z_0-9]*]], %{{.*}} = tf_executor.island { +// CHECK-NEXT: %[[OP_A_OUTPUT:[a-z_0-9]*]] = "tf.opA"(%[[ARG_0]]) : (tensor) -> tensor +// CHECK-NEXT: tf_executor.yield %[[OP_A_OUTPUT]] : tensor +// CHECK: %[[INPUT_0:[a-z_0-9]*]], %[[INPUT_CONTROL:[a-z_0-9]*]] = tf_executor.island { +// CHECK-NEXT: tf_executor.yield %[[INPUT_A]] : tensor 
+// CHECK: %[[ISLAND_1_OUTPUT:[a-z_0-9]*]], %[[ISLAND_1_CTL:[a-z_0-9]*]] = tf_executor.island { +// CHECK-NEXT: %[[OP_B_OUTPUT:[a-z_0-9]*]] = "tf.opB"(%[[INPUT_0]]) : (tensor) -> tensor +// CHECK: tf_executor.yield %[[OP_B_OUTPUT]] : tensor +// CHECK: %[[ISLAND_2_OUTPUT:[a-z_0-9]*]], %[[ISLAND_2_CTL:[a-z_0-9]*]] = tf_executor.island { +// CHECK-NEXT: %[[OP_C_OUTPUT:[a-z_0-9]*]] = "tf.opC"(%outputs_0) : (tensor) -> tensor +// CHECK: tf_executor.yield %[[OP_C_OUTPUT]] : tensor +// CHECK: %{{.*}} = tf_executor.island(%[[ISLAND_1_CTL]], %[[ISLAND_2_CTL]]) { +// CHECK-NEXT: tf_executor.yield + + +// CHECK-LABEL: func @check_input_sink_island_forwards_control_inputs +// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) +func @check_input_sink_island_forwards_control_inputs(%arg0 : tensor) { + tf_executor.graph { + %1:2 = tf_executor.island { + %2 = "tf.opA"(%arg0) : (tensor) -> tensor + tf_executor.yield %2 : tensor + } + %7 = tf_executor.ControlTrigger {} + %8 = tf_executor.ControlTrigger {} + tf_executor.island(%7, %8) { + "tf_device.parallel_execute"() ({ + %3 = "tf.opB"(%1#0) : (tensor) -> tensor + tf_device.return %3 : tensor + }, + { + %5 = "tf.opC"() : () -> tensor + tf_device.return %5 : tensor + }) {} : () -> (tensor, tensor) + tf_executor.yield + } + tf_executor.fetch + } + return +} + +// CHECK: %[[INPUT_A:[a-z_0-9]*]], %{{.*}} = tf_executor.island { +// CHECK-NEXT: %[[OP_A_OUTPUT:[a-z_0-9]*]] = "tf.opA"(%[[ARG_0]]) : (tensor) -> tensor +// CHECK-NEXT: tf_executor.yield %[[OP_A_OUTPUT]] : tensor +// CHECK: %[[CT_0:[0-9]*]] = tf_executor.ControlTrigger +// CHECK: %[[CT_1:[0-9]*]] = tf_executor.ControlTrigger +// CHECK: %[[INPUT_0:[a-z_0-9]*]], %[[INPUT_CONTROL:[a-z_0-9]*]] = tf_executor.island(%[[CT_0]], %[[CT_1]]) { +// CHECK-NEXT: tf_executor.yield %[[INPUT_A]] : tensor +// CHECK: %[[ISLAND_1_OUTPUT:[a-z_0-9]*]], %[[ISLAND_1_CTL:[a-z_0-9]*]] = tf_executor.island { +// CHECK-NEXT: %[[OP_B_OUTPUT:[a-z_0-9]*]] = "tf.opB"(%[[INPUT_0]]) : (tensor) -> tensor +// CHECK: tf_executor.yield %[[OP_B_OUTPUT]] : tensor +// CHECK: %[[ISLAND_2_OUTPUT:[a-z_0-9]*]], %[[ISLAND_2_CTL:[a-z_0-9]*]] = tf_executor.island(%[[INPUT_CONTROL]]) { +// CHECK-NEXT: %[[OP_C_OUTPUT:[a-z_0-9]*]] = "tf.opC"() : () -> tensor +// CHECK: tf_executor.yield %[[OP_C_OUTPUT]] : tensor +// CHECK: %{{.*}} = tf_executor.island(%[[ISLAND_1_CTL]], %[[ISLAND_2_CTL]]) { +// CHECK-NEXT: tf_executor.yield + + +// CHECK-LABEL: func @check_control_dep_added_when_region_does_not_have_inputs +// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) +func @check_control_dep_added_when_region_does_not_have_inputs(%arg0 : tensor) { + tf_executor.graph { + %1:2 = tf_executor.island { + %2 = "tf.opA"(%arg0) : (tensor) -> tensor + tf_executor.yield %2 : tensor + } + %7:3 = tf_executor.island() { + %8:2 = "tf_device.parallel_execute"() ( + { + %3 = "tf.opB"() : () -> tensor + tf_device.return %3 : tensor + }, + { + %5 = "tf.opC"(%1#0) : (tensor) -> tensor + tf_device.return %5 : tensor + } + ) {} : () -> (tensor, tensor) + + tf_executor.yield %8#0, %8#1 : tensor, tensor + } + + tf_executor.island { + "tf.opD"(%7#0, %7#1) : (tensor, tensor) -> () + tf_executor.yield + } + tf_executor.fetch + } + return +} + +// CHECK: %[[INPUT_A:[a-z_0-9]*]], %{{.*}} = tf_executor.island { +// CHECK-NEXT: %[[OP_A_OUTPUT:[a-z_0-9]*]] = "tf.opA"(%[[ARG_0]]) : (tensor) -> tensor +// CHECK-NEXT: tf_executor.yield %[[OP_A_OUTPUT]] : tensor +// CHECK: %[[INPUT_0:[a-z_0-9]*]], %[[INPUT_CTL:[a-z_0-9]*]] = tf_executor.island { +// CHECK-NEXT: tf_executor.yield %[[INPUT_A]] : tensor +// 
CHECK: %[[ISLAND_1_OUTPUT:[a-z_0-9]*]], %{{.*}} = tf_executor.island(%[[INPUT_CTL]]) { +// CHECK-NEXT: %[[OP_B_OUTPUT:[a-z_0-9]*]] = "tf.opB"() : () -> tensor +// CHECK: tf_executor.yield %[[OP_B_OUTPUT]] : tensor +// CHECK: %[[ISLAND_2_OUTPUT:[a-z_0-9]*]], %{{.*}} = tf_executor.island { +// CHECK-NEXT: %[[OP_C_OUTPUT:[a-z_0-9]*]] = "tf.opC"(%outputs_0) : (tensor) -> tensor +// CHECK: tf_executor.yield %[[OP_C_OUTPUT]] : tensor +// CHECK: %{{.*}} = tf_executor.island { +// CHECK-NEXT: tf_executor.yield %[[ISLAND_1_OUTPUT]], %[[ISLAND_2_OUTPUT]] + + +// CHECK-LABEL: func @check_output_barrier_correctly_forwards_outputs +func @check_output_barrier_correctly_forwards_outputs(%arg0 : tensor) -> tensor { + %0 = tf_executor.graph { + %1:2 = tf_executor.island { + %2 = "tf.opA"(%arg0) : (tensor) -> tensor + tf_executor.yield %2 : tensor + } + %8:3 = tf_executor.island() { + %7:2 = "tf_device.parallel_execute"() ({ + %3 = "tf.opB"() : () -> tensor + tf_device.return %3 : tensor + }, + { + %5 = "tf.opC"(%1#0) : (tensor) -> tensor + tf_device.return %5 : tensor + }) {} : () -> (tensor, tensor) + tf_executor.yield %7#0, %7#1 : tensor, tensor + } + tf_executor.fetch %8#0 : tensor + } + return %0 : tensor +} + +// CHECK: %[[INPUT_A:[a-z_0-9]*]], %{{.*}} = tf_executor.island { +// CHECK-NEXT: %[[OP_A_OUTPUT:[a-z_0-9]*]] = "tf.opA"(%[[ARG_0]]) : (tensor) -> tensor +// CHECK-NEXT: tf_executor.yield %[[OP_A_OUTPUT]] : tensor +// CHECK: %[[INPUT_0:[a-z_0-9]*]], %[[INPUT_CTL:[a-z_0-9]*]] = tf_executor.island { +// CHECK-NEXT: tf_executor.yield %[[INPUT_A]] : tensor +// CHECK: %[[ISLAND_1_OUTPUT:[a-z_0-9]*]], %{{.*}} = tf_executor.island(%[[INPUT_CTL]]) { +// CHECK-NEXT: %[[OP_B_OUTPUT:[a-z_0-9]*]] = "tf.opB"() : () -> tensor +// CHECK: tf_executor.yield %[[OP_B_OUTPUT]] : tensor +// CHECK: %[[ISLAND_2_OUTPUT:[a-z_0-9]*]], %{{.*}} = tf_executor.island { +// CHECK-NEXT: %[[OP_C_OUTPUT:[a-z_0-9]*]] = "tf.opC"(%[[INPUT_0]]) : (tensor) -> tensor +// CHECK: tf_executor.yield %[[OP_C_OUTPUT]] : tensor +// CHECK: %[[OUTPUT_SINK_OUTPUT:[a-z_0-9]*]]:2, %[[OUTPUT_SINK_CTL:[a-z_0-9]*]] = tf_executor.island { +// CHECK-NEXT: tf_executor.yield %[[ISLAND_1_OUTPUT]], %[[ISLAND_2_OUTPUT]] : tensor, tensor diff --git a/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir b/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir index 016b06b662a..52bc0f878fc 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir @@ -542,3 +542,116 @@ func @if_else(%arg0: tensor<*x!tf.resource>>, %arg1: tensor<*x!tf. -> (tensor<*x!tf.resource>>) { return %arg1 : tensor<*x!tf.resource>> } + +// ----- + +// Tests that the pass lifts resources on two partitioned call ops sharing the +// same callee. The lifting should clone the callee then modify the clone. 
+ +// CHECK-LABEL: @launch_with_partitioned_call +func @launch_with_partitioned_call() -> tensor { + // CHECK: %[[VH:.*]] = "tf.VarHandleOp"() + %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> + // CHECK: %[[CONST:.*]] = "tf.Const"() + %1 = "tf.Const"() {value = dense<10.0> : tensor} : () -> tensor + // CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VH]]) + // CHECK: %[[LAUNCH:.*]] = "tf_device.launch"() + %2 = "tf_device.launch"() ( { + // CHECK: %[[PC0:.*]] = "tf.PartitionedCall"(%[[CONST]], %[[READ]], %[[CONST]]) + // CHECK-SAME: f = @callee_resource_lifted + %3 = "tf.PartitionedCall"(%1, %0, %1) {f = @callee, config = "", config_proto = "", executor_type = ""} + : (tensor, tensor<*x!tf.resource>>, tensor) -> tensor + // CHECK: %[[PC1:.*]] = "tf.PartitionedCall"(%[[CONST]], %[[READ]], %[[CONST]]) + // CHECK-SAME: f = @callee_resource_lifted + %4 = "tf.PartitionedCall"(%1, %0, %1) {f = @callee, config = "", config_proto = "", executor_type = ""} + : (tensor, tensor<*x!tf.resource>>, tensor) -> tensor + // CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[PC0]], %[[PC1]]) + %5 = "tf.AddV2"(%3, %4) : (tensor, tensor) -> tensor + // CHECK: tf_device.return %[[ADD]] : tensor + tf_device.return %5 : tensor + }) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor + return %2 : tensor +} +// CHECK: @callee(%[[OA0:.*]]: tensor, %[[OA1:.*]]: tensor<*x!tf.resource>>, %[[OA2:.*]]: tensor) -> tensor +func @callee(%arg0: tensor, %arg1: tensor<*x!tf.resource>>, %arg2: tensor) -> tensor { + // CHECK: "tf.ReadVariableOp"(%[[OA1]]) + %0 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf.resource>>) -> tensor + %1 = "tf.AddV2"(%0, %arg0) : (tensor, tensor) -> tensor + %2 = "tf.AddV2"(%1, %arg2) : (tensor, tensor) -> tensor + return %2 : tensor +} +// CHECK: func @callee_resource_lifted(%[[A0:.*]]: tensor, %[[A1:.*]]: tensor, %[[A2:.*]]: tensor) -> tensor +// CHECK-NEXT: %[[ADD0:.*]] = "tf.AddV2"(%[[A1]], %[[A0]]) +// CHECK-NEXT: %[[ADD1:.*]] = "tf.AddV2"(%[[ADD0]], %[[A2]]) +// CHECK-NEXT: return %[[ADD1]] + + +// ----- + +// Tests that the pass lifts resources on two stateful partitioned call ops +// sharing the same callee. The lifting should clone the callee then modify the +// clone. 
+ +// CHECK-LABEL: @launch_with_stateful_partitioned_call +func @launch_with_stateful_partitioned_call() -> () { + // CHECK: %[[VH0:.*]] = "tf.VarHandleOp"() + %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> + // CHECK: %[[VH1:.*]] = "tf.VarHandleOp"() + %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource>> + // CHECK: %[[CONST:.*]] = "tf.Const"() + %2 = "tf.Const"() {value = dense<10.0> : tensor} : () -> tensor + // CHECK-DAG: %[[READ0:.*]] = "tf.ReadVariableOp"(%[[VH0]]) + // CHECK-DAG: %[[READ1:.*]] = "tf.ReadVariableOp"(%[[VH1]]) + // CHECK: %[[LAUNCH:.*]] = "tf_device.launch"() + "tf_device.launch"() ( { + // CHECK: %[[PC0:.*]] = "tf.StatefulPartitionedCall"(%[[READ0]], %[[READ1]], %[[CONST]]) + // CHECK-SAME: f = @callee_resource_lifted + %3 = "tf.StatefulPartitionedCall"(%0, %1, %2) {f = @callee, config = "", config_proto = "", executor_type = ""} + : (tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, tensor) -> tensor<*x!tf.resource>> + // CHECK: %[[PC1:.*]] = "tf.StatefulPartitionedCall"(%[[PC0]], %[[READ1]], %[[CONST]]) + // CHECK-SAME: f = @callee_resource_lifted + %4 = "tf.StatefulPartitionedCall"(%3, %1, %2) {f = @callee, config = "", config_proto = "", executor_type = ""} + : (tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, tensor) -> tensor<*x!tf.resource>> + // CHECK: tf_device.return %[[PC1]] : tensor + tf_device.return + // CHECK: {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor + }) {device = "tpu0", launch_attr = "launch_attr"} : () -> () + // CHECK: "tf.AssignVariableOp"(%[[VH0]], %[[LAUNCH]]) + return +} +// CHECK: @callee(%[[OA0:.*]]: tensor<*x!tf.resource>>, %[[OA1:.*]]: tensor<*x!tf.resource>>, %[[OA2:.*]]: tensor) -> tensor<*x!tf.resource>> +func @callee(%arg0: tensor<*x!tf.resource>>, %arg1: tensor<*x!tf.resource>>, %arg2: tensor) -> tensor<*x!tf.resource>> { + // CHECK: "tf.ReadVariableOp"(%[[OA1]]) + %0 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf.resource>>) -> tensor + %1 = "tf.AddV2"(%0, %arg2) : (tensor, tensor) -> tensor + "tf.AssignVariableOp"(%arg0, %1) {dtype = i32} : (tensor<*x!tf.resource>>, tensor) -> () + return %arg0 : tensor<*x!tf.resource>> +} +// CHECK: func @callee_resource_lifted(%[[A0:.*]]: tensor, %[[A1:.*]]: tensor, %[[A2:.*]]: tensor) -> tensor +// CHECK-NEXT: %[[ADD:.*]] = "tf.AddV2"(%[[A1]], %[[A2]]) +// CHECK-NEXT: return %[[ADD]] + + +// ----- + +// Tests that the pass reports error on called function that has resource output +// which doesn't alias an input. 
+ +func @launch_with_stateful_partitioned_call() -> () { + %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> + %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource>> + %2 = "tf.Const"() {value = dense<10.0> : tensor} : () -> tensor + "tf_device.launch"() ( { + %3 = "tf.StatefulPartitionedCall"(%0, %1, %2) {f = @callee, config = "", config_proto = "", executor_type = ""} + : (tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, tensor) -> tensor<*x!tf.resource>> + %4 = "tf.StatefulPartitionedCall"(%3, %1, %2) {f = @callee, config = "", config_proto = "", executor_type = ""} + : (tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, tensor) -> tensor<*x!tf.resource>> + tf_device.return + }) {device = "tpu0", launch_attr = "launch_attr"} : () -> () + return +} +// expected-error @+1 {{Unsupported function call: resource return value does not alias an input.}} +func @callee(%arg0: tensor<*x!tf.resource>>, %arg1: tensor<*x!tf.resource>>, %arg2: tensor) -> tensor<*x!tf.resource>> { + %0 = "tf._Unknown_"() : () -> tensor<*x!tf.resource>> + return %0 : tensor<*x!tf.resource>> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir index 23cc06de453..c9db7e0a1dc 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir @@ -45,6 +45,17 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr return %1 : tensor<*xf32> } +// CHECK-LABEL: func @multiple_blocks_one_return(%arg0: tensor) -> tensor +func @multiple_blocks_one_return(%arg0: tensor) -> tensor<*xf32> { + br ^bb1 +^bb1: +// CHECK: %[[IDENTITY:.*]] = "tf.Identity"(%arg0) : (tensor) -> tensor +// CHECK: return %[[IDENTITY]] : tensor + %ret = "tf.Identity"(%arg0) : (tensor) -> tensor<*xf32> + return %ret : tensor<*xf32> +} + + // Tests the case where an inference opportunity relies on folding. 
// CHECK-LABEL: func @simple_folding diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/basic.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/basic.py index 52ed0b4ed2b..4248099637c 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/basic.py +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/basic.py @@ -46,7 +46,7 @@ class TestModule(tf.Module): # CHECK: "tf_saved_model.global_tensor"() {sym_name = "[[CONST:[a-zA-Z_0-9]+]]", tf_saved_model.exported_names = [], type = tensor, value = dense<4.300000e+01> : tensor} : () -> () # CHECK: func {{@[a-zA-Z_0-9]+}}( # CHECK-SAME: %arg0: tensor {tf_saved_model.index_path = [0]}, - # CHECK-SAME: %arg1: tensor<*x!tf.resource> {tf_saved_model.bound_input = @[[VAR]]}, + # CHECK-SAME: %arg1: tensor>> {tf_saved_model.bound_input = @[[VAR]]}, # CHECK-SAME: %arg2: tensor {tf_saved_model.bound_input = @[[CONST]]}) -> ( # CHECK-SAME: tensor {tf_saved_model.index_path = []}) # CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["some_function"] diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/call_to_exported.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/call_to_exported.py index 8e9e197d62f..658cc37a22f 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/call_to_exported.py +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/call_to_exported.py @@ -46,7 +46,7 @@ class TestModule(tf.Module): # # CHECK: func {{@[a-zA-Z_0-9]+}}( # CHECK-SAME: %arg0: tensor {tf_saved_model.index_path = [0]}, - # CHECK-SAME: %arg1: tensor<*x!tf.resource> {tf_saved_model.bound_input = {{@[a-zA-Z_0-9]+}}} + # CHECK-SAME: %arg1: tensor> {tf_saved_model.bound_input = {{@[a-zA-Z_0-9]+}}} # CHECK-SAME: ) -> ( # CHECK-SAME: tensor {tf_saved_model.index_path = [0]}, # CHECK-SAME: tensor {tf_saved_model.index_path = [1]}) @@ -55,7 +55,7 @@ class TestModule(tf.Module): # # CHECK: func {{@[a-zA-Z_0-9]+}}( # CHECK-SAME: %arg0: tensor {tf_saved_model.index_path = [0]}, - # CHECK-SAME: %arg1: tensor<*x!tf.resource> {tf_saved_model.bound_input = {{@[a-zA-Z_0-9]+}}} + # CHECK-SAME: %arg1: tensor> {tf_saved_model.bound_input = {{@[a-zA-Z_0-9]+}}} # CHECK-SAME: ) -> ( # CHECK-SAME: tensor {tf_saved_model.index_path = [0]}, # CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = [1]}) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_inline_global_tensors.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_inline_global_tensors.mlir index d1e1c9d6b09..365a5a3f402 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_inline_global_tensors.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_inline_global_tensors.mlir @@ -25,8 +25,8 @@ module attributes {tf_saved_model.semantics} { // CHECK: tf_saved_model.global_tensor "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor, value = dense<1.0> : tensor } : () -> () - // CHECK: func @f(%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v}) - func @f(%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v}) + // CHECK: func @f(%arg0: tensor>> {tf_saved_model.bound_input = @v}) + func @f(%arg0: tensor>> {tf_saved_model.bound_input = @v}) attributes {tf_saved_model.exported_names = ["f"]} { // CHECK-NOT: tf.Const return diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops.mlir index cc809909f79..1bf172b2655 100644 --- 
a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops.mlir @@ -26,7 +26,7 @@ module attributes {tf_saved_model.semantics} { func @__concrete_function_run_computation( %arg0: tensor {tf_saved_model.index_path = [0, "foo"]}, %arg1: tensor<1x64xf32> {tf_saved_model.bound_input = @some_constant}, - %arg2: tensor<*x!tf.resource> {tf_saved_model.bound_input = @some_variable} + %arg2: tensor>> {tf_saved_model.bound_input = @some_variable} ) -> ( tensor {tf_saved_model.index_path = [0, "bar"]} ) attributes { tf_saved_model.exported_names = ["some_func"] } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops_invalid.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops_invalid.mlir index 0a5fe2708c1..6e6c8ae3821 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops_invalid.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops_invalid.mlir @@ -219,8 +219,8 @@ module attributes {tf_saved_model.semantics} { "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor, value = dense<42.0> : tensor } : () -> () // expected-error@+1 {{duplicate 'tf_saved_model.bound_input' binding}} func @f( - %arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v}, - %arg1: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v} + %arg0: tensor>> {tf_saved_model.bound_input = @v}, + %arg1: tensor>> {tf_saved_model.bound_input = @v} ) attributes {tf_saved_model.exported_names = ["f"]} { return } @@ -232,9 +232,9 @@ module attributes {tf_saved_model.semantics} { "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor, value = dense<1.> : tensor<1xf32> } : () -> () // expected-error@+1 {{can only apply 'tf_saved_model' argument attributes to exported functions}} - func @f(%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v}) + func @f(%arg0: tensor>> {tf_saved_model.bound_input = @v}) -> (tensor {tf_saved_model.index_path = []}) { - %0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>) -> tensor + %0 = "tf.ReadVariableOp"(%arg0) : (tensor>>) -> tensor return %0 : tensor } } @@ -244,7 +244,7 @@ module attributes {tf_saved_model.semantics} { module attributes {tf_saved_model.semantics} { "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor, value = dense<1.> : tensor<1xf32> } : () -> () - // expected-error@+1 {{bound inputs for mutable 'tf_saved_model.global_tensor's must be tensors of '!tf.resource'}} + // expected-error@+1 {{mutable bound input with type 'tensor' expected to have type 'tensor>>'}} func @f(%arg0: tensor {tf_saved_model.bound_input = @v}) attributes {tf_saved_model.exported_names = ["f"]} { return @@ -257,7 +257,7 @@ module attributes {tf_saved_model.semantics} { "tf_saved_model.global_tensor"() { sym_name = "v", type = tensor<1xf32>, value = dense<1.> : tensor<1xf32> } : () -> () // expected-error@+1 {{bound input for immutable 'tf_saved_model.global_tensor' must match the global tensor's type}} - func @f(%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v}) + func @f(%arg0: tensor>> {tf_saved_model.bound_input = @v}) attributes {tf_saved_model.exported_names = ["f"]} { return } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_optimize_global_tensors.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_optimize_global_tensors.mlir index 95b0bd54d70..f2a4373c777 100644 --- 
a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_optimize_global_tensors.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_optimize_global_tensors.mlir @@ -14,10 +14,10 @@ module attributes {tf_saved_model.semantics} { "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor, value = dense<42.> : tensor } : () -> () // CHECK: func @f(%arg0: tensor {tf_saved_model.bound_input = @v}) - func @f(%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v}) -> (tensor {tf_saved_model.index_path = []}) + func @f(%arg0: tensor>> {tf_saved_model.bound_input = @v}) -> (tensor {tf_saved_model.index_path = []}) attributes {tf_saved_model.exported_names = ["f"]} { // CHECK-NOT: tf.ReadVariableOp - %val = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>) -> tensor + %val = "tf.ReadVariableOp"(%arg0) : (tensor>>) -> tensor // CHECK: return %arg0 return %val : tensor } @@ -35,12 +35,12 @@ module attributes {tf_saved_model.semantics} { // CHECK-SAME: } : () -> () "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor, value = dense<42.> : tensor } : () -> () - // CHECK: func @f(%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v}) - func @f(%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v}) + // CHECK: func @f(%arg0: tensor>> {tf_saved_model.bound_input = @v}) + func @f(%arg0: tensor>> {tf_saved_model.bound_input = @v}) attributes {tf_saved_model.exported_names = ["f"]} { %c0 = "tf.Const"() { value = dense<1.0> : tensor } : () -> tensor // CHECK: tf.AssignVariableOp - "tf.AssignVariableOp"(%arg0, %c0) : (tensor<*x!tf.resource>, tensor) -> () + "tf.AssignVariableOp"(%arg0, %c0) : (tensor>>, tensor) -> () return } @@ -57,10 +57,10 @@ module attributes {tf_saved_model.semantics} { // CHECK-SAME: } : () -> () "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", tf_saved_model.exported_names = ["v"], type = tensor, value = dense<42.> : tensor } : () -> () - // CHECK: func @f(%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v}) - func @f(%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v}) -> (tensor {tf_saved_model.index_path = []}) + // CHECK: func @f(%arg0: tensor>> {tf_saved_model.bound_input = @v}) + func @f(%arg0: tensor>> {tf_saved_model.bound_input = @v}) -> (tensor {tf_saved_model.index_path = []}) attributes {tf_saved_model.exported_names = ["f"]} { - %val = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>) -> tensor + %val = "tf.ReadVariableOp"(%arg0) : (tensor>>) -> tensor return %val : tensor } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_inline_tpu_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_inline_tpu_island.cc index 80fcd52056d..9660367cb68 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_inline_tpu_island.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_inline_tpu_island.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Twine.h" @@ -70,10 +71,20 @@ void TPUBridgeExecutorIslandInlining::runOnModule() { call_op.emitOpError() << "Failed to inline\n"; return WalkResult::interrupt(); } + called_func.erase(); call_op.erase(); return WalkResult::advance(); }); if (walk_result.wasInterrupted()) return signalPassFailure(); + // Move all remaining nested functions back into the parent module. + Block &nested_block = nested_module->getRegion(0).front(); + for (FuncOp func_op : + llvm::make_early_inc_range(nested_block.getOps())) { + if (!symbol_table.lookupSymbolIn(getModule(), func_op.getName())) { + nested_block.getOperations().remove(func_op.getOperation()); + symbol_table.insert(func_op.getOperation()); + } + } nested_module->erase(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_island_coarsening.cc b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_island_coarsening.cc index cd669abcc24..cc87bd31486 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_island_coarsening.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_island_coarsening.cc @@ -29,10 +29,12 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Support/Casting.h" +#include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/Block.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project #include "mlir/IR/Location.h" // TF:llvm-project #include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/SymbolTable.h" // TF:llvm-project #include "mlir/IR/UseDefLists.h" // TF:llvm-project #include "mlir/IR/Visitors.h" // TF:llvm-project #include "mlir/Pass/Pass.h" // TF:llvm-project @@ -57,8 +59,8 @@ constexpr llvm::StringRef kTpuStatusAttr = "_tpu_compilation_status"; // TPU-annotated operations and intended to preserve backward compatibility with // TFv1. struct TpuV1BridgeExecutorIslandCoarsening - : public FunctionPass { - void runOnFunction() override; + : public ModulePass { + void runOnModule() override; }; // Sort the Operations in the provided range to enforce dominance. @@ -88,9 +90,10 @@ LogicalResult SortTopologically(Block::iterator first_op, Operation* producer_in_block = block->findAncestorOpInBlock(*defining_op); if (producer_in_block && producer_in_block != &op && - unscheduled_ops.count(producer_in_block)) + unscheduled_ops.count(producer_in_block)) { // Found an operand that isn't scheduled yet, interrupt the walk. return WalkResult::interrupt(); + } } return WalkResult::advance(); }); @@ -113,7 +116,9 @@ LogicalResult SortTopologically(Block::iterator first_op, // A failure is returned if a cycle preventing the merge from happening // correctly without breaking dominance. The IR is left in invalid state in case // of failure. -LogicalResult MergeIsland(Operation* op, bool* changed) { +LogicalResult MergeIsland(llvm::function_ref + is_op_calling_func_for_cluster, + Operation* op, bool* changed) { // Find the first island wrapping a single operation with the `_tpu_replicate` // attribute, it'll be used as the root of the algorithm to find the other // operations that are part of the same cluster. 
@@ -146,7 +151,9 @@ LogicalResult MergeIsland(Operation* op, bool* changed) { if (!candidate_cluster_name) candidate_cluster_name = candidate_wrapped_op.getAttrOfType(kTpuStatusAttr); - if (candidate_cluster_name != cluster_name) continue; + if (candidate_cluster_name != cluster_name && + !is_op_calling_func_for_cluster(cluster_name, &candidate_wrapped_op)) + continue; // Look at captured operands to bring-in ReplicatedInputOp in the // island as well. TODO: also pull in tf.Const, some optimizations can @@ -250,34 +257,71 @@ LogicalResult MergeIsland(Operation* op, bool* changed) { first_op_after); } -void TpuV1BridgeExecutorIslandCoarsening::runOnFunction() { - getFunction().walk([&](GraphOp graph) { - Block& graph_body = graph.GetBody(); +void TpuV1BridgeExecutorIslandCoarsening::runOnModule() { + SymbolTable symbol_table(getModule()); - // Iterate until fixed point on the block, as it may contain multiple - // clusters. - bool changed = true; - while (changed) { - changed = false; - for (Operation& op : graph_body) { - if (failed(MergeIsland(&op, &changed))) { - graph.emitError() << "Merging island failed: the TPU cluster likely " - << "contains a cycle with non-TPU operations\n"; - signalPassFailure(); - return WalkResult::interrupt(); - } - // If islands were merged, restart scanning the block from the beginning - // as we lost track of where to continue. - if (changed) break; - } + // Map tpu cluster names to the functions that contain operations for this + // cluster. + DenseMap> tpu_funcs; + for (FuncOp func_op : getModule().getOps()) { + func_op.walk([&](Operation* op) { + StringAttr cluster_name = + op->getAttrOfType(kTpuReplicateAttr); + if (!cluster_name) + cluster_name = op->getAttrOfType(kTpuStatusAttr); + if (!cluster_name) return; + tpu_funcs[cluster_name.getValue()].insert(func_op); + }); + } + + // Return true if the operation is containing a reference to a function + // containing operations for this cluster. + auto is_op_calling_func_for_cluster = [&](StringAttr cluster, Operation* op) { + auto funcs_for_cluster = tpu_funcs.find(cluster.getValue()); + assert(funcs_for_cluster != tpu_funcs.end()); + assert(!funcs_for_cluster->second.empty()); + if (funcs_for_cluster->second.size() == 1) return false; + for (NamedAttribute attr : op->getAttrs()) { + auto symbol_ref = attr.second.dyn_cast(); + if (!symbol_ref) continue; + FuncOp callee = symbol_table.lookup(symbol_ref.getValue()); + if (!callee) continue; + if (funcs_for_cluster->second.count(callee)) return true; } - return WalkResult::advance(); - }); + return false; + }; + + for (FuncOp func_op : getModule().getOps()) { + func_op.walk([&](GraphOp graph) { + Block& graph_body = graph.GetBody(); + + // Iterate until fixed point on the block, as it may contain multiple + // clusters. + bool changed = true; + while (changed) { + changed = false; + for (Operation& op : graph_body) { + if (failed( + MergeIsland(is_op_calling_func_for_cluster, &op, &changed))) { + graph.emitError() + << "Merging island failed: the TPU cluster likely " + << "contains a cycle with non-TPU operations\n"; + signalPassFailure(); + return WalkResult::interrupt(); + } + // If islands were merged, restart scanning the block from the + // beginning as we lost track of where to continue. 
+ if (changed) break; + } + } + return WalkResult::advance(); + }); + } } } // namespace -std::unique_ptr> +std::unique_ptr> CreateTFExecutorTPUV1IslandCoarseningPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc index b553a74d097..57ea1822b5b 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc @@ -133,9 +133,23 @@ void TPUBridgeExecutorIslandOutlining::runOnModule() { /*executor_type=*/builder.getStringAttr("")); SmallVector yield_operands(call_op.getResults()); builder.create(island_op.getLoc(), yield_operands); + } - // TODO(aminim): handle transitively referenced function and clone them in - // the new module. + // Outlined all the transitively called functions by moving them in the + // outlined module. + for (FuncOp func : outlined_module.getOps()) { + func.walk([&](Operation *op) { + for (NamedAttribute attr : op->getAttrs()) { + auto symbol_ref = attr.second.dyn_cast(); + if (!symbol_ref) continue; + if (outlined_symbol_table.lookup(symbol_ref.getValue())) + continue; + FuncOp callee = symbol_table.lookup(symbol_ref.getValue()); + callee.getOperation()->getBlock()->getOperations().remove( + callee.getOperation()); + outlined_symbol_table.insert(callee); + } + }); } } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc b/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc index 24624e356ea..e16c71673e4 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc @@ -13,10 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "llvm/ADT/STLExtras.h" +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project #include "mlir/IR/Function.h" // TF:llvm-project #include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Pass/PassManager.h" // TF:llvm-project #include "mlir/Pass/PassRegistry.h" // TF:llvm-project +#include "mlir/Transforms/Passes.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #define DEBUG_TYPE "tf-layout-optimization" @@ -25,9 +31,15 @@ namespace TF { namespace { +// LayoutAssignmentPass assigns optimal data layout (data format) for all +// layout sensitive operations. class LayoutAssignmentPass : public FunctionPass { public: LayoutAssignmentPass() = default; + explicit LayoutAssignmentPass(const std::string& force_data_format) { + force_data_format_ = force_data_format; + } + LayoutAssignmentPass(const LayoutAssignmentPass& pass) {} void runOnFunction() final; @@ -39,6 +51,29 @@ class LayoutAssignmentPass : public FunctionPass { llvm::cl::desc("Force data format for all layout sensitive ops")}; }; +// MoveTransposesPass moves all Transpose ops to the beginning or to the end of +// the basic block where they are defined. This will allow canonicalzer to +// delete redundant transposes. 
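// As a hand-written illustration (operands and shapes elided), with direction
// `kEnd` a transpose that feeds a layout agnostic op is pushed below it:
//
//   %0 = "tf.Transpose"(%x, %perm)            %0 = "tf.Tanh"(%x)
//   %1 = "tf.Tanh"(%0)                 ==>    %1 = "tf.Transpose"(%0, %perm)
//
// so that pairs of opposite transposes meet and can be folded away.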
+class MoveTransposesPass : public FunctionPass { + public: + enum class Direction { kBegin, kEnd }; + + MoveTransposesPass() = default; + explicit MoveTransposesPass(Direction direction) { direction_ = direction; } + MoveTransposesPass(const MoveTransposesPass& pass) {} + + void runOnFunction() final; + + private: + Option direction_{ + *this, "direction", + llvm::cl::desc("Move transposes to the beginning or the end of the block " + "where they are defined."), + llvm::cl::values( + clEnumValN(Direction::kBegin, "begin", "beginning of the block"), + clEnumValN(Direction::kEnd, "end", "end of the block"))}; +}; + using Permutation = SmallVector; Permutation GetDataFormatPermutation(StringRef from_data_format, @@ -52,22 +87,6 @@ Permutation GetDataFormatPermutation(StringRef from_data_format, } } -Type PermuteRankedTensorType(Type type, Permutation permutation) { - if (auto ranked_type = type.dyn_cast()) { - ArrayRef shape = ranked_type.getShape(); - assert(permutation.size() == shape.size()); - - SmallVector new_shape(permutation.size()); - for (size_t i = 0; i < permutation.size(); ++i) { - new_shape[i] = shape[permutation[i]]; - } - - return RankedTensorType::get(new_shape, ranked_type.getElementType()); - } - - return type; -} - void LayoutAssignmentPass::runOnFunction() { FuncOp func = getFunction(); @@ -100,8 +119,8 @@ void LayoutAssignmentPass::runOnFunction() { }; // Change operation data format. - op->setAttr("data_format", - StringAttr::get(force_data_format_, op->getContext())); + if (failed(layout_sensitive_interface.UpdateDataFormat(force_data_format_))) + return; // Permute arguments into the target data format. builder.setInsertionPoint(op); @@ -118,8 +137,6 @@ void LayoutAssignmentPass::runOnFunction() { for (int64_t res : layout_sensitive_interface.GetLayoutDependentResults()) { OpResult result = op->getResult(res); - result.setType( - PermuteRankedTensorType(result.getType(), args_permutation)); auto transposed_res = builder.create(loc, result, res_perm); result.replaceAllUsesWith(transposed_res); @@ -128,10 +145,287 @@ void LayoutAssignmentPass::runOnFunction() { }); } +// Move Transpose operations that permute `op` results before the `op`. +void MoveTransposeBefore(Operation* op, SmallVector* work_list) { + // TODO(ezhulenev): Move transpose across layout sensitive operations. + if (!op->hasTrait()) return; + + // Transpose operations that use operation results. + SmallVector transpose_ops; + + // Constant operation that defines permutation indices for result transposes. + ConstOp permutation_op; + + // All operation results must be used by transpose operations with the same + // permutation indices. + for (OpResult result : op->getResults()) { + for (Operation* user : result.getUsers()) { + // Result user must be a transpose operation. + TransposeOp transpose = dyn_cast(user); + if (!transpose) return; + + // With permutation defined by constant operation. + ConstOp perm = + dyn_cast_or_null(transpose.getOperand(1).getDefiningOp()); + if (!perm) return; + + // With the same permutation indices. + auto dense_elem_attr = perm.value().dyn_cast(); + if (!dense_elem_attr) return; + + if (!permutation_op) permutation_op = perm; + + // Check that permutation matches for all result transposes. + if (perm.value() != permutation_op.value()) return; + + // Add a transpose operation for later reuse. + transpose_ops.push_back(transpose); + } + } + + // Nothing to do here. 
+ if (!permutation_op || transpose_ops.empty()) return; + + // At this point we checked that we can safely move Transpose node before + // `op`, and bypass all result transposes. + Location loc = op->getLoc(); + + // Move constant op defining result permutation to the beginning of the block. + permutation_op.getOperation()->moveBefore(&op->getBlock()->front()); + + // Bypass Transpose nodes for all results. + for (OpResult result : op->getResults()) { + result.setType(cast(*result.getUsers().begin()).y().getType()); + for (Operation* transpose : result.getUsers()) { + transpose->getResult(0).replaceAllUsesWith(result); + } + } + + // Maybe add a Transpose node for all operands (or reuse existing transposes). + OpBuilder builder(op); + builder.setInsertionPoint(op); + + for (OpOperand& operand : op->getOpOperands()) { + // Try to push transpose further up. + if (Operation* operand_op = operand.get().getDefiningOp()) + work_list->push_back(operand_op); + + // Try to reuse result transposes. + TransposeOp transpose; + if (!transpose_ops.empty()) { + transpose = transpose_ops.pop_back_val(); + transpose.getOperation()->moveBefore(op); + transpose.setOperand(0, operand.get()); + transpose.setOperand(1, permutation_op); + } else { + transpose = + builder.create(loc, operand.get(), permutation_op); + } + + operand.set(transpose); + } + + // Remove unused transpose operations. + while (!transpose_ops.empty()) { + TransposeOp transpose = transpose_ops.pop_back_val(); + transpose.erase(); + } +} + +// Move Transpose operations that permute `op` operands after the `op`. +void MoveTransposeAfter(Operation* op, SmallVector* work_list) { + // Indices of operands and results that depend on data layout. + SmallVector layout_dependent_operands; + SmallVector layout_dependent_results; + + auto fold_operands = dyn_cast(op); + bool layout_agnostic = op->hasTrait(); + + if (fold_operands) { + layout_dependent_operands = fold_operands.GetLayoutDependentArgs(); + layout_dependent_results = fold_operands.GetLayoutDependentResults(); + + } else if (layout_agnostic) { + // For layout agnostic operation (e.g. element wise operations) all operands + // and results must have the same data layout. + for (unsigned i = 0; i < op->getNumOperands(); ++i) + layout_dependent_operands.push_back(i); + for (unsigned i = 0; i < op->getNumResults(); ++i) + layout_dependent_results.push_back(i); + } + + // Transpose operations that are operands of the `op`. + SmallVector transpose_ops; + + // Constant operation that defines permutation indices for operand transposes. + ConstOp permutation_op; + + // Layout dependent operands must be transpose operations with the same + // permutation indices. + for (unsigned idx : layout_dependent_operands) { + OpOperand& operand = op->getOpOperand(idx); + + // Operand must be defined by a transpose op. + TransposeOp transpose = + dyn_cast_or_null(operand.get().getDefiningOp()); + if (!transpose) return; + + // With permutation defined by constant operation. + ConstOp perm = + dyn_cast_or_null(transpose.getOperand(1).getDefiningOp()); + if (!perm) return; + + // With the same permutation indices. + auto dense_elem_attr = perm.value().dyn_cast(); + if (!dense_elem_attr) return; + + if (!permutation_op) permutation_op = perm; + + // Check that permutation matches for all result transposes. + if (perm.value() != permutation_op.value()) return; + + // Add a transpose operation for later reuse only if it's used once. 
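// (A transpose with additional users must stay in place for those users, so
// only a single-use transpose can be recycled here and erased afterwards.)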
+ if (transpose.getResult().hasOneUse()) transpose_ops.push_back(transpose); + } + + // Nothing to do here. + if (!permutation_op) return; + + // All results after transpose must preserve the original result type. + SmallVector original_type(op->getNumResults()); + for (unsigned idx : layout_dependent_results) + original_type[idx] = op->getResult(idx).getType(); + + // Check if we can fold transpose into the operation. + if (fold_operands) { + SmallVector permutation; + + auto attr = permutation_op.value().cast(); + for (auto value : attr.getIntValues()) + permutation.push_back(value.getSExtValue()); + + if (failed(fold_operands.FoldOperandsPermutation(permutation))) return; + } + + // At this point we checked that we can safely move Transpose node after + // `op`, bypass all operands transposes, and transpose op results. + Location loc = op->getLoc(); + + // Move constant op defining result permutation to the beginning of the block. + permutation_op.getOperation()->moveBefore(&op->getBlock()->front()); + + // Bypass Transpose nodes for layout dependent operands. + for (unsigned idx : layout_dependent_operands) { + OpOperand& operand = op->getOpOperand(idx); + TransposeOp transpose = + dyn_cast(operand.get().getDefiningOp()); + operand.set(transpose.getOperand(0)); + } + + // Maybe add Transpose nodes for layout dependent results + // (or reuse existing transposes). + OpBuilder builder(op); + builder.setInsertionPoint(op); + + for (unsigned idx : layout_dependent_results) { + OpResult result = op->getResult(idx); + + // Forward operand type only for layout agnostic operations, operations with + // custom folding will update the result type in `FoldOperandsPermutation`. + if (layout_agnostic) result.setType(op->getOperand(0).getType()); + + // Try to push transpose further down. + for (Operation* user : result.getUsers()) work_list->push_back(user); + + // Try to reuse operand transposes. + TransposeOp transpose; + if (!transpose_ops.empty()) { + transpose = transpose_ops.pop_back_val(); + transpose.getOperation()->moveBefore(op->getNextNode()); + transpose.setOperand(0, result); + transpose.setOperand(1, permutation_op); + transpose.getResult().setType(original_type[idx]); + } else { + transpose = builder.create(loc, result, permutation_op); + } + + // Forward all users to the transpose operation. + result.replaceAllUsesWith(transpose); + transpose.setOperand(0, result); + } + + // Remove unused transpose operations. + while (!transpose_ops.empty()) { + TransposeOp transpose = transpose_ops.pop_back_val(); + transpose.erase(); + } +} + +void MoveTransposesPass::runOnFunction() { + FuncOp func = getFunction(); + + SmallVector work_list; + + func.walk([&](TransposeOp transpose) { + if (direction_ == Direction::kBegin) { + // Try to push transpose before the operand operation. + for (auto operand : transpose.getOperands()) { + if (auto op = operand.getDefiningOp()) work_list.push_back(op); + } + } else { + // Try to push transpose after the user operation. 
+ for (Operation* user : transpose.y().getUsers()) { + work_list.push_back(user); + } + } + }); + + while (!work_list.empty()) { + Operation* op = work_list.pop_back_val(); + if (direction_ == Direction::kBegin) { + MoveTransposeBefore(op, &work_list); + } else if (direction_ == Direction::kEnd) { + MoveTransposeAfter(op, &work_list); + } + } + + func.walk([&](TransposeOp transpose) { + OpBuilder builder(transpose); + SmallVector fold_result; + if (succeeded(builder.tryFold(transpose.getOperation(), fold_result))) { + assert(fold_result.size() == 1); + transpose.replaceAllUsesWith(fold_result[0]); + } + }); +} + } // namespace -static PassRegistration pass("tf-layout-assignment", - "Layout assignment pass"); +void CreateLayoutOptimizationPipeline( + OpPassManager& pm, // NOLINT - MLIR contract is pass by mutable reference. + const LayoutOptimizationPipelineOptions& options) { + using Direction = MoveTransposesPass::Direction; + + // Assign optimal layout for layout sensitive ops. + pm.addPass(std::make_unique(options.force_data_format)); + + // Move transposes to the beginning of the block and try to fold them. + pm.addPass(std::make_unique(Direction::kBegin)); + + // Move transposes to the end of the block and try to fold them. + pm.addPass(std::make_unique(Direction::kEnd)); +} + +static PassRegistration layout_assignment( + "tf-layout-assignment", "Layout assignment pass"); +static PassRegistration move_transposes( + "tf-move-transposes", "Move transposes pass"); + +static mlir::PassPipelineRegistration + pipeline("tf-layout-optimization", + "Assigns optimal data layout to all layout sensitive operations " + "and cancel redundant transpose operations.", + CreateLayoutOptimizationPipeline); } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/parallel_execute_to_islands.cc b/tensorflow/compiler/mlir/tensorflow/transforms/parallel_execute_to_islands.cc new file mode 100644 index 00000000000..5caf08c672e --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/parallel_execute_to_islands.cc @@ -0,0 +1,263 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This pass forms `tf_executor.island` per region of +// `tf_device.parallel_execute`. +// +// For example: +// %1:2 = tf_executor.island { +// %2 = "tf.opA"(%arg0) : (tensor) -> tensor +// tf_executor.yield %2 : tensor +// } +// tf_executor.island() { +// "tf_device.parallel_execute"() ({ +// %3 = "tf.opB"() : () -> tensor +// tf_device.return %3 : tensor +// }, +// { +// %5 = "tf.opC"(%1#0) : (tensor) -> tensor +// tf_device.return +// }) {} : () -> (tensor) +// tf_executor.yield +// } +// tf_executor.fetch +// +// Would become: +// %1:2 = tf_executor.island { +// %2 = "tf.opA"(%arg0) : (tensor) -> tensor +// tf_executor.yield %2 : tensor +// } +// +// // Input barrier sink island that forwards all inputs. 
+// %output_0, %control_1 = tf_executor.island { +// tf_executor.yield %1#0: tensor +// } +// +// // Island for the first region of above parallel_execute. +// %output_2, %control_3 = tf_executor.island(%control_1) { +// %3 = "tf.opB"() : () -> tensor +// tf_executor.yield %3 : tensor +// } +// +// // Island for the second region of above parallel_execute. +// %control_5 = tf_executor.island { +// %5 = "tf.opC"(%output_0) : (tensor) -> tensor +// tf_executor.yield +// } +// +// // Output barrier sink island that forwards all outputs. +// %output_5, %control_6 = tf_executor.island(%control_5) { +// tf_executor.yield %output_2 +// } +// +// When tf_device.parallel_execute op is enclosed after tf_device.replicate, +// then this pass will run following `replicate-to-island` pass and +// `tf-executor-break-up-islands` pass. + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Block.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Support/LLVM.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/Transforms/RegionUtils.h" // TF:llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" + +namespace mlir { +namespace TFDevice { +namespace { + +struct ParallelExecuteToIslandsPass + : public FunctionPass { + void runOnFunction() override; +}; + +// Convert parallel_execute op to a set of islands where each region of +// parallel_execute op becomes a separate island. This ensures that +// regions of parallel_execute op gets executed concurrently. +LogicalResult ExpandParallelExecuteToIslands( + tf_executor::IslandOp island_op, tf_executor::IslandOp input_sink_island, + tf_device::ParallelExecuteOp parallel_execute_op, OpBuilder* builder, + llvm::SmallVector* islands) { + const int num_executions = + parallel_execute_op.getOperation()->getNumRegions(); + llvm::SmallVector executions; + executions.reserve(num_executions); + builder->setInsertionPoint(island_op); + + auto control_type = tf_executor::ControlType::get(island_op.getContext()); + for (int i : llvm::seq(0, num_executions)) { + auto execute_region = + parallel_execute_op.GetRegionBlockWithIndex(i).getParent(); + + // If region does not have any inputs, then add explicit control dependency + // from the input sink island. This guarantees that all inputs of + // parallel_execute op must be materialized before any of the islands are + // executed. + llvm::SetVector region_inputs; + getUsedValuesDefinedAbove(*execute_region, region_inputs); + llvm::SmallVector execution_control_inputs; + if (region_inputs.empty()) + execution_control_inputs.emplace_back(input_sink_island.control()); + + // Collect result types and operands. + Operation* terminator = execute_region->front().getTerminator(); + llvm::SmallVector output_types(terminator->getOperandTypes()); + + // Replace terminator with YieldOp as island op always ends with yield op. + builder->setInsertionPoint(terminator); + builder->create(terminator->getLoc(), + terminator->getOperands()); + terminator->erase(); + + // Create new island for each region. + builder->setInsertionPoint(island_op); + auto execution_island = builder->create( + island_op.getLoc(), output_types, control_type, + execution_control_inputs); + + // Move over tf_device.parallel_execute body region into newly a + // created island. 
+ execution_island.body().takeBody(*execute_region); + islands->push_back(execution_island); + } + + return success(); +} + +// Creates an island that works as input sync point for islands. This guarantees +// that all (implicitly captured) inputs of parallel_execute are materialized +// before any of the islands are executed. +tf_executor::IslandOp CreateInputBarrierIsland( + OpBuilder* builder, tf_executor::IslandOp island_op) { + builder->setInsertionPoint(island_op); + + llvm::SetVector island_inputs; + getUsedValuesDefinedAbove(island_op.body(), island_inputs); + + llvm::SmallVector input_types; + input_types.reserve(island_inputs.size()); + for (const auto& input_val : island_inputs) + input_types.emplace_back(input_val.getType()); + + // Create new island for that forwards all inputs. + auto control_type = tf_executor::ControlType::get(island_op.getContext()); + auto input_sink_island = builder->create( + island_op.getLoc(), input_types, control_type, island_op.controlInputs()); + input_sink_island.body().push_back(new Block); + + for (auto input_index_and_value : llvm::enumerate(island_inputs)) { + int index = input_index_and_value.index(); + Value input_value = input_index_and_value.value(); + replaceAllUsesInRegionWith(input_value, input_sink_island.getResult(index), + island_op.body()); + } + + // Create YieldOp for the new input sink island. + builder->setInsertionPointToEnd(&input_sink_island.GetBody()); + builder->create(island_op.getLoc(), + llvm::to_vector<8>(island_inputs)); + return input_sink_island; +} + +// Creates an islands that works as output sync point. This guarantees that +// execution of all islands must be completed before op following +// parallel_execute runs. +tf_executor::IslandOp CreateOutputBarrierIsland( + OpBuilder* builder, tf_executor::IslandOp island_op, + llvm::SmallVectorImpl* islands) { + // Add control dependency to island operand if island output has no uses. + llvm::SmallVector island_operands; + for (auto& island : *islands) + if (island.use_empty()) island_operands.push_back(island.control()); + + // Create single island forwarding all island results. + builder->setInsertionPoint(island_op); + auto island_output_sink = builder->create( + island_op.getLoc(), llvm::to_vector<8>(island_op.getResultTypes()), + island_operands, llvm::ArrayRef{}); + island_output_sink.body().push_back(new Block); + return island_output_sink; +} + +LogicalResult CreateIslandsFromParallelExecute( + tf_executor::IslandOp island_op, + tf_device::ParallelExecuteOp parallel_execute_op) { + OpBuilder builder(island_op); + auto input_sink_island = CreateInputBarrierIsland(&builder, island_op); + + // Create N islands where N is the number of regions inside parallel_execute + // op. + llvm::SmallVector islands; + auto result = ExpandParallelExecuteToIslands( + island_op, input_sink_island, parallel_execute_op, &builder, &islands); + if (failed(result)) return result; + + // Remap all results of parallel_execute op with outputs from newly + // created islands. 
+ llvm::SmallVector parallel_execute_outputs; + parallel_execute_outputs.reserve( + parallel_execute_op.getOperation()->getNumResults()); + + for (auto island : islands) + for (auto output_value : island.outputs()) + parallel_execute_outputs.emplace_back(output_value); + + parallel_execute_op.getOperation()->replaceAllUsesWith( + parallel_execute_outputs); + + auto island_output_sink = + CreateOutputBarrierIsland(&builder, island_op, &islands); + + // Move island YieldOp over to new single island and remap island results. + island_op.GetYield().getOperation()->moveBefore( + &island_output_sink.GetBody(), island_output_sink.GetBody().begin()); + island_op.replaceAllUsesWith(island_output_sink); + island_op.erase(); + + return success(); +} + +// Finds islands with a single `tf_device.parallel_execute` and create +// individual islands per region of parallel_execute. +void LowerSingleIslandParallelExecuteToIslands( + tf_executor::IslandOp island_op) { + if (!has_single_element(island_op.GetBody().without_terminator())) return; + + if (auto parallel_execute_op = llvm::dyn_cast( + &island_op.GetBody().front())) + CreateIslandsFromParallelExecute(island_op, parallel_execute_op); +} + +void ParallelExecuteToIslandsPass::runOnFunction() { + getFunction().walk([&](tf_executor::IslandOp island_op) { + LowerSingleIslandParallelExecuteToIslands(island_op); + }); +} +} // anonymous namespace + +std::unique_ptr> CreateParallelExecuteToIslandsPass() { + return std::make_unique(); +} + +static PassRegistration pass( + "tf-parallel-execute-to-islands", + "Lowers device parallel_execute to executor islands"); + +} // namespace TFDevice +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index 02cdb9dc229..548fbf8a8cf 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -46,6 +46,19 @@ std::unique_ptr> CreateTFShapeInferencePass(); // Optimizes Tensorflow graph. std::unique_ptr> CreateTFOptimizePass(); +struct LayoutOptimizationPipelineOptions + : public PassPipelineOptions { + Option force_data_format{ + *this, "force-data-format", + llvm::cl::desc("Force data format for all layout sensitive ops")}; +}; + +// Layout optimization assigns optimal data layout for layout sensitive +// operations, and cancels all redundant transposes. +void CreateLayoutOptimizationPipeline( + OpPassManager& pm, // NOLINT - MLIR contract is pass by mutable reference. + const LayoutOptimizationPipelineOptions& options); + struct StandardPipelineOptions : public PassPipelineOptions { Option enable_inliner{*this, "enable-inliner", @@ -106,7 +119,8 @@ std::unique_ptr> CreateTFExecutorIslandCoarseningPass(); // Creates a pass to merge IslandOps for operation marked for execution on TPU. // This is a V1 backward compatibility. -std::unique_ptr> CreateTFExecutorTPUV1IslandCoarseningPass(); +std::unique_ptr> +CreateTFExecutorTPUV1IslandCoarseningPass(); // Creates a pass to outlining TPU clusters from single IslandOp into a nested // module suitable for being processed as-if it was a V2 module. @@ -164,6 +178,10 @@ std::unique_ptr> CreateReplicateInvariantOpHoistingPass(); // `tf_device.replicate` island. std::unique_ptr> CreateReplicateToIslandPass(); +// Creates a pass that creates `tf_executor.island` from a single +// `tf_device.parallel_execute` island. 
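// (Registered in parallel_execute_to_islands.cc under
// `-tf-parallel-execute-to-islands`, so it can also be exercised in isolation
// through tf-opt.)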
+std::unique_ptr> CreateParallelExecuteToIslandsPass(); + // Creates a pass that annotates whether a LaunchFuncOp's parameters have the // same data across replicas. std::unique_ptr> CreateAnnotateParameterReplicationPass(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc index 7f0b1b96560..8dc21feca90 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc @@ -31,6 +31,7 @@ limitations under the License. #include "mlir/IR/Function.h" // TF:llvm-project #include "mlir/IR/Module.h" // TF:llvm-project #include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/SymbolTable.h" // TF:llvm-project #include "mlir/IR/TypeUtilities.h" // TF:llvm-project #include "mlir/IR/Types.h" // TF:llvm-project #include "mlir/IR/Value.h" // TF:llvm-project @@ -811,16 +812,185 @@ LogicalResult HanldeIfOP(TF::IfOp if_op, FuncOp then_branch, return success(); } +// A resource-lifted function for (potentially multiple) PartitionedCallOps and +// information about the lifting changes. +struct PartitionedCallLiftingInfo { + // Function with resources lifted. Can be nullptr if nothing needs to change. + FuncOp lifted_callee; + // Mapping from old resource outputs to their aliasing output inputs. + llvm::SmallDenseMap old_outputs_aliasing_old_inputs; + // Mapping from old to new output indices in case any output is removed. + llvm::SmallVector old_to_new_output_indices; + // ResourceArgUseInfo for each old resource argument. + llvm::SmallDenseMap use_info; + // Input for AddLoadsStoresOutsideControlFlowOp(), see its comment. + llvm::SmallDenseMap> + arg_data_type_and_updated_output_index; +}; + +// Lifts loads/stores from a PartitionedCallOp's callee function. If anything +// needs to be changed, the original function will be preserved, and the lifting +// happens on a clone, which will be stored in `result`. +LogicalResult HandlePartitionedCallOpCallee( + FuncOp callee, PartitionedCallLiftingInfo* result) { + // Remove identity nodes to avoid aliasing. + RemoveIdentity(&callee.front()); + // Sanity check: return of resources should be aliases of inputs. Such outputs + // will be removed later. + int64_t non_resource_results = 0; + for (auto entry : + llvm::enumerate(callee.front().getTerminator()->getOperands())) { + auto retval = entry.value(); + if (!getElementTypeOrSelf(retval.getType()).isa()) { + result->old_to_new_output_indices.push_back(non_resource_results++); + continue; + } + auto aliasing_arg = retval.dyn_cast(); + if (!aliasing_arg) { + return callee.emitOpError( + "Unsupported function call: resource return value does not alias an " + "input."); + } + result->old_outputs_aliasing_old_inputs[entry.index()] = + aliasing_arg.getArgNumber(); + result->old_to_new_output_indices.push_back(-1); + } + + if (failed(FindResourceArgUseInfo(callee, &result->use_info))) { + return failure(); + } + if (result->use_info.empty()) { + result->lifted_callee = nullptr; + return success(); + } + + // Clone the callee before making changes. 
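// Cloning keeps the original callee intact for callers that are not being
// rewritten; the clone is renamed with a "_resource_lifted" suffix (plus a
// numeric counter if that name is already taken) and inserted into the
// module's symbol table.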
+ SmallString<64> name_base = callee.getName(); + auto module = callee.getParentOfType(); + name_base += "_resource_lifted"; + auto name = name_base; + { + int64_t counter = 0; + while (module.lookupSymbol(name)) { + auto name = name_base; + name += "_" + std::to_string(counter++); + } + } + callee = callee.clone(); + callee.setName(name); + SymbolTable(module).insert(callee); + result->lifted_callee = callee; + + // Remove unused resources in functions. + llvm::SmallDenseMap remaining_resource_data_types; + RemoveUnusedResourceArgumentsAndForwardedRetvals( + result->use_info, callee, /*old_to_new_arg_indices=*/nullptr, + &remaining_resource_data_types); + for (const auto& entry : remaining_resource_data_types) { + result->arg_data_type_and_updated_output_index[entry.getFirst()] = { + entry.getSecond(), -1}; + } + llvm::SmallVector new_retvals; + for (auto val : callee.front().getTerminator()->getOperands()) { + // Remove resource type outputs. + if (getElementTypeOrSelf(val.getType()).isa()) continue; + new_retvals.push_back(val); + } + // Lift resources. + LiftArgRetResourcesForFunction( + callee, remaining_resource_data_types, [&](int64_t index, Value value) { + result->arg_data_type_and_updated_output_index[index].second = + new_retvals.size(); + new_retvals.push_back(value); + }); + auto old_return = callee.front().getTerminator(); + // Replace old return with the new ones with update values. + OpBuilder builder(old_return); + auto new_return = builder.create(old_return->getLoc(), new_retvals); + old_return->erase(); + callee.setType(FunctionType::get( + callee.getType().getInputs(), + llvm::to_vector<4>(new_return.getOperandTypes()), callee.getContext())); + return success(); +} + +// Updates a PartitionedCallOp/StatefulPartitionedCallOp according to the +// resource-lifted new callee function in lifting_info. +template +void UpdatePartitionedCallOpWithNewCallee( + CallOpType call_op, const PartitionedCallLiftingInfo& lifting_info) { + if (lifting_info.lifted_callee == nullptr) return; + // Replace output resource uses with the aliasing input, so that we can remove + // this output. + for (const auto& entry : lifting_info.old_outputs_aliasing_old_inputs) { + call_op.getResult(entry.getFirst()) + .replaceAllUsesWith(call_op.getOperand(entry.getSecond())); + } + // Recreate the call op. + OpBuilder builder(call_op); + // Now use the filtered original operands, which will be replaced by + // AddLoadsStoresOutsideControlFlowOp(). + auto new_operands = + FilterRange(call_op.args(), lifting_info.use_info); + auto new_call = builder.create( + call_op.getLoc(), + const_cast(lifting_info.lifted_callee).getType().getResults(), + new_operands, call_op.getAttrs()); + new_call.setAttr( + "f", builder.getSymbolRefAttr( + const_cast(lifting_info.lifted_callee).getName())); + AddLoadsStoresOutsideControlFlowOp( + new_call, lifting_info.arg_data_type_and_updated_output_index); + // Replace uses. + for (int64_t i = 0; i < lifting_info.old_to_new_output_indices.size(); ++i) { + if (lifting_info.old_to_new_output_indices[i] >= 0) { + call_op.getResult(i).replaceAllUsesWith( + new_call.getResult(lifting_info.old_to_new_output_indices[i])); + } + } + call_op.erase(); +} + +LogicalResult HoistForFunctionalControlFlow( + Block*, ModuleOp, llvm::SmallDenseMap*); + +// A templated routine for handling both PartitionedCallOp and +// StatefulPartitionedCallOp. 
If the callee is already lifted, it just updates +// the caller op itself; otherwise, it first recursively handles nested control +// flow, then performs lifting on the callee. +template +LogicalResult HandlePartitionedCallOp( + CallOpType call_op, FuncOp callee, ModuleOp module, + llvm::SmallDenseMap* lifted_callees) { + auto emplace_res = + lifted_callees->try_emplace(callee, PartitionedCallLiftingInfo()); + if (emplace_res.second) { + // Unseen callee. Perform resource lifting on it. + HoistForFunctionalControlFlow(&callee.front(), module, lifted_callees); + if (failed(HandlePartitionedCallOpCallee( + callee, &emplace_res.first->getSecond()))) { + return failure(); + } + } + UpdatePartitionedCallOpWithNewCallee(call_op, emplace_res.first->getSecond()); + return success(); +} + // Hoists resource loads/stores from control flow ops in `block` outside the -// body/cond/branch functions. -LogicalResult HoistForFunctionalControlFlow(Block* block, ModuleOp module) { +// body/cond/branch/callee functions. +LogicalResult HoistForFunctionalControlFlow( + Block* block, ModuleOp module, + llvm::SmallDenseMap* + lifted_partitioned_call_callees) { for (Operation& op : llvm::make_early_inc_range(*block)) { if (auto while_op = llvm::dyn_cast(&op)) { auto body = llvm::cast(module.lookupSymbol(while_op.body())); auto cond = llvm::cast(module.lookupSymbol(while_op.cond())); // Recursively handle the nested control flow. - HoistForFunctionalControlFlow(&body.front(), module); - HoistForFunctionalControlFlow(&cond.front(), module); + HoistForFunctionalControlFlow(&body.front(), module, + lifted_partitioned_call_callees); + HoistForFunctionalControlFlow(&cond.front(), module, + lifted_partitioned_call_callees); if (failed(HanldeWhileLoop(while_op, body, cond))) return failure(); } else if (auto if_op = llvm::dyn_cast(&op)) { auto then_branch = @@ -828,9 +998,30 @@ LogicalResult HoistForFunctionalControlFlow(Block* block, ModuleOp module) { auto else_branch = llvm::cast(module.lookupSymbol(if_op.else_branch())); // Recursively handle the nested control flow. - HoistForFunctionalControlFlow(&then_branch.front(), module); - HoistForFunctionalControlFlow(&else_branch.front(), module); + HoistForFunctionalControlFlow(&then_branch.front(), module, + lifted_partitioned_call_callees); + HoistForFunctionalControlFlow(&else_branch.front(), module, + lifted_partitioned_call_callees); if (failed(HanldeIfOP(if_op, then_branch, else_branch))) return failure(); + } else if (auto call_op = llvm::dyn_cast(&op)) { + if (!call_op.f().isa()) { + return call_op.emitError( + "Resource lifting does not support call with nested references."); + } + auto callee = llvm::cast( + module.lookupSymbol(call_op.f().getRootReference())); + if (failed(HandlePartitionedCallOp(call_op, callee, module, + lifted_partitioned_call_callees))) { + // Nested control flow handling is done in HandlePartitionedCallOp(). + return failure(); + } + } else if (auto call_op = + llvm::dyn_cast(&op)) { + auto callee = llvm::cast(module.lookupSymbol(call_op.f())); + if (failed(HandlePartitionedCallOp(call_op, callee, module, + lifted_partitioned_call_callees))) { + return failure(); + } } } return success(); @@ -840,10 +1031,13 @@ LogicalResult HoistForFunctionalControlFlow(Block* block, ModuleOp module) { // outside. Returns failure if there are remaining resource-type values that can // not be lifted. 
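As a rough sketch of what the new call handling is meant to achieve (hand-written, names and types abbreviated; the lifted callee name follows the "_resource_lifted" suffix used above):

//   // Before: the callee reads and updates the resource internally.
//   %0 = "tf.StatefulPartitionedCall"(%var) {f = @f}
//        : (tensor<*x!tf.resource>) -> tensor<f32>
//
//   // After: the load/store are materialized at the call site and the call
//   // targets the lifted clone, which only passes plain tensors around.
//   %v = "tf.ReadVariableOp"(%var)
//   %0:2 = "tf.StatefulPartitionedCall"(%v) {f = @f_resource_lifted}
//   "tf.AssignVariableOp"(%var, %0#1)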
void ResourceOpLiftingPass::runOnModule() { + llvm::SmallDenseMap + lifted_partitioned_call_callees; auto result = getModule().walk([&](FuncOp func_op) { return func_op.walk([&](tf_device::LaunchOp launch_op) { - if (failed(HoistForFunctionalControlFlow(&launch_op.GetBody(), - getModule())) || + if (failed(HoistForFunctionalControlFlow( + &launch_op.GetBody(), getModule(), + &lifted_partitioned_call_callees)) || failed(HoistResourceOpsFromLaunchOp(launch_op))) { return WalkResult::interrupt(); } @@ -901,8 +1095,11 @@ LogicalResult ResourceLiftingForFunctionalControlFlow(FuncOp function) { << function.getBlocks().size(); } + llvm::SmallDenseMap + lifted_partitioned_call_callees; return HoistForFunctionalControlFlow(&function.front(), - cast(function.getParentOp())); + cast(function.getParentOp()), + &lifted_partitioned_call_callees); } } // namespace TF diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc index fd485d17374..c44f0f97fd6 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc @@ -60,16 +60,23 @@ namespace TF { namespace { Optional> InferShapeForFunctionReturnType( FuncOp func) { - // Only infer shape when there is one return op for now. - if (!has_single_element(func.getBody()) || func.front().empty()) { + // Find any return ops. + SmallVector return_ops; + for (Block& block : func) { + if (auto return_op = dyn_cast(block.getTerminator())) { + return_ops.push_back(return_op); + } + } + + // Right now we only handle the case of a single return op. + // To handle multiple return ops, we would need to look at all their shapes + // and come up with a common shape and insert appropriate casts. + if (return_ops.size() != 1) { return None; } // Find the return type. - auto return_op = dyn_cast(func.front().back()); - if (!return_op) { - return None; - } + auto return_op = return_ops.front(); // Manually fold tf.Cast that precedes the return instruction and only differs // in shape refinement level. diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_device_assignment.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tf_device_assignment.cc index a4a8c1ab95f..83451e130ba 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_device_assignment.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_device_assignment.cc @@ -34,15 +34,16 @@ class SimpleTFDeviceAssignmentPass void runOnFunction() override { Builder builder(&getContext()); - getFunction().walk([this, &builder](Operation* op) { + Dialect* tf = getContext().getRegisteredDialect(); + getFunction().walk([&](Operation* op) { if (auto device_attr = op->getAttrOfType("device")) { // We assign default device to ops with device attribute that is empty. if (device_attr.getValue() == "") { op->setAttr("device", builder.getStringAttr(default_device_)); } - } else if (llvm::isa(op)) { - // tf.Const may sometimes contain no device attribute. In this case, we - // assign it the default device. + } else if (op->getDialect() == tf) { + // Assign default device to all ops in Tensorflow dialect that do not + // have device attribute. 
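// For example, a bare "tf.AddV2" with no `device` attribute now receives the
// configured `default_device_`, while ops from other dialects (and TF ops
// that already carry a non-empty device) are left as they are.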
op->setAttr("device", builder.getStringAttr(default_device_)); } }); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc index 84ae3e735f2..e7bd44464d0 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc @@ -263,6 +263,7 @@ tf_device::ReplicateOp AddInputsToReplicateOp( llvm::SmallVector, Type>, 8> new_replicated_inputs; llvm::SmallVector, 8> replicated_inputs; + replicated_inputs.reserve(replicate.GetBody().getNumArguments()); for (auto arg : llvm::enumerate(replicate.GetBody().getArguments())) { int64_t i = arg.index(); replicated_inputs.emplace_back(); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc b/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc index cef1f4e5567..8136db7d164 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc @@ -42,8 +42,8 @@ namespace mlir { namespace { -struct BreakUpIslands : OperationPass { - void runOnOperation() final; +struct BreakUpIslands : FunctionPass { + void runOnFunction() final; void BreakUpIsland(tf_executor::IslandOp island_op, const TF::SideEffectAnalysis& side_effect_analysis, @@ -51,8 +51,8 @@ struct BreakUpIslands : OperationPass { new_control_inputs); }; -void BreakUpIslands::runOnOperation() { - auto graph_op_range = getOperation().getBody().front().without_terminator(); +void BreakUpIslands::runOnFunction() { + auto graph_op_range = getFunction().getBody().front().without_terminator(); tf_executor::GraphOp graph_op; if (graph_op_range.begin() != graph_op_range.end() && std::next(graph_op_range.begin()) == graph_op_range.end()) { @@ -60,7 +60,7 @@ void BreakUpIslands::runOnOperation() { getOperation().getBody().front().front()); } if (!graph_op) { - getOperation().emitError("Expected function to contain only a graph_op"); + getOperation().emitError("expected function to contain only a graph_op"); signalPassFailure(); return; } @@ -239,7 +239,7 @@ void BreakUpIslands::BreakUpIsland( } else { // TODO(parkers): Any defining op that has a control output can be handled // just like an island. - fetch.getDefiningOp()->emitError("Fetching non-island as dependency."); + fetch.getDefiningOp()->emitError("fetching non-island as dependency"); return signalPassFailure(); } } @@ -298,18 +298,21 @@ void BreakUpIslands::BreakUpIsland( auto& sink_island_control = sink_island_controls[0]; island_op.control().replaceAllUsesWith(sink_island_control); // All existing outputs need to add sink_island_control as control input. + // GraphOp, YieldOp and NextIterationSourceOp don't have control inputs so + // exclude them below. 
for (Value out : island_op.outputs()) { for (auto& use : out.getUses()) { Operation* owner = use.getOwner(); if (auto other_island_op = llvm::dyn_cast(owner->getParentOp())) { (*new_control_inputs)[other_island_op].push_back(sink_island_control); - } else if (llvm::isa(owner) || - llvm::isa(owner) || - llvm::isa(owner)) { + } else if (owner->getDialect() == island_op.getDialect() && + !llvm::isa(owner) && + !llvm::isa(owner) && + !llvm::isa(owner)) { (*new_control_inputs)[owner].push_back(sink_island_control); } else { - use.getOwner()->emitError("Adding control dependency not supported"); + owner->emitOpError("adding control dependency not supported"); return signalPassFailure(); } } diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index f6939abdf9f..39fe17800c9 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -51,6 +51,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // TF:llvm-project #include "mlir/IR/Module.h" // TF:llvm-project #include "mlir/IR/OpDefinition.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project #include "mlir/IR/Types.h" // TF:llvm-project #include "tensorflow/compiler/jit/shape_inference_helpers.h" #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" @@ -2515,6 +2516,43 @@ void StructuredValueLinearizer::RecursivelyFindLeaves( } } +// For exported functions with mutable bound inputs, rewrite the function +// signature to annotate resource subtypes on the types. +// +// The raw imported functions have `tensor<*x!tf.resource>` as the type for +// mutable bound inputs. Here we turn that into +// `tensor>>`. +void SetResourceSubtypes(mlir::ModuleOp module) { + mlir::SymbolTable symbol_table(module); + for (auto func : module.getOps()) { + if (!mlir::tf_saved_model::IsExported(func)) continue; + mlir::OpBuilder builder(func.getBody()); + llvm::SmallVector new_input_types; + for (int i = 0, e = func.getNumArguments(); i < e; i++) { + auto arg = func.front().getArgument(i); + auto global_tensor = + mlir::tf_saved_model::LookupBoundInput(func, i, symbol_table); + if (global_tensor && global_tensor.is_mutable()) { + auto old_type = arg.getType(); + auto new_type = mlir::RankedTensorType::get( + {}, mlir::TF::ResourceType::get( + {global_tensor.type().cast()}, + module.getContext())); + arg.setType(new_type); + auto arg_with_original_type = builder.create( + global_tensor.getLoc(), old_type, arg, + /*Truncate=*/builder.getBoolAttr(false)); + arg.replaceAllUsesWith(arg_with_original_type); + // The RAUW replaces the arg with itself, so we need to set it back. + arg_with_original_type.setOperand(arg); + } + new_input_types.push_back(arg.getType()); + } + func.setType(mlir::FunctionType::get( + new_input_types, func.getType().getResults(), module.getContext())); + } +} + // Reorder the ops in the module to make testing easier and less dependent // on implementation details such as the order of functions in the // FunctionDefLibrary. 
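This ties back to the updated saved-model tests; a hand-written sketch (element type chosen for illustration) of the signature rewrite SetResourceSubtypes performs:

//   // Before: mutable bound inputs are plain unranked resources.
//   func @f(%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v})
//
//   // After: the global tensor's type is recorded as a resource subtype, and
//   // a tf.Cast back to the original type keeps existing uses valid.
//   func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}) {
//     %0 = "tf.Cast"(%arg0) {Truncate = false}
//          : (tensor<!tf.resource<tensor<f32>>>) -> tensor<*x!tf.resource>
//     ...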
@@ -2755,6 +2793,7 @@ Status CreateSavedModelIR( builder.getStrArrayAttr(object_names.GetExportedNames(node_id))); } } + SetResourceSubtypes(module); module.setAttr("tf_saved_model.semantics", builder.getUnitAttr()); SortSavedModelModule(module); return Status::OK(); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc index ead26c8f17d..1b8ae8403bf 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc @@ -27,6 +27,7 @@ limitations under the License. #include "mlir/IR/Operation.h" // TF:llvm-project #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/path.h" namespace tensorflow { @@ -42,7 +43,8 @@ std::string MakeUniqueFilename(string name) { // Remove illegal characters from `name`. for (int i = 0; i < name.size(); ++i) { char ch = name[i]; - if (ch == '/' || ch == '[' || ch == ']' || ch == '*' || ch == '?') { + if (ch == '/' || ch == '[' || ch == ']' || ch == '*' || ch == '?' || + ch == '\\') { name[i] = '_'; } } @@ -97,18 +99,18 @@ struct WritableFileRawStream : public llvm::raw_ostream { Status CreateFileForDumping(llvm::StringRef name, std::unique_ptr* os, std::string* filepath, llvm::StringRef dirname) { - const char* dir = nullptr; + std::string dir; if (!dirname.empty()) - dir = dirname.data(); + dir = std::string(dirname); else dir = GetDumpDirFromEnvVar(); - if (!dir) { + if (dir.empty()) { return Status(error::Code::INVALID_ARGUMENT, "(TF_DUMP_GRAPH_PREFIX not specified)"); } - if (std::strncmp(dir, "-", 2) == 0) { + if (dir == "-") { *os = std::make_unique(); *filepath = "LOG(INFO)"; return Status(); @@ -122,10 +124,7 @@ Status CreateFileForDumping(llvm::StringRef name, << "' directory for dumping: " << status; return Status(error::Code::UNAVAILABLE, "(unavailable)"); } - *filepath = llvm::Twine(dir) - .concat("/") - .concat(MakeUniqueFilename(std::string(name))) - .str(); + *filepath = io::JoinPath(dir, MakeUniqueFilename(std::string(name))); // Try to open the file and generate a raw_ostream. 
std::unique_ptr file; @@ -151,25 +150,24 @@ std::string DumpMlirOpToFile(llvm::StringRef name, mlir::Operation* op, return filepath; } -const char* GetDumpDirFromEnvVar() { +std::string GetDumpDirFromEnvVar() { const char* prefix_env = getenv("TF_DUMP_GRAPH_PREFIX"); if (!prefix_env) { LOG(WARNING) << "Failed to dump MLIR module because dump location is not " << " specified through TF_DUMP_GRAPH_PREFIX environment variable."; - return nullptr; + return ""; } - if (absl::EqualsIgnoreCase(prefix_env, "sponge")) { - const char* tmp_dir = getenv("TEST_UNDECLARED_OUTPUTS_DIR"); - if (!tmp_dir) { - LOG(WARNING) << "TF_DUMP_GRAPH_PREFIX=sponge but " - "TEST_UNDECLARED_OUTPUT_DIRS is not set"; - return nullptr; - } - return tmp_dir; + std::string result = prefix_env; + + if (absl::EqualsIgnoreCase(result, "sponge") && + !io::GetTestUndeclaredOutputsDir(&result)) { + LOG(WARNING) << "TF_DUMP_GRAPH_PREFIX=sponge but " + "TEST_UNDECLARED_OUTPUT_DIRS is not set"; + return ""; } - return prefix_env; + return result; } } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h index 7c25a809089..14c0d1f0b6e 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h @@ -54,7 +54,7 @@ std::string DumpMlirOpToFile(llvm::StringRef name, mlir::Operation* op, // Default is reading from TF_DUMP_GRAPH_PREFIX, and if the string is 'sponge' // read from TEST_UNDECLARED_OUTPUTS_DIR. Returns nullptr if the directory // cannot be determined and generates a warning message. -const char* GetDumpDirFromEnvVar(); +std::string GetDumpDirFromEnvVar(); } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/error_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/error_util.cc index 2181f4f8c9b..60646ae764e 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/error_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/error_util.cc @@ -63,21 +63,21 @@ Status StatusScopedDiagnosticHandler::Combine(Status status) { } LogicalResult StatusScopedDiagnosticHandler::handler(Diagnostic* diag) { -#ifndef NDEBUG + // Non-error diagnostic are ignored when VLOG isn't enabled. + if (diag->getSeverity() != DiagnosticSeverity::Error && VLOG_IS_ON(1)) + return success(); + size_t current_diag_str_size_ = diag_str_.size(); -#endif // Emit the diagnostic and flush the stream. emitDiagnostic(*diag); diag_stream_.flush(); -#ifndef NDEBUG // Emit non-errors to VLOG instead of the internal status. if (diag->getSeverity() != DiagnosticSeverity::Error) { VLOG(1) << diag_str_.substr(current_diag_str_size_); diag_str_.resize(current_diag_str_size_); } -#endif // Return failure to signal propagation if necessary. 
return failure(propagate_); diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index 8a2b18cd906..bf2d8103872 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -131,12 +131,37 @@ cc_library( ) cc_library( - name = "lhlo_legalize_to_affine", - srcs = ["transforms/lhlo_legalize_to_affine.cc"], + name = "map_xla_to_scalar_op", + srcs = [], hdrs = ["transforms/map_xla_to_scalar_op.h"], deps = [ ":hlo", ":lhlo", + "@llvm-project//llvm:support", + "@llvm-project//mlir:StandardOps", + ], +) + +cc_library( + name = "hlo_shape_derivation", + srcs = [], + hdrs = ["transforms/hlo_shape_derivation.h"], + deps = [ + ":hlo", + ":lhlo", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Transforms", + ], +) + +cc_library( + name = "lhlo_legalize_to_affine", + srcs = ["transforms/lhlo_legalize_to_affine.cc"], + deps = [ + ":hlo", + ":lhlo", + ":map_xla_to_scalar_op", "//tensorflow/compiler/xla:status", "@com_google_absl//absl/memory", "@llvm-project//llvm:support", @@ -151,13 +176,12 @@ cc_library( cc_library( name = "xla_legalize_to_linalg", srcs = ["transforms/xla_legalize_to_linalg.cc"], - hdrs = ["transforms/map_xla_to_scalar_op.h"], deps = [ ":hlo", ":lhlo", + ":map_xla_to_scalar_op", "@com_google_absl//absl/memory", "@llvm-project//llvm:support", - "@llvm-project//mlir:AllPassesAndDialects", # TODO: only Linalg is needed "@llvm-project//mlir:IR", "@llvm-project//mlir:LinalgOps", "@llvm-project//mlir:Pass", @@ -170,10 +194,10 @@ cc_library( cc_library( name = "lhlo_legalize_to_gpu", srcs = ["transforms/lhlo_legalize_to_gpu.cc"], - hdrs = ["transforms/map_xla_to_scalar_op.h"], deps = [ ":hlo", ":lhlo", + ":map_xla_to_scalar_op", "@com_google_absl//absl/memory", "@llvm-project//llvm:support", "@llvm-project//mlir:GPUDialect", @@ -193,7 +217,7 @@ cc_library( deps = [ ":lhlo", "@com_google_absl//absl/memory", - "@llvm-project//mlir:AllPassesAndDialects", # TODO: only Linalg is needed + "@llvm-project//llvm:support", "@llvm-project//mlir:LinalgOps", "@llvm-project//mlir:LinalgTransforms", "@llvm-project//mlir:Pass", @@ -207,9 +231,9 @@ cc_library( srcs = ["transforms/hlo_legalize_to_lhlo.cc"], deps = [ ":hlo", + ":hlo_shape_derivation", ":lhlo", "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:StandardOps", @@ -306,6 +330,7 @@ cc_library( deps = [ ":hlo", "@llvm-project//mlir:IR", + "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Transforms", ], ) @@ -334,6 +359,7 @@ cc_library( ":xla_unfuse_batch_norm", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Transforms", ], alwayslink = 1, diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc index 545bcb4f44f..bc9bdf49a39 100644 --- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc +++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc @@ -370,6 +370,22 @@ StatusOr HloFunctionImporter::ImportInstruction( Convert(interior_padding)) .getOperation(); } + case HloOpcode::kScatter: { + auto scatter = static_cast(instruction); + attributes.push_back( + ConvertScatterDimensionNumbers(scatter->scatter_dimension_numbers())); + attributes.push_back(builder_->getNamedAttr( + "indices_are_sorted", + builder_->getBoolAttr(scatter->indices_are_sorted()))); + attributes.push_back(builder_->getNamedAttr( + "unique_indices", 
builder_->getBoolAttr(scatter->unique_indices()))); + + auto scatter_op = func_builder->create( + loc, result_type, operands, attributes); + TF_RETURN_IF_ERROR(ImportComputation(scatter->to_apply(), + &scatter_op.update_computation())); + return scatter_op.getOperation(); + } case HloOpcode::kSetDimensionSize: { attributes.push_back(builder_->getNamedAttr( "dimension", builder_->getIntegerAttr(builder_->getIntegerType(32), @@ -385,6 +401,16 @@ StatusOr HloFunctionImporter::ImportInstruction( ConvertDimensions(instruction->slice_strides())) .getOperation(); } + case HloOpcode::kSort: { + auto sort_instruction = static_cast(instruction); + auto sort_op = func_builder->create( + loc, result_type, operands, + builder_->getI64IntegerAttr(sort_instruction->sort_dimension()), + builder_->getBoolAttr(sort_instruction->is_stable())); + TF_RETURN_IF_ERROR(ImportComputation(sort_instruction->to_apply(), + &sort_op.comparator())); + return sort_op.getOperation(); + } case HloOpcode::kConditional: { llvm::SmallVector rets; TF_RETURN_IF_ERROR(GetMlirTypes( @@ -834,6 +860,22 @@ mlir::NamedAttribute HloFunctionImporter::ConvertGatherDimensionNumbers( return builder_->getNamedAttr("dimension_numbers", attr); } +mlir::NamedAttribute HloFunctionImporter::ConvertScatterDimensionNumbers( + const xla::ScatterDimensionNumbers& dnums) { + std::vector update_window_dims(dnums.update_window_dims().begin(), + dnums.update_window_dims().end()); + std::vector inserted_window_dims( + dnums.inserted_window_dims().begin(), dnums.inserted_window_dims().end()); + std::vector scatter_dims_to_operand_dims( + dnums.scatter_dims_to_operand_dims().begin(), + dnums.scatter_dims_to_operand_dims().end()); + auto attr = mlir::xla_hlo::ScatterDimensionNumbers::get( + Convert(update_window_dims), Convert(inserted_window_dims), + Convert(scatter_dims_to_operand_dims), + builder_->getI64IntegerAttr(dnums.index_vector_dim()), context_); + return builder_->getNamedAttr("scatter_dimension_numbers", attr); +} + mlir::NamedAttribute HloFunctionImporter::ConvertSourceTargetPairs( const std::vector>& source_target_pairs) { diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.h b/tensorflow/compiler/mlir/xla/hlo_function_importer.h index d373e88e1c0..93c8e6e818c 100644 --- a/tensorflow/compiler/mlir/xla/hlo_function_importer.h +++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.h @@ -121,6 +121,10 @@ class HloFunctionImporter { mlir::NamedAttribute ConvertGatherDimensionNumbers( const xla::GatherDimensionNumbers& dnums); + // Converts the scatter dimensions to attributes. + mlir::NamedAttribute ConvertScatterDimensionNumbers( + const xla::ScatterDimensionNumbers& dnums); + // Converts XLA instruction source target pairs to MLIR attribute. 
mlir::NamedAttribute ConvertSourceTargetPairs( const std::vector>& diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc index 23c25e7d0cd..41ef8690735 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc @@ -177,29 +177,18 @@ void ConstOp::build(Builder* builder, OperationState& result, Attribute value) { // IotaOp //===----------------------------------------------------------------------===// -OpFoldResult IotaOp::fold(ArrayRef operands) { - const auto output_type = getResult().getType().cast(); - const auto output_size = output_type.getNumElements(); - const auto dimension = iota_dimension().getSExtValue(); - const auto max_dim_size = output_type.getDimSize(dimension); - int bitwidth = output_type.getElementType().getIntOrFloatBitWidth(); +static LogicalResult Verify(IotaOp op) { + auto shape = op.getType().cast(); + if (!shape.hasRank()) return success(); - llvm::SmallVector values; - values.reserve(output_size); + if (shape.getRank() == 0) + return op.emitOpError() << "does not support scalars."; - int64_t increase_stride = output_size; - for (int i = 0; i <= dimension; i++) { - increase_stride /= output_type.getDimSize(i); - } - - int64_t current_value = 0; - for (int i = 0; i < output_size; i++) { - int64_t value = (current_value / increase_stride) % max_dim_size; - values.push_back(APInt(bitwidth, value)); - ++current_value; - } - - return DenseIntElementsAttr::get(output_type, values); + auto iota_dimension = op.iota_dimension().getSExtValue(); + if (iota_dimension >= shape.getRank() || iota_dimension < 0) + return op.emitOpError() << "iota dimension cannot go beyond the output " + "rank or be negative."; + return success(); } //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td index 869995fe68f..269e1cc8897 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td @@ -60,6 +60,13 @@ def HLO_Tuple : NestedTupleOf<[HLO_Tensor, HLO_Token]>; def HLO_TensorOrTuple : AnyTypeOf<[HLO_Tensor, HLO_Tuple]>; +// Dynamic representation of a shape vector as a tensor. Ideally this would be +// an index type (as it stores indices) but that is currently disallowed in +// MLIR. +def HLO_DimensionTensor : ShapedContainerType< + [AnyInteger], And<[IsTensorTypePred, HasAnyRankOfPred<[1]>]>, + "a 1D tensor of dimensions">; + // In general, static shaped tensor constraints should be avoided unless // it is for a legacy op which is only correct with static shapes. def HLO_StaticShapeTensor : StaticShapeTensorOf<[ @@ -113,9 +120,7 @@ def HLO_ConstOp : HLO_Op<"constant", [NoSideEffect]>, BASE_HLO_ConstOp { def HLO_IotaOp : HLO_Op<"iota", [NoSideEffect]>, BASE_HLO_IotaOp { let arguments = (ins I64Attr:$iota_dimension); - let results = (outs HLO_Tensor:$output); - - let hasFolder = 1; + let results = (outs HLO_IntFpOrComplexTensor:$output); // TODO(b/130357376): Iota has special conversion logic to HLO. let hasCustomHLOConverter = 1; @@ -770,11 +775,39 @@ def HLO_BroadcastInDimOp : HLO_Op<"broadcast_in_dim", let hasCustomHLOConverter = 1; } +def HLO_ScalarsToDimensionTensorOp : HLO_Op<"scalars_to_dimension_tensor", + [SameOperandsElementType, NoSideEffect]> { + string summary = "Converts a sequence of scalars into a 1d tensor."; + + string description = [{ + This is a useful operation that is currently missing in Standard. 
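For example, assuming two hypothetical i64 scalars %dim0 and %dim1 (e.g. obtained from an operand via dim/index_cast), the op can assemble the output_dimensions operand of a dynamic broadcast; this is only a sketch, not part of the op definition:

    // %dim0, %dim1: hypothetical i64 scalars.
    %shape = "xla_hlo.scalars_to_dimension_tensor"(%dim0, %dim1)
        : (i64, i64) -> tensor<2xi64>
    %bcast = "xla_hlo.dynamic_broadcast_in_dim"(%operand, %shape)
        {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
        : (tensor<?x?xf32>, tensor<2xi64>) -> tensor<?x?xf32>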
Used to + compute shape arguments to dynamic operations. + }]; + + let arguments = (ins Variadic); + let results = (outs HLO_DimensionTensor); + + // Cannot be exported to legacy formats. + let hasCustomHLOConverter = 1; +} + def HLO_DynamicBroadcastInDimOp : HLO_Op<"dynamic_broadcast_in_dim", - [NoSideEffect]>, BASE_HLO_DynamicBroadcastInDimOp { + [NoSideEffect]> { + string summary = "Broadcast a tensor into the given dynamic shape by adding dimensions."; + string description = [{ + This is a generalization of the BroadcastInDimOp which accepts its output + dimensions as an argument. It should eventually supercede the statically + shaped original, but is being phased as a separate op in order to support + compatibility with lowerings and translations that precede dynamic + shapes. + + Note that the `broadcast_dimensions` attribute is optional and if omitted, + it is assumed to be an ordered, right-aligned mapping from input to + output dimensions. + }]; let arguments = (ins HLO_Tensor:$operand, - HLO_BASE_DimensionTensor:$output_dimensions, + HLO_DimensionTensor:$output_dimensions, BroadcastDimAttr:$broadcast_dimensions ); diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td index cace05a0913..64303e86fe0 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td @@ -27,13 +27,6 @@ def HLO_Pred : TypeAlias; // matching the matrix to dimensions 1 and 2 of the cuboid. def BroadcastDimAttr : OptionalAttr; -// Dynamic representation of a shape vector as a tensor. Ideally this would be -// an index type (as it stores indices) but that is currently disallowed in -// MLIR. -def HLO_BASE_DimensionTensor : ShapedContainerType< - [AnyInteger], And<[IsTensorTypePred, HasAnyRankOfPred<[1]>]>, - "a 1D tensor of dimensions">; - //===----------------------------------------------------------------------===// // XLA nullary op definitions. //===----------------------------------------------------------------------===// @@ -817,22 +810,6 @@ class BASE_HLO_BroadcastInDimOp { }]; } -class BASE_HLO_DynamicBroadcastInDimOp { - string summary = "Broadcast a tensor into the given dynamic shape by adding dimensions."; - - string description = [{ - This is a generalization of the BroadcastInDimOp which accepts its output - dimensions as an argument. It should eventually supercede the statically - shaped original, but is being phased as a separate op in order to support - compatibility with lowerings and translations that precede dynamic - shapes. - - Note that the `broadcast_dimensions` attribute is optional and if omitted, - it is assumed to be an ordered, right-aligned mapping from input to - output dimensions. 
- }]; -} - class BASE_HLO_CholeskyOp { string summary = "Cholesky operator"; diff --git a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td index 411c8a89396..794fee181a6 100644 --- a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td +++ b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td @@ -242,16 +242,6 @@ def LHLO_BroadcastInDimOp : LHLO_Op<"broadcast_in_dim", ); } -def HLO_DynamicBroadcastInDimOp : LHLO_Op<"dynamic_broadcast_in_dim", - [NoSideEffect]>, BASE_HLO_DynamicBroadcastInDimOp { - let arguments = (ins - LHLO_Buffer:$operand, - HLO_BASE_DimensionTensor:$output_dimensions, - LHLO_Buffer:$output, - BroadcastDimAttr:$broadcast_dimensions - ); -} - def LHLO_ClampOp : LHLO_Op<"clamp", []>, BASE_HLO_ClampOp { let arguments = (ins LHLO_Buffer:$min, diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc index c45baef855b..8fa7d809024 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc @@ -533,6 +533,12 @@ LogicalResult ExportXlaOp(BroadcastInDimOp op, OpLoweringContext ctx) { return success(); } +LogicalResult ExportXlaOp(ScalarsToDimensionTensorOp op, + OpLoweringContext ctx) { + // This op has no expression in the legacy export format. + return failure(); +} + LogicalResult ExportXlaOp(DynamicBroadcastInDimOp op, OpLoweringContext ctx) { // This op has no expression in the legacy export format. return failure(); diff --git a/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir b/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir index fa39b77918a..2232063fd6a 100644 --- a/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir @@ -49,6 +49,14 @@ func @complex_collapse_fold(%arg0: tensor<4xcomplex>) -> tensor<4xcomplex> } +// CHECK-LABEL: @iota_not_lowered_to_constant +func @iota_not_lowered_to_constant() -> tensor<4xi32> { + // CHECK: [[RESULT:%.*]] = "xla_hlo.iota" + // CHECK: return [[RESULT]] + %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32> + return %0 : tensor<4xi32> +} + // CHECK-LABEL: @unary_einsum func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> { // CHECK: %[[ONE:.*]] = xla_hlo.constant dense<1.000000e+00> : tensor diff --git a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir index 7ed4e97053d..be6f0e6a949 100644 --- a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir +++ b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt -hlo-legalize-to-lhlo -lhlo-redundant-copies-removal %s -o - | FileCheck %s --dump-input=always +// RUN: tf-opt -hlo-legalize-to-lhlo -lhlo-redundant-copies-removal -split-input-file %s -o - | FileCheck %s -dump-input-on-failure // CHECK-LABEL: func @attrs func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { @@ -11,6 +11,8 @@ func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { return } +// ----- + // CHECK-LABEL: func @func_op func @func_op(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { // CHECK: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>) @@ -20,6 +22,8 @@ func @func_op(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: "xla_lhlo.terminator"() : () -> () } +// ----- + // CHECK-LABEL: func @func_op_long func @func_op_long(%arg0: tensor<4xf32>, %arg1: 
tensor<4xf32>) -> tensor<4xf32> { // CHECK: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>) @@ -45,6 +49,8 @@ func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: "xla_lhlo.terminator"() : () -> () } +// ----- + // CHECK-LABEL: func @remove_lhlo_copy_op_created_from_tensor_store func @remove_lhlo_copy_op_created_from_tensor_store(%arg0: tensor, %arg1: tensor, %arg2: memref) { %0 = "xla_hlo.max"(%arg0, %arg1) : (tensor, tensor) -> tensor @@ -58,6 +64,8 @@ func @remove_lhlo_copy_op_created_from_tensor_store(%arg0: tensor, %arg1: t // CHECK-NOT: dealloc %[[ALLOC_OPERAND]] : memref // CHECK: "xla_lhlo.terminator"() : () -> () +// ----- + // CHECK-LABEL: func @fusion func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>, %summand_2: memref<2x2xf32>, %result: memref<2x2xf32>) { @@ -77,6 +85,8 @@ func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>, "xla_lhlo.terminator"() : () -> () } +// ----- + // CHECK-LABEL: func @copy func @copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> @@ -87,6 +97,8 @@ func @copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { return } +// ----- + // CHECK-LABEL: func @exp func @exp(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> @@ -97,6 +109,8 @@ func @exp(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { return } +// ----- + // CHECK-LABEL: func @select func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) { @@ -110,6 +124,8 @@ func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>, return } +// ----- + // CHECK-LABEL: func @compare func @compare(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xi1>) { %tensor_lhs = tensor_load %lhs : memref<2x2xf32> @@ -122,6 +138,8 @@ func @compare(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2x return } +// ----- + // CHECK-LABEL: func @broadcast func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) { %tensor_operand = tensor_load %operand : memref<5xf32> @@ -133,6 +151,34 @@ func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) { return } +// ----- + +// CHECK-LABEL: func @dyn_broadcast +func @dyn_broadcast(%operand: memref) { + %tensor_operand = tensor_load %operand : memref + %shape = "compute.shape"() : () -> tensor<3xi64> + %tensor_result = "xla_hlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape) + {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} + : (tensor, tensor<3xi64>) -> tensor + // CHECK: %[[SHAPE:.*]] = "compute.shape"() + // CHECK: %[[C0:.*]] = constant 0 : index + // CHECK: %[[EL0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<3xi64> + // CHECK: %[[IC0:.*]] = index_cast %[[EL0]] : i64 to index + // CHECK: %[[C1:.*]] = constant 1 : index + // CHECK: %[[EL1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<3xi64> + // CHECK: %[[IC1:.*]] = index_cast %[[EL1]] : i64 to index + // CHECK: %[[C2:.*]] = constant 2 : index + // CHECK: %[[EL2:.*]] = extract_element %[[SHAPE]][%[[C2]]] : tensor<3xi64> + // CHECK: %[[IC2:.*]] = index_cast %[[EL2]] : i64 to index + // CHECK: %[[RESULT:.*]] = alloc(%[[IC0]], %[[IC1]], %[[IC2]]) + // CHECK-NEXT: "xla_lhlo.broadcast_in_dim"(%{{.*}}, %[[RESULT]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} + // Do not store the value back to avoid the tensor-store being rewritten to + // a 
copy into the pre-allocated argument. + return +} + +// ----- + // CHECK-LABEL: func @iota func @iota(%result: memref<10xi32>) { %tensor_result = "xla_hlo.iota"() @@ -142,6 +188,8 @@ func @iota(%result: memref<10xi32>) { return } +// ----- + // CHECK-LABEL: func @abs func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> @@ -152,6 +200,8 @@ func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { return } +// ----- + // CHECK-LABEL: func @ceil func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> @@ -162,6 +212,8 @@ func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { return } +// ----- + // CHECK-LABEL: func @convert func @convert(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> @@ -172,6 +224,8 @@ func @convert(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { return } +// ----- + // CHECK-LABEL: func @cos func @cos(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> @@ -182,6 +236,8 @@ func @cos(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { return } +// ----- + // CHECK-LABEL: func @neg func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> @@ -192,6 +248,8 @@ func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { return } +// ----- + // CHECK-LABEL: func @sign func @sign(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> @@ -202,6 +260,8 @@ func @sign(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { return } +// ----- + // CHECK-LABEL: func @tanh func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> @@ -212,6 +272,8 @@ func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { return } +// ----- + // CHECK-LABEL: func @remainder func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_lhs = tensor_load %lhs : memref<2x2xf32> @@ -222,3 +284,47 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x tensor_store %tensor_result, %result : memref<2x2xf32> return } + +// ----- + +// Dynamic shape binary element-wise operation. +// CHECK-LABEL: func @add_dyn +func @add_dyn(%lhs: tensor, %rhs: tensor) { + %result = "xla_hlo.add"(%lhs, %rhs) + : (tensor, tensor) -> tensor + // CHECK: %[[DIM0:.*]] = dim %arg0, 0 : memref + // CHECK: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64 + // CHECK: %[[DIM1:.*]] = dim %arg0, 1 : memref + // CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64 + // CHECK: %[[SHAPE:.*]] = "xla_hlo.scalars_to_dimension_tensor"(%[[IC0]], %[[IC1]]) : (i64, i64) -> tensor<2xi64> + // CHECK: %[[C0:.*]] = constant 0 : index + // CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<2xi64> + // CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index + // CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<2xi64> + // CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index + // CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]]) + // CHECK: "xla_lhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref, memref, memref) -> () + return +} + +// ----- + +// Dynamic shape unary element-wise operation. 
+// CHECK-LABEL: func @tanh_dyn +func @tanh_dyn(%arg0: tensor) { + %result = "xla_hlo.tanh"(%arg0) + : (tensor) -> tensor + // CHECK: %[[DIM0:.*]] = dim %arg0, 0 : memref + // CHECK: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64 + // CHECK: %[[DIM1:.*]] = dim %arg0, 1 : memref + // CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64 + // CHECK: %[[SHAPE:.*]] = "xla_hlo.scalars_to_dimension_tensor"(%[[IC0]], %[[IC1]]) : (i64, i64) -> tensor<2xi64> + // CHECK: %[[C0:.*]] = constant 0 : index + // CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<2xi64> + // CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index + // CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<2xi64> + // CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index + // CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]]) + // CHECK: "xla_lhlo.tanh"(%arg0, %[[RESULT]]) : (memref, memref) -> () + return +} diff --git a/tensorflow/compiler/mlir/xla/tests/iota.mlir b/tensorflow/compiler/mlir/xla/tests/iota.mlir deleted file mode 100644 index 65b9f73ba67..00000000000 --- a/tensorflow/compiler/mlir/xla/tests/iota.mlir +++ /dev/null @@ -1,61 +0,0 @@ -// RUN: tf-opt %s -split-input-file -xla-legalize-to-std | FileCheck %s - -// ----- - -// CHECK-LABEL: func @iota.const.1() -> tensor<4xi32> { -func @iota.const.1() -> tensor<4xi32> { - // CHECK-NEXT: %[[CST:.*]] = constant dense<[0, 1, 2, 3]> : tensor<4xi32> - %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32> - // CHECK-NEXT: return %[[CST]] : tensor<4xi32> - return %0 : tensor<4xi32> -} - -// ----- - -// CHECK-LABEL: func @iota.const.2() -> tensor<2x4xi32> { -func @iota.const.2() -> tensor<2x4xi32> { - // CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[}}0, 0, 0, 0], [1, 1, 1, 1]]> : tensor<2x4xi32> - %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x4xi32> - // CHECK-NEXT: return %[[CST]] : tensor<2x4xi32> - return %0 : tensor<2x4xi32> -} - -// ----- - -// CHECK-LABEL: func @iota.const.3() -> tensor<2x4xi32> { -func @iota.const.3() -> tensor<2x4xi32> { - // CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[}}0, 1, 2, 3], [0, 1, 2, 3]]> : tensor<2x4xi32> - %0 = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x4xi32> - // CHECK-NEXT: return %[[CST]] : tensor<2x4xi32> - return %0 : tensor<2x4xi32> -} - -// ----- - -// CHECK-LABEL: func @iota.const.4() -> tensor<2x3x4xi32> { -func @iota.const.4() -> tensor<2x3x4xi32> { - // CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0{{\]\]}}, {{\[\[}}1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]]> : tensor<2x3x4xi32> - %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x3x4xi32> - // CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32> - return %0 : tensor<2x3x4xi32> -} - -// ----- - -// CHECK-LABEL: func @iota.const.5() -> tensor<2x3x4xi32> { -func @iota.const.5() -> tensor<2x3x4xi32> { - // CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2{{\]\]}}, {{\[\[}}0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2]]]> : tensor<2x3x4xi32> - %0 = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x3x4xi32> - // CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32> - return %0 : tensor<2x3x4xi32> -} - -// ----- - -// CHECK-LABEL: func @iota.const.6() -> tensor<2x3x4xi32> { -func @iota.const.6() -> tensor<2x3x4xi32> { - // CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3{{\]\]}}, {{\[\[}}0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]]]> : tensor<2x3x4xi32> - %0 = 
"xla_hlo.iota"() {iota_dimension = 2 : i64} : () -> tensor<2x3x4xi32> - // CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32> - return %0 : tensor<2x3x4xi32> -} diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir index 67f085ef9a0..d80722e2865 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir @@ -3308,7 +3308,7 @@ func @random_shuffle_1D_10240(%input: tensor<10240xf32>) -> tensor<10240xf32> { // CHECK-LABEL: @random_shuffle_3D // CHECK-SAME: [[INPUT:%.*]]: tensor<4x?x16xf32> func @random_shuffle_3D(%input: tensor<4x?x16xf32>) -> tensor<4x?x16xf32> { - // CHECK: [[INDICES:%.*]] = "xla_hlo.iota"() {iota_dimension = 4 : i64} : () -> tensor<4xi32> + // CHECK: [[INDICES:%.*]] = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32> // CHECK: [[RNG_SHAPE:%.*]] = xla_hlo.constant dense<4> : tensor<1xi64> // CHECK: [[RNG_LOWER:%.*]] = xla_hlo.constant dense<0> : tensor diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir index 1d2cf767939..f56174ae075 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir @@ -135,3 +135,51 @@ func @float_constant() -> (tensor, tensor<2x3xf32>, tensor<2x3xf32>) { return %0, %1, %2: tensor, tensor<2x3xf32>, tensor<2x3xf32> } +// Test Iota lowering to constant +// CHECK-LABEL: func @iota.const.1() -> tensor<4xi32> { +func @iota.const.1() -> tensor<4xi32> { + // CHECK-NEXT: %[[CST:.*]] = constant dense<[0, 1, 2, 3]> : tensor<4xi32> + %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32> + // CHECK-NEXT: return %[[CST]] : tensor<4xi32> + return %0 : tensor<4xi32> +} + +// CHECK-LABEL: func @iota.const.2() -> tensor<2x4xi32> { +func @iota.const.2() -> tensor<2x4xi32> { + // CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[}}0, 0, 0, 0], [1, 1, 1, 1]]> : tensor<2x4xi32> + %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x4xi32> + // CHECK-NEXT: return %[[CST]] : tensor<2x4xi32> + return %0 : tensor<2x4xi32> +} + +// CHECK-LABEL: func @iota.const.3() -> tensor<2x4xi32> { +func @iota.const.3() -> tensor<2x4xi32> { + // CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[}}0, 1, 2, 3], [0, 1, 2, 3]]> : tensor<2x4xi32> + %0 = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x4xi32> + // CHECK-NEXT: return %[[CST]] : tensor<2x4xi32> + return %0 : tensor<2x4xi32> +} + +// CHECK-LABEL: func @iota.const.4() -> tensor<2x3x4xi32> { +func @iota.const.4() -> tensor<2x3x4xi32> { + // CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0{{\]\]}}, {{\[\[}}1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]]> : tensor<2x3x4xi32> + %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x3x4xi32> + // CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32> + return %0 : tensor<2x3x4xi32> +} + +// CHECK-LABEL: func @iota.const.5() -> tensor<2x3x4xi32> { +func @iota.const.5() -> tensor<2x3x4xi32> { + // CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2{{\]\]}}, {{\[\[}}0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2]]]> : tensor<2x3x4xi32> + %0 = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x3x4xi32> + // CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32> + return %0 : tensor<2x3x4xi32> +} + +// CHECK-LABEL: func @iota.const.6() -> tensor<2x3x4xi32> { +func @iota.const.6() -> tensor<2x3x4xi32> 
{ + // CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3{{\]\]}}, {{\[\[}}0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]]]> : tensor<2x3x4xi32> + %0 = "xla_hlo.iota"() {iota_dimension = 2 : i64} : () -> tensor<2x3x4xi32> + // CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32> + return %0 : tensor<2x3x4xi32> +} diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-fuse-linalg.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-fuse-linalg.mlir index 7f9e8c19780..7f7e37ebe66 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo-fuse-linalg.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lhlo-fuse-linalg.mlir @@ -1,32 +1,57 @@ -// RUN: tf-opt -lhlo-fuse-linalg %s -o - | FileCheck %s +// RUN: tf-opt -lhlo-fuse-linalg %s -o - | FileCheck %s --dump-input=always +// RUN: tf-opt -lhlo-fuse-linalg=tile-sizes=2,3 %s -o - | FileCheck %s -check-prefix=TILED --dump-input-on-failure +// RUN: tf-opt -lhlo-fuse-linalg=use-parallel-loops %s -o - | FileCheck %s -check-prefix=PLOOP --dump-input-on-failure + #map0 = affine_map<(d0, d1) -> (d0, d1)> #pointwise_2d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]} -func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>, - %summand_2: memref<2x2xf32>, %result: memref<2x2xf32>) { - %temp_result = alloc() {temp = true} : memref<2x2xf32> +func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>, + %summand_2: memref<6x6xf32>, %result: memref<6x6xf32>) { + %temp_result = alloc() {temp = true} : memref<6x6xf32> linalg.generic #pointwise_2d_trait %summand_1, %summand_2, %temp_result { ^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32): %out = addf %summand_1_in, %summand_2_in : f32 linalg.yield %out : f32 - } : memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32> + } : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32> linalg.generic #pointwise_2d_trait %temp_result, %multiplier, %result { ^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32): %out = mulf %temp_result_in, %multiplier_in : f32 linalg.yield %out : f32 - } : memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32> - dealloc %temp_result : memref<2x2xf32> + } : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32> + dealloc %temp_result : memref<6x6xf32> "xla_lhlo.terminator"() : () -> () } // CHECK-LABEL: func @fusion -// CHECK-NOT: linalg.generic -// CHECK: loop.for -// CHECK: loop.for -// CHECK-NOT: loop.for -// CHECK: linalg.generic -// CHECK: addf -// CHECK: linalg.generic -// CHECK: mulf +// CHECK: %[[C1:.*]] = constant 1 +// CHECK-NOT: linalg.generic +// CHECK: loop.for {{.*}} step %[[C1]] +// CHECK: loop.for {{.*}} step %[[C1]] +// CHECK-NOT: loop.for +// CHECK: linalg.generic +// CHECK: addf +// CHECK: linalg.generic +// CHECK: mulf + +// TILED-LABEL: func @fusion +// TILED-DAG: %[[C2:.*]] = constant 2 +// TILED-DAG: %[[C3:.*]] = constant 3 +// TILED-NOT: linalg.generic +// TILED: loop.for {{.*}} step %[[C2]] +// TILED: loop.for {{.*}} step %[[C3]] +// TILED-NOT: loop.for +// TILED: linalg.generic +// TILED: addf +// TILED: linalg.generic +// TILED: mulf + +// PLOOP-LABEL: func @fusion +// PLOOP-NOT: linalg.generic +// PLOOP: loop.parallel +// PLOOP-NOT: loop.parallel +// PLOOP: linalg.generic +// PLOOP: addf +// PLOOP: linalg.generic +// PLOOP: mulf func @fusion_of_three(%arg0: memref<100x10xf32>, %arg1: memref<100xf32>, @@ -67,12 +92,36 @@ func @fusion_of_three(%arg0: memref<100x10xf32>, return } // CHECK-LABEL: func @fusion -// CHECK-NOT: linalg.generic -// CHECK: 
loop.for -// CHECK: loop.for -// CHECK-NOT: loop.for -// CHECK: linalg.generic -// CHECK: linalg.generic -// CHECK: subf -// CHECK: linalg.generic -// CHECK: exp +// CHECK: %[[C1:.*]] = constant 1 +// CHECK-NOT: linalg.generic +// CHECK: loop.for {{.*}} step %[[C1]] +// CHECK: loop.for {{.*}} step %[[C1]] +// CHECK-NOT: loop.for +// CHECK: linalg.generic +// CHECK: linalg.generic +// CHECK: subf +// CHECK: linalg.generic +// CHECK: exp + +// TILED-LABEL: func @fusion_of_three +// TILED-DAG: %[[C2:.*]] = constant 2 +// TILED-DAG: %[[C3:.*]] = constant 3 +// TILED-NOT: linalg.generic +// TILED: loop.for {{.*}} step %[[C2]] +// TILED: loop.for {{.*}} step %[[C3]] +// TILED-NOT: loop.for +// TILED: linalg.generic +// TILED: linalg.generic +// TILED: subf +// TILED: linalg.generic +// TILED: exp + +// PLOOP-LABEL: func @fusion_of_three +// PLOOP-NOT: linalg.generic +// PLOOP: loop.parallel +// PLOOP-NOT: loop.parallel +// PLOOP: linalg.generic +// PLOOP: linalg.generic +// PLOOP: subf +// PLOOP: linalg.generic +// PLOOP: exp diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir index 19e16ceab44..78f0d9ffb18 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir @@ -179,6 +179,22 @@ func @iota(%out: memref<7x10xi64>) { // ----- +// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)> +// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)> +// CHECK-LABEL: func @dynamic_broadcast +func @dynamic_broadcast(%operand: memref, + %result: memref) { + "xla_lhlo.broadcast_in_dim"(%operand, %result) + {broadcast_dimensions = dense<[4,0,2]> : tensor<3xi64>} + : (memref, memref) -> () + return +} +// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] +// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %[[RESULT:.*]]: f32): +// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 + +// ----- + // CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, 0)> // CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)> // CHECK-LABEL: func @broadcast diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir index 00ad25503d7..9f181d574c0 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir @@ -152,13 +152,6 @@ func @broadcast_in_dim_zero_rank_memref(%arg0: memref, %out: memref<1x2x3xi // ----- -// CHECK-LABEL: func @dynamic_broadcast_in_dim_memref -func @dynamic_broadcast_in_dim_memref(%arg0: memref, %out: memref, %shape: tensor<3xi64>) -> () { - "xla_lhlo.dynamic_broadcast_in_dim"(%arg0, %shape, %out) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (memref, tensor<3xi64>, memref) -> () - return -} - -// ----- // CHECK-LABEL: func @reduce_memref func @reduce_memref(%input: memref<10xf32>, %init: memref, %out: memref<1xf32>) -> () { diff --git a/tensorflow/compiler/mlir/xla/tests/materialize-broadcasts.mlir b/tensorflow/compiler/mlir/xla/tests/materialize-broadcasts.mlir index 53781158d58..682b153d474 100644 --- a/tensorflow/compiler/mlir/xla/tests/materialize-broadcasts.mlir +++ b/tensorflow/compiler/mlir/xla/tests/materialize-broadcasts.mlir @@ -235,3 +235,39 @@ func @compareBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tenso %0 = "xla_hlo.compare"(%arg0, %arg1) 
{broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xi1> return %0 : tensor<1x4xi1> } + +// ----- + +// CHECK-LABEL: @dynamicBroadcastAdd +func @dynamicBroadcastAdd(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK-NEXT: %[[DIM0:.*]] = dim %arg0, 0 : tensor + // CHECK-NEXT: %[[DIM0C:.*]] = index_cast %[[DIM0]] : index to i32 + // CHECK-NEXT: %c1 = constant 1 : index + // CHECK-NEXT: %[[DIM1_0:.*]] = dim %arg0, 1 : tensor + // CHECK-NEXT: %[[DIM1_1:.*]] = dim %arg1, 0 : tensor + // CHECK-NEXT: %[[CMPI:.*]] = cmpi "eq", %[[DIM1_0]], %c1 : index + // CHECK-NEXT: %[[SEL:.*]] = select %[[CMPI]], %[[DIM1_0]], %[[DIM1_1]] : index + // CHECK-NEXT: %[[DIM1C:.*]] = index_cast %[[SEL]] : index to i32 + // CHECK-NEXT: %[[SHAPE:.*]] = "xla_hlo.scalars_to_dimension_tensor"(%[[DIM0C]], %[[DIM1C]]) : (i32, i32) -> tensor<2xi32> + // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[SHAPE]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<2xi32>) -> tensor + // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[SHAPE]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<2xi32>) -> tensor + // CHECK-NEXT: xla_hlo.add %[[BROADCAST0]], %[[BROADCAST1]] : tensor + %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor) -> tensor + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: @dynamicBroadcastAddScalar +func @dynamicBroadcastAddScalar(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK-NEXT: %[[DIM0:.*]] = dim %arg0, 0 : tensor + // CHECK-NEXT: %[[DIM0C:.*]] = index_cast %[[DIM0]] : index to i32 + // CHECK-NEXT: %[[DIM1:.*]] = dim %arg0, 1 : tensor + // CHECK-NEXT: %[[DIM1C:.*]] = index_cast %[[DIM1]] : index to i32 + // CHECK-NEXT: %[[SHAPE:.*]] = "xla_hlo.scalars_to_dimension_tensor"(%[[DIM0C]], %[[DIM1C]]) : (i32, i32) -> tensor<2xi32> + // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[SHAPE]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<2xi32>) -> tensor + // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[SHAPE]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor<2xi32>) -> tensor + // CHECK-NEXT: xla_hlo.add %[[BROADCAST0]], %[[BROADCAST1]] : tensor + %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor + return %0 : tensor +} diff --git a/tensorflow/compiler/mlir/xla/tests/ops.mlir b/tensorflow/compiler/mlir/xla/tests/ops.mlir index 3c91f1d7dd0..7e2845daa06 100644 --- a/tensorflow/compiler/mlir/xla/tests/ops.mlir +++ b/tensorflow/compiler/mlir/xla/tests/ops.mlir @@ -292,6 +292,22 @@ func @infeed_non_token_second_result(%token: !xla_hlo.token) -> tuple tensor { + // expected-error@+1 {{does not support scalars}} + %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor + return %0 : tensor +} + +// ----- + +func @iota_invalid_iota_dimension() -> tensor<4xi32> { + // expected-error@+1 {{iota dimension cannot go beyond the output rank or be negative}} + %0 = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<4xi32> + return %0 : tensor<4xi32> +} + +// ----- + func @map_mismatched_args(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { // expected-error@+1 {{expects number of operands to match the arity of map computation, but got: 2 and 1}} %0 = "xla_hlo.map"(%arg0, %arg1) ( { diff --git 
a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt index b9f88ef699c..b2dec8c950f 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt @@ -716,6 +716,37 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] { ROOT %Arg_0.1 = f32[] parameter(0) } +// Test scatter +%update_computation { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %sum = f32[] add(f32[] %lhs, f32[] %rhs) +} + +%test_scatter { + %input_tensor = f32[200,100,300] parameter(0) + %scatter_indices = s64[10,2] parameter(1) + %updates = f32[10,300] parameter(2) + ROOT %scatter = f32[200,100,300] scatter(f32[200,100,300] %input_tensor, s64[10,2] %scatter_indices, f32[10,300] %updates), update_window_dims={1}, inserted_window_dims={0,1}, scatter_dims_to_operand_dims={0,1}, index_vector_dim=1, to_apply=%update_computation +} + +// CHECK-LABEL: func @test_scatter +// CHECK-SAME: [[ARG_0:%.*]]: tensor<200x100x300xf32>, [[ARG_1:%.*]]: tensor<10x2xi64>, [[ARG_2:%.*]]: tensor<10x300xf32>) -> tensor<200x100x300xf32> +// CHECK: "xla_hlo.scatter"([[ARG_0]], [[ARG_1]], [[ARG_2]]) ( { +// CHECK: ^bb0([[LHS:%.*]]: tensor, [[RHS:%.*]]: tensor): +// CHECK: [[ADD:%.*]] = xla_hlo.add [[LHS]], [[RHS]] +// CHECK: "xla_hlo.return"([[ADD]]) : (tensor) -> () +// CHECK: }) +// CHECK-SAME: indices_are_sorted = false +// CHECK-SAME: scatter_dimension_numbers = { +// CHECK-SAME: index_vector_dim = 1 : i64 +// CHECK-SAME: inserted_window_dims = dense<[0, 1]> : tensor<2xi64> +// CHECK-SAME: scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64> +// CHECK-SAME: update_window_dims = dense<1> : tensor<1xi64> +// CHECK-SAME: } +// CHECK-SAME: unique_indices = false + + // CHECK-LABEL: func @test_select(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> { %test_select { %Arg_0.1 = pred[2,3] parameter(0) @@ -743,6 +774,25 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] { ROOT %sine.3 = f32[1,16,16,3]{3,2,1,0} sine(f32[1,16,16,3]{3,2,1,0} %arg0.1) } +// Test sort +%compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT +} + +%test_sort { + x = f32[1024]{0} parameter(0) + ROOT sorted = f32[1024]{0} sort(x), dimensions={0}, is_stable=true, to_apply=compare +} +// CHECK-LABEL: func @test_sort +// CHECK-SAME: [[ARG:%.*]]: tensor<1024xf32>) -> tensor<1024xf32> +// CHECK: "xla_hlo.sort"([[ARG]]) ( { +// CHECK: ^bb0([[ARG0:%.*]]: tensor, [[ARG1:%.*]]: tensor): +// CHECK: [[CMP:%.*]] = "xla_hlo.compare"([[ARG0]], [[ARG1]]) {comparison_direction = "LT", name = "lt"} : (tensor, tensor) -> tensor +// CHECK: "xla_hlo.return"([[CMP]]) : (tensor) -> () +// CHECK: }) {dimension = 0 : i64, is_stable = true} : (tensor<1024xf32>) -> tensor<1024xf32> + // CHECK-LABEL: func @test_subtract %test_subtract (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] { %Arg_0.1 = f32[4] parameter(0) diff --git a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc index 57610758bae..1384abed91c 100644 --- a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc +++ b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc @@ -30,6 +30,7 @@ limitations under the License. 
#include "mlir/Transforms/DialectConversion.h" // TF:llvm-project #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h" +#include "tensorflow/compiler/mlir/xla/transforms/hlo_shape_derivation.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" #include "tensorflow/compiler/mlir/xla/transforms/rewriters.h" @@ -49,6 +50,48 @@ Operation* FindInsertionPointForCopy(Value value) { return nullptr; } +Value InsertDynamicAllocAndDealloc(Location loc, Value result, + Value shape_operand, + ConversionPatternRewriter* rewriter) { + auto result_type = result.getType().dyn_cast(); + if (!result_type) { + result.getDefiningOp()->emitOpError() + << "tensor to buffer conversion expects ranked results"; + } + auto memref_type = + MemRefType::get(result_type.getShape(), result_type.getElementType()); + + Operation* op = result.getDefiningOp(); + auto block = op->getBlock(); + + // Extract the required element out of the vector. + SmallVector dynamic_operands; + for (auto shape_element : llvm::enumerate(result_type.getShape())) { + if (shape_element.value() != ShapedType::kDynamicSize) continue; + Value index = rewriter->create( + loc, rewriter->getIntegerAttr(rewriter->getIndexType(), + shape_element.index())); + Value alloc_operand = rewriter->create(loc, shape_operand, + ValueRange{index}); + if (!alloc_operand.getType().isIndex()) { + alloc_operand = rewriter->create(loc, alloc_operand, + rewriter->getIndexType()); + } + dynamic_operands.push_back(alloc_operand); + } + + // Insert in front of op to ensure sizes are available. + OpBuilder allocBuilder(op); + auto alloc = allocBuilder.create(loc, memref_type, dynamic_operands); + + alloc.setAttr(kTempBufferAttr, rewriter->getBoolAttr(true)); + + allocBuilder.setInsertionPoint(block, std::prev(block->end())); + allocBuilder.create(loc, alloc); + + return alloc; +} + Value InsertAllocAndDealloc(Location loc, Value result, ConversionPatternRewriter* rewriter) { auto result_type = result.getType().dyn_cast(); @@ -85,9 +128,24 @@ class HloToLhloOpConverter : public ConversionPattern { ConversionPatternRewriter& rewriter) const final { const auto& original_results = op->getResults(); SmallVector buffer_args(operands.begin(), operands.end()); - for (auto result : original_results) { - buffer_args.push_back( - InsertAllocAndDealloc(op->getLoc(), result, &rewriter)); + for (auto result : llvm::enumerate(original_results)) { + RankedTensorType resultType = + result.value().getType().dyn_cast(); + if (!resultType) { + return matchFailure(); + } + if (resultType.hasStaticShape()) { + buffer_args.push_back( + InsertAllocAndDealloc(op->getLoc(), result.value(), &rewriter)); + } else { + Value shape_value = ShapeDerivation::impl::deriveShapeFromOp( + op, result.index(), &rewriter); + if (!shape_value) { + return matchFailure(); + } + buffer_args.push_back(InsertDynamicAllocAndDealloc( + op->getLoc(), result.value(), shape_value, &rewriter)); + } } rewriter.create(op->getLoc(), llvm::None, buffer_args, op->getAttrs()); @@ -96,6 +154,30 @@ class HloToLhloOpConverter : public ConversionPattern { } }; +struct HloToLHloDynamicBroadcastInDimOpConverter + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + PatternMatchResult matchAndRewrite( + xla_hlo::DynamicBroadcastInDimOp op, ArrayRef operands, + ConversionPatternRewriter& rewriter) const final { + auto loc = op.getLoc(); + auto broadcast_dimensions = op.broadcast_dimensions(); + if (!broadcast_dimensions.hasValue()) { 
+ return matchFailure(); + } + Value resultBuffer = InsertDynamicAllocAndDealloc( + loc, op.getResult(), op.output_dimensions(), &rewriter); + rewriter.create( + loc, operands[0], resultBuffer, broadcast_dimensions.getValue()); + + rewriter.replaceOp(op, {resultBuffer}); + + return matchSuccess(); + } +}; + struct HloToLHloReduceOpConverter : public OpConversionPattern { public: @@ -254,6 +336,7 @@ struct HloLegalizeToLhlo : public ModulePass { target.addIllegalOp(); target.addIllegalOp(); target.addLegalOp(); + target.addLegalOp(); target.addIllegalDialect(); target.addDynamicallyLegalOp([&](FuncOp op) { auto inputs = op.getType().getInputs(); @@ -264,7 +347,8 @@ struct HloLegalizeToLhlo : public ModulePass { auto module = getModule(); populateHLOToLHLOConversionPattern(module.getContext(), &patterns); - if (failed(applyFullConversion(module, target, patterns, nullptr))) { + // Do partial conversion so we can have unknown ops in tests. + if (failed(applyPartialConversion(module, target, patterns, nullptr))) { signalPassFailure(); } } @@ -354,7 +438,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context, OwningRewritePatternList* patterns) { // clang-format off patterns->insert< - HloToLHloReduceOpConverter, + HloToLHloDynamicBroadcastInDimOpConverter, HloToLhloFuncOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, @@ -379,9 +463,10 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, + HloToLHloReduceOpConverter, + StdToLhloReturnOpConverter, HloToLhloTensorLoadOpConverter, - HloToLhloTensorStoreOpConverter, - StdToLhloReturnOpConverter + HloToLhloTensorStoreOpConverter >(context); // clang-format on } diff --git a/tensorflow/compiler/mlir/xla/transforms/hlo_shape_derivation.h b/tensorflow/compiler/mlir/xla/transforms/hlo_shape_derivation.h new file mode 100644 index 00000000000..7c6d162632f --- /dev/null +++ b/tensorflow/compiler/mlir/xla/transforms/hlo_shape_derivation.h @@ -0,0 +1,130 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_HLO_SHAPE_DERIVATION_H_ +#define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_HLO_SHAPE_DERIVATION_H_ + +#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Location.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/Transforms/DialectConversion.h" // TF:llvm-project +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" + +namespace mlir { +namespace xla_hlo { + +// This file contains implementations for shape derivation functions that, +// given some operation and a result number, produce IR that computes the +// shape of the given result at runtime based on operands of the provided +// operation. +// These should be generated at some point based on annotations on the HLO +// using the new shape dialect. While this is still in the works, we hardcode +// the expected IR here to unblock progress. +// The implementation is based on templates to allow for using these derivation +// functions in templated code. + +namespace impl { + +struct UnknownShape { + // Default shape derivation function that simply fails with a runtime error. + static Value deriveShapeFromOp(Operation* op, int operand_position, + ConversionPatternRewriter* rewriter) { + op->emitOpError() + << "dynamic result shapes cannot be derived for this operation"; + return {}; + } +}; + +struct SameShapeAsFirstOperand { + // Shape derivation function that computes the shape of the result based on + // the first argument. For a 2-dimensional input tensor, this produces IR of + // the form + // + // %0 = dim %arg0, 0 : memref + // %1 = index_cast %0 : index to i64 + // %2 = dim %arg0, 1 : memref + // %3 = index_cast %2 : index to i64 + // %4 = "xla_hlo.scalars_to_dimension_tensor"(%1, %3) + // : (i64, i64) -> tensor<2xi64> + // + // and returns %4 as the shape value. + static Value deriveShapeFromOp(Operation* op, int result_postion, + ConversionPatternRewriter* rewriter) { + Value operand = op->getOperand(0); + ShapedType operand_type = operand.getType().dyn_cast(); + if (!operand_type) { + op->emitOpError() << "first operand has no shaped type"; + return {}; + } + auto loc = op->getLoc(); + SmallVector shape_values; + shape_values.reserve(operand_type.getRank()); + auto shape_scalar_type = rewriter->getIntegerType(64); + for (auto element : llvm::enumerate(operand_type.getShape())) { + if (element.value() == ShapedType::kDynamicSize) { + Value dim = rewriter->create(loc, operand, element.index()); + shape_values.push_back( + rewriter->create(loc, dim, shape_scalar_type)); + } else { + shape_values.push_back(rewriter->create( + loc, rewriter->getI64IntegerAttr(element.value()))); + } + } + return rewriter->create( + loc, RankedTensorType::get({operand_type.getRank()}, shape_scalar_type), + shape_values); + } +}; + +} // namespace impl + +// Default template to cover HLO operations whose shape derivation is unknown. +template +struct ShapeDerivation { + using impl = impl::UnknownShape; +}; + +// Element-wise operations that have the shape of their first operand. 
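// As a further illustration for the registrations below (the memref<4x?xf32>
// operand type is hypothetical): when the first operand mixes static and
// dynamic dimensions, the static ones are materialized as constants, e.g.
//
//   %c4 = constant 4 : i64
//   %0 = dim %arg0, 1 : memref<4x?xf32>
//   %1 = index_cast %0 : index to i64
//   %2 = "xla_hlo.scalars_to_dimension_tensor"(%c4, %1)
//       : (i64, i64) -> tensor<2xi64>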
+ +#define SAME_SHAPE_AS_FIRST_OPERAND(Op) \ + template <> \ + struct ShapeDerivation { \ + using impl = impl::SameShapeAsFirstOperand; \ + }; + +SAME_SHAPE_AS_FIRST_OPERAND(AbsOp) +SAME_SHAPE_AS_FIRST_OPERAND(AddOp) +SAME_SHAPE_AS_FIRST_OPERAND(AndOp) +SAME_SHAPE_AS_FIRST_OPERAND(CeilOp) +SAME_SHAPE_AS_FIRST_OPERAND(CosOp) +SAME_SHAPE_AS_FIRST_OPERAND(DivOp) +SAME_SHAPE_AS_FIRST_OPERAND(ExpOp) +SAME_SHAPE_AS_FIRST_OPERAND(MaxOp) +SAME_SHAPE_AS_FIRST_OPERAND(MinOp) +SAME_SHAPE_AS_FIRST_OPERAND(MulOp) +SAME_SHAPE_AS_FIRST_OPERAND(NegOp) +SAME_SHAPE_AS_FIRST_OPERAND(RemOp) +SAME_SHAPE_AS_FIRST_OPERAND(SubOp) +SAME_SHAPE_AS_FIRST_OPERAND(TanhOp) + +#undef SAME_SHAPE_AS_FIRST_OPERAND + +} // namespace xla_hlo +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_HLO_SHAPE_DERIVATION_H_ diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc index 50ecce24df3..da135ea1860 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc @@ -3362,7 +3362,7 @@ class ConvertRandomShuffleOp : public OpRewritePattern { auto indices_type = RankedTensorType::get({first_dim_size}, rewriter.getIntegerType(32)); Value indices = rewriter.create( - op.getLoc(), indices_type, rewriter.getI64IntegerAttr(first_dim_size)); + op.getLoc(), indices_type, rewriter.getI64IntegerAttr(0)); // Generate random numbers to be used as swaps for the indices. Value swaps = CreateRngUniform32(op.getLoc(), first_dim_size, 0, diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td index a78d9cc2d2d..872a288c259 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td @@ -368,6 +368,9 @@ def : Pat<(TF_ConstOp:$res ElementsAttr:$value), (HLO_ConstOp $value), // Relu op patterns. //===----------------------------------------------------------------------===// +// TODO(hinsu): Make these patterns to TF to TF lowering. Relu6 lowering will +// require HLO canonicalization of min and max on a tensor to ClampOp. + // TODO(hinsu): Lower unsinged and quantized types after supporting // them in GetScalarOfType. def : Pat<(TF_ReluOp AnyRankedTensor:$input), diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc index 5e12abc466c..5ee6010c3a8 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc @@ -24,12 +24,6 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/xla/transforms/passes.h" #include "tensorflow/compiler/mlir/xla/transforms/rewriters.h" -using mlir::Builder; -using mlir::FunctionPass; -using mlir::OpPassBase; -using mlir::OwningRewritePatternList; -using mlir::PassRegistration; - namespace mlir { namespace { #include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_to_standard.inc" @@ -37,16 +31,14 @@ namespace { namespace xla_hlo { namespace { -struct CompareIConvert : public RewritePattern { - explicit CompareIConvert(MLIRContext *context) - : RewritePattern("xla_hlo.compare", 1, context) {} +class CompareIConvert : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(Operation *op, + PatternMatchResult matchAndRewrite(xla_hlo::CompareOp op, PatternRewriter &rewriter) const override { - auto compare_op = cast(op); - - auto lhs = compare_op.lhs(); - auto rhs = compare_op.rhs(); + auto lhs = op.lhs(); + auto rhs = op.rhs(); auto lhs_type = lhs.getType().cast(); auto rhs_type = rhs.getType().cast(); @@ -57,7 +49,7 @@ struct CompareIConvert : public RewritePattern { !rhs_type.getElementType().isa()) return matchFailure(); - auto comparison_direction = compare_op.comparison_direction(); + auto comparison_direction = op.comparison_direction(); auto compare_predicate = llvm::StringSwitch>(comparison_direction) .Case("EQ", CmpIPredicate::eq) @@ -76,16 +68,14 @@ struct CompareIConvert : public RewritePattern { } }; -struct CompareFConvert : public RewritePattern { - explicit CompareFConvert(MLIRContext *context) - : RewritePattern("xla_hlo.compare", 1, context) {} +class CompareFConvert : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(Operation *op, + PatternMatchResult matchAndRewrite(xla_hlo::CompareOp op, PatternRewriter &rewriter) const override { - auto compare_op = cast(op); - - auto lhs = compare_op.lhs(); - auto rhs = compare_op.rhs(); + auto lhs = op.lhs(); + auto rhs = op.rhs(); auto lhs_type = lhs.getType().cast(); auto rhs_type = rhs.getType().cast(); @@ -96,7 +86,7 @@ struct CompareFConvert : public RewritePattern { !rhs_type.getElementType().isa()) return matchFailure(); - auto comparison_direction = compare_op.comparison_direction(); + auto comparison_direction = op.comparison_direction(); CmpFPredicate compare_predicate = llvm::StringSwitch(comparison_direction) .Case("EQ", CmpFPredicate::OEQ) @@ -115,9 +105,42 @@ struct CompareFConvert : public RewritePattern { } }; +class ConvertIotaOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(xla_hlo::IotaOp op, + PatternRewriter &rewriter) const override { + auto output_type = op.getType().cast(); + // TODO(prakalps): Handle FP and ComplexType iota ops. 
+ if (!output_type.getElementType().isa()) return matchFailure(); + auto output_size = output_type.getNumElements(); + auto dimension = op.iota_dimension().getSExtValue(); + auto max_dim_size = output_type.getDimSize(dimension); + int bitwidth = output_type.getElementType().getIntOrFloatBitWidth(); + + llvm::SmallVector values; + values.reserve(output_size); + + int64_t increase_stride = output_size; + for (int i = 0; i <= dimension; i++) { + increase_stride /= output_type.getDimSize(i); + } + + int64_t current_value = 0; + for (int i = 0; i < output_size; i++) { + int64_t value = (current_value / increase_stride) % max_dim_size; + values.push_back(APInt(bitwidth, value)); + ++current_value; + } + + rewriter.replaceOpWithNewOp( + op, DenseIntElementsAttr::get(output_type, values)); + return matchSuccess(); + } +}; + } // end anonymous namespace -} // end namespace xla_hlo -} // end namespace mlir namespace { struct LegalizeToStandard : public FunctionPass { @@ -126,17 +149,14 @@ struct LegalizeToStandard : public FunctionPass { }; } // end anonymous namespace -std::unique_ptr> -mlir::xla_hlo::createLegalizeToStdPass() { +std::unique_ptr> createLegalizeToStdPass() { return std::make_unique(); } -void mlir::xla_hlo::PopulateXlaToStdPatterns(OwningRewritePatternList *patterns, - mlir::MLIRContext *ctx) { +void PopulateXlaToStdPatterns(OwningRewritePatternList *patterns, + mlir::MLIRContext *ctx) { mlir::populateWithGenerated(ctx, patterns); - patterns - ->insert( - ctx); + patterns->insert(ctx); } /// Perform the lowering to standard dialect. @@ -148,3 +168,6 @@ void LegalizeToStandard::runOnFunction() { static PassRegistration legalize_pass( "xla-legalize-to-std", "Legalize from XLA dialect to standard dialect"); + +} // end namespace xla_hlo +} // end namespace mlir diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc index b5e33fb0663..a52d2318ba7 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" #include "absl/memory/memory.h" +#include "llvm/ADT/ArrayRef.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" // TF:llvm-project #include "mlir/Pass/Pass.h" // TF:llvm-project #include "mlir/Transforms/FoldUtils.h" // TF:llvm-project @@ -28,7 +29,15 @@ namespace { using linalg::LinalgOp; -struct LhloFuseLinalg : public FunctionPass { +class LhloFuseLinalg : public FunctionPass { + public: + LhloFuseLinalg() = default; + LhloFuseLinalg(const LhloFuseLinalg&) {} + LhloFuseLinalg(bool use_parallel_loops, llvm::ArrayRef tile_sizes) { + tile_sizes_->assign(tile_sizes.begin(), tile_sizes.end()); + use_parallel_loops_.setValue(use_parallel_loops); + } + void runOnFunction() override { auto func = getFunction(); @@ -50,13 +59,16 @@ struct LhloFuseLinalg : public FunctionPass { OpBuilder b(func); OperationFolder folder(func.getContext()); func.walk([&](linalg::GenericOp generic_op) { - const SmallVector tile_sizes( - generic_op.getNumInputsAndOutputs(), 1); + SmallVector tile_sizes(tile_sizes_.begin(), + tile_sizes_.end()); + if (tile_sizes.empty()) { + tile_sizes = + SmallVector(generic_op.getNumInputsAndOutputs(), 1); + } auto op = cast(generic_op.getOperation()); for (const Value result : op.getOutputBuffers()) { if (!func_args.count(result)) continue; - if (linalg::tileLinalgOp(b, op, tile_sizes, /*permutation=*/{}, - &folder)) { + if (tileGenericOp(op, tile_sizes, &b, &folder)) { generic_op.erase(); return; } @@ -83,6 +95,30 @@ struct LhloFuseLinalg : public FunctionPass { } for (auto* e : erase_set) e->erase(); } + + private: + bool tileGenericOp(LinalgOp op, ArrayRef tile_sizes, OpBuilder* b, + OperationFolder* folder) { + auto tiled_generic_op = + use_parallel_loops_ + ? linalg::tileLinalgOpToParallelLoops(*b, op, tile_sizes, + /*permutation=*/{}, folder) + : linalg::tileLinalgOp(*b, op, tile_sizes, + /*permutation=*/{}, folder); + return tiled_generic_op.hasValue(); + } + + Option use_parallel_loops_{ + *this, "use-parallel-loops", + llvm::cl::desc( + "Tiles GenericOp consumer to parallel loops before linalg fusion"), + llvm::cl::init(false)}; + + ListOption tile_sizes_{ + *this, "tile-sizes", + llvm::cl::desc( + "Tile sizes by which to tile linalg generic before linalg fusion"), + llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; }; } // namespace diff --git a/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts.cc b/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts.cc index 3ff6d374493..13467be41d9 100644 --- a/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts.cc +++ b/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts.cc @@ -15,6 +15,7 @@ limitations under the License. #include +#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project #include "mlir/IR/MLIRContext.h" // TF:llvm-project #include "mlir/IR/Operation.h" // TF:llvm-project #include "mlir/IR/PatternMatch.h" // TF:llvm-project @@ -72,10 +73,9 @@ bool CreateBroadcastsForBinaryOp(SrcOp op, PatternRewriter *rewriter, return false; } - if (!op_ranked_type.hasStaticShape()) { - // Dynamic result shape, can't use BroadcastInDimOp. - return false; - } + // Dynamic result shape, can't use BroadcastInDimOp. 
+ assert(op_ranked_type.hasStaticShape() && + "dynamic shape requires DynamicBroadcastInDim"); auto lhs_rank = lhs_ranked_type.getRank(); auto rhs_rank = rhs_ranked_type.getRank(); @@ -118,6 +118,143 @@ bool CreateBroadcastsForBinaryOp(SrcOp op, PatternRewriter *rewriter, return true; } +// Helper template to generate code for computing the result shape of a +// broadcasted operation. This ultimately should be subsumed by functions +// from the shape dialect. +// Assumes that large and small are the operand values of `op` and that they +// have a ranked tensor type with rank(large) >= rank(small). +template +std::vector ComputeBroadcastedShape(SrcOp op, Value small, Value large, + PatternRewriter *rewriter) { + auto loc = op.getLoc(); + auto larger_ranked_type = large.getType().cast(); + auto output_rank = larger_ranked_type.getRank(); + + constexpr int kExpandShape = -1; + + std::vector shape_values; + shape_values.reserve(output_rank); + std::vector indexes(output_rank, kExpandShape); + DenseIntElementsAttr broadcast_dimensions = + op.broadcast_dimensions().getValue(); + // Compute a mapping from output dimensions to their corresponding input + // dimensions in the smaller ranked operand. + for (auto pair : llvm::enumerate(broadcast_dimensions.getIntValues())) { + indexes.at(pair.value().getLimitedValue()) = pair.index(); + } + + // Compute the broadcasted shape of the result using numpy style broadcasting + // semantics. The result shape at a position is the shape of the larger + // operand at that position if no dimension of the smaller operand is + // mapped to it. + // If both operands contribute to an output dimension, their shape has to + // either be the same in that dimension or it can be 1, in which case the + // shape of the other operand is used. + for (int i = 0; i < output_rank; ++i) { + Value index_value; + if (indexes[i] == kExpandShape) { + // The smaller shape gets expanded to the larger one in this case. + index_value = rewriter->create(loc, large, i); + } else { + // Compute the result shape depending on whether the rank of smaller is 1. + // This does not check that the broadcast operation actually is correct. + // In particular, we do not check that both shapes are the same if the + // smaller ranked shape is not 1. + ConstantOp one = rewriter->create( + loc, rewriter->getIntegerAttr(rewriter->getIndexType(), 1)); + DimOp lrg_dim = rewriter->create(loc, large, i); + DimOp sml_dim = rewriter->create(loc, small, indexes[i]); + CmpIOp compare = + rewriter->create(loc, CmpIPredicate::eq, lrg_dim, one); + index_value = + rewriter->create(loc, compare, lrg_dim, sml_dim); + } + // Ideally, we would like to keep this on index but MLIR does not allow + // this. + shape_values.push_back(rewriter->create( + loc, index_value, rewriter->getIntegerType(32))); + } + + return shape_values; +} + +// Helper function for OpRewritePattern classes to materialize dynamic +// broadcasts on LHS and RHS arguments to a binary op. +// +// Returns true and sets out_lhs and out_rhs for materialized dynamic broadcasts +// for LHS and RHS arguments, else returns false. +template +bool CreateDynamicBroadcastsForBinaryOp(SrcOp op, PatternRewriter *rewriter, + Value *out_lhs, Value *out_rhs) { + if (!op.broadcast_dimensions().hasValue()) { + // Note: the op may still have an implicit broadcast on it, such as + // for (tensor<1xf32>, tensor<4xf32>).
+ return false; + } + + // Insert BroadcastInDimOps for the left-hand-side and right-hand-side args, + // replacing the original LHS and RHS args in the source op with the results + // of the broadcasts. + Value lhs = op.lhs(); + Value rhs = op.rhs(); + + auto lhs_ranked_type = lhs.getType().dyn_cast(); + auto rhs_ranked_type = rhs.getType().dyn_cast(); + if (!lhs_ranked_type || !rhs_ranked_type) { + // Unranked, can't determine at this point how to perform the broadcast. + return false; + } + + auto lhs_rank = lhs_ranked_type.getRank(); + auto rhs_rank = rhs_ranked_type.getRank(); + + // Set broadcast_dimensions to [0, ..., rank] for the higher rank arg. + // Use the original op.broadcast_dimensions for the lower rank arg. + auto higher_rank_broadcast_dims = + GetI64ElementsAttrForSeq(0, std::max(lhs_rank, rhs_rank), rewriter); + DenseIntElementsAttr lhs_broadcast_dims; + DenseIntElementsAttr rhs_broadcast_dims; + std::vector shape_elements; + if (lhs_rank > rhs_rank) { + lhs_broadcast_dims = higher_rank_broadcast_dims; + rhs_broadcast_dims = op.broadcast_dimensions().getValue(); + shape_elements = ComputeBroadcastedShape(op, rhs, lhs, rewriter); + } else if (lhs_rank < rhs_rank) { + lhs_broadcast_dims = op.broadcast_dimensions().getValue(); + rhs_broadcast_dims = higher_rank_broadcast_dims; + shape_elements = ComputeBroadcastedShape(op, lhs, rhs, rewriter); + } else { + // This shouldn't happen for legal ops. If the broadcast_dimensions + // attribute is set, the ranks should be different. + // TODO(scotttodd): Add a custom verification for ops and assert here. + return false; + } + + // DynamicBroadcastInDimOp preserves the element type but produces a tensor + // with unranked shape. The rank of the output is the length of the + // output shape argument. + SmallVector op_shape(shape_elements.size(), + RankedTensorType::kDynamicSize); + auto lhs_type = + RankedTensorType::get(op_shape, lhs_ranked_type.getElementType()); + auto rhs_type = + RankedTensorType::get(op_shape, rhs_ranked_type.getElementType()); + + // We need a way to turn a list of scalars into a vector. While Standard + // dialect does not have one, use the XLA_HLO variant. + int shape_size = shape_elements.size(); + Type shape_element_type = shape_elements.front().getType(); + Value shape_value = rewriter->create( + op.getLoc(), RankedTensorType::get({shape_size}, shape_element_type), + shape_elements); + + *out_lhs = rewriter->createOrFold( + op.getLoc(), lhs_type, lhs, shape_value, lhs_broadcast_dims); + *out_rhs = rewriter->createOrFold( + op.getLoc(), rhs_type, rhs, shape_value, rhs_broadcast_dims); + return true; +} + template struct BinaryOpWithBroadcastConvert : public OpRewritePattern { explicit BinaryOpWithBroadcastConvert(MLIRContext *context) @@ -127,8 +264,19 @@ struct BinaryOpWithBroadcastConvert : public OpRewritePattern { PatternRewriter &rewriter) const override { Value new_lhs; Value new_rhs; - if (!CreateBroadcastsForBinaryOp(op, &rewriter, &new_lhs, &new_rhs)) { - return this->matchFailure(); + + auto op_ranked_type = op.getType().template dyn_cast(); + if (!op_ranked_type) return this->matchFailure(); + + if (op_ranked_type.hasStaticShape()) { + if (!CreateBroadcastsForBinaryOp(op, &rewriter, &new_lhs, &new_rhs)) { + return this->matchFailure(); + } + } else { + if (!CreateDynamicBroadcastsForBinaryOp(op, &rewriter, &new_lhs, + &new_rhs)) { + return this->matchFailure(); + } } // Replace the original op with a new one that uses the new args. 
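To make the shape computation described in the ComputeBroadcastedShape comments above concrete, here is a small NumPy-flavoured sketch (illustrative only; it follows the documented intent rather than the exact IR that is emitted):

    def broadcasted_result_shape(small_shape, large_shape, broadcast_dimensions):
        # broadcast_dimensions[j] is the output dimension that dimension j of
        # the smaller-ranked operand maps to.
        small_dim_for_output = {out_dim: j for j, out_dim in enumerate(broadcast_dimensions)}
        result = []
        for i, large_size in enumerate(large_shape):
            if i not in small_dim_for_output:
                # Only the larger-ranked operand contributes here.
                result.append(large_size)
            elif large_size == 1:
                # A size-1 dimension is expanded to the other operand's size.
                result.append(small_shape[small_dim_for_output[i]])
            else:
                # Otherwise the sizes must agree; like the C++ helper, this
                # sketch does not re-verify that.
                result.append(large_size)
        return result

    # NumPy-style broadcasting of shapes (3, 1) and (4,), with the smaller
    # operand mapped to output dimension 1, yields (3, 4).
    assert broadcasted_result_shape([4], [3, 1], [1]) == [3, 4]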
diff --git a/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts_pass.cc b/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts_pass.cc index 933f8a73fd5..596b67f0eed 100644 --- a/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts_pass.cc +++ b/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts_pass.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project #include "mlir/IR/MLIRContext.h" // TF:llvm-project #include "mlir/IR/Operation.h" // TF:llvm-project #include "mlir/IR/PatternMatch.h" // TF:llvm-project @@ -34,6 +35,8 @@ struct TestMaterializeBroadcastsPass // Consider the xla_hlo dialect legal for tests. conversionTarget.addLegalDialect(); + // The conversion uses helpers from the Standard dialect. + conversionTarget.addLegalDialect(); SetupMaterializeBroadcastsLegality(&getContext(), &conversionTarget); PopulateMaterializeBroadcastsPatterns(&getContext(), &conversionPatterns); diff --git a/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc index b6019b1e263..d07819284e5 100644 --- a/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc +++ b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc @@ -227,19 +227,21 @@ class BroadcastInDimConverter unsigned nloops = resultMemrefType.getRank(); + auto operandShape = operandMemrefType.getShape(); SmallVector dimExprs; { dimExprs.reserve(nloops); + for (const auto& broadcastDim : llvm::enumerate( + broadcastOp.broadcast_dimensions().getValue().getIntValues())) { + int dim = broadcastDim.value().getSExtValue(); - auto operandShape = operandMemrefType.getShape(); - int index = 0; - for (const auto& broadcastSize : - broadcastOp.broadcast_dimensions().getValue().getIntValues()) { - int size = broadcastSize.getSExtValue(); - dimExprs.push_back( - operandShape[index++] == 1 + // TODO(pifon): Add support for args with dynamic shapes for the case + // when a dimension of size 1 is broadcasted into dim of size N. + AffineExpr affineExpr = + operandShape[broadcastDim.index()] == 1 ? mlir::getAffineConstantExpr(0, broadcastOp.getContext()) - : mlir::getAffineDimExpr(size, broadcastOp.getContext())); + : mlir::getAffineDimExpr(dim, broadcastOp.getContext()); + dimExprs.push_back(affineExpr); } } diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 77dbb1919be..203ef51c842 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -18,6 +18,10 @@ package_group( includes = [ "//tensorflow/compiler/tf2xla:internal", ], + packages = [ + # To pass open source testing in the pip Kokoros. + "//bazel_pip/tensorflow/compiler/tests/...", + ], ) package_group( @@ -25,6 +29,10 @@ package_group( includes = [ "//tensorflow/compiler/tf2xla:friends", ], + packages = [ + # To pass open source testing in the pip Kokoros. 
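As an aside on the BroadcastInDimConverter hunk above, the indexing-map construction can be read as: a size-1 operand dimension always maps to the constant 0, and any other operand dimension follows the result dimension it was broadcast into. A hedged Python sketch with illustrative names only:

    def operand_index_exprs(operand_shape, broadcast_dimensions):
        # For operand dimension j, broadcast_dimensions[j] is the result
        # dimension it was broadcast into. A size-1 operand dimension always
        # reads element 0; otherwise it follows that result loop index.
        exprs = []
        for j, result_dim in enumerate(broadcast_dimensions):
            exprs.append(0 if operand_shape[j] == 1 else 'd%d' % result_dim)
        return exprs

    # Broadcasting a <1x3> operand into a <4x3> result along dims [0, 1]
    # indexes the operand with (0, d1).
    assert operand_index_exprs([1, 3], [0, 1]) == [0, 'd1']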
+ "//bazel_pip/tensorflow/compiler/tests/...", + ], ) generate_backend_suites() @@ -53,7 +61,10 @@ py_library( py_library( name = "test_utils", testonly = 1, - srcs = ["test_utils.py"], + srcs = [ + "__init__.py", + "test_utils.py", + ], srcs_version = "PY2AND3", deps = [ "//third_party/py/numpy", @@ -66,6 +77,9 @@ py_test( size = "small", srcs = ["xla_test_test.py"], python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], deps = [ ":xla_test", ], @@ -76,6 +90,9 @@ tf_xla_py_test( size = "medium", srcs = ["adadelta_test.py"], python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -90,6 +107,9 @@ tf_xla_py_test( size = "small", srcs = ["adagrad_test.py"], python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -105,6 +125,9 @@ tf_xla_py_test( size = "small", srcs = ["adagrad_da_test.py"], python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -119,6 +142,9 @@ tf_xla_py_test( size = "small", srcs = ["adam_test.py"], python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -136,6 +162,9 @@ tf_xla_py_test( # TensorList ops are not implemented in the on-demand compilation model yet. disabled_backends = ["cpu_ondemand"], python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -151,6 +180,9 @@ tf_xla_py_test( size = "small", srcs = ["argminmax_test.py"], python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -168,6 +200,7 @@ tf_xla_py_test( shard_count = 5, tags = [ "no_oss", # TODO(b/148108508): Re-enable this test in OSS. + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip "optonly", # Times out frequently in fastbuild mode. ], deps = [ @@ -194,6 +227,7 @@ tf_xla_py_test( python_version = "PY3", shard_count = 2, tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip "optonly", # Times out frequently in fastbuild mode. 
], deps = [ @@ -212,6 +246,9 @@ tf_xla_py_test( size = "small", srcs = ["bucketize_op_test.py"], python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -226,7 +263,10 @@ tf_xla_py_test( size = "small", srcs = ["categorical_op_test.py"], python_version = "PY3", - tags = ["optonly"], + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "optonly", + ], deps = [ ":xla_test", "//tensorflow/python:framework", @@ -242,6 +282,7 @@ tf_xla_py_test( srcs = ["cholesky_op_test.py"], python_version = "PY3", tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip "no_rocm", "optonly", ], @@ -261,6 +302,9 @@ tf_xla_py_test( size = "small", srcs = ["cond_test.py"], python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], deps = [ ":xla_test", "//tensorflow/compiler/tf2xla/python:xla", @@ -278,7 +322,10 @@ tf_xla_py_test( size = "medium", srcs = ["self_adjoint_eig_op_test.py"], python_version = "PY3", - tags = ["optonly"], + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "optonly", + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -291,18 +338,6 @@ tf_xla_py_test( ], ) -tf_xla_py_test( - name = "searchsorted_op_test", - size = "small", - timeout = "moderate", - srcs = ["searchsorted_op_test.py"], - python_version = "PY3", - deps = [ - ":xla_test", - "//tensorflow/python:platform_test", - ], -) - tf_xla_py_test( name = "svd_op_test", size = "medium", @@ -314,6 +349,7 @@ tf_xla_py_test( ], python_version = "PY3", tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip "no_rocm", "optonly", ], @@ -336,6 +372,7 @@ tf_xla_py_test( srcs = ["matrix_inverse_op_test.py"], python_version = "PY3", tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip "noasan", "nomsan", "notsan", @@ -356,6 +393,9 @@ tf_xla_py_test( timeout = "moderate", srcs = ["matrix_solve_op_test.py"], python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], deps = [ ":xla_test", "//tensorflow/python:linalg_ops", @@ -371,7 +411,10 @@ tf_xla_py_test( timeout = "moderate", srcs = ["matrix_triangular_solve_op_test.py"], python_version = "PY3", - tags = ["optonly"], + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "optonly", + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -387,6 +430,9 @@ tf_xla_py_test( size = "small", srcs = ["clustering_test.py"], python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -403,6 +449,7 @@ tf_xla_py_test( python_version = "PY3", tags = [ "many_xla_args", + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip "no_rocm", ], deps = [ @@ -423,6 +470,9 @@ tf_xla_py_test( srcs = ["conv2d_test.py"], python_version = "PY3", shard_count = 10, + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], deps = [ ":test_utils", ":xla_test", @@ -442,6 +492,9 @@ tf_xla_py_test( srcs = ["conv3d_test.py"], python_version = "PY3", shard_count = 5, + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro 
pip + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -460,6 +513,7 @@ tf_xla_py_test( python_version = "PY3", shard_count = 5, tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip "no_rocm", "noasan", "nomsan", @@ -482,6 +536,9 @@ tf_xla_py_test( size = "small", srcs = ["dynamic_slice_ops_test.py"], python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], deps = [ "//tensorflow/compiler/tests:xla_test", "//tensorflow/compiler/tf2xla/python:xla", @@ -499,6 +556,9 @@ tf_xla_py_test( "gpu", ], python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -513,6 +573,9 @@ tf_xla_py_test( size = "small", srcs = ["reshape_op_test.py"], python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], deps = [ "//tensorflow/compiler/tests:xla_test", "//tensorflow/compiler/tf2xla/python:xla", @@ -527,6 +590,9 @@ tf_xla_py_test( size = "small", srcs = ["dynamic_stitch_test.py"], python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -541,6 +607,9 @@ tf_xla_py_test( size = "small", srcs = ["extract_image_patches_op_test.py"], python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -556,6 +625,7 @@ tf_xla_py_test( python_version = "PY3", tags = [ "multi_and_single_gpu", + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip ], deps = [ ":xla_test", @@ -574,6 +644,9 @@ tf_xla_py_test( size = "medium", srcs = ["fifo_queue_test.py"], python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -591,6 +664,7 @@ tf_xla_py_test( python_version = "PY3", shard_count = 6, tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip "no_rocm", "optonly", ], @@ -609,6 +683,9 @@ tf_xla_py_test( size = "small", srcs = ["slice_ops_test.py"], python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -623,6 +700,9 @@ tf_xla_py_test( size = "medium", srcs = ["ftrl_test.py"], python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -638,6 +718,9 @@ tf_xla_py_test( size = "small", srcs = ["function_test.py"], python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -653,6 +736,7 @@ tf_xla_py_test( python_version = "PY3", shard_count = 10, tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip "optonly", # Times out frequently in fastbuild mode. 
], deps = [ @@ -669,6 +753,9 @@ tf_xla_py_test( size = "small", srcs = ["listdiff_op_test.py"], python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -685,6 +772,9 @@ tf_xla_py_test( size = "medium", srcs = ["lrn_ops_test.py"], python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -700,6 +790,9 @@ tf_xla_py_test( size = "small", srcs = ["manip_ops_test.py"], python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -715,7 +808,10 @@ tf_xla_py_test( timeout = "long", srcs = ["matrix_band_part_test.py"], python_version = "PY3", - tags = ["optonly"], + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "optonly", + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -731,6 +827,9 @@ tf_xla_py_test( timeout = "long", srcs = ["matrix_diag_ops_test.py"], python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -744,6 +843,9 @@ tf_xla_py_test( size = "small", srcs = ["momentum_test.py"], python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -759,6 +861,9 @@ tf_xla_py_test( size = "small", srcs = ["nary_ops_test.py"], python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -773,6 +878,9 @@ tf_xla_py_test( size = "small", srcs = ["nullary_ops_test.py"], python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], deps = [ ":xla_test", "//tensorflow/python:control_flow_ops", @@ -787,6 +895,9 @@ tf_xla_py_test( srcs = ["pooling_ops_test.py"], python_version = "PY3", shard_count = 10, + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -803,6 +914,9 @@ tf_xla_py_test( srcs = ["pooling_ops_3d_test.py"], python_version = "PY3", shard_count = 10, + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -818,6 +932,9 @@ tf_xla_py_test( size = "medium", srcs = ["proximal_adagrad_test.py"], python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -832,6 +949,9 @@ tf_xla_py_test( size = "medium", srcs = ["proximal_gradient_descent_test.py"], python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -852,7 +972,10 @@ tf_xla_py_test( ], python_version = "PY3", shard_count = 5, - tags = ["optonly"], + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "optonly", + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -871,6 +994,7 @@ tf_xla_py_test( python_version = "PY3", 
shard_count = 5, tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip "no_rocm", "optonly", ], @@ -892,6 +1016,7 @@ tf_xla_py_test( python_version = "PY3", shard_count = 10, tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip "notap", # TODO(b/141057424): flaky on TPU ], deps = [ @@ -911,6 +1036,9 @@ tf_xla_py_test( srcs = ["reduce_ops_test.py"], python_version = "PY3", shard_count = 5, + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -927,6 +1055,9 @@ tf_xla_py_test( size = "small", srcs = ["reduce_window_test.py"], python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], deps = [ ":xla_test", "//tensorflow/compiler/tf2xla/python:xla", @@ -943,6 +1074,9 @@ tf_xla_py_test( size = "medium", srcs = ["reverse_ops_test.py"], python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -955,7 +1089,10 @@ tf_xla_py_test( size = "medium", srcs = ["reverse_sequence_op_test.py"], python_version = "PY3", - tags = ["optonly"], + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "optonly", + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -969,6 +1106,9 @@ tf_xla_py_test( size = "small", srcs = ["rmsprop_test.py"], python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -984,7 +1124,10 @@ tf_xla_py_test( size = "small", srcs = ["scan_ops_test.py"], python_version = "PY3", - tags = ["optonly"], + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "optonly", + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -999,6 +1142,9 @@ tf_xla_py_test( size = "medium", srcs = ["segment_reduction_ops_test.py"], python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -1015,6 +1161,9 @@ tf_xla_py_test( srcs = ["spacetobatch_op_test.py"], python_version = "PY3", shard_count = 3, + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -1029,6 +1178,9 @@ tf_xla_py_test( size = "small", srcs = ["sparse_to_dense_op_test.py"], python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -1043,7 +1195,10 @@ tf_xla_py_test( size = "small", srcs = ["stack_ops_test.py"], python_version = "PY3", - tags = ["config-cuda-only"], + tags = [ + "config-cuda-only", + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], use_xla_device = False, deps = [ ":xla_test", @@ -1060,7 +1215,10 @@ tf_xla_py_test( srcs = ["stateful_random_ops_test.py"], python_version = "PY3", shard_count = 10, - tags = ["optonly"], + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "optonly", + ], deps = [ ":xla_test", "//tensorflow/python:framework", @@ -1076,7 +1234,10 @@ tf_xla_py_test( size = "medium", srcs = ["stateless_random_ops_test.py"], python_version 
= "PY3", - tags = ["optonly"], + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "optonly", + ], deps = [ ":xla_test", "//tensorflow/python:framework", @@ -1096,6 +1257,7 @@ tf_xla_py_test( python_version = "PY3", tags = [ "config-cuda-only", + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip "v1only", ], use_xla_device = False, @@ -1121,6 +1283,9 @@ tf_xla_py_test( # TensorList ops are not implemented in the on-demand compilation model yet. disabled_backends = ["cpu_ondemand"], python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -1136,6 +1301,9 @@ tf_xla_py_test( size = "medium", srcs = ["ternary_ops_test.py"], python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -1152,6 +1320,9 @@ tf_xla_py_test( size = "medium", srcs = ["unary_ops_test.py"], python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -1168,6 +1339,9 @@ tf_xla_py_test( size = "medium", srcs = ["fused_batchnorm_test.py"], python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], deps = [ ":test_utils", ":xla_test", @@ -1188,7 +1362,10 @@ tf_xla_py_test( size = "small", srcs = ["variable_ops_test.py"], python_version = "PY3", - tags = ["optonly"], + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "optonly", + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -1207,6 +1384,9 @@ tf_xla_py_test( size = "small", srcs = ["while_test.py"], python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], deps = [ ":xla_test", "//tensorflow/compiler/tf2xla/python:xla", @@ -1222,7 +1402,10 @@ tf_xla_py_test( size = "medium", srcs = ["gather_test.py"], python_version = "PY3", - tags = ["optonly"], + tags = [ + "no_pip", + "optonly", + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -1237,6 +1420,9 @@ tf_xla_py_test( size = "medium", srcs = ["gather_nd_op_test.py"], python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -1250,7 +1436,10 @@ tf_xla_py_test( size = "medium", srcs = ["scatter_nd_op_test.py"], python_version = "PY3", - tags = ["optonly"], + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "optonly", + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -1266,7 +1455,10 @@ tf_xla_py_test( python_version = "PY3", shard_count = 1, # Times out in fastbuild mode. 
- tags = ["optonly"], + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "optonly", + ], deps = [ "//tensorflow/compiler/tests:xla_test", "//tensorflow/compiler/tf2xla/python:xla", @@ -1280,6 +1472,9 @@ tf_xla_py_test( size = "small", srcs = ["data_format_ops_test.py"], python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], deps = [ "//tensorflow/compiler/tests:xla_test", "//tensorflow/python:array_ops", @@ -1294,7 +1489,10 @@ tf_xla_py_test( size = "small", srcs = ["xla_device_test.py"], python_version = "PY3", - tags = ["optonly"], + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "optonly", + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -1307,6 +1505,9 @@ cuda_py_test( name = "xla_device_gpu_test", size = "small", srcs = ["xla_device_gpu_test.py"], + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], xla_enable_strict_auto_jit = False, deps = [ "//tensorflow/python:array_ops", @@ -1323,7 +1524,10 @@ cuda_py_test( size = "medium", srcs = ["jit_test.py"], shard_count = 5, - tags = ["no_rocm"], + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "no_rocm", + ], xla_enable_strict_auto_jit = False, deps = [ ":test_utils", @@ -1344,7 +1548,10 @@ cuda_py_test( name = "dense_layer_test", size = "medium", srcs = ["dense_layer_test.py"], - tags = ["no_rocm"], + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "no_rocm", + ], xla_enable_strict_auto_jit = False, deps = [ ":test_utils", @@ -1385,6 +1592,7 @@ tf_cuda_cc_test( size = "large", # This test is randomized, so only run it if explicitly requested. 
tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip "manual", "notap", ] + tf_cuda_tests_tags(), @@ -1394,7 +1602,9 @@ tf_cuda_cc_test( tf_cuda_cc_test( name = "unary_ops_composition_test", srcs = ["unary_ops_composition_test.cc"], - tags = tf_cuda_tests_tags(), + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ] + tf_cuda_tests_tags(), deps = [ "//tensorflow/cc:cc_ops", "//tensorflow/compiler/jit", @@ -1430,7 +1640,10 @@ py_library( cuda_py_test( name = "lstm_test", srcs = ["lstm_test.py"], - tags = ["no_rocm"], + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "no_rocm", + ], xla_enable_strict_auto_jit = False, deps = [ ":lstm", @@ -1474,6 +1687,9 @@ tf_xla_py_test( size = "medium", srcs = ["fake_quant_ops_test.py"], python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], deps = [ ":xla_test", "//tensorflow/python:framework", @@ -1486,6 +1702,9 @@ tf_xla_py_test( size = "small", srcs = ["placeholder_test.py"], python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -1499,6 +1718,9 @@ tf_xla_py_test( size = "medium", srcs = ["quantized_ops_test.py"], python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], deps = [ ":xla_test", "//tensorflow/compiler/tf2xla/python:xla", @@ -1516,6 +1738,9 @@ tf_xla_py_test( size = "medium", srcs = ["xla_ops_test.py"], python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], deps = [ ":xla_test", "//tensorflow/compiler/tf2xla/python:xla", @@ -1535,6 +1760,7 @@ tf_xla_py_test( shard_count = 5, tags = [ "no_oss", # TODO(b/148108508): Re-enable this test in OSS. + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip "no_rocm", ], deps = [ @@ -1560,6 +1786,7 @@ tf_xla_py_test( ], python_version = "PY3", tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip "optonly", ], deps = [ @@ -1576,7 +1803,10 @@ tf_xla_py_test( size = "medium", srcs = ["special_math_test.py"], shard_count = 5, - tags = ["optonly"], + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "optonly", + ], deps = [ ":xla_test", "//tensorflow/python:extra_py_tests_deps", diff --git a/third_party/toolchains/preconfig/centos6/tensorrt5/LICENSE b/tensorflow/compiler/tests/__init__.py similarity index 100% rename from third_party/toolchains/preconfig/centos6/tensorrt5/LICENSE rename to tensorflow/compiler/tests/__init__.py diff --git a/tensorflow/compiler/tests/searchsorted_op_test.py b/tensorflow/compiler/tests/searchsorted_op_test.py deleted file mode 100644 index d77bd0902d3..00000000000 --- a/tensorflow/compiler/tests/searchsorted_op_test.py +++ /dev/null @@ -1,75 +0,0 @@ -# Copyright 2020 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Test for XLA implementation of tf.searchsorted.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.compiler.tests import xla_test -from tensorflow.python.ops import array_ops -from tensorflow.python.platform import test - - -class SearchSorteddOpTest(xla_test.XLATestCase): - - def test1D(self): - # Test against NumPy implementation (which is 1D only). - np.random.seed(1) - for side in ['left', 'right']: - for dtype in [np.float32, np.int32]: - values = np.random.uniform( - low=-1000, high=1000, size=(10,)).astype(dtype) - unsorted = np.random.uniform( - low=-1000, high=1000, size=(20,)).astype(dtype) - - sorted_sequence = np.sort(unsorted) - np_ans = np.searchsorted(sorted_sequence, values, side=side) - - with self.session() as session: - with self.test_scope(): - tf_ans = array_ops.searchsorted(sorted_sequence, values, side=side) - tf_out = session.run(tf_ans) - self.assertAllEqual(np_ans, tf_out) - - def _test2DExample(self, dtype, side, sorted_sequence, values, correct_ans): - - with self.session() as session: - with self.test_scope(): - tf_ans = array_ops.searchsorted(sorted_sequence, values, side=side) - tf_out = session.run(tf_ans) - self.assertAllEqual(correct_ans, tf_out) - - def testLowerBound2DExample(self): - # 2D TensorFlow documentation example. - for dtype in self.float_types | self.int_types: - sorted_sequence = np.array([[0, 3, 9, 9, 10], [1, 2, 3, 4, 5]], dtype) - values = np.array([[2, 4, 9], [0, 2, 6]], dtype) - correct_ans = np.array([[1, 2, 2], [0, 1, 5]], dtype) - self._test2DExample(dtype, 'left', sorted_sequence, values, correct_ans) - - def testUpperBound2DExample(self): - # 2D TensorFlow documentation example. - for dtype in self.float_types | self.int_types: - sorted_sequence = np.array([[0, 3, 9, 9, 10], [1, 2, 3, 4, 5]], dtype) - values = np.array([[2, 4, 9], [0, 2, 6]], dtype) - correct_ans = np.array([[1, 2, 4], [0, 2, 5]], dtype) - self._test2DExample(dtype, 'right', sorted_sequence, values, correct_ans) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/compiler/tests/special_math_test.py b/tensorflow/compiler/tests/special_math_test.py index 7beebf0720e..b3abc40f82d 100644 --- a/tensorflow/compiler/tests/special_math_test.py +++ b/tensorflow/compiler/tests/special_math_test.py @@ -29,6 +29,10 @@ import scipy.special as sps import six from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import gen_math_ops +from tensorflow.python.ops import gen_random_ops +from tensorflow.python.ops import gradient_checker_v2 from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -39,6 +43,13 @@ flags.DEFINE_bool('vary_seed', False, NUM_SAMPLES = int(1e3) +# This is df/da / df/dx, where f = igamma. 
+def implicit_reparameterization_grad(a, x): + log_prob = math_ops.xlogy(a - 1., x) - math_ops.lgamma(a) - x + prob = math_ops.exp(log_prob) + return -gen_math_ops.igamma_grad_a(a, x) / prob + + class IgammaTest(xla_test.XLATestCase, parameterized.TestCase): def setUp(self): @@ -48,9 +59,15 @@ class IgammaTest(xla_test.XLATestCase, parameterized.TestCase): answer = int(entropy.encode('hex'), 16) else: answer = int.from_bytes(entropy, 'big') - np.random.seed(answer) + np.random.seed(answer % (2**32 - 1)) super(IgammaTest, self).setUp() + # Skip Float64 test on TPU due to missing ops. + def maybe_skip_test(self, dtype): + if self.device not in ['XLA_GPU', 'XLA_CPU', 'CPU'] and dtype == np.float64: + self.skipTest( + 'Skipping test because some F64 operations not supported on TPU.') + @parameterized.parameters((np.float32, 1e-2, 1e-11), (np.float64, 1e-4, 1e-30)) def testIgammaSmallValues(self, dtype, rtol, atol): @@ -93,6 +110,97 @@ class IgammaTest(xla_test.XLATestCase, parameterized.TestCase): actual = sess.run(math_ops.igamma(a, x)) self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol) + # We don't check small values because the numerical gradients become quite + # large. + @parameterized.parameters((np.float32, 0.09), (np.float64, 1e-7)) + def testIgammaGradMediumValues(self, dtype, tolerance): + self.maybe_skip_test(dtype) + with self.session(): + with self.test_scope(): + x = constant_op.constant( + np.random.uniform(low=1., high=100., + size=[NUM_SAMPLES]).astype(dtype)) + a = constant_op.constant( + np.random.uniform(low=1., high=100., + size=[NUM_SAMPLES]).astype(dtype)) + + f = lambda b: math_ops.igamma(b, x) + max_error = gradient_checker_v2.max_error( + *gradient_checker_v2.compute_gradient(f, x=[a], delta=1e-3)) + self.assertLessEqual(max_error, tolerance) + + @parameterized.parameters((np.float32, 0.5), (np.float64, 1e-7)) + def testIgammaGradLargeValues(self, dtype, tolerance): + self.maybe_skip_test(dtype) + with self.session(): + with self.test_scope(): + x = constant_op.constant( + np.random.uniform(low=100., high=int(1e4), + size=[NUM_SAMPLES]).astype(dtype)) + a = constant_op.constant( + np.random.uniform(low=100., high=int(1e4), + size=[NUM_SAMPLES]).astype(dtype)) + + f = lambda b: math_ops.igamma(b, x) + max_error = gradient_checker_v2.max_error( + *gradient_checker_v2.compute_gradient(f, x=[a], delta=1e-2)) + self.assertLessEqual(max_error, tolerance) + + @parameterized.parameters((np.float32, 1e-2, 1e-11), + (np.float64, 1e-4, 1e-30)) + def testRandomGammaGradSmallValues(self, dtype, rtol, atol): + self.maybe_skip_test(dtype) + # Test values near zero. + + with self.session() as sess: + with self.test_scope(): + x = constant_op.constant( + np.random.uniform( + low=np.finfo(dtype).tiny, high=1., + size=[NUM_SAMPLES]).astype(dtype)) + a = constant_op.constant( + np.random.uniform( + low=np.finfo(dtype).tiny, high=1., + size=[NUM_SAMPLES]).astype(dtype)) + gamma_sample_grad = gen_random_ops.random_gamma_grad(a, x) + actual_grad = implicit_reparameterization_grad(a, x) + gamma_sample_grad, actual_grad = sess.run( + [gamma_sample_grad, actual_grad]) + # We do this because the ratio computed in + # implicit_reparameterization_grad can very easily result in a NaN due + # to the computed numerator and denominator zeroing out. 
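As a cross-check of implicit_reparameterization_grad above (purely illustrative, not part of the patch): the returned quantity works out to -(dF/da)/(dF/dx) for F(a, x) the regularized lower incomplete gamma function, with dF/dx written in closed form. A SciPy-based sketch that approximates the same ratio using finite differences for dF/da:

    import numpy as np
    from scipy import special

    def implicit_gamma_sample_grad(a, x, eps=1e-6):
        # F(a, x) = scipy.special.gammainc(a, x), the same convention as
        # tf.math.igamma. dF/dx = x**(a - 1) * exp(-x) / Gamma(a).
        df_da = (special.gammainc(a + eps, x) - special.gammainc(a - eps, x)) / (2.0 * eps)
        df_dx = np.exp(special.xlogy(a - 1.0, x) - special.gammaln(a) - x)
        return -df_da / df_dx

    print(implicit_gamma_sample_grad(np.array([2.0, 5.0]), np.array([1.5, 4.0])))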
+ gamma_sample_grad = gamma_sample_grad[ + ~np.logical_or(np.isnan(actual_grad), np.isinf(actual_grad))] + actual_grad = actual_grad[ + ~np.logical_or(np.isnan(actual_grad), np.isinf(actual_grad))] + self.assertAllClose(actual_grad, gamma_sample_grad, atol=atol, rtol=rtol) + + @parameterized.parameters((np.float32, 1e-2, 1e-11), + (np.float64, 1e-4, 1e-30)) + def testRandomGammaGradMediumValues(self, dtype, rtol, atol): + self.maybe_skip_test(dtype) + + with self.session() as sess: + with self.test_scope(): + x = constant_op.constant( + np.random.uniform(low=1., high=10., + size=[NUM_SAMPLES]).astype(dtype)) + a = constant_op.constant( + np.random.uniform(low=1., high=10., + size=[NUM_SAMPLES]).astype(dtype)) + gamma_sample_grad = gen_random_ops.random_gamma_grad(a, x) + actual_grad = implicit_reparameterization_grad(a, x) + gamma_sample_grad, actual_grad = sess.run( + [gamma_sample_grad, actual_grad]) + # We do this because the ratio computed in + # implicit_reparameterization_grad can very easily result in a NaN due + # to the computed numerator and denominator zeroing out. + gamma_sample_grad = gamma_sample_grad[ + ~np.logical_or(np.isnan(actual_grad), np.isinf(actual_grad))] + actual_grad = actual_grad[ + ~np.logical_or(np.isnan(actual_grad), np.isinf(actual_grad))] + self.assertAllClose(actual_grad, gamma_sample_grad, atol=atol, rtol=rtol) + if __name__ == '__main__': os.environ['XLA_FLAGS'] = '--xla_cpu_enable_fast_math=false' diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index c3ecc1c6215..a0aea950cde 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -587,6 +587,26 @@ class UnaryOpsTest(xla_test.XLATestCase): rtol=1e-6, atol=1e-6) + # For real part close to zero, or imaginary part close to a multiple of + # pi. + + self._assertOpOutputMatchesExpected( + math_ops.expm1, + np.array([[1e-11 + 1j, -1e-11 - 1j, 1. + 1e-11j, + -1. - 1e-11j, 1e-13j + 1e-13j]], dtype=dtype), + # TODO(srvasude): Use numpy as the source of truth after we depend on + # latest numpy with this pull request: + # https://github.com/numpy/numpy/pull/15110. + # The numbers below were generated by scipy.special.expm1. 
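A brief aside on the complex expm1 expectations used just below: for z = x + iy with exp(z) close to 1, the naive exp(z) - 1 cancels badly, while the split form in this NumPy sketch stays accurate. This is illustrative only, not the implementation under test:

    import numpy as np

    def complex_expm1(z):
        # exp(x + iy) - 1 = expm1(x)*cos(y) - 2*sin(y/2)**2 + 1j*exp(x)*sin(y)
        x, y = np.real(z), np.imag(z)
        real = np.expm1(x) * np.cos(y) - 2.0 * np.sin(y / 2.0) ** 2
        imag = np.exp(x) * np.sin(y)
        return real + 1j * imag

    # exp(1e-11 + 1j) - 1 ~ -0.459697694 + 0.841470985j; the tiny real input
    # barely perturbs cos(1) - 1, which naive evaluation can lose.
    print(complex_expm1(np.array([1e-11 + 1j, 2e-13j])))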
+ expected=np.array([[ + -4.59697694e-01+8.41470985e-01j, + -4.59697694e-01-8.41470985e-01j, + 1.71828183e+00+2.71828183e-11j, + -6.32120559e-01-3.67879441e-12j, + -2.00000000e-26+2.00000000e-13j]], dtype=dtype), + rtol=1e-09, + atol=1e-20) + self._assertOpOutputMatchesExpected( math_ops.reciprocal, np.array([[1, 2j, 2 + 3j]], dtype=dtype), diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD index a55ca56e551..b26b509b067 100644 --- a/tensorflow/compiler/tf2tensorrt/BUILD +++ b/tensorflow/compiler/tf2tensorrt/BUILD @@ -189,6 +189,8 @@ tf_cuda_cc_test( "//tensorflow/core:test_main", "//tensorflow/core:testlib", "//tensorflow/core/kernels:ops_testutil", + "//tensorflow/core/kernels:function_ops", + "//tensorflow/core/kernels:array", ] + if_tensorrt([ "@local_config_cuda//cuda:cuda_headers", ]), @@ -247,10 +249,12 @@ tf_cuda_library( srcs = [ "utils/trt_int8_calibrator.cc", "utils/trt_lru_cache.cc", + "utils/trt_shape_optimization_profiles.cc", ], hdrs = [ "utils/trt_int8_calibrator.h", "utils/trt_lru_cache.h", + "utils/trt_shape_optimization_profiles.h", ], deps = [ ":trt_allocator", @@ -306,6 +310,22 @@ tf_cc_test( ], ) +tf_cuda_cc_test( + name = "trt_shape_optimization_profiles_test", + size = "small", + srcs = ["utils/trt_shape_optimization_profiles_test.cc"], + tags = [ + "no_cuda_on_cpu_tap", + "no_windows", + "nomac", + ], + deps = [ + ":trt_resources", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + tf_cuda_library( name = "logger_registry", srcs = ["convert/logger_registry.cc"], diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc index 0131d45f815..6f276546451 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc @@ -431,7 +431,8 @@ Status CreateTRTNode(const ConversionParams& params, calibrate_int8 ? 
TrtPrecisionMode::FP32 : info.precision_mode, max_batch_size, info.max_workspace_size_bytes, input_shapes, trt_logger, alloc, /*calibrator=*/nullptr, &engine, info.use_calibration, - params.use_implicit_batch, /*convert_successfully=*/nullptr)); + params.use_implicit_batch, /*convert_successfully=*/nullptr, + /*profile=*/nullptr)); TrtUniquePtrType engine_data(engine->serialize()); segment_string = string(static_cast(engine_data->data()), engine_data->size()); @@ -468,6 +469,7 @@ Status CreateTRTNode(const ConversionParams& params, .Attr("precision_mode", prec_string) .Attr("use_calibration", info.use_calibration) .Attr("_use_implicit_batch", params.use_implicit_batch) + .Attr("_allow_build_at_runtime", info.allow_build_at_runtime) .Attr("OutT", out_types) .Finalize(&trt_node); if (!status.ok()) { @@ -671,6 +673,7 @@ Status ConvertAfterShapes(const ConversionParams& params) { : EngineInfo::EngineType::TRTStatic); curr_engine.use_calibration = params.use_calibration; curr_engine.maximum_cached_engines = params.max_cached_engines; + curr_engine.allow_build_at_runtime = params.allow_build_at_runtime; status = RegisterGraphToFunctionLibrary(curr_engine.segment_graph_def, &graph, curr_engine.engine_name); diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h index 00dc4c72f43..2bfaa2a786c 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h @@ -49,6 +49,7 @@ struct ConversionParams { int max_cached_engines = 1; bool use_calibration = true; bool use_implicit_batch = true; + bool allow_build_at_runtime = true; }; // Method to call from optimization pass diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index 2d4c0d49bad..433564513db 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -32,6 +32,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "tensorflow/compiler/tf2tensorrt/convert/utils.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h" #include "tensorflow/core/framework/node_def.pb.h" // NOLINT #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/tensor.pb.h" // NOLINT @@ -249,6 +250,16 @@ void GetInputProperties(const grappler::GraphProperties& graph_properties, } } +// This function checks if a tensor is compatible with TRT. +// +// We check that the shape and datatype are compatible with TensorRT. We also +// return the corresponding trt_dtype, the trt_dims and the batch_size (the latter +// is only needed in implicit batch mode). +// +// The return status indicates whether the tensor is compatible. +// +// For implicit batch mode, when validation_only == false, we also check that +// all input dimensions (besides the batch dimension) are known dimensions. Status ValidateTensorProperties(const string& producer_node_type, const DataType dtype, const PartialTensorShape& shape, @@ -293,11 +304,7 @@ Status ValidateTensorProperties(const string& producer_node_type, if (validation_only) return Status::OK(); - // Following checks are only used during TRT engine creation time. In implicit - batch mode we check that all inputs for the network has static shape (as - required by the TensorRT).
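A hedged sketch of the implicit-batch requirement described in the ValidateTensorProperties comment above (illustrative helper, with -1 denoting an unknown dimension): at engine build time every dimension except the batch dimension must be known.

    def implicit_batch_dims_are_known(shape):
        # The batch dimension (index 0) may be unknown; all remaining
        # dimensions must have a concrete size.
        return all(d >= 0 for d in shape[1:])

    assert implicit_batch_dims_are_known([-1, 224, 224, 3])      # unknown batch is fine
    assert not implicit_batch_dims_are_known([8, -1, 224, 3])    # unknown spatial dim is not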
The only exception is the batch size, which - // could be unknown. In contrast, using explicit batch mode this test is not - // necessary, since any dimension could be unknown in explicit batch mode. + // Following checks are only used during TRT engine creation time. if (use_implicit_batch) { for (int d = first_trt_dim; d < shape.dims(); ++d) { if (shape.dim_size(d) < 0) { @@ -653,6 +660,9 @@ size_t TRT_ShapedWeights::size_bytes() const { data_type_size = 2; break; case nvinfer1::DataType::kINT8: +#if IS_TRT_VERSION_GE(7, 0, 0, 0) + case nvinfer1::DataType::kBOOL: +#endif data_type_size = 1; break; } @@ -1336,7 +1346,7 @@ Status Converter::RenameAndMarkOutputTensors( Status Converter::BuildCudaEngine( TrtUniquePtrType* engine, int max_batch_size, size_t max_workspace_size_bytes, nvinfer1::IGpuAllocator* allocator, - TRTInt8Calibrator* calibrator) { + TRTInt8Calibrator* calibrator, TrtShapeOptimizationProfile* profiles) { VLOG(1) << "Configuring TensorRT builder"; trt_builder_->setMaxBatchSize(max_batch_size); trt_builder_->setGpuAllocator(allocator); @@ -1356,7 +1366,10 @@ Status Converter::BuildCudaEngine( builder_config->setInt8Calibrator(nullptr); } } - + if (!use_implicit_batch_ && profiles) { + TF_RETURN_IF_ERROR(profiles->ConfigureBuilder( + trt_builder_.get(), builder_config.get(), network())); + } VLOG(1) << "Building TensorRT engine"; engine->reset( trt_builder_->buildEngineWithConfig(*network(), *builder_config)); @@ -5743,7 +5756,8 @@ Status ConvertGraphDefToEngine( nvinfer1::ILogger* trt_logger, nvinfer1::IGpuAllocator* allocator, TRTInt8Calibrator* calibrator, TrtUniquePtrType* engine, bool use_calibration, - const bool use_implicit_batch, bool* convert_successfully) { + const bool use_implicit_batch, bool* convert_successfully, + TrtShapeOptimizationProfile* profiles) { engine->reset(); if (convert_successfully) *convert_successfully = false; @@ -5842,7 +5856,8 @@ Status ConvertGraphDefToEngine( // Build the engine. TF_RETURN_IF_ERROR(converter->BuildCudaEngine( - engine, max_batch_size, max_workspace_size_bytes, allocator, calibrator)); + engine, max_batch_size, max_workspace_size_bytes, allocator, calibrator, + profiles)); VLOG(1) << "Finished conversion"; return Status::OK(); diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h index d295f074a98..8608c8226ee 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/grappler/costs/graph_properties.h" @@ -92,7 +93,8 @@ struct EngineInfo { : engine_type(EngineType::TRTStatic), max_workspace_size_bytes(0), precision_mode(TrtPrecisionMode::FP32), - use_calibration(true) {} + use_calibration(true), + allow_build_at_runtime(true) {} string engine_name; string device; @@ -109,6 +111,7 @@ struct EngineInfo { int maximum_cached_engines; TrtPrecisionMode precision_mode; bool use_calibration; + bool allow_build_at_runtime; }; // Constructs a graphdef from the segment in the given graph. 
Adds _Arg @@ -145,7 +148,8 @@ Status ConvertGraphDefToEngine( nvinfer1::ILogger* logger, nvinfer1::IGpuAllocator* allocator, TRTInt8Calibrator* calibrator, TrtUniquePtrType* engine, bool use_calibration, - const bool use_implicit_batch, bool* convert_successfully); + const bool use_implicit_batch, bool* convert_successfully, + TrtShapeOptimizationProfile* profiles); // Helper class for the segmenter to determine whether an output edge from the // TRT segment is valid. @@ -465,7 +469,8 @@ class Converter { Status BuildCudaEngine(TrtUniquePtrType* engine, int max_batch_size, size_t max_workspace_size_bytes, nvinfer1::IGpuAllocator* allocator, - TRTInt8Calibrator* calibrator); + TRTInt8Calibrator* calibrator, + TrtShapeOptimizationProfile* profiles); ////////////////////////////////////////////////////////////////////////////// // Methods used by op converters to convert individual TF node and add layers diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index 98aaa18e9fc..400c53614f9 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -1187,7 +1187,7 @@ class ConvertGraphDefToEngineTest : public ::testing::Test { /*max_workspace_size_bytes=*/64 << 20, input_shapes, &logger_, /*allocator=*/nullptr, /*calibrator=*/nullptr, &engine_, /*use_calibration=*/false, /*use_implicit_batch=*/true, - /*convert_successfully=*/nullptr); + /*convert_successfully=*/nullptr, /*profiles=*/nullptr); } protected: @@ -1302,7 +1302,8 @@ class OpConverterTest : public ::testing::Test { /*max_batch_size=*/batch_size, /*max_workspace_size_bytes=*/1 << 26, /*allocator=*/nullptr, - /*calibrator=*/nullptr)); + /*calibrator=*/nullptr, + /*profiles=*/nullptr)); CHECK_NOTNULL(engine_.get()); CheckDataTypeMatches(input_data); CheckDataTypeMatches(*output_data); diff --git a/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc index 757ddd159c9..7995163ed44 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc @@ -70,6 +70,9 @@ Status TRTOptimizationPass::Init( if (params.count("trt_logger")) { trt_logger_name_ = params.at("trt_logger").s(); } + if (params.count("allow_build_at_runtime")) { + allow_build_at_runtime_ = params.at("allow_build_at_runtime").b(); + } if (params.count("use_implicit_batch")) { use_implicit_batch_ = params.at("use_implicit_batch").b(); } @@ -265,6 +268,7 @@ Status TRTOptimizationPass::Optimize(grappler::Cluster* cluster, cp.max_cached_engines = max_cached_batches_; cp.use_calibration = use_calibration_; cp.use_implicit_batch = use_implicit_batch_; + cp.allow_build_at_runtime = allow_build_at_runtime_; auto status = ConvertAfterShapes(cp); VLOG(1) << "Returning from " << name_; return status; diff --git a/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h index 3ce0d09b7c0..f79048bb5f6 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h +++ b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h @@ -42,7 +42,8 @@ class TRTOptimizationPass : public grappler::CustomGraphOptimizer { max_cached_batches_(1), max_workspace_size_bytes_(256LL << 20), use_calibration_(true), - use_implicit_batch_(true) { + use_implicit_batch_(true), + 
allow_build_at_runtime_(true) { VLOG(1) << "Constructing " << name_; } @@ -75,6 +76,7 @@ class TRTOptimizationPass : public grappler::CustomGraphOptimizer { int64_t max_workspace_size_bytes_; bool use_calibration_; bool use_implicit_batch_; + bool allow_build_at_runtime_; }; } // namespace convert diff --git a/tensorflow/compiler/tf2tensorrt/convert/utils.cc b/tensorflow/compiler/tf2tensorrt/convert/utils.cc index ae6555d2219..2fb8902883e 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/utils.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/utils.cc @@ -133,6 +133,25 @@ string DebugString(const std::vector& shapes) { string DebugString(const std::vector& shapes) { return PartialTensorShapeUtils::PartialShapeListString(shapes); } + +int GetNumberOfEngineInputs(const nvinfer1::ICudaEngine* engine) { + int n_bindings = engine->getNbBindings(); + int n_input = 0; + for (int i = 0; i < n_bindings; i++) { + if (engine->bindingIsInput(i)) n_input++; + } + // According to TensorRT 7 doc: "If the engine has been built for K profiles, + // the first getNbBindings() / K bindings are used by profile number 0, the + // following getNbBindings() / K bindings are used by profile number 1 etc." + // Therefore, to get the number of input tensors, we need to divide by the + // the number of profiles. +#if IS_TRT_VERSION_GE(6, 0, 0, 0) + int n_profiles = engine->getNbOptimizationProfiles(); +#else + int n_profiles = 1; +#endif + return n_input / n_profiles; +} #endif string GetLinkedTensorRTVersion() { diff --git a/tensorflow/compiler/tf2tensorrt/convert/utils.h b/tensorflow/compiler/tf2tensorrt/convert/utils.h index 97dcf8976f4..668620bb90a 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/utils.h +++ b/tensorflow/compiler/tf2tensorrt/convert/utils.h @@ -106,6 +106,11 @@ string GetLinkedTensorRTVersion(); // TensorRT library version information {Maj, Min, Patch}. string GetLoadedTensorRTVersion(); +// Returns the number of inputs for the engine, which also correspends to the +// number of input tensors for the network. This can differ from the number of +// input bindings, because the number of total input bindings equals the number +// of profiles times the number of engine inputs. +int GetNumberOfEngineInputs(const nvinfer1::ICudaEngine* engine); #endif // GOOGLE_CUDA && GOOGLE_TENSORRT } // namespace tensorrt diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc index 909e3e11006..25bed655f19 100644 --- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_optimizer.h" #include "tensorflow/core/framework/function.h" @@ -92,7 +93,7 @@ class TRTEngineOp : public AsyncOpKernel { LRUCache, std::unique_ptr, VectorTensorShapeHasher>; - // Execute calibration + // Executes calibration. 
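The binding bookkeeping behind GetNumberOfEngineInputs (added in convert/utils.cc above) is easiest to see with concrete numbers. A minimal illustration, not part of the patch, for a hypothetical engine with 2 network inputs, 1 output and 3 optimization profiles:

    // Illustration only: 3 profiles * (2 inputs + 1 output) = 9 bindings in total.
    constexpr int kNbBindings = 9;      // engine->getNbBindings()
    constexpr int kNbProfiles = 3;      // engine->getNbOptimizationProfiles() (TRT >= 6)
    constexpr int kInputBindings = 6;   // bindings with engine->bindingIsInput(i) == true
    constexpr int kEngineInputs = kInputBindings / kNbProfiles;  // == 2, the value returned
    // Bindings for profile k occupy the slice starting at k * (kNbBindings / kNbProfiles),
    // which is why the binding-index lookups in trt_engine_op.cc now take a profile index.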
void ExecuteCalibration(OpKernelContext* ctx, TRTEngineCacheResource* cache_res, AsyncHelper* helper); @@ -103,14 +104,15 @@ class TRTEngineOp : public AsyncOpKernel { Status ConstructFunctionHandle(FunctionLibraryRuntime* lib, const string& device_name); - // Execute replaced native segment as function Op. + // Executes replaced native segment as function Op. void ExecuteNativeSegment(OpKernelContext* ctx, AsyncHelper* helper); - // Execute the tensorrt engine. Returns whether we need to retry by running + // Executes the tensorrt engine. Returns whether we need to retry by running // the native segment. - bool ExecuteTrtEngine(OpKernelContext* ctx, EngineContext* engine_context); + bool ExecuteTrtEngine(OpKernelContext* ctx, EngineContext* engine_context, + int trt_context_idx); - // Allocate necessary resources for calibration + // Allocates necessary resources for calibration. Status AllocateCalibrationResources(OpKernelContext* ctx, TRTEngineCacheResource* cache_res); @@ -157,6 +159,9 @@ class TRTEngineOp : public AsyncOpKernel { // Whether to use implicit batch dimension for TensorRT bool use_implicit_batch_; + // Whether to build TensorRT engines at runtime + bool allow_build_at_runtime_; + // Maximum number of cached engines int max_cached_engines_; @@ -281,6 +286,14 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) context->GetAttr("use_calibration", &use_calibration_)); OP_REQUIRES_OK(context, context->GetAttr("input_shapes", &input_partial_shapes_)); + auto status = + context->GetAttr("_allow_build_at_runtime", &allow_build_at_runtime_); + if (status.code() == tensorflow::error::NOT_FOUND) { + VLOG(2) << "Not found _allow_build_at_runtime in " + << context->device()->name() + << ", thus setting _allow_build_at_runtime=true"; + allow_build_at_runtime_ = true; + } func_handle_ = kInvalidHandle; if (!static_engine_) { FunctionLibraryRuntime* lib = context->function_library(); @@ -302,7 +315,7 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) OP_REQUIRES_OK(context, context->GetAttr("max_cached_engines_count", &max_cached_engines_)); - auto status = context->GetAttr("_use_implicit_batch", &use_implicit_batch_); + status = context->GetAttr("_use_implicit_batch", &use_implicit_batch_); if (status.code() == tensorflow::error::NOT_FOUND) { VLOG(2) << "Not found _use_implicit_batch in " << context->device()->name() << ", thus setting _use_implicit_batch=true"; @@ -594,11 +607,24 @@ void TRTEngineOp::ComputeAsync(OpKernelContext* ctx, OP_REQUIRES_OK_ASYNC(ctx, VerifyInputShapes(input_concrete_shapes), *helper); + if (!use_implicit_batch_) { + if (cache_res->profiles_.GetNumProfiles() == 0) { + // Create a single profile from the current input shape. In the future we + // will collect a set of input shapes during build mode and create + // profiles for each of them. + cache_res->profiles_.AddShape(input_concrete_shapes); + cache_res->profiles_.InitProfiles(); + } + } StatusOr status = GetEngine(input_concrete_shapes, ctx, cache_res); OP_REQUIRES_OK_ASYNC(ctx, status.status(), *helper); EngineContext* engine_context = status.ValueOrDie(); + // Context idx equals with the profile idx because for each profile we create + // one context. Currently we do not have profile_generation mode, therefore we + // have just a single profile. 
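The invariant stated in the comment above (execution context i serves optimization profile i) also suggests how the lookup generalizes once more than one profile is cached. The following is only a sketch of a possible extension, not code from this patch; it reuses GetProfileNumber and ExecuteNativeSegment, both of which the patch does introduce or use:

    // Sketch of a multi-profile lookup (hypothetical; the patch below hardcodes index 0).
    int trt_context_idx = cache_res->profiles_.GetProfileNumber(input_concrete_shapes);
    if (trt_context_idx < 0) {
      // No cached profile covers these shapes; fall back to the native segment.
      ExecuteNativeSegment(ctx, helper);
      return;
    }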
+ int trt_context_idx = 0; if (!engine_context->cuda_engine) { VLOG(1) << "Engine retrieval for input shapes: " << TensorShapeUtils::ShapeListString(input_concrete_shapes) @@ -606,7 +632,8 @@ void TRTEngineOp::ComputeAsync(OpKernelContext* ctx, ExecuteNativeSegment(ctx, helper); return; } - const bool retry = ExecuteTrtEngine(ctx, engine_context); + + const bool retry = ExecuteTrtEngine(ctx, engine_context, trt_context_idx); if (retry) { LOG(WARNING) << "Failed to execute engine, " << "retrying with native segment for " << name(); @@ -654,7 +681,8 @@ Status GetTrtBindingIndex(const char* tensor_name, int profile_index, } bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx, - EngineContext* engine_context) { + EngineContext* engine_context, + int trt_context_idx) { VLOG(1) << "Executing TRT engine: " << name(); auto& cuda_engine = engine_context->cuda_engine; @@ -677,6 +705,11 @@ bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx, } const bool kRetry = true; + if (trt_context_idx >= 1) { + LOG(ERROR) << "Requested engine context with index " << trt_context_idx + << ", but only 1 context is present."; + return kRetry; + } const int num_binding = cuda_engine->getNbBindings(); std::vector buffers(num_binding); @@ -687,8 +720,8 @@ bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx, for (int i = 0; i < ctx->num_inputs(); i++) { const string input_name = StrCat(IONamePrefixes::kInputPHName, i); int binding_index; - auto status = GetTrtBindingIndex(input_name.c_str(), 0, cuda_engine.get(), - &binding_index); + auto status = GetTrtBindingIndex(input_name.c_str(), trt_context_idx, + cuda_engine.get(), &binding_index); if (!status.ok()) { ctx->SetStatus(status); return !kRetry; @@ -759,8 +792,8 @@ bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx, for (int i = 0; i < ctx->num_outputs(); i++) { const string output_name = StrCat(IONamePrefixes::kOutputPHName, i); int binding_index; - auto status = GetTrtBindingIndex(output_name.c_str(), 0, cuda_engine.get(), - &binding_index); + auto status = GetTrtBindingIndex(output_name.c_str(), trt_context_idx, + cuda_engine.get(), &binding_index); if (!status.ok()) { ctx->SetStatus(status); return !kRetry; @@ -790,7 +823,7 @@ bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx, trt_shape.push_back(dims.d[j]); } } - // Allocate output tensor of TRTEngineOp + // Allocate output tensor of TRTEngineOp. Tensor* output_tensor = nullptr; TensorShape output_shape; status = TensorShapeUtils::MakeShape(trt_shape.data(), trt_shape.size(), @@ -957,6 +990,16 @@ StatusOr TRTEngineOp::GetEngine( // If matched, use that engine. Otherwise, we will look in cache for that // exact shape and possibly create a new engine if it is not in cache. if (!cache.count(engine_input_shapes)) { + if (!allow_build_at_runtime_) { + LOG(WARNING) << "Found no engine in cache matching input shapes. " + << "Not building a new engine because " + << "allow_build_at_runtime=False. " + << "The native segment will be used instead."; + // Store an empty engine in the cache for these input shapes so we don't + // try to build the same failing engine again. 
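What that placeholder buys is a cheap negative cache: the next lookup with the same shapes finds an EngineContext whose cuda_engine is null and goes straight to the native segment instead of re-attempting the build. A small illustration of that later lookup (names as in GetEngine and ComputeAsync; not part of the patch):

    // cache refers to cache_res->cache_, as in GetEngine().
    EngineContext* ectx = cache.at(engine_input_shapes).get();
    if (ectx->cuda_engine == nullptr) {
      // allow_build_at_runtime=false left an empty context for these shapes, so
      // ComputeAsync will keep routing them to the native segment.
    }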
+ cache.emplace(engine_input_shapes, absl::make_unique()); + return &empty_context; + } TrtUniquePtrType engine; bool convert_successfully = false; LOG(INFO) << "Building a new TensorRT engine for " << name() @@ -976,7 +1019,8 @@ StatusOr TRTEngineOp::GetEngine( auto status = convert::ConvertGraphDefToEngine( segment_graph_def_, precision_mode_, batch_size, workspace_size_, conversion_input_shapes, &logger, allocator, calibrator_.get(), &engine, - use_calibration_, use_implicit_batch_, &convert_successfully); + use_calibration_, use_implicit_batch_, &convert_successfully, + &cache_res->profiles_); if (!status.ok()) { LOG(WARNING) << "Engine creation for " << name() << " failed. " << "The native segment will be used instead. " @@ -986,11 +1030,12 @@ StatusOr TRTEngineOp::GetEngine( cache.emplace(input_concrete_shapes, absl::make_unique()); return &empty_context; } - TrtUniquePtrType exec_context( - engine->createExecutionContext()); + std::vector> exec_context; + TF_RETURN_IF_ERROR(cache_res->profiles_.CreateExecutionContexts( + engine.get(), exec_context)); cache.emplace(input_concrete_shapes, absl::make_unique(std::move(engine), - std::move(exec_context))); + std::move(exec_context[0]))); VLOG(1) << "Added new engine to cache of " << name() << ". Cache size: " << cache.size(); } @@ -1064,9 +1109,9 @@ Status TRTEngineOp::AllocateCalibrationResources( this->segment_graph_def_, TrtPrecisionMode::INT8, cres->calibrator_->getBatchSize(), this->workspace_size_, partial_shapes, &cache_res->GetLogger(), cache_res->allocator_.get(), - cres->calibrator_.get(), &cres->engine_, - /*use_calibration=*/true, this->use_implicit_batch_, - /*convert_successfully=*/nullptr); + cres->calibrator_.get(), &cres->engine_, /*use_calibration=*/true, + this->use_implicit_batch_, /*convert_successfully=*/nullptr, + /*profiles=*/nullptr); if (!s.ok()) { LOG(ERROR) << "Calibration failed: " << s; cres->calibrator_->setDone(); // Ignore further pushes diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc index a88f2b5e29e..da8bd6686a7 100644 --- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc +++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc @@ -62,7 +62,8 @@ class TRTEngineOpTestBase : public OpsTestBase { public: void AddSimpleTrtOp(DataType dtype, int max_cached_engines_count = 1, PartialTensorShape shape = PartialTensorShape({-1, -1}), - bool use_implicit_batch = true) { + bool use_implicit_batch = true, + bool allow_build_at_runtime = true) { // Create the GPU device. 
std::unique_ptr device( DeviceFactory::NewDevice("GPU", {}, "/job:worker/replica:0/task:0")); @@ -104,6 +105,7 @@ class TRTEngineOpTestBase : public OpsTestBase { .Attr("precision_mode", "FP32") .Attr("use_calibration", false) .Attr("_use_implicit_batch", use_implicit_batch) + .Attr("_allow_build_at_runtime", allow_build_at_runtime) .Attr("OutT", {dtype}) .Finalize(OpsTestBase::node_def())); TF_ASSERT_OK(InitOpWithFunctionLibrary()); @@ -127,9 +129,14 @@ class TRTEngineOpTestBase : public OpsTestBase { private: Status InitOpWithFunctionLibrary() { OpKernel* kernel = nullptr; - Status status = CreateOpKernel(device_type_, device_, allocator(), - pflr_->GetFLR(device_->name()), node_def_, - TF_GRAPH_DEF_VERSION, &kernel); + auto flr = pflr_->GetFLR(device_->name()); + std::shared_ptr props; + Status status = NodeProperties::CreateFromNodeDef( + node_def_, flr->GetFunctionLibraryDefinition(), &props); + if (status.ok()) { + status.Update(CreateOpKernel(device_type_, device_, allocator(), flr, + props, TF_GRAPH_DEF_VERSION, &kernel)); + } kernel_ = std::unique_ptr(kernel); if (kernel_ != nullptr) input_types_ = kernel_->input_types(); return status; @@ -186,6 +193,33 @@ TEST_F(TRTEngineOpTestBase, DynamicEngines) { EXPECT_EQ(1, cache->count({TensorShape({10, 10})})); } +TEST_F(TRTEngineOpTestBase, AllowBuildAtRuntime) { + TRTEngineOpTestBase::AddSimpleTrtOp(DT_FLOAT, /*max_cached_engines_count=*/1, + PartialTensorShape({-1, -1}), + /*use_implicit_batch=*/true, + /*allow_build_at_runtime=*/false); + + // Execute the op + TensorShape input_shape({2, 2}); + TRTEngineOpTestBase::AddSimpleInput(input_shape); + TF_ASSERT_OK(OpsTestBase::RunOpKernel()); + + // Get the engine cache. + TRTEngineCacheResource* cache_resource = nullptr; + TF_ASSERT_OK( + device_->resource_manager()->Lookup("TF-TRT", "myop", &cache_resource)); + core::ScopedUnref sc(cache_resource); + + // It should contain a placeholder with an empty cuda_engine (to mark that + // engine creation was not successful for the given input shape). + auto cache = &cache_resource->cache_; + EXPECT_EQ(1, cache->size()); + ASSERT_EQ(1, cache->count({input_shape})); + EngineContext* ectx = cache->at({input_shape}).get(); + EXPECT_EQ(ectx->cuda_engine, nullptr); +} + +#if IS_TRT_VERSION_GE(6, 0, 0, 0) TEST_F(TRTEngineOpTestBase, ExplicitBatch) { // Test inference in explicit batch mode with static input shapes. Static // shapes in this context means that the TensorRT knows all the input shapes @@ -225,15 +259,6 @@ TEST_F(TRTEngineOpTestBase, DynamicShapes) { TensorShape input_shape({1, 2}); TRTEngineOpTestBase::AddSimpleInput(input_shape); - // We expect that TensorRT engine creation fails: we would need to configure - // the engine with optimization profiles to use dynamic input shapes, but that - // feature is not yet implemented. - // - // Since TRT engine creation has failed, we fall back to native segment. - // Calling the native segment fails for the same reason that is investigated - // in https://github.com/tensorflow/tensorflow/pull/34919. This is irrelevant - // for the current test, here we want to just check wether TRT engine creation - // has failed. TF_ASSERT_OK(OpsTestBase::RunOpKernel()); // Get the engine cache. @@ -246,11 +271,8 @@ TEST_F(TRTEngineOpTestBase, DynamicShapes) { auto cache = &cache_resource->cache_; EXPECT_EQ(1, cache->size()); ASSERT_EQ(1, cache->count({input_shape})); - // TODO(bixia): re-enable the check below when the problem is fixed. 
- // EngineContext* ectx = cache->at({input_shape}).get(); - // Since engine creation failed, we expect to find nullptr. Finding a nullptr - // indicates that unknown shapes were used to define the TensorRT network. - // EXPECT_EQ(ectx->cuda_engine, nullptr); + EngineContext* ectx = cache->at({input_shape}).get(); + EXPECT_NE(ectx->cuda_engine, nullptr); } template @@ -274,6 +296,7 @@ TYPED_TEST(TRTEngineOpTest, Basic) { output->NumElements()), ElementsAre(TypeParam(0.0f), TypeParam(2.0f))); } +#endif } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops.cc index 891b75be824..de7b7381d0c 100644 --- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops.cc +++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops.cc @@ -140,11 +140,24 @@ class InitializeTRTResource : public OpKernel { engine_instance.serialized_engine().c_str(), engine_instance.serialized_engine().size(), nullptr)); auto raw_engine = engine.get(); - resource->cache_.emplace( - engine_input_shapes, - absl::make_unique( - std::move(engine), TrtUniquePtrType( - raw_engine->createExecutionContext()))); + std::vector> ctx_vec; + if (num_loaded_engine == 0) { + // Restore profiles if there are any. Currently only 1 engine is allowed + // in dynamic mode therefore we call this only for the 0th engine. + // it is a no-op in implicit batch mode. + OP_REQUIRES_OK(ctx, resource->profiles_.RestoreProfiles(raw_engine)); + OP_REQUIRES_OK(ctx, resource->profiles_.CreateExecutionContexts( + raw_engine, ctx_vec)); + } else { + // Multiple engines are only available in static mode. For each engine + // we have only a single execution context. + TrtUniquePtrType exec_ctx( + raw_engine->createExecutionContext()); + ctx_vec.push_back(std::move(exec_ctx)); + } + resource->cache_.emplace(engine_input_shapes, + absl::make_unique( + std::move(engine), std::move(ctx_vec[0]))); ++num_loaded_engine; } while (1); VLOG(1) << "Loaded " << num_loaded_engine << " TRT engines for op " diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h index 808b689127e..ae54569a726 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/lib/core/errors.h" @@ -182,6 +183,11 @@ class TRTEngineCacheResource : public ResourceBase { // TODO(hinsu): Use different calibration context for the available shapes and // attach it to each item of the cache. std::unique_ptr calib_ctx_; + + // This object maintains all the optimization profiles during profile + // generation and engine build. During runtime the list of profiles is used to + // look up a matching profile for the input data. 
+ TrtShapeOptimizationProfile profiles_; }; #endif // GOOGLE_TENSORRT diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.cc new file mode 100644 index 00000000000..27ef726514b --- /dev/null +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.cc @@ -0,0 +1,185 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h" + +#include +#include + +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" + +#if GOOGLE_CUDA && GOOGLE_TENSORRT +namespace tensorflow { +namespace tensorrt { + +// Creates optimization profiles for a list of input shapes. The list of input +// shapes are stored in shapes_. +void TrtShapeOptimizationProfile::InitProfiles() { + if (input_shapes_.size() == 0) { + VLOG(1) << "Not creating profiles without input_shapes. " + "You have to enable profile generation mode first (build)."; + } else { + VLOG(1) << "Creating profiles with startegy of one profile " + << "for each input (min=opt=max)."; + } + for (auto& shape_vec : input_shapes_) { + std::vector dimvec; + for (auto& shape : shape_vec) { + dimvec.push_back(TensorShapeToTrtDims(shape, false)); + } + // We set min=opt=max. + OptimizationProfileConfig profConfig{dimvec, dimvec, dimvec}; + profiles_.push_back(std::move(profConfig)); + VLOG(1) << "Created profile " << profiles_.back().DebugString(); + } +} + +#if IS_TRT_VERSION_GE(6, 0, 0, 0) +Status TrtShapeOptimizationProfile::AddProfiles( + nvinfer1::IBuilder* builder, nvinfer1::IBuilderConfig* config, + const nvinfer1::INetworkDefinition* network) { + // Create a vector of optimization profiles + for (int i = 0; i < profiles_.size(); i++) { + auto* optProfile = builder->createOptimizationProfile(); + Status status = profiles_[i].SetDimensions(network, optProfile); + if (!status.ok()) { + return status; + } + int idx = -1; + if (optProfile->isValid()) { + idx = config->addOptimizationProfile(optProfile); + } + if (idx >= 0) { + if (i != idx) { + return errors::Internal( + "Profile index of engine config is different from resource profile " + "index: ", + i, " != ", idx); + } + VLOG(1) << "Added optimization profile " << profiles_[i].DebugString() + << " to builder config."; + } else { + LOG(ERROR) << "Failed to add optimization profile " + << profiles_[i].DebugString() + << ". 
This usually happens when profile is invalid."; + } + } + if (config->getNbOptimizationProfiles() == 0) { + return errors::Internal("Failure in adding an optimization profile."); + } + // if TRT_VERSION < 6, then we do not need to add + return Status::OK(); +} +#endif + +#if IS_TRT_VERSION_GE(6, 0, 0, 0) +Status TrtShapeOptimizationProfile::ConfigureBuilder( + nvinfer1::IBuilder* builder, nvinfer1::IBuilderConfig* config, + const nvinfer1::INetworkDefinition* network) { + TF_RETURN_IF_ERROR(AddProfiles(builder, config, network)); + return Status::OK(); +} +#endif + +int TrtShapeOptimizationProfile::GetProfileNumber( + std::vector shapes) { + for (int i = 0; i < profiles_.size(); i++) { + if (profiles_[i].IncludesShapes(shapes)) { + return i; + } + } + VLOG(1) << "Profile not found for input shapes " << DebugString(shapes) + << "."; + return -1; +} + +Status TrtShapeOptimizationProfile::CreateExecutionContexts( + nvinfer1::ICudaEngine* engine, + std::vector>& exec_context) { + int i = 0; + // The following loop runs once if we have static shapes, to create a single + // execution context without profiles. In dynamic mode we create one context + // for each profile and set the corresponding optimization profile. + do { + VLOG(1) << "Creating execution context " << i; + nvinfer1::IExecutionContext* ctx = engine->createExecutionContext(); + if (ctx == nullptr) { + return errors::Internal("Failed to create execution context"); + } + if (i > 0) { + // This condition is needed for two reasons: + // - using static shapes we do not have any profiles so we cannot call + // set optimizationprofiles. + // - The 0th profile is set implicitly for the first execution context + // therefore we do not need to set. +#if IS_TRT_VERSION_GE(6, 0, 0, 0) + bool stat = ctx->setOptimizationProfile(i); + if (!stat) { + ctx->destroy(); + return errors::Internal("Could not set TRT optimization profile."); + } +#endif + } + exec_context.push_back(TrtUniquePtrType(ctx)); + i++; + } while (i < profiles_.size()); + + return Status::OK(); +} + +Status TrtShapeOptimizationProfile::RestoreProfiles( + const nvinfer1::ICudaEngine* engine) { +#if IS_TRT_VERSION_GE(6, 0, 0, 0) + if (!engine) { + // We do not need to restore profiles for an empty engine + return Status::OK(); + } +#if IS_TRT_VERSION_GE(7, 0, 0, 0) + if (engine->hasImplicitBatchDimension()) { + // Nothing to do, we cannot have profiles in implicit batch mode + return Status::OK(); + } +#endif + int n_profiles = engine->getNbOptimizationProfiles(); + int n_inputs = GetNumberOfEngineInputs(engine); + VLOG(2) << "Attempting to restore " << n_profiles << " profiles, each with " + << n_inputs << " inputs"; + for (int prof_idx = 0; prof_idx < n_profiles; prof_idx++) { + OptimizationProfileConfig cfg; + for (int j = 0; j < n_inputs; j++) { + nvinfer1::Dims min = engine->getProfileDimensions( + j, prof_idx, nvinfer1::OptProfileSelector::kMIN); + nvinfer1::Dims max = engine->getProfileDimensions( + j, prof_idx, nvinfer1::OptProfileSelector::kMAX); + nvinfer1::Dims opt = engine->getProfileDimensions( + j, prof_idx, nvinfer1::OptProfileSelector::kOPT); + cfg.min.push_back(min); + cfg.max.push_back(max); + cfg.opt.push_back(opt); + } + VLOG(2) << "Restored profile " << cfg.DebugString(); + profiles_.push_back(std::move(cfg)); + } +#endif + return Status::OK(); +} + +int TrtShapeOptimizationProfile::GetNumProfiles() const { + return profiles_.size(); +} + +} // namespace tensorrt +} // namespace tensorflow +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT diff --git 
a/tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h b/tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h new file mode 100644 index 00000000000..40c7f5dcf31 --- /dev/null +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h @@ -0,0 +1,178 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_SHAPE_OPTIMIZATION_PROFILES_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_SHAPE_OPTIMIZATION_PROFILES_H_ + +#include +#include +#include +#include + +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT + +#include "third_party/tensorrt/NvInfer.h" + +namespace tensorflow { +namespace tensorrt { + +// Stores optimization profile parameters (min/opt/max of each input shape). +// +// A TensorRT optimization profile describes the possible min/max values of +// each dynamic input shape along with an optimum value. These values are used +// by the TensorRT builder to select the best kernel for the optimum value among +// those kernels that are valid for all input tensors in the [min, max] range. +struct OptimizationProfileConfig { + // Length of vector == num_inputs to engine + std::vector min; + std::vector opt; + std::vector max; + + string DebugString() const { + using absl::StrCat; + return StrCat("[min: ", tensorflow::tensorrt::DebugString(min), + ", opt: : ", tensorflow::tensorrt::DebugString(opt), + ", max: ", tensorflow::tensorrt::DebugString(max), "]"); + } + +#if IS_TRT_VERSION_GE(6, 0, 0, 0) + // Sets the stored min/opt/max dimensions for profile. + // + // Parameters: + // network - TensorRT network, used to enumerate all the input tensors + // profile - on exit the profile information will be set for each input tensor + Status SetDimensions(const nvinfer1::INetworkDefinition* network, + nvinfer1::IOptimizationProfile* profile) const { + int n_inputs = network->getNbInputs(); + if (min.size() != n_inputs || opt.size() != n_inputs || + max.size() != n_inputs) { + return errors::Internal("Incorrect number of profile config parameters"); + } + for (int i = 0; i < n_inputs; i++) { + const char* name = network->getInput(i)->getName(); + profile->setDimensions(name, nvinfer1::OptProfileSelector::kMIN, min[i]); + profile->setDimensions(name, nvinfer1::OptProfileSelector::kOPT, opt[i]); + profile->setDimensions(name, nvinfer1::OptProfileSelector::kMAX, max[i]); + } + return Status::OK(); + } +#endif + + // Returns true if profile range completely includes the given shapes. 
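Since InitProfiles currently sets min = opt = max, the inclusion test below degenerates to an exact-shape match, while a wider hand-written range behaves differently. A worked example with illustrative shapes, not taken from the patch:

    // Profile built by InitProfiles from the collected shapes {[2,2,10], [2,2,10]}:
    //   IncludesShapes({[2,2,10], [2,2,10]}) -> true   (every dim inside [min, max])
    //   IncludesShapes({[3,3,10], [3,3,10]}) -> false  (3 exceeds max = 2 in dims 0 and 1)
    // A profile with min = [1,1,10] and max = [16,16,10] would instead accept any
    // shape pair in that range, e.g. {[4,8,10], [4,8,10]}.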
+ bool IncludesShapes(const std::vector& shapes) const { + // min, max, and opt must have the same size which is already verified in + // SetDimensions. + if (min.size() != shapes.size()) { + return false; + } + for (int i = 0; i < shapes.size(); i++) { + auto current_shape = shapes[i]; + // min, max, and opt must have the same nbDims, which is already verified + // in SetDimensions. + if (min[i].nbDims != current_shape.dims()) { + return false; + } + // Check if range [min, max] includes current_shape. + for (int dim = 0; dim < current_shape.dims(); dim++) { + if ((min[i].d[dim] > current_shape.dim_size(dim)) || + (max[i].d[dim] < current_shape.dim_size(dim))) { + return false; + } + } + } + return true; + } +}; + +// Manages Optimization profiles during TRT Engine construction. +// +// An optimization profile describes a range of dimensions for each TRT network +// input, and the optimal dimensions that the auto-tuner should use for +// optimization. +// +// This class stores the list of input shapes that were seen during the +// build/profile_generation_mode phase, and using them it creates a set of +// OptimizationProfileConfigs. These configs will be added to IBuilderConfig +// before the engine is created. +class TrtShapeOptimizationProfile { + public: + TrtShapeOptimizationProfile() {} + + // Stores input shape information during profile_generation_mode + void AddShape(std::vector shapes) { + input_shapes_.insert(shapes); + VLOG(1) << "Collected shape(s) " << DebugString(shapes) << " for profiles."; + } + + void clear() { profiles_.clear(); } + + // Returns the profile number that should be used to execute the network with + // the given input shapes. Returns -1 if none of cached profiles are + // compatible with the given input shapes. + int GetProfileNumber(std::vector shapes); + +#if IS_TRT_VERSION_GE(6, 0, 0, 0) + // Creates optimization profiles and add them to the builder config. + Status ConfigureBuilder(nvinfer1::IBuilder* builder, + nvinfer1::IBuilderConfig* config, + const nvinfer1::INetworkDefinition* network); +#endif + + // Creates execution contexts for each optimization profile. + Status CreateExecutionContexts( + nvinfer1::ICudaEngine* engine, + std::vector>& exec_context); + + // Maps input vector shapes to TRT Optimization profiles (min, max, opt) i.e. + // maps input_shapes_ to profiles_ + void InitProfiles(); + + // Returns number of created profiles. 
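Taken end to end, the class is used in two phases: collect shapes and build (ConfigureBuilder before buildEngineWithConfig), then look up a profile per inference call. A condensed sketch of that flow on the TRT 6+ path, mirroring the new unit test further below; builder, config, network and the two shape vectors are assumed to be set up elsewhere:

    TrtShapeOptimizationProfile profiles;
    profiles.AddShape(build_time_shapes);   // std::vector<TensorShape>, one entry per input
    profiles.InitProfiles();                // one min=opt=max profile per collected vector
    TF_CHECK_OK(profiles.ConfigureBuilder(builder, config, network));
    TrtUniquePtrType<nvinfer1::ICudaEngine> engine(
        builder->buildEngineWithConfig(*network, *config));
    std::vector<TrtUniquePtrType<nvinfer1::IExecutionContext>> contexts;
    TF_CHECK_OK(profiles.CreateExecutionContexts(engine.get(), contexts));
    // At run time: pick the context whose profile covers the actual input shapes.
    int idx = profiles.GetProfileNumber(runtime_shapes);
    if (idx >= 0) { /* execute with contexts[idx] */ }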
+ int GetNumProfiles() const; + + // Restores profiles from the engine (used after deserialization) + Status RestoreProfiles(const nvinfer1::ICudaEngine* engine); + + private: + // Set of input shape vetors that we collect during profile_generation_mode + std::unordered_set, VectorTensorShapeHasher> + input_shapes_; + + // The optimization profiles generated from input_shapes_ + std::vector profiles_; + +#if IS_TRT_VERSION_GE(6, 0, 0, 0) + /// Adds optimization profiles to the builder config + Status AddProfiles(nvinfer1::IBuilder* builder, + nvinfer1::IBuilderConfig* config, + const nvinfer1::INetworkDefinition* network); +#endif +}; + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_SHAPE_OPTIMIZATION_PROFILES_H_ diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles_test.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles_test.cc new file mode 100644 index 00000000000..501810587e0 --- /dev/null +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles_test.cc @@ -0,0 +1,218 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT + +#include + +#include + +#include "absl/memory/memory.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/test.h" +#include "third_party/tensorrt/NvInfer.h" + +namespace tensorflow { +namespace tensorrt { + +std::vector DimVecToShapeVec(std::vector dimvec) { + std::vector shapevec(dimvec.size()); + for (int i = 0; i < dimvec.size(); i++) { + TensorShape shape; + TF_CHECK_OK( + TensorShapeUtils::MakeShape(dimvec[i].d, dimvec[i].nbDims, &shape)); + shapevec[i] = shape; + } + return shapevec; +} + +bool DimsContained(const nvinfer1::Dims& dim, const nvinfer1::Dims& min, + const nvinfer1::Dims& max) { + if (dim.nbDims != min.nbDims || dim.nbDims != max.nbDims) { + return false; + } + for (int i = 0; i < dim.nbDims; i++) { + if (dim.d[i] < min.d[i] || dim.d[i] > max.d[i]) { + return false; + } + } + return true; +} + +bool DimsEqual(const nvinfer1::Dims& a, const nvinfer1::Dims& b) { + if (a.nbDims != b.nbDims) { + return false; + } + for (int i = 0; i < a.nbDims; i++) { + if (a.d[i] != b.d[i]) { + return false; + } + } + return true; +} + +class TrtShapeOptimizationProfileTest : public ::testing::Test { + protected: + void SetUp() override { + builder_ = TrtUniquePtrType( + nvinfer1::createInferBuilder(logger_)); +#if IS_TRT_VERSION_GE(6, 0, 0, 0) + network_ = TrtUniquePtrType( + builder_->createNetworkV2(flags_)); + builder_config_ = TrtUniquePtrType( + 
builder_->createBuilderConfig()); + builder_config_->setMaxWorkspaceSize(1 << 10); +#else + network_ = TrtUniquePtrType( + builder_->createNetwork()); + builder_->setMaxWorkspaceSize(1 << 10); +#endif + } + + // Defines a simple network: output = input1 + input2. + void DefineNetwork(nvinfer1::INetworkDefinition* network, + nvinfer1::Dims3& dims) { + nvinfer1::ITensor* input1 = + network->addInput("input1", nvinfer1::DataType::kFLOAT, dims); + EXPECT_NE(nullptr, input1); + + nvinfer1::ITensor* input2 = + network->addInput("input2", nvinfer1::DataType::kFLOAT, dims); + EXPECT_NE(nullptr, input1); + + auto layer = network->addElementWise(*input1, *input2, + nvinfer1::ElementWiseOperation::kSUM); + EXPECT_NE(nullptr, layer); + // Mark the output. + nvinfer1::ITensor* output = layer->getOutput(0); + output->setName("output"); + network->markOutput(*output); + } + + Logger logger_; + TrtUniquePtrType builder_; + TrtUniquePtrType network_; +#if IS_TRT_VERSION_GE(6, 0, 0, 0) + TrtUniquePtrType builder_config_; +#endif + TrtUniquePtrType engine; + std::vector> exec_context_; + // The order is important: exec_context_ must be destroyed first, and logger + // at last. +#if IS_TRT_VERSION_GE(6, 0, 0, 0) + const uint32_t flags_ = + 1U << static_cast( + nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); +#endif +}; + +TEST_F(TrtShapeOptimizationProfileTest, Static) { + // Network with static input shape + nvinfer1::Dims3 dims(8, 8, 10); + DefineNetwork(network_.get(), dims); + + TrtShapeOptimizationProfile profile; + +#if IS_TRT_VERSION_GE(6, 0, 0, 0) + // Configure and build engine - should be a no-op + TF_CHECK_OK(profile.ConfigureBuilder(builder_.get(), builder_config_.get(), + network_.get())); + + engine = TrtUniquePtrType( + builder_->buildEngineWithConfig(*network_, *builder_config_)); +#else + engine = TrtUniquePtrType( + builder_->buildCudaEngine(*network_)); +#endif + EXPECT_NE(nullptr, engine); + TF_CHECK_OK(profile.CreateExecutionContexts(engine.get(), exec_context_)); + // A single execution context should be created for a graph with static input + ASSERT_EQ(exec_context_.size(), 1); + EXPECT_NE(nullptr, exec_context_[0]); + + std::vector dim_vec(2, dims); + std::vector shape_vec = DimVecToShapeVec(dim_vec); + EXPECT_EQ(-1, profile.GetProfileNumber(shape_vec)); +} + +#if IS_TRT_VERSION_GE(6, 0, 0, 0) +TEST_F(TrtShapeOptimizationProfileTest, Dynamic) { + // Network with dynamic input shapes + nvinfer1::Dims3 dims(-1, -1, 10); + DefineNetwork(network_.get(), dims); + + TrtShapeOptimizationProfile profile; + std::vector> input_profiles{ + {nvinfer1::Dims3(2, 2, 10), nvinfer1::Dims3(2, 2, 10)}, + {nvinfer1::Dims3(3, 3, 10), nvinfer1::Dims3(3, 3, 10)}, + {nvinfer1::Dims3(16, 16, 10), nvinfer1::Dims3(16, 16, 10)}, + }; + + // Simulate a profile collection phase + for (auto dim_vec : input_profiles) { + std::vector shape_vec = DimVecToShapeVec(dim_vec); + profile.AddShape(shape_vec); + } + profile.InitProfiles(); + + // Configure and build engine + TF_CHECK_OK(profile.ConfigureBuilder(builder_.get(), builder_config_.get(), + network_.get())); + engine = TrtUniquePtrType( + builder_->buildEngineWithConfig(*network_.get(), *builder_config_.get())); + ASSERT_NE(nullptr, engine); + + TF_CHECK_OK(profile.CreateExecutionContexts(engine.get(), exec_context_)); + + // Each profile has an associated execution context. + EXPECT_EQ(exec_context_.size(), input_profiles.size()); + + // Check if the profiles are assigned correctly. 
+ for (auto dimvec : input_profiles) { + std::vector shape_vec = DimVecToShapeVec(dimvec); + int idx = profile.GetProfileNumber(shape_vec); + int prof_idx = exec_context_[idx]->getOptimizationProfile(); + ASSERT_GE(prof_idx, 0); + + for (int j = 0; j < dimvec.size(); j++) { + nvinfer1::Dims min = engine->getProfileDimensions( + j, prof_idx, nvinfer1::OptProfileSelector::kMIN); + nvinfer1::Dims max = engine->getProfileDimensions( + j, prof_idx, nvinfer1::OptProfileSelector::kMAX); + nvinfer1::Dims opt = engine->getProfileDimensions( + j, prof_idx, nvinfer1::OptProfileSelector::kOPT); + + // This should always hold. + EXPECT_TRUE(DimsContained(dimvec[j], min, max)); + + // The following test depends on the profile creation strategy, and needs + // to be updated (disabled) if the default trategy (defined by + // InitProfiles) changes. + EXPECT_TRUE(DimsEqual(dimvec[j], opt)); + } + } +} +#endif + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index 34888fc0e2f..f0aebc9b543 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -133,7 +133,7 @@ Status GraphCompiler::Compile() { OpKernel* op_kernel_raw = nullptr; // The kernel is not actually run for functional ops, we just need it // for metadata. - Status s = flib_->CreateKernel(n->def(), &op_kernel_raw); + Status s = flib_->CreateKernel(n->properties(), &op_kernel_raw); // Transfer ownership of the kernel to a local smart pointer. std::unique_ptr op_kernel(op_kernel_raw); diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 5f1c2f28ba4..8571c503299 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -55,7 +55,6 @@ tf_kernel_library( "index_ops.cc", "l2loss_op.cc", "listdiff_op.cc", - "lower_upper_bound_ops.cc", "lrn_ops.cc", "matmul_op.cc", "matrix_band_part_op.cc", @@ -150,7 +149,6 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/lib:util", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:array4d", - "//tensorflow/compiler/xla:comparison_util", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index 62ed069b4f0..0ea851e9325 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -264,6 +264,23 @@ xla::XlaOp IgammaImpl(xla::XlaOp x, xla::XlaOp y, XLA_MAKE_BINARY(Igamma, IgammaImpl(lhs, rhs, broadcast_helper)); +xla::XlaOp IgammaGradAImpl(xla::XlaOp x, xla::XlaOp y, + const BCast& broadcast_helper) { + std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper); + return xla::IgammaGradA(x, y); +} + +XLA_MAKE_BINARY(IgammaGradA, IgammaGradAImpl(lhs, rhs, broadcast_helper)); + +xla::XlaOp RandomGammaGradImpl(xla::XlaOp x, xla::XlaOp y, + const BCast& broadcast_helper) { + std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper); + return xla::RandomGammaGrad(x, y); +} + +XLA_MAKE_BINARY(RandomGammaGrad, + RandomGammaGradImpl(lhs, rhs, broadcast_helper)); + xla::XlaOp IgammacImpl(xla::XlaOp x, xla::XlaOp y, const BCast& broadcast_helper) { std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper); diff --git 
a/tensorflow/compiler/tf2xla/kernels/lower_upper_bound_ops.cc b/tensorflow/compiler/tf2xla/kernels/lower_upper_bound_ops.cc deleted file mode 100644 index 0eacf8812f1..00000000000 --- a/tensorflow/compiler/tf2xla/kernels/lower_upper_bound_ops.cc +++ /dev/null @@ -1,116 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" -#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" -#include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/comparison_util.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" - -namespace tensorflow { -namespace { - -// Builds a LowerBound or UpperBound op, the distinction lying in -// comparison_direction: GT => LowerBoundOp, GE => UpperBoundOp. -// Note that this is an O(MN) algorithm: all entries in each sorted_inputs row -// are considered, and their sorted nature is not fully exploited. -void BuildLowerUpperBoundOp(XlaOpKernelContext* ctx, DataType out_dtype, - xla::ComparisonDirection comparison_direction) { - const TensorShape sorted_inputs_shape = ctx->InputShape("sorted_inputs"); - const TensorShape values_shape = ctx->InputShape("values"); - const xla::XlaOp sorted_inputs = ctx->Input("sorted_inputs"); - const xla::XlaOp values = ctx->Input("values"); - - // We are assuming both inputs are 2D, which they will be given the current - // implementation of tf.searchsorted. - OP_REQUIRES(ctx, sorted_inputs_shape.dims() == 2, - errors::FailedPrecondition("sorted_inputs must be 2D")); - OP_REQUIRES(ctx, values_shape.dims() == 2, - errors::FailedPrecondition("values must be 2D")); - - // Add a new inner dimension to values, to allow broadcasting along the inner - // dimension of sorted_sequence. - auto new_values_shape = values_shape; - new_values_shape.InsertDim(/* d */ 2, /* size */ 1); - auto values_reshaped = xla::Reshape(values, new_values_shape.dim_sizes()); - - // Add a new penultimate dimension to sorted_inputs, to allow broadcasting of - // sorted_sequence entries for each value. - auto new_sorted_inputs_shape = sorted_inputs_shape; - new_sorted_inputs_shape.InsertDim(/* d */ 1, /* size */ 1); - auto sorted_inputs_reshaped = - xla::Reshape(sorted_inputs, new_sorted_inputs_shape.dim_sizes()); - - // We are relying on broadcasting to compare each value against each entry in - // the associated sorted_inputs row. - // The reshapes above leave the tensors with equal rank of 3, so broadcast - // dimensions are not explicitly specified. 
- auto comparison = xla::Compare(values_reshaped, sorted_inputs_reshaped, {}, - comparison_direction); - - const DataType accumulation_type = XlaHelpers::SumAccumulationType(out_dtype); - - // Convert boolean comparison results to integers so we can sum them. - auto comparison_int = - XlaHelpers::ConvertElementType(comparison, accumulation_type); - - // Sum the comparison results over the inner dimension to find the index for - // each value. - xla::XlaBuilder* builder = ctx->builder(); - auto reduced = - xla::Reduce(comparison_int, XlaHelpers::Zero(builder, accumulation_type), - *ctx->GetOrCreateAdd(accumulation_type), {2}); - - ctx->SetOutput(0, reduced); -} - -class LowerBoundOp : public XlaOpKernel { - public: - explicit LowerBoundOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_)); - } - - void Compile(XlaOpKernelContext* ctx) override { - BuildLowerUpperBoundOp(ctx, out_dtype_, xla::ComparisonDirection::kGt); - } - - private: - DataType out_dtype_; -}; - -REGISTER_XLA_OP(Name("LowerBound"), LowerBoundOp); - -class UpperBoundOp : public XlaOpKernel { - public: - explicit UpperBoundOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_)); - } - - void Compile(XlaOpKernelContext* ctx) override { - BuildLowerUpperBoundOp(ctx, out_dtype_, xla::ComparisonDirection::kGe); - } - - private: - DataType out_dtype_; -}; - -REGISTER_XLA_OP(Name("UpperBound"), UpperBoundOp); - -} // namespace -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index 67d49eafcde..5f5cae8f176 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -32,6 +32,8 @@ limitations under the License. #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/pooling_ops_common.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/util/tensor_format.h" namespace tensorflow { namespace { @@ -157,6 +159,13 @@ class MaxPoolOp : public PoolingOp { OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), errors::InvalidArgument("Invalid data format")); + OP_REQUIRES( + ctx, + data_format_ != FORMAT_NCHW_VECT_C && + data_format_ != FORMAT_NHWC_VECT_W, + errors::Unimplemented("XLA does not support the VECT_* data formats. " + "Returning unimplemented from MaxPool to keep " + "Tensorflow's intended optimized MaxPool here.")); } void Compile(XlaOpKernelContext* ctx) override { diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc index a0ffd1908c5..7ac4cb8fb06 100644 --- a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc +++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc @@ -35,11 +35,10 @@ namespace tensorflow { // This require the TF_DUMP_GRAPH_PREFIX to be set to a path that exist (or can // be created). 
static void DumpModule(mlir::ModuleOp module, llvm::StringRef file_prefix) { - const char* prefix_env = GetDumpDirFromEnvVar(); - if (!prefix_env) { + std::string prefix = GetDumpDirFromEnvVar(); + if (prefix.empty()) { return; } - std::string prefix = prefix_env; auto* env = tensorflow::Env::Default(); auto status = env->RecursivelyCreateDir(prefix); diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py index 3efdda15a94..0df61da57a3 100644 --- a/tensorflow/compiler/tf2xla/python/xla.py +++ b/tensorflow/compiler/tf2xla/python/xla.py @@ -34,6 +34,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import bitwise_ops from tensorflow.python.ops import gen_math_ops +from tensorflow.python.ops import gen_random_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops @@ -200,6 +201,8 @@ shift_right_logical = _broadcasting_binary_op(_shift_right_logical_helper) shift_right_arithmetic = _broadcasting_binary_op(_shift_right_arithmetic_helper) igamma = _broadcasting_binary_op(math_ops.igamma) +igamma_grad_a = _broadcasting_binary_op(gen_math_ops.igamma_grad_a) +random_gamma_grad = _broadcasting_binary_op(gen_random_ops.random_gamma_grad) igammac = _broadcasting_binary_op(math_ops.igammac) diff --git a/tensorflow/compiler/tf2xla/sharding_util.cc b/tensorflow/compiler/tf2xla/sharding_util.cc index 4d5bf0835e1..366e8d49228 100644 --- a/tensorflow/compiler/tf2xla/sharding_util.cc +++ b/tensorflow/compiler/tf2xla/sharding_util.cc @@ -26,22 +26,6 @@ const char kShardingAttribute[] = "_XlaSharding"; } // namespace namespace { -xla::StatusOr> GetShardingFromNodeDef( - const NodeDef& node_def) { - if (!HasNodeAttr(node_def, kShardingAttribute)) { - return absl::optional(); - } - string value; - xla::OpSharding sharding; - TF_RETURN_IF_ERROR(GetNodeAttr(node_def, kShardingAttribute, &value)); - if (!sharding.ParseFromString(value)) { - return xla::InvalidArgument( - "Experimental _XlaSharding attribute was not a valid encoded " - "xla::OpSharding proto."); - } - return absl::optional(sharding); -} - Status CoreOutOfRangeError(int core, int num_cores_per_replica) { return errors::InvalidArgument( "Invalid replicated core id: ", core, @@ -107,4 +91,19 @@ void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst) { } } +xla::StatusOr> GetShardingFromNodeDef( + const NodeDef& node_def) { + if (!HasNodeAttr(node_def, kShardingAttribute)) { + return absl::optional(); + } + string value; + xla::OpSharding sharding; + TF_RETURN_IF_ERROR(GetNodeAttr(node_def, kShardingAttribute, &value)); + if (!sharding.ParseFromString(value)) { + return xla::InvalidArgument( + "Experimental _XlaSharding attribute was not a valid encoded " + "xla::OpSharding proto."); + } + return absl::optional(sharding); +} } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/sharding_util.h b/tensorflow/compiler/tf2xla/sharding_util.h index ab67d4f1542..196434826f9 100644 --- a/tensorflow/compiler/tf2xla/sharding_util.h +++ b/tensorflow/compiler/tf2xla/sharding_util.h @@ -45,6 +45,10 @@ xla::StatusOr> ParseShardingFromDevice( void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst); +// Get sharding inforamtion from node. 
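Exposing GetShardingFromNodeDef in the header lets other passes read the experimental _XlaSharding attribute directly. A hypothetical call site (the surrounding node variable is assumed, not defined in this patch):

    // Hypothetical usage; yields an empty optional when the attribute is absent.
    xla::StatusOr<absl::optional<xla::OpSharding>> sharding =
        GetShardingFromNodeDef(node.def());
    if (sharding.ok() && sharding.ValueOrDie().has_value()) {
      // The node carries an explicit sharding proto; apply or propagate it here.
    }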
+xla::StatusOr> GetShardingFromNodeDef( + const NodeDef& node_def); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_SHARDING_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/type_util.cc b/tensorflow/compiler/tf2xla/type_util.cc index 634f64e01e6..2266a07463d 100644 --- a/tensorflow/compiler/tf2xla/type_util.cc +++ b/tensorflow/compiler/tf2xla/type_util.cc @@ -97,6 +97,7 @@ xla::StatusOr EncodePrimitiveTypeAsDataType(xla::PrimitiveType type) { {xla::U16, DT_UINT16}, {xla::U32, DT_UINT32}, {xla::U64, DT_UINT64}, + {xla::C128, DT_COMPLEX128}, }); auto it = data_type_map.find(type); diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 8e44d3d4255..3ea62882dcb 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -139,6 +139,86 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, return Status::OK(); } +// Rewrites the layout of xla_shape if there is tiled sharding. +Status RewriteLayoutWithShardedShape( + const absl::optional& sharding, bool use_fast_memory, + XlaCompiler::ShapeRepresentationFn shape_representation_fn, + xla::Shape* xla_shape) { + if (sharding && !sharding->IsTileMaximal()) { + // After sharding, per core shape might have different layout. For example, + // before sharding, a shape [128, 128] will be assigned default + // minor-to-major {1, 0}. But after we shard this shape to [128, 64] * 2, + // the sharded shapes will have minor-to-major {0, 1}. + // + // As a result, for sharded shapes, we set their layout to per core shape's + // layout. + // + // TODO(endlessroad): for variable input & update, we might have + // different layouts which will prevent input output aliasing and + // increase memory usage. Investigate such cases. + int64 device = *sharding->tile_assignment().begin(); + std::vector offset = + sharding->TileOffsetForDevice(*xla_shape, device); + std::vector limit = sharding->TileLimitForDevice(*xla_shape, device); + std::vector dimensions(xla_shape->rank()); + for (int64 i = 0; i < xla_shape->rank(); ++i) { + dimensions[i] = limit[i] - offset[i]; + } + xla::Shape per_device_xla_shape = + xla::ShapeUtil::MakeShape(xla_shape->element_type(), dimensions); + TensorShape per_device_tensor_shape; + TF_RETURN_IF_ERROR( + XLAShapeToTensorShape(per_device_xla_shape, &per_device_tensor_shape)); + TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType( + xla_shape->element_type())); + TF_ASSIGN_OR_RETURN(per_device_xla_shape, + shape_representation_fn(per_device_tensor_shape, dtype, + use_fast_memory)); + *xla_shape->mutable_layout() = per_device_xla_shape.layout(); + } + return Status::OK(); +} + +// There is a shape_representation_fn or sharding for an output, this function +// uses a reshape to fix the layout. +xla::StatusOr ReshapeWithCorrectRepresentationAndSharding( + xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape, + XlaCompiler::ShapeRepresentationFn shape_representation_fn, + absl::optional sharding, bool fast_mem) { + if (original_shape.IsTuple()) { + std::vector elements; + for (int64 i = 0; i < original_shape.tuple_shapes_size(); ++i) { + auto subsharding = sharding ? 
sharding->tuple_shardings(i) : sharding; + TF_ASSIGN_OR_RETURN(auto element, + ReshapeWithCorrectRepresentationAndSharding( + builder, xla::GetTupleElement(original, i), + original_shape.tuple_shapes(i), + shape_representation_fn, subsharding, fast_mem)); + elements.push_back(element); + } + return xla::Tuple(builder, elements); + } + if (!original_shape.IsArray()) return original; + TensorShape shape; + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(original_shape, &shape)); + TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType( + original_shape.element_type())); + TF_ASSIGN_OR_RETURN(auto to_shape, + shape_representation_fn(shape, dtype, fast_mem)); + if (sharding) { + TF_ASSIGN_OR_RETURN(auto hlo_sharding, + xla::HloSharding::FromProto(*sharding)); + TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape( + hlo_sharding, fast_mem, shape_representation_fn, &to_shape)); + } + if (xla::ShapeUtil::Compatible(original_shape, to_shape)) { + for (int64 i = 0; i < original_shape.rank(); ++i) { + to_shape.set_dynamic_dimension(i, original_shape.is_dynamic_dimension(i)); + } + } + return xla::Reshape(to_shape, original); +} + // Builds the XLA computation. // - `args` is the list of input arguments // - `retvals` is the list of retvals produced by _Retval operators, in index @@ -188,10 +268,6 @@ Status BuildComputation( std::vector elems; elems.reserve(retvals.size()); - // Keeps track of the layout of each retval. If a retval is not in this list, - // a descending layout is used. The first element is the output index, second - // element is the new layout. - std::vector> retval_index_and_layout; // Keeps track of sharding of each retval. If a retval is not in this list, // replicate sharding is used. The first element is the output index, second // element is the sharding. @@ -219,22 +295,22 @@ Status BuildComputation( TF_ASSIGN_OR_RETURN(output.shape, retval.GetShape()); xla::XlaOp value = retval.handle(); auto it = retval_shardings.find(i); - xla::XlaScopedShardingAssignment assign_sharding( - builder, it == retval_shardings.end() - ? absl::optional() - : it->second); + absl::optional sharding = + it == retval_shardings.end() ? absl::optional() + : it->second; if (it != retval_shardings.end()) { retval_index_and_sharding[elems.size()] = it->second; } if (shape_representation_fn) { - // If there is a shape representation function, reshape the output - // tensor to the shape given by the representation shape function. - TF_ASSIGN_OR_RETURN(xla::Shape shape, shape_representation_fn( - output.shape, output.type, - /*use_fast_memory=*/false)); - value = xla::Reshape(value, xla::AsInt64Slice(shape.dimensions())); - retval_index_and_layout.emplace_back(elems.size(), shape.layout()); - } else if (it != retval_shardings.end()) { + TF_ASSIGN_OR_RETURN(auto original_shape, builder->GetShape(value)); + TF_ASSIGN_OR_RETURN(value, + ReshapeWithCorrectRepresentationAndSharding( + builder, value, original_shape, + shape_representation_fn, sharding, + /*fast_mem=*/false)); + } + if (it != retval_shardings.end()) { + xla::XlaScopedShardingAssignment assign_sharding(builder, sharding); // Apply the sharding to the output, if there is a core assignment. value = identity_op(value); } @@ -312,43 +388,27 @@ Status BuildComputation( update.tensor_array_gradients_accessed.insert(grad.first); } + xla::XlaOp handle; + TF_RETURN_IF_ERROR(resource->Pack(&handle, builder)); + auto sharding = it == arg_shardings.end() + ? absl::optional() + : it->second; + // Set layout of the retval to device representation layout. 
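The layout rewrite used here is easiest to follow with the numbers from the comment in RewriteLayoutWithShardedShape above. A worked example with illustrative values, not additional code in the patch:

    // f32[128,128] retval, tiled across 2 devices along dimension 1; for device 0:
    //   offset = sharding.TileOffsetForDevice(shape, 0) -> {0, 0}
    //   limit  = sharding.TileLimitForDevice(shape, 0)  -> {128, 64}
    //   per-device dimensions = limit - offset          -> {128, 64}
    // shape_representation_fn is then asked for the layout of the [128,64]
    // per-device shape; if it returns minor-to-major {0,1}, that layout replaces
    // the default {1,0} on the full [128,128] retval shape.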
+ if (shape_representation_fn) { + TF_ASSIGN_OR_RETURN(auto original_shape, builder->GetShape(handle)); + TF_ASSIGN_OR_RETURN( + handle, ReshapeWithCorrectRepresentationAndSharding( + builder, handle, original_shape, + shape_representation_fn, sharding, arg.fast_mem)); + } + // Request that the value be returned on a specific core. - xla::XlaScopedShardingAssignment assign_sharding( - builder, it == arg_shardings.end() ? absl::optional() - : it->second); + xla::XlaScopedShardingAssignment assign_sharding(builder, sharding); if (it != arg_shardings.end()) { retval_index_and_sharding[elems.size()] = it->second; } - - xla::XlaOp handle; - TF_RETURN_IF_ERROR(resource->Pack(&handle, builder)); - // Ensures the correct sharding is applied to the output. handle = identity_op(handle); - - // Set layout of the retval to device representation layout. - absl::optional representation_shape; - if (shape_representation_fn) { - TF_ASSIGN_OR_RETURN( - xla::Shape xla_shape, - shape_representation_fn(resource->shape(), resource->type(), - /*use_fast_memory=*/false)); - representation_shape = xla_shape; - } - if (resource->representation_shape().has_value()) { - const xla::Shape& xla_shape = resource->representation_shape().value(); - if (representation_shape) { - TF_RET_CHECK( - xla::ShapeUtil::Compatible(*representation_shape, xla_shape)); - } else { - representation_shape = xla_shape; - } - } - if (representation_shape) { - retval_index_and_layout.emplace_back(elems.size(), - representation_shape->layout()); - } - elems.push_back(handle); } } @@ -411,20 +471,8 @@ Status BuildComputation( } *computation = computation_status.ConsumeValueOrDie(); - TF_ASSIGN_OR_RETURN(const auto& program_shape, - computation->GetProgramShape()); + TF_ASSIGN_OR_RETURN(auto program_shape, computation->GetProgramShape()); *output_shape = program_shape.result(); - // Update the output layout to the layout of retval. - for (auto& index_and_layout : retval_index_and_layout) { - if (!always_return_tuple && elems.size() == 1) { - *output_shape->mutable_layout() = index_and_layout.second; - continue; - } - - xla::Shape* output_sub_shape = xla::ShapeUtil::GetMutableSubshape( - output_shape, {index_and_layout.first}); - *output_sub_shape->mutable_layout() = index_and_layout.second; - } return Status::OK(); } @@ -779,47 +827,6 @@ Status XlaCompiler::XLAShapeForArgument( const XlaCompiler::Argument& arg, bool is_entry_computation, const absl::optional& arg_sharding, xla::Shape* xla_shape) const { - auto rewrite_layout_with_sharded_shape = - [](const absl::optional& arg_sharding, - bool use_fast_memory, - XlaCompiler::ShapeRepresentationFn shape_representation_fn, - xla::Shape* xla_shape) { - if (arg_sharding && !arg_sharding->IsTileMaximal()) { - // After parameter sharding, per core parameter might have different - // layout. For example, before sharding, a parameter of shape [128, - // 128] will be assigned default minor-to-major {1, 0}. But after we - // shard this parameter to [128, 64] * 2, the sharded parameters - // will have minor-to-major {0, 1}. - // - // As a result, for sharded parameters, we set their layout to per - // core parameter's layout. - // - // TODO(endlessroad): for variable input & update, we might have - // different layouts which will prevent input output aliasing and - // increase memory usage. Investigate such cases. 
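
For the tiled-sharding layout rewrite above, the per-core shape is simply `limit[i] - offset[i]` for the tile assigned to one device. Below is a minimal standalone sketch of that computation (plain vectors and a hypothetical even split, not the XLA sharding API), using the [128,128] -> 2 x [128,64] example from the comment.

```
// Minimal sketch (plain vectors, not the XLA sharding API) of how the
// per-device shape is derived in RewriteLayoutWithShardedShape: assuming an
// even split of dimension 1 across two cores, each tile's dimensions are
// limit[i] - offset[i].
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const std::vector<int64_t> full = {128, 128};
  const int64_t num_devices = 2, split_dim = 1, device = 0;

  const int64_t tile = full[split_dim] / num_devices;
  std::vector<int64_t> offset = {0, device * tile};          // {0, 0}
  std::vector<int64_t> limit = {full[0], offset[1] + tile};  // {128, 64}

  std::vector<int64_t> per_device(full.size());
  for (size_t i = 0; i < full.size(); ++i) {
    per_device[i] = limit[i] - offset[i];
  }

  // per_device ({128, 64}) is what gets handed to shape_representation_fn;
  // its layout is then copied back onto the full sharded shape.
  std::printf("per-device shape: [%lld, %lld]\n", (long long)per_device[0],
              (long long)per_device[1]);
}
```
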
- int64 device = *arg_sharding->tile_assignment().begin(); - std::vector offset = - arg_sharding->TileOffsetForDevice(*xla_shape, device); - std::vector limit = - arg_sharding->TileLimitForDevice(*xla_shape, device); - std::vector dimensions(xla_shape->rank()); - for (int64 i = 0; i < xla_shape->rank(); ++i) { - dimensions[i] = limit[i] - offset[i]; - } - xla::Shape per_device_xla_shape = - xla::ShapeUtil::MakeShape(xla_shape->element_type(), dimensions); - TensorShape per_device_tensor_shape; - TF_RETURN_IF_ERROR(XLAShapeToTensorShape(per_device_xla_shape, - &per_device_tensor_shape)); - TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType( - xla_shape->element_type())); - TF_ASSIGN_OR_RETURN(per_device_xla_shape, - shape_representation_fn(per_device_tensor_shape, - dtype, use_fast_memory)); - *xla_shape->mutable_layout() = per_device_xla_shape.layout(); - } - return Status::OK(); - }; switch (arg.kind) { case XlaCompiler::Argument::kConstant: LOG(FATAL) << "Unreachable case"; @@ -835,7 +842,7 @@ Status XlaCompiler::XLAShapeForArgument( TF_ASSIGN_OR_RETURN(*xla_shape, options_.shape_representation_fn( shape, arg.type, /*use_fast_memory=*/false)); - TF_RETURN_IF_ERROR(rewrite_layout_with_sharded_shape( + TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape( arg_sharding, /*use_fast_memory=*/false, options_.shape_representation_fn, xla_shape)); } else { @@ -863,7 +870,7 @@ Status XlaCompiler::XLAShapeForArgument( options_.shape_representation_fn( absl::get(arg.shape), arg.type, /*use_fast_memory=*/arg.fast_mem)); - TF_RETURN_IF_ERROR(rewrite_layout_with_sharded_shape( + TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape( arg_sharding, arg.fast_mem, options_.shape_representation_fn, xla_shape)); return Status::OK(); diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index cf8bd6b6ce4..76780167187 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -365,7 +365,8 @@ TEST_F(XlaCompilerTest, HonorShapeRepresentationFnForFastMemVar) { compile_options.return_updated_values_for_all_resources = true; TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph), args, &result)); - EXPECT_EQ(fast_mem_arg_count, 1); + // Count 2: one for argument, one for the return value. + EXPECT_EQ(fast_mem_arg_count, 2); } // Tests that the compiler can correctly propagate the layout assigned by @@ -417,6 +418,8 @@ TEST_F(XlaCompilerTest, HonorShapeRepresentationFnForRetVal) { // Check that the return shapes are correctly tranposed. 
EXPECT_EQ(result.xla_output_shape, xla::ShapeUtil::MakeTupleShape({transposed, transposed})); + EXPECT_EQ(result.computation->GetProgramShape().ConsumeValueOrDie().result(), + xla::ShapeUtil::MakeTupleShape({transposed, transposed})); } // The layout of resource variable shouldn't change after transpose @@ -1091,6 +1094,8 @@ TEST_F(XlaCompilerTest, ResultLayoutSingle) { EXPECT_TRUE(xla::ShapeUtil::Equal( result.xla_output_shape, xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1}))); + EXPECT_EQ(result.computation->GetProgramShape().ConsumeValueOrDie().result(), + result.xla_output_shape); } TEST_F(XlaCompilerTest, ResultLayoutMultiple) { @@ -1131,6 +1136,8 @@ TEST_F(XlaCompilerTest, ResultLayoutMultiple) { EXPECT_TRUE(xla::ShapeUtil::Equal( result.xla_output_shape, xla::ShapeUtil::MakeTupleShape({result_shape, result_shape}))); + EXPECT_EQ(result.computation->GetProgramShape().ConsumeValueOrDie().result(), + result.xla_output_shape); } // Tests a simple graph that reads and writes a variable. diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index dd9f83bf26e..01f35df0e20 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -722,6 +722,7 @@ tf_cc_test( ":text_literal_writer", ":types", "//tensorflow/core:lib", + "//tensorflow/core:test", "//tensorflow/core:test_main", ], ) diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc index d0971734570..701479614aa 100644 --- a/tensorflow/compiler/xla/client/lib/math.cc +++ b/tensorflow/compiler/xla/client/lib/math.cc @@ -693,7 +693,10 @@ XlaOp Digamma(XlaOp input) { namespace { +enum kIgammaMode { VALUE, DERIVATIVE, SAMPLE_DERIVATIVE }; + // Helper function for computing Igamma using a power series. 
+template XlaOp IgammaSeries(XlaOp ax, XlaOp x, XlaOp a, XlaOp enabled, xla::PrimitiveType type) { // vals: (enabled, r, c, ans, x) @@ -715,24 +718,60 @@ XlaOp IgammaSeries(XlaOp ax, XlaOp x, XlaOp a, XlaOp enabled, XlaOp c = vals[2]; XlaOp ans = vals[3]; XlaOp x = vals[4]; + XlaOp dc_da = vals[5]; + XlaOp dans_da = vals[6]; + r = r + ScalarLike(r, 1); + dc_da = dc_da * (x / r) + (ScalarLike(r, -1) * c * x) / (r * r); + dans_da = dans_da + dc_da; c = c * (x / r); ans = ans + c; + XlaOp conditional; + if (mode == VALUE) { + conditional = And(enabled, Gt(c / ans, Epsilon(builder, type))); + } else { + conditional = + And(enabled, Gt(Abs(dc_da / dans_da), Epsilon(builder, type))); + } + return std::vector{ - And(enabled, Gt(c / ans, Epsilon(builder, type))), - Select(enabled, r, vals[1]), Select(enabled, c, vals[2]), - Select(enabled, ans, vals[3]), Select(enabled, x, vals[4])}; + conditional, + Select(enabled, r, vals[1]), + Select(enabled, c, vals[2]), + Select(enabled, ans, vals[3]), + Select(enabled, x, vals[4]), + Select(enabled, dc_da, vals[5]), + Select(enabled, dans_da, vals[6]), + }; }; auto& b = *ax.builder(); return b.ReportErrorOrReturn([&]() -> StatusOr { - std::vector vals = {enabled, a, FullLike(a, 1), FullLike(a, 1), x}; + std::vector vals = { + enabled, a, FullLike(a, 1), FullLike(a, 1), x, FullLike(a, 0), + FullLike(a, 0), + }; + TF_ASSIGN_OR_RETURN(vals, WhileLoopHelper(cond, body, vals, "igamma", &b)); XlaOp ans = vals[3]; - return (ans * ax) / a; + XlaOp dans_da = vals[6]; + if (mode == VALUE) { + return (ans * ax) / a; + } + + XlaOp dlogax_da = Log(x) - Digamma(a + ScalarLike(a, 1)); + + switch (mode) { + case DERIVATIVE: + return ax * (ans * dlogax_da + dans_da) / a; + case SAMPLE_DERIVATIVE: + default: + return -(dans_da + ans * dlogax_da) * x / a; + } }); } // Helper function for computing Igammac using a continued fraction. 
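
As a sanity check on the power series above, here is a self-contained numeric sketch (plain C++, not the XLA builder API) that runs the same recurrence — r, c, ans plus the new dc_da/dans_da state — and evaluates both the VALUE and DERIVATIVE modes. Digamma is approximated by a central difference of std::lgamma purely for illustration (the real code calls its own Digamma), and the identity P(1/2, x) = erf(sqrt(x)) plus a finite difference in `a` serve as cross-checks.

```
// Numeric sketch of IgammaSeries: the regularized lower incomplete gamma
// P(a, x) and its derivative w.r.t. a, carried through the same recurrence as
// the XLA while loop above. Not XLA code; digamma is a finite-difference
// stand-in here.
#include <cmath>
#include <cstdio>

double DigammaApprox(double z) {  // d/dz log Gamma(z), central difference.
  const double h = 1e-6;
  return (std::lgamma(z + h) - std::lgamma(z - h)) / (2 * h);
}

// Returns P(a, x); if grad != nullptr also fills dP/da. Series branch only,
// i.e. valid where IgammaSeries is used (x <= 1 or x <= a).
double IgammaSeriesSketch(double a, double x, double* grad) {
  double ax = std::exp(a * std::log(x) - x - std::lgamma(a));
  double r = a, c = 1.0, ans = 1.0, dc_da = 0.0, dans_da = 0.0;
  for (int k = 0; k < 2000; ++k) {
    r += 1.0;
    dc_da = dc_da * (x / r) - c * x / (r * r);  // uses the old c, new r
    dans_da += dc_da;
    c *= x / r;
    ans += c;
    if (c / ans <= 1e-15 && std::fabs(dc_da / dans_da) <= 1e-15) break;
  }
  if (grad) {
    double dlogax_da = std::log(x) - DigammaApprox(a + 1.0);
    *grad = ax * (ans * dlogax_da + dans_da) / a;  // DERIVATIVE mode
  }
  return ans * ax / a;                             // VALUE mode
}

int main() {
  // P(0.5, x) == erf(sqrt(x)) is a handy check for the VALUE path.
  double p = IgammaSeriesSketch(0.5, 0.3, nullptr);
  std::printf("P(0.5,0.3)=%.12f  erf(sqrt(0.3))=%.12f\n", p,
              std::erf(std::sqrt(0.3)));

  // Check the DERIVATIVE path against a finite difference in a.
  double grad;
  IgammaSeriesSketch(1.5, 0.8, &grad);
  double fd = (IgammaSeriesSketch(1.5 + 1e-6, 0.8, nullptr) -
               IgammaSeriesSketch(1.5 - 1e-6, 0.8, nullptr)) / 2e-6;
  std::printf("dP/da=%.9f  finite-diff=%.9f\n", grad, fd);
}
```
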
+template XlaOp IgammacContinuedFraction(XlaOp ax, XlaOp x, XlaOp a, XlaOp enabled, xla::PrimitiveType type) { // vals: enabled, ans, t, y, z, c, pkm1, qkm1, pkm2, qkm2 @@ -754,6 +793,13 @@ XlaOp IgammacContinuedFraction(XlaOp ax, XlaOp x, XlaOp a, XlaOp enabled, XlaOp qkm1 = vals[7]; XlaOp pkm2 = vals[8]; XlaOp qkm2 = vals[9]; + + XlaOp dpkm2_da = vals[10]; + XlaOp dqkm2_da = vals[11]; + XlaOp dpkm1_da = vals[12]; + XlaOp dqkm1_da = vals[13]; + XlaOp dans_da = vals[14]; + c = c + ScalarLike(c, 1); y = y + ScalarLike(y, 1); z = z + ScalarLike(z, 2); @@ -762,18 +808,46 @@ XlaOp IgammacContinuedFraction(XlaOp ax, XlaOp x, XlaOp a, XlaOp enabled, XlaOp qk = qkm1 * z - qkm2 * yc; XlaOp qk_is_nonzero = Ne(qk, ScalarLike(qk, 0)); XlaOp r = pk / qk; + t = Select(qk_is_nonzero, Abs((ans - r) / r), FullLike(t, 1)); ans = Select(qk_is_nonzero, r, ans); + + XlaOp dpk_da = dpkm1_da * z - pkm1 - dpkm2_da * yc + pkm2 * c; + XlaOp dqk_da = dqkm1_da * z - qkm1 - dqkm2_da * yc + qkm2 * c; + XlaOp dans_da_new = + Select(qk_is_nonzero, (dpk_da - ans * dqk_da) / qk, dans_da); + XlaOp grad_conditional = + Select(qk_is_nonzero, Abs(dans_da_new - dans_da), FullLike(dans_da, 1)); + pkm2 = pkm1; pkm1 = pk; qkm2 = qkm1; qkm1 = qk; + + dpkm2_da = dpkm1_da; + dqkm2_da = dqkm1_da; + dpkm1_da = dpk_da; + dqkm1_da = dqk_da; + XlaOp rescale = Gt(Abs(pk), Reciprocal(Epsilon(builder, type))); pkm2 = Select(rescale, pkm2 * Epsilon(builder, type), pkm2); pkm1 = Select(rescale, pkm1 * Epsilon(builder, type), pkm1); qkm2 = Select(rescale, qkm2 * Epsilon(builder, type), qkm2); qkm1 = Select(rescale, qkm1 * Epsilon(builder, type), qkm1); - return std::vector{And(enabled, Gt(t, Epsilon(builder, type))), + + dpkm2_da = Select(rescale, dpkm2_da * Epsilon(builder, type), dpkm2_da); + dqkm2_da = Select(rescale, dqkm2_da * Epsilon(builder, type), dqkm2_da); + dpkm1_da = Select(rescale, dpkm1_da * Epsilon(builder, type), dpkm1_da); + dqkm1_da = Select(rescale, dqkm1_da * Epsilon(builder, type), dqkm1_da); + + XlaOp conditional; + if (mode == VALUE) { + conditional = And(enabled, Gt(t, Epsilon(builder, type))); + } else { + conditional = And(enabled, Gt(grad_conditional, Epsilon(builder, type))); + } + + return std::vector{conditional, Select(enabled, ans, vals[1]), Select(enabled, t, vals[2]), Select(enabled, y, vals[3]), @@ -782,7 +856,12 @@ XlaOp IgammacContinuedFraction(XlaOp ax, XlaOp x, XlaOp a, XlaOp enabled, Select(enabled, pkm1, vals[6]), Select(enabled, qkm1, vals[7]), Select(enabled, pkm2, vals[8]), - Select(enabled, qkm2, vals[9])}; + Select(enabled, qkm2, vals[9]), + Select(enabled, dpkm2_da, vals[10]), + Select(enabled, dqkm2_da, vals[11]), + Select(enabled, dpkm1_da, vals[12]), + Select(enabled, dqkm1_da, vals[13]), + Select(enabled, dans_da_new, vals[14])}; }; auto& b = *ax.builder(); @@ -796,11 +875,31 @@ XlaOp IgammacContinuedFraction(XlaOp ax, XlaOp x, XlaOp a, XlaOp enabled, XlaOp qkm1 = z * x; XlaOp ans = pkm1 / qkm1; XlaOp t = FullLike(x, 1); - std::vector vals = {enabled, ans, t, y, z, - c, pkm1, qkm1, pkm2, qkm2}; + XlaOp dpkm2_da = FullLike(x, 0); + XlaOp dqkm2_da = FullLike(x, 0); + XlaOp dpkm1_da = FullLike(x, 0); + XlaOp dqkm1_da = -x; + XlaOp dans_da = (dpkm1_da - ans * dqkm1_da) / qkm1; + std::vector vals = {enabled, ans, t, y, z, + c, pkm1, qkm1, pkm2, qkm2, + dpkm2_da, dqkm2_da, dpkm1_da, dqkm1_da, dans_da}; + TF_ASSIGN_OR_RETURN(vals, WhileLoopHelper(cond, body, vals, "igammac", &b)); ans = vals[1]; - return ans * ax; + if (mode == VALUE) { + return ans * ax; + } + + dans_da = vals[14]; + XlaOp dlogax_da 
= Log(x) - Digamma(a); + + switch (mode) { + case DERIVATIVE: + return ax * (ans * dlogax_da + dans_da); + case SAMPLE_DERIVATIVE: + default: + return -(dans_da + ans * dlogax_da) * x; + } }); } @@ -820,9 +919,9 @@ XlaOp Igamma(XlaOp a, XlaOp x) { const double nan = std::numeric_limits::quiet_NaN(); XlaOp output = Select( use_igammac, - ScalarLike(a, 1) - - IgammacContinuedFraction(ax, x, a, And(enabled, use_igammac), type), - IgammaSeries(ax, x, a, And(enabled, Not(use_igammac)), type)); + ScalarLike(a, 1) - IgammacContinuedFraction( + ax, x, a, And(enabled, use_igammac), type), + IgammaSeries(ax, x, a, And(enabled, Not(use_igammac)), type)); output = Select(underflow, ZerosLike(output), output); output = Select(x_is_zero, ZerosLike(output), output); output = Select(Or(domain_error, is_nan), FullLike(a, nan), output); @@ -852,6 +951,101 @@ XlaOp Igamma(XlaOp a, XlaOp x) { }); } +XlaOp IgammaGradA(XlaOp a, XlaOp x) { + auto& b = *a.builder(); + auto doit = [&b](XlaOp a, XlaOp x, PrimitiveType type) -> XlaOp { + XlaOp is_nan = Or(IsNan(a), IsNan(x)); + XlaOp x_is_zero = Eq(x, ScalarLike(x, 0)); + XlaOp domain_error = Or(Lt(x, ScalarLike(x, 0)), Le(a, ScalarLike(a, 0))); + XlaOp use_igammac = And(Gt(x, ScalarLike(x, 1)), Gt(x, a)); + XlaOp ax = a * Log(x) - x - Lgamma(a); + XlaOp underflow = Lt(ax, -Log(MaxFiniteValue(&b, type))); + ax = Exp(ax); + XlaOp enabled = Not(Or(Or(Or(x_is_zero, domain_error), underflow), is_nan)); + const double nan = std::numeric_limits::quiet_NaN(); + XlaOp output = Select(use_igammac, + -IgammacContinuedFraction( + ax, x, a, And(enabled, use_igammac), type), + IgammaSeries( + ax, x, a, And(enabled, Not(use_igammac)), type)); + output = Select(underflow, ZerosLike(output), output); + output = Select(x_is_zero, ZerosLike(output), output); + output = Select(Or(domain_error, is_nan), FullLike(a, nan), output); + return output; + }; + return b.ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(auto a_shape, b.GetShape(a)); + TF_ASSIGN_OR_RETURN(auto x_shape, b.GetShape(x)); + if (a_shape != x_shape) { + return InvalidArgument( + "Arguments to IgammaGradA must have equal shapes and types; got %s " + "and %s", + a_shape.ToString(), x_shape.ToString()); + } + TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IgammaGradA", a)); + bool needs_upcast = + a_shape.element_type() == F16 || a_shape.element_type() == BF16; + + if (needs_upcast) { + a = ConvertElementType(a, F32); + x = ConvertElementType(x, F32); + } + XlaOp result = doit(a, x, a_shape.element_type()); + if (needs_upcast) { + result = ConvertElementType(result, a_shape.element_type()); + } + return result; + }); +} + +// Gradient of Gamma sample from Gamma(a, 1) with respect to `a`. 
+XlaOp RandomGammaGrad(XlaOp a, XlaOp x) { + auto& b = *a.builder(); + auto doit = [&b](XlaOp a, XlaOp x, PrimitiveType type) -> XlaOp { + XlaOp is_nan = Or(IsNan(a), IsNan(x)); + XlaOp x_is_zero = Eq(x, ScalarLike(x, 0)); + XlaOp domain_error = Or(Lt(x, ScalarLike(x, 0)), Le(a, ScalarLike(a, 0))); + XlaOp use_igammac = And(Gt(x, ScalarLike(x, 1)), Gt(x, a)); + XlaOp ax = a * Log(x) - x - Lgamma(a); + XlaOp underflow = Lt(ax, -Log(MaxFiniteValue(&b, type))); + ax = Exp(ax); + XlaOp enabled = Not(Or(Or(Or(x_is_zero, domain_error), underflow), is_nan)); + const double nan = std::numeric_limits::quiet_NaN(); + XlaOp output = Select(use_igammac, + -IgammacContinuedFraction( + ax, x, a, And(enabled, use_igammac), type), + IgammaSeries( + ax, x, a, And(enabled, Not(use_igammac)), type)); + output = Select(underflow, ZerosLike(output), output); + output = Select(x_is_zero, ZerosLike(output), output); + output = Select(Or(domain_error, is_nan), FullLike(a, nan), output); + return output; + }; + return b.ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(auto a_shape, b.GetShape(a)); + TF_ASSIGN_OR_RETURN(auto x_shape, b.GetShape(x)); + if (a_shape != x_shape) { + return InvalidArgument( + "Arguments to RandomGammaGrad must have equal shapes and types; got " + "%s and %s", + a_shape.ToString(), x_shape.ToString()); + } + TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("RandomGammaGrad", a)); + bool needs_upcast = + a_shape.element_type() == F16 || a_shape.element_type() == BF16; + + if (needs_upcast) { + a = ConvertElementType(a, F32); + x = ConvertElementType(x, F32); + } + XlaOp result = doit(a, x, a_shape.element_type()); + if (needs_upcast) { + result = ConvertElementType(result, a_shape.element_type()); + } + return result; + }); +} + XlaOp Igammac(XlaOp a, XlaOp x) { auto& b = *a.builder(); auto doit = [&b](XlaOp a, XlaOp x, PrimitiveType type) -> XlaOp { @@ -863,10 +1057,10 @@ XlaOp Igammac(XlaOp a, XlaOp x) { ax = Exp(ax); XlaOp result = Select(use_igamma, - ScalarLike(a, 1) - - IgammaSeries(ax, x, a, And(enabled, use_igamma), type), - IgammacContinuedFraction(ax, x, a, And(enabled, Not(use_igamma)), - type)); + ScalarLike(a, 1) - IgammaSeries( + ax, x, a, And(enabled, use_igamma), type), + IgammacContinuedFraction( + ax, x, a, And(enabled, Not(use_igamma)), type)); return Select(underflow, ZerosLike(a), Select(out_of_range, FullLike(a, 1), result)); }; @@ -1008,12 +1202,23 @@ XlaOp Asinh(XlaOp x) { if (primitive_util::IsComplexType(shape.element_type())) { return Log(x + Sqrt(x * x + one)); } + // For small x, sqrt(x**2 + 1) will evaluate to 1 due to floating point + // arithmetic. However, we would like to retain the low order term of this, + // which is around 0.5 * x**2 using a binomial expansion. + // Let z = sqrt(a**2 + 1) + // log(a + sqrt(a**2 + 1)) = + // log((a + sqrt(a**2 + 1)) * (1 + sqrt(a**2 + 1)) / (1 + sqrt(a**2 + 1))) = + // log((a + a**2 + 1 + a * z + z) / (1 + z)) = + // log(1 + a + a**2 / (1 + z)) = + // log(1 + a + a ** 2 / (1 + sqrt(a**2 + 1))) + // This rewrite retains the lower order term. 
auto a = Abs(x); + auto small_result = Log1p(a + a * a / (one + Sqrt(a * a + one))); auto naive_result = Log(a + Sqrt(a * a + one)); auto overflow_result = Log(Abs(a)) + Log(ScalarLike(a, 2)); auto sqrt_max_value = Sqrt(MaxFiniteValue(b, shape.element_type())); - return Sign(x) * - Select(Ge(a, sqrt_max_value), overflow_result, naive_result); + return Sign(x) * Select(Ge(a, sqrt_max_value), overflow_result, + Select(Le(a, one), small_result, naive_result)); }; // These upcasts are not strictly necessary on all platforms to get within our // error tolerances, so we could relax this if it ever mattered. @@ -1028,9 +1233,7 @@ XlaOp Atanh(XlaOp x) { XlaBuilder* b = x.builder(); auto do_it = [&](XlaOp x) -> StatusOr { TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(x)); - auto naive_result = - Log((ScalarLike(x, 1.0) + x) / (ScalarLike(x, 1.0) - x)) * - ScalarLike(x, 0.5); + auto naive_result = (Log1p(x) - Log1p(-x)) * ScalarLike(x, 0.5); // TODO(jlebar): For now, we ignore the nan edge case for complex inputs, // because we don't yet have exhaustive tests for complex trig functions. @@ -1074,9 +1277,35 @@ XlaOp Cosh(XlaOp x) { // correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so // we deem this acceptable. XlaOp Sinh(XlaOp x) { - return DoWithUpcastToF32(x, {BF16, F16}, [](XlaOp x) { + XlaBuilder* b = x.builder(); + auto do_it = [&](XlaOp x) -> StatusOr { + TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(x)); + auto one_half = ScalarLike(x, 0.5); auto log_one_half = Log(ScalarLike(x, 0.5)); - return Exp(x + log_one_half) - Exp(-x + log_one_half); + auto large_sinh_result = Exp(x + log_one_half) - Exp(-x + log_one_half); + + if (primitive_util::IsComplexType(shape.element_type())) { + return large_sinh_result; + } + + // Here we use e^x = e^(x / 2) * e^(x / 2). This avoids overflow for large + // values of x. + + // For smaller x, we get unwanted cancellations of e^x - e^-x, resulting in + // 0. + // Rewrite this to avoid that. We use expm1(x) because that preserves the + // first order term of the taylor series of e^x. + // (e^(x) - e^(-x)) / 2. = + // (e^(x) - 1 + 1 - e^(-x)) / 2. + // (expm1(x) + (e^(x) - 1) / e^x) / 2. + // (expm1(x) + expm1(x) / (expm1(x) + 1)) / 2. + auto expm1 = Expm1(x); + auto one = ScalarLike(x, 1.); + auto small_sinh_result = one_half * (expm1 + expm1 / (expm1 + one)); + return Select(Lt(Abs(x), one), small_sinh_result, large_sinh_result); + }; + return DoWithUpcastToF32(x, {BF16, F16}, [&](XlaOp x) { + return b->ReportErrorOrReturn(do_it(x)); }); } diff --git a/tensorflow/compiler/xla/client/lib/math.h b/tensorflow/compiler/xla/client/lib/math.h index ac96a50aecc..f862372a288 100644 --- a/tensorflow/compiler/xla/client/lib/math.h +++ b/tensorflow/compiler/xla/client/lib/math.h @@ -61,6 +61,14 @@ XlaOp Digamma(XlaOp input); // Computes an approximation of the incomplete gamma function. XlaOp Igamma(XlaOp a, XlaOp x); +// Computes an approximation of the derivative of the incomplete gamma function +// with respect to a. +XlaOp IgammaGradA(XlaOp a, XlaOp x); + +// Computes an approximation of the derivative of a sample `x` from a `Gamma(a, +// 1)` distribution with respect to a. +XlaOp RandomGammaGrad(XlaOp a, XlaOp x); + // Computes an approximation of the complementary incomplete gamma function. 
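
To see why the expm1/log1p rewrites above for Sinh, Asinh and Atanh matter, the following scalar sketch (standard C math, not the XLA builder code) compares the naive formulas with the rewritten ones on the same small inputs the new tests use; for tiny x the naive forms cancel or lose their leading-order term ~x, while the rewrites keep it.

```
// Scalar illustration of the small-value rewrites: expm1 for sinh, log1p for
// asinh and atanh, checked against the C library reference functions.
#include <cmath>
#include <cstdio>

double SinhSmall(double x) {   // 0.5 * (expm1(x) + expm1(x) / (expm1(x) + 1))
  double e = std::expm1(x);
  return 0.5 * (e + e / (e + 1.0));
}

double AsinhSmall(double x) {  // sign(x) * log1p(a + a^2 / (1 + sqrt(a^2 + 1)))
  double a = std::fabs(x);
  return std::copysign(std::log1p(a + a * a / (1.0 + std::sqrt(a * a + 1.0))),
                       x);
}

double AtanhViaLog1p(double x) {  // (log1p(x) - log1p(-x)) / 2
  return 0.5 * (std::log1p(x) - std::log1p(-x));
}

int main() {
  const double xs[] = {1e-3, 1e-7, 1e-11};
  for (double x : xs) {
    double sinh_naive = 0.5 * (std::exp(x) - std::exp(-x));
    double asinh_naive = std::log(x + std::sqrt(x * x + 1.0));
    double atanh_naive = 0.5 * std::log((1.0 + x) / (1.0 - x));
    std::printf("x=%.0e  sinh: naive=%.17g rewrite=%.17g ref=%.17g\n", x,
                sinh_naive, SinhSmall(x), std::sinh(x));
    std::printf("        asinh: naive=%.17g rewrite=%.17g ref=%.17g\n",
                asinh_naive, AsinhSmall(x), std::asinh(x));
    std::printf("        atanh: naive=%.17g rewrite=%.17g ref=%.17g\n",
                atanh_naive, AtanhViaLog1p(x), std::atanh(x));
  }
}
```
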
XlaOp Igammac(XlaOp a, XlaOp x); diff --git a/tensorflow/compiler/xla/client/lib/math_test.cc b/tensorflow/compiler/xla/client/lib/math_test.cc index faf30f68a10..32796dd8d70 100644 --- a/tensorflow/compiler/xla/client/lib/math_test.cc +++ b/tensorflow/compiler/xla/client/lib/math_test.cc @@ -298,6 +298,30 @@ XLA_TEST_F(MathTest, SqrtSixValues) { ComputeAndCompareR1(&builder, expected, {}, error_spec_); } +XLA_TEST_F(MathTest, SinhSmallValues) { + XlaBuilder builder(TestName()); + auto x = ConstantR1(&builder, {1e-3, 1e-5, 1e-7, 1e-9, 1e-11}); + Sinh(x); + std::vector expected = {1e-3, 1e-5, 1e-7, 1e-9, 1e-11}; + ComputeAndCompareR1(&builder, expected, {}, error_spec_); +} + +XLA_TEST_F(MathTest, AsinhSmallValues) { + XlaBuilder builder(TestName()); + auto x = ConstantR1(&builder, {1e-3, 1e-5, 1e-7, 1e-9, 1e-11}); + Asinh(x); + std::vector expected = {1e-3, 1e-5, 1e-7, 1e-9, 1e-11}; + ComputeAndCompareR1(&builder, expected, {}, error_spec_); +} + +XLA_TEST_F(MathTest, AtanhSmallValues) { + XlaBuilder builder(TestName()); + auto x = ConstantR1(&builder, {1e-8, 1e-9, 1e-10, 1e-11}); + Atanh(x); + std::vector expected = {1e-8, 1e-9, 1e-10, 1e-11}; + ComputeAndCompareR1(&builder, expected, {}, error_spec_); +} + XLA_TEST_F(MathTest, Lgamma) { XlaBuilder builder(TestName()); auto x = ConstantR1(&builder, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.5, 1.5, diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index a7e761b7dd0..d4a267d4356 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -528,7 +528,8 @@ StatusOr XlaBuilder::AddBroadcastSequence(const Shape& output_shape, } // Eliminate the size one dimensions. - TF_ASSIGN_OR_RETURN(XlaOp reshaped_operand, Reshape(reshaped_shape, operand)); + TF_ASSIGN_OR_RETURN(XlaOp reshaped_operand, + ReshapeInternal(reshaped_shape, operand)); // Broadcast 'reshape' up to the larger size. return InDimBroadcast(broadcast_shape, reshaped_operand, broadcast_dimensions); @@ -828,8 +829,8 @@ XlaOp XlaBuilder::BroadcastInDim( }); } -StatusOr XlaBuilder::Reshape(const Shape& shape, XlaOp operand, - int64 inferred_dimension) { +StatusOr XlaBuilder::ReshapeInternal(const Shape& shape, XlaOp operand, + int64 inferred_dimension) { TF_RETURN_IF_ERROR(first_error_); HloInstructionProto instr; @@ -1020,7 +1021,7 @@ XlaOp XlaBuilder::Reshape(XlaOp operand, absl::Span dimensions, XlaOp transposed = IsIdentityPermutation(dimensions) ? 
operand : Transpose(operand, dimensions); - return Reshape(shape, transposed, inferred_dimension); + return ReshapeInternal(shape, transposed, inferred_dimension); }); } @@ -1034,6 +1035,13 @@ XlaOp XlaBuilder::Reshape(XlaOp operand, absl::Span new_sizes, }); } +XlaOp XlaBuilder::Reshape(const Shape& shape, XlaOp operand, + int64 inferred_dimension) { + return ReportErrorOrReturn([&]() -> StatusOr { + return ReshapeInternal(shape, operand, inferred_dimension); + }); +} + XlaOp XlaBuilder::Collapse(XlaOp operand, absl::Span dimensions) { return ReportErrorOrReturn([&]() -> StatusOr { if (dimensions.size() <= 1) { @@ -2951,6 +2959,10 @@ XlaOp Reshape(const XlaOp operand, absl::Span new_sizes) { return operand.builder()->Reshape(operand, new_sizes); } +XlaOp Reshape(const Shape& shape, XlaOp operand) { + return operand.builder()->Reshape(shape, operand); +} + XlaOp ReshapeWithInferredDimension(XlaOp operand, absl::Span new_sizes, int64 inferred_dimension) { diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 993394ea275..6ec9aeb809f 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -397,6 +397,9 @@ class XlaBuilder { XlaOp Reshape(XlaOp operand, absl::Span new_sizes, int64 inferred_dimension = -1); + XlaOp Reshape(const Shape& shape, XlaOp operand, + int64 inferred_dimension = -1); + XlaOp Collapse(XlaOp operand, absl::Span dimensions); XlaOp Slice(XlaOp operand, absl::Span start_indices, @@ -668,8 +671,8 @@ class XlaBuilder { // Internal helper method for creating a Reshape op with the already inferred // shape. - StatusOr Reshape(const Shape& shape, XlaOp operand, - int64 inferred_dimension = -1); + StatusOr ReshapeInternal(const Shape& shape, XlaOp operand, + int64 inferred_dimension = -1); // Returns the (inferred) result for the program shape using the given root. StatusOr GetProgramShape(int64 root_id) const; @@ -777,6 +780,8 @@ class XlaBuilder { friend XlaOp Reshape(XlaOp operand, absl::Span new_sizes); + friend XlaOp Reshape(const Shape& shape, XlaOp operand); + friend XlaOp ReshapeWithInferredDimension(XlaOp operand, absl::Span new_sizes, int64 inferred_dimension); @@ -1252,6 +1257,9 @@ XlaOp Reshape(XlaOp operand, absl::Span dimensions, // sizes. Conceptually, this is a limited form of "shape casting". XlaOp Reshape(XlaOp operand, absl::Span new_sizes); +// Enqueues a Reshape op that uses an explicit target shape. +XlaOp Reshape(const Shape& shape, XlaOp operand); + // `inferred_dimension` represents the output dimension that's inferred by // upper-level framework by dividing the input element count by the known // output element count. While an inferred_dimension can be static, if there diff --git a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py index ded290a234d..b89bfd68073 100644 --- a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py +++ b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py @@ -90,7 +90,7 @@ class Sharding(object): tile_assignment_devices=list(flattened_devices))) @classmethod - def split(cls, tensor, split_dimension, num_devices): + def split(cls, tensor, split_dimension, num_devices, input_shape=None): """Returns a Sharding that splits a tensor across a dimension. This creates a Tiled attribute, similar to tile(), but easier to use for the @@ -100,12 +100,16 @@ class Sharding(object): tensor: A tf.Tensor to split. 
split_dimension: The dimension number to split. num_devices: The number of cores to split `tensor` over. + input_shape: The shape of the original tensor. Raises: ValueError: The tensor to split was smaller in the split dimension than the number of devices to split over. """ - shape = tensor.shape.as_list() + if input_shape: + shape = input_shape + else: + shape = tensor.shape.as_list() if (shape[split_dimension] is not None and shape[split_dimension] < num_devices): raise ValueError('Split dimension was smaller than the required number ' @@ -221,7 +225,8 @@ def split(tensor, split_dimension, num_devices, assign_tuple_sharding=False, - use_sharding_op=False): + use_sharding_op=False, + input_shape=None): """Returns a tensor that is split along the given dimension. Args: @@ -230,10 +235,11 @@ def split(tensor, num_devices: The number of devices to partition the dimension. assign_tuple_sharding: If the sharding type should be a tuple. use_sharding_op: If true, adds a sharding op to set the sharding. + input_shape: The full shape of the input tensor. """ if use_sharding_op: tensor = tf2xla.sharding(tensor) - Sharding.split(tensor, split_dimension, num_devices).apply_to_tensor( - tensor, - assign_tuple_sharding=assign_tuple_sharding) + Sharding.split( + tensor, split_dimension, num_devices, input_shape).apply_to_tensor( + tensor, assign_tuple_sharding=assign_tuple_sharding) return tensor diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index 9fc0c5b04d0..d6c1a034859 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -353,6 +353,7 @@ pybind_extension( "//tensorflow/compiler/xla/client/lib:math", "//tensorflow/compiler/xla/client/lib:qr", "//tensorflow/compiler/xla/client/lib:self_adjoint_eig", + "//tensorflow/compiler/xla/client/lib:sorting", "//tensorflow/compiler/xla/client/lib:svd", "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:custom_call_target_registry", @@ -372,6 +373,7 @@ pybind_extension( # not require Tensorflow. 
"//tensorflow/core:lib_internal_impl", # buildcleaner: keep "//tensorflow/core/profiler/lib:profiler_backends", + "//tensorflow/core/profiler/lib:profiler_session", "//tensorflow/core/profiler/lib:traceme", "//tensorflow/core/profiler/rpc:profiler_server", "//tensorflow/stream_executor:device_memory_allocator", diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/BUILD b/tensorflow/compiler/xla/python/tpu_driver/client/BUILD index 5237ce3ab7a..148822f3ba7 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/BUILD +++ b/tensorflow/compiler/xla/python/tpu_driver/client/BUILD @@ -22,6 +22,7 @@ cc_library( "//tensorflow/compiler/xla/python:local_client", "//tensorflow/compiler/xla/python:semaphore", "//tensorflow/compiler/xla/python/tpu_driver", + "//tensorflow/compiler/xla/python/tpu_driver:direct_tpu_driver", "//tensorflow/compiler/xla/python/tpu_driver:grpc_tpu_driver", "//tensorflow/compiler/xla/python/tpu_driver:recording_tpu_driver", "//tensorflow/compiler/xla/python/tpu_driver:tpu_driver_proto_cc", diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc index 00e38a5f90d..f6e2fab7ef0 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc @@ -25,6 +25,11 @@ namespace xla { namespace py = pybind11; PYBIND11_MODULE(tpu_client_extension, m) { + // Initializes the NumPy API for the use of the types module. + if (!InitializeNumpyAPIForTypes()) { + throw std::runtime_error("Unable to initialize Numpy API"); + } + py::class_>(m, "TpuClient") .def_static("Get", &PyTpuClient::Get, py::arg("worker")) .def("device_count", &PyTpuClient::device_count) diff --git a/tensorflow/compiler/xla/python/tpu_driver/direct_tpu_driver.cc b/tensorflow/compiler/xla/python/tpu_driver/direct_tpu_driver.cc index 3e4626c5841..76d79786bbf 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/direct_tpu_driver.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/direct_tpu_driver.cc @@ -27,7 +27,8 @@ namespace tpu_driver { namespace { -// Enable the macro by default in the env where the libtpu.so is available. +// Enable the macro by default in the Google internal environment where the +// libtpu.so is linked in statically. #ifdef PLATFORM_GOOGLE #define TPU_SHARED_LIBRARY_COMPILE_LINK 1 #endif diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc index 07fff76668f..4be375ac15a 100644 --- a/tensorflow/compiler/xla/python/xla.cc +++ b/tensorflow/compiler/xla/python/xla.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/lib/qr.h" #include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h" +#include "tensorflow/compiler/xla/client/lib/sorting.h" #include "tensorflow/compiler/xla/client/lib/svd.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -349,7 +350,10 @@ void BuildOpsSubmodule(py::module* m) { py::arg("precision_config") = nullptr); ops.def("ConvertElementType", &ConvertElementType, py::arg("operand"), py::arg("new_element_type")); + // TODO(phawkins): remove CustomCall after callers are updated to use + // CustomCallWithLayout. 
ops.def("CustomCall", &CustomCallWithLayout); + ops.def("CustomCallWithLayout", &CustomCallWithLayout); ops.def("Dot", &Dot, py::arg("lhs"), py::arg("rhs"), py::arg("precision_config") = nullptr); ops.def("DotGeneral", &DotGeneral, py::arg("lhs"), py::arg("rhs"), @@ -451,6 +455,7 @@ void BuildOpsSubmodule(py::module* m) { }, py::arg("builder"), py::arg("operands"), py::arg("dimension") = -1, py::arg("comparator") = absl::nullopt); + ops.def("TopK", &TopK, py::arg("input"), py::arg("k")); ops.def("Transpose", &Transpose); ops.def("TriangularSolve", &TriangularSolve); ops.def("Tuple", &Tuple); @@ -458,6 +463,8 @@ void BuildOpsSubmodule(py::module* m) { ops.def("Igamma", &Igamma); ops.def("Igammac", &Igammac); + ops.def("IgammaGradA", &IgammaGradA); + ops.def("RandomGammaGrad", &RandomGammaGrad); ops.def("RegularizedIncompleteBeta", &RegularizedIncompleteBeta); #define BINARY_OP(op) \ @@ -494,6 +501,7 @@ void BuildOpsSubmodule(py::module* m) { #define UNARY_OP(op) ops.def(#op, &op) UNARY_OP(Not); + UNARY_OP(PopulationCount); UNARY_OP(Clz); UNARY_OP(Abs); UNARY_OP(Exp); diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index a8f29009d9e..9d53f9bd082 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -1189,12 +1189,12 @@ class ComputationBuilder(object): return ops.Call(self._builder, computation_to_apply.computation, list(operands)) - def CustomCall(self, - call_target_name, - operands, - shape_with_layout, - operand_shapes_with_layout, - opaque=None): + def CustomCallWithLayout(self, + call_target_name, + operands, + shape_with_layout, + operand_shapes_with_layout, + opaque=None): """Enqueues a custom call operation onto the computation. Args: @@ -1214,6 +1214,10 @@ class ComputationBuilder(object): list(operands), shape_with_layout, list(operand_shapes_with_layout), opaque) + # TODO(phawkins): remove CustomCall after callers are updated to use + # CustomCallWithLayout. + CustomCall = CustomCallWithLayout + def Map(self, operands, computation_to_apply, dimensions): """Enqueues a map operation onto the computation. 
@@ -1635,6 +1639,7 @@ FftType = _xla.FftType _UNARY_OPS = [ 'Not', + 'PopulationCount', 'Clz', 'Abs', 'Exp', @@ -1698,6 +1703,7 @@ _BINARY_OPS = [ 'ShiftRightLogical', 'Atan2', 'Igamma', + 'IgammaGradA', 'Igammac', 'Complex', 'NextAfter', @@ -1719,6 +1725,7 @@ _OTHER_OPS = [ 'Rev', 'Select', 'SliceInDim', + 'TopK', ] diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index a3a16f09ce6..de5ae258976 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -969,6 +969,12 @@ class SingleOpTest(ComputationTest): c.Not(c.Constant(arr)) self._ExecuteAndCompareClose(c, expected=~arr) + def testPopulationCount(self): + c = self._NewComputation() + arr = NumpyArrayS32([3, 0, 1]) + c.PopulationCount(c.Constant(arr)) + self._ExecuteAndCompareClose(c, expected=np.array([2, 0, 1])) + def testCountLeadingZeros(self): c = self._NewComputation() arr = NumpyArrayS32([0x7FFF, 0x12345678]) diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 7dc03511f30..da50e92de32 100755 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -27,9 +27,7 @@ package_group( includes = [ "//tensorflow/compiler/xla:friends", ], - packages = [ - "//learning/brain/experimental/tf_runtime/...", - ], + packages = ["//learning/brain/experimental/tf_runtime/..."], ) tf_proto_library_cc( @@ -1947,6 +1945,51 @@ tf_cc_test( ], ) +cc_library( + name = "all_reduce_combiner", + srcs = ["all_reduce_combiner.cc"], + hdrs = ["all_reduce_combiner.h"], + deps = [ + ":hlo", + ":hlo_domain_map", + ":hlo_pass", + ":hlo_query", + ":hlo_reachability", + ":shape_inference", + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", + ], +) + +tf_cc_test( + name = "all_reduce_combiner_test", + srcs = ["all_reduce_combiner_test.cc"], + deps = [ + ":all_reduce_combiner", + ":hlo", + ":hlo_matchers", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "@com_google_absl//absl/memory", + ], +) + cc_library( name = "all_reduce_simplifier", srcs = ["all_reduce_simplifier.cc"], @@ -3389,6 +3432,7 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/core:lib", + "//tensorflow/core/platform:hash", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", ], @@ -4301,6 +4345,7 @@ cc_library( ":call_graph", ":hlo", ":hlo_pass", + ":hlo_query", ":pattern_matcher", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 5f50c2b303b..fd373671b97 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ 
b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -3204,53 +3204,6 @@ StatusOr AlgebraicSimplifierVisitor::TrySimplifyScalarSlice( return false; } - if (slice->operand(0)->opcode() == HloOpcode::kPad) { - VLOG(10) << "Trying to simplify scalar slice of pad"; - // Check there's no internal padding. Again, we could handle that too, since - // everything is statically known, but it's not worth it. - auto pad = Cast(slice->mutable_operand(0)); - auto padding_config = pad->padding_config(); - int64 rank = padding_config.dimensions_size(); - if (HasInteriorPadding(padding_config)) { - VLOG(10) << "Not folding scalar slice of pad, pad has interior padding"; - return false; - } - - // Check whether the scalar we're slicing out falls into the padding. - bool in_padding = [&]() { - for (int64 i = 0; i < rank; ++i) { - int64 start = slice->slice_starts(i); - int64 low = padding_config.dimensions(i).edge_padding_low(); - int64 data = pad->operand(0)->shape().dimensions(i); - if (start < low || start >= low + data) { - return true; - } - } - return false; - }(); - - if (in_padding) { - VLOG(10) << "Folding scalar slice of pad into padding value"; - TF_RETURN_IF_ERROR(ReplaceWithNewInstruction( - slice, HloInstruction::CreateReshape(slice->shape(), - pad->mutable_padding_value()))); - return true; - } else { - // We already know the output of the slice is scalar. If the padded - // value is scalar, and it's not in the padding, then it's exactly the - // output value. - bool replaced = - ReplaceInstructionIfSameShape(slice, pad->mutable_operand(0)); - if (replaced) { - VLOG(10) << "Folding scalar slice of pad into padded value"; - } else { - VLOG(10) << "Not folding scalar slice of pad into padded value as they " - "have different shapes."; - } - return replaced; - } - } - if (slice->operand(0)->opcode() == HloOpcode::kConcatenate) { VLOG(10) << "Trying to simplify scalar slice of concat"; // Only do this for R1, there's no chance of this being useful otherwise. @@ -3356,20 +3309,54 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { HloInstruction* pad; HloInstruction* pad_operand; if (Match(slice, m::Slice(m::Pad(&pad, m::Op(&pad_operand), m::Op())))) { + // Is the result of the slice the pad operand. bool slice_undoes_pad = true; + // Can the slice be moved to the pad_operand without any padding being read. + bool slice_inside_pad = true; + // Does this slice slice out pading only. 
+ bool slice_in_padding = false; + std::vector new_starts = slice->slice_starts(); + std::vector new_limits = slice->slice_limits(); for (int64 i = 0; i < slice->shape().rank(); ++i) { - if (slice->slice_starts(i) != - pad->padding_config().dimensions(i).edge_padding_low()) { + const int64 start = slice->slice_starts(i); + const int64 stride = slice->slice_strides(i); + const int64 limit = slice->slice_limits(i); + const int64 size = pad->shape().dimensions(i); + + const auto& dim = pad->padding_config().dimensions(i); + const int64 low = dim.edge_padding_low(); + const int64 high = dim.edge_padding_high(); + const int64 interior = dim.interior_padding(); + const int64 edge = size - high; + + if (limit <= low || start >= edge) { + slice_in_padding = true; + break; + } + + if (start != low || stride - 1 != interior) { slice_undoes_pad = false; } - if (slice->slice_strides(i) - 1 != - pad->padding_config().dimensions(i).interior_padding()) { - slice_undoes_pad = false; + + if (start < low || limit > edge || interior != 0 || stride != 1) { + slice_inside_pad = false; } + new_starts[i] -= low; + new_limits[i] -= low; + } + if (slice_in_padding) { + return ReplaceInstruction( + slice, MakeBroadcastHlo(pad->mutable_operand(1), {}, slice->shape())); } if (slice_undoes_pad && ReplaceInstructionIfSameShape(slice, pad_operand)) { return Status::OK(); } + if (slice_inside_pad) { + TF_ASSIGN_OR_RETURN(HloInstruction * new_slice, + MakeSliceHlo(pad_operand, new_starts, new_limits, + slice->slice_strides())); + return ReplaceInstruction(slice, new_slice); + } } if (slice->operand(0)->opcode() == HloOpcode::kSlice && @@ -3727,7 +3714,8 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) { // Convert Reduce(Dot(X,Y)) to Dot(X,Y) if any of the dimensions reduced were // batch dimensions of the dot. The transformation supports reducing other // dimensions as well. 
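
The new slice(pad(x)) analysis above is easier to follow in isolation. Below is a standalone sketch using plain structs instead of HLO instructions (the helper names are made up for the sketch), walked through on the SliceOfPad test case added further down: param f32[3,4], padding=3_2x1_5, slice [4:6],[2:5] with stride 1, which is neither in the padding nor an undo of the pad but can be rebased onto the operand with starts {1,1}.

```
// Standalone sketch of the slice(pad(x)) classification in HandleSlice:
// detect slices entirely inside the padding, slices that exactly undo the
// pad, and slices that can be moved onto the pad operand with rebased bounds.
#include <cstdint>
#include <cstdio>
#include <vector>

struct PadDim { int64_t low, high, interior; };

int main() {
  const std::vector<int64_t> operand_dims = {3, 4};
  const std::vector<PadDim> pad = {{3, 2, 0}, {1, 5, 0}};  // pads to [8, 10]
  const std::vector<int64_t> starts = {4, 2}, limits = {6, 5}, strides = {1, 1};

  bool slice_in_padding = false;  // slice reads padding values only
  bool slice_undoes_pad = true;   // slice returns exactly the pad operand
  bool slice_inside_pad = true;   // slice can be rebased onto the operand
  std::vector<int64_t> new_starts = starts, new_limits = limits;

  for (size_t i = 0; i < operand_dims.size(); ++i) {
    // Size of the padded dimension: low + data + interior gaps + high.
    const int64_t padded = pad[i].low + operand_dims[i] +
                           (operand_dims[i] - 1) * pad[i].interior +
                           pad[i].high;
    const int64_t edge = padded - pad[i].high;  // first index past the data
    if (limits[i] <= pad[i].low || starts[i] >= edge) slice_in_padding = true;
    if (starts[i] != pad[i].low || strides[i] - 1 != pad[i].interior)
      slice_undoes_pad = false;
    if (starts[i] < pad[i].low || limits[i] > edge || pad[i].interior != 0 ||
        strides[i] != 1)
      slice_inside_pad = false;
    new_starts[i] -= pad[i].low;
    new_limits[i] -= pad[i].low;
  }

  // Prints in_padding=0 undoes_pad=0 inside_pad=1 new=[1:3],[1:4], i.e. the
  // slice becomes slice(param) with starts {1,1} and result shape [2,3].
  std::printf("in_padding=%d undoes_pad=%d inside_pad=%d new=[%lld:%lld],[%lld:%lld]\n",
              slice_in_padding, slice_undoes_pad, slice_inside_pad,
              (long long)new_starts[0], (long long)new_limits[0],
              (long long)new_starts[1], (long long)new_limits[1]);
}
```
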
- if (Match(arg, m::Dot(&dot, m::Op(&lhs), m::Op(&rhs)).WithOneUser()) && + if (options_.enable_dot_strength_reduction() && + Match(arg, m::Dot(&dot, m::Op(&lhs), m::Op(&rhs)).WithOneUser()) && Match(reduce->to_apply()->root_instruction(), m::Add(m::Parameter(), m::Parameter())) && absl::c_any_of(reduce->dimensions(), [&](int64 dim) { diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 8f66f8084f3..31fa125b3e1 100755 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -4389,7 +4389,7 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadLow) { AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, GmockMatch(m::Reshape(m::Constant()))); + EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Constant()))); } TEST_F(AlgebraicSimplifierTest, SliceOfPadHigh) { @@ -4410,7 +4410,7 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadHigh) { AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, GmockMatch(m::Reshape(m::Constant()))); + EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Constant()))); } TEST_F(AlgebraicSimplifierTest, SliceOfPadMidNonScalar) { @@ -4429,7 +4429,31 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadMidNonScalar) { AlgebraicSimplifierOptions options; AlgebraicSimplifier simplifier(options); - EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Slice(m::Parameter(0)))); +} + +TEST_F(AlgebraicSimplifierTest, SliceOfPad) { + const char* hlo_string = R"( + HloModule module + + ENTRY test { + param = f32[3,4] parameter(0) + constant = f32[] constant(0.0) + pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5 + ROOT slice = f32[2,3] slice(f32[8,10] pad), slice={[4:6],[2:5]} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options; + AlgebraicSimplifier simplifier(options); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Slice(m::Parameter(0)))); + EXPECT_THAT(root->slice_starts(), ElementsAre(1, 1)); } TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalarConstant) { @@ -4450,7 +4474,7 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalarConstant) { AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, GmockMatch(m::Reshape(m::Constant()))); + EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Constant()))); } TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalar) { @@ -4494,7 +4518,7 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadSomeDimsInPadding) { AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, GmockMatch(m::Reshape(m::ConstantScalar(-7.0)))); + EXPECT_THAT(root, GmockMatch(m::Broadcast(m::ConstantScalar(-7.0)))); } TEST_F(AlgebraicSimplifierTest, SliceOfConcatScalarInput) { diff --git 
a/tensorflow/compiler/xla/service/all_reduce_combiner.cc b/tensorflow/compiler/xla/service/all_reduce_combiner.cc new file mode 100644 index 00000000000..2b41f19f288 --- /dev/null +++ b/tensorflow/compiler/xla/service/all_reduce_combiner.cc @@ -0,0 +1,452 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/all_reduce_combiner.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/service/hlo_domain_map.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_query.h" +#include "tensorflow/compiler/xla/service/hlo_reachability.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +// Combines the elements of to_combine into a single AllReduce op. All +// entries in to_combine must be AllReduce ops with exactly one operand +// and the same reduction operation. +Status CombineAllReduces(absl::Span to_combine) { + if (to_combine.size() < 2) { + return Status::OK(); + } + VLOG(1) << "Combined " << to_combine.size() << " CRS ops"; + + HloComputation& computation = *to_combine.back()->parent(); + HloComputation* reduction = to_combine[0]->to_apply(); + const HloOpcode type = reduction->root_instruction()->opcode(); + + // Create a single bigger AllReduce of the operands of the smaller + // AllReduces. + std::vector operands; + std::vector operand_shapes; + VLOG(1) << "Combining set"; + for (HloInstruction* hlo : to_combine) { + VLOG(1) << "Set element: " << hlo->ToString(); + TF_RET_CHECK(hlo->opcode() == HloOpcode::kAllReduce); + TF_RET_CHECK(hlo->operands().size() == 1); + TF_RET_CHECK(hlo->to_apply() == reduction || + (hlo->to_apply()->instruction_count() == 3 && + hlo->to_apply()->num_parameters() == 2 && + hlo->to_apply()->root_instruction()->opcode() == type)); + TF_RET_CHECK(hlo->shape().IsArray()); + for (HloInstruction* operand : hlo->operands()) { + operands.push_back(operand); + operand_shapes.push_back(operand->shape()); + } + } + + HloInstruction* combined; + // AllReduce ops with more than one operand produce a tuple. 
+ TF_RET_CHECK(operands.size() >= 2); + combined = computation.AddInstruction(HloInstruction::CreateAllReduce( + ShapeUtil::MakeTupleShape(operand_shapes), operands, reduction, + to_combine.front()->replica_groups(), + /*constrain_layout=*/false, to_combine.front()->channel_id())); + + // We have to propagate the sharding manually because Domain instructions are + // not guaranteed to preserve it for side effecting instructions. + if (to_combine.front()->has_sharding()) { + combined->set_sharding(to_combine.front()->sharding()); + } + VLOG(1) << "Replacing with : " << combined->ToString(); + + // Replace all the smaller AllReduces with elements of the tuple output + // of the single bigger AllReduce. + for (int64 i = 0; i < to_combine.size(); ++i) { + auto replace_with = HloInstruction::CreateGetTupleElement( + to_combine[i]->shape(), combined, i); + TF_RETURN_IF_ERROR(computation.ReplaceWithNewInstruction( + to_combine[i], std::move(replace_with))); + } + return Status::OK(); +} + +struct GroupKey { + GroupKey(const HloInstruction* hlo, const HloDomainMap& domain_map) + : opcode(hlo->to_apply()->root_instruction()->opcode()), + accum_type(hlo->to_apply()->root_instruction()->shape().element_type()), + domain_id(domain_map.GetDomainMetadataId(hlo)), + is_cross_shard(hlo->channel_id().has_value()), + replica_groups(hlo->replica_groups()) {} + + bool operator<(const GroupKey& other) const { + if (opcode != other.opcode) { + return opcode < other.opcode; + } + if (accum_type != other.accum_type) { + return accum_type < other.accum_type; + } + if (domain_id != other.domain_id) { + return domain_id < other.domain_id; + } + if (is_cross_shard != other.is_cross_shard) { + return is_cross_shard < other.is_cross_shard; + } + if (replica_groups.size() != other.replica_groups.size()) { + return replica_groups.size() < other.replica_groups.size(); + } + for (int64 i = 0; i < replica_groups.size(); ++i) { + const auto& rg = replica_groups[i]; + const auto& org = other.replica_groups[i]; + if (rg.replica_ids_size() != org.replica_ids_size()) { + return rg.replica_ids_size() < org.replica_ids_size(); + } + for (int64 j = 0; j < rg.replica_ids_size(); ++j) { + if (rg.replica_ids(j) != org.replica_ids(j)) { + return rg.replica_ids(j) < org.replica_ids(j); + } + } + } + return false; + } + + HloOpcode opcode; + PrimitiveType accum_type; + int64 domain_id; + bool is_cross_shard; + std::vector replica_groups; +}; + +// Group AllReduce instructions by the reduction types, e.g., add, min, +// max, replica groups and domain. For cross-module all reduce instructions +// we group them by the set of domains they are reducing across. +// +// Note that the shape of the reduction computation is not included in the +// reduction types, e.g.: "f32[] add" and "bf16[] add" will be the same type. We +// need to disallow combining CRS instructions with different domain metadata as +// well as that could end up short-cutting two or more different domains. +// +// In each group, the instructions should be in post order. We will then iterate +// each group and try to combine them, so to prevent non-determinism, we use +// std::map here. +// +// The return value is a list of groups where every group contains a list of +// all-reduce instruction sets in topological order and with a deterministic +// order within the set. 
Additionally due to the above constraints every all +// reduce set within a group will contain the same number of elements +// and every instruction within an all reduce set will have the same +// all-reduce-id (if specified) and thus shape (all reduce sets without an +// all-reduce-id will have a single instruction). +using InstructionGroups = + std::vector>>; +StatusOr CreateComputationGroups( + HloComputation* computation) { + TF_ASSIGN_OR_RETURN(auto domain_map, HloDomainMap::Create(computation, "")); + + // Group instructions by opcode, domain id and replica group. + std::map> opcode_groups; + for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) { + if (instruction->opcode() != HloOpcode::kAllReduce) { + continue; + } + if (instruction->to_apply()->instruction_count() != 3 || + instruction->to_apply()->num_parameters() != 2) { + VLOG(1) << "Skipping due to non-trivial reduction function."; + continue; + } + opcode_groups[GroupKey(instruction, *domain_map)].push_back(instruction); + } + + // Generate a unique all-reduce-id for instructions without one by negating + // the unique id of the hlo. This way we can treat cross module and normal CRS + // instructions uniformly. + auto channel_id = [](const HloInstruction* all_reduce) { + return all_reduce->IsCrossModuleAllReduce() + ? all_reduce->channel_id().value() + : -1 * all_reduce->unique_id(); + }; + + // Group instructions by all-reduce id with instructions for an all-reduce id + // is listed along their group id and the (group id, instruction) pairs are + // sorted by group id in the vector. + std::map>> + all_reduce_sets; + int64 group_id = 0; + for (auto& domain_groups : opcode_groups) { + for (HloInstruction* hlo : domain_groups.second) { + all_reduce_sets[channel_id(hlo)].emplace_back(group_id, hlo); + } + ++group_id; + } + + // Group instructions by participating group ids. Instructions within a group + // are sorted by topological order and instructions within an all reduce group + // is still sorted by group id. + std::map, std::vector>> + all_reduce_group_map; + for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) { + if (instruction->opcode() != HloOpcode::kAllReduce) { + continue; + } + if (instruction->to_apply()->instruction_count() != 3 || + instruction->to_apply()->num_parameters() != 2) { + VLOG(1) << "Skipping due to non-trivial reduction function."; + continue; + } + + int64 arid = channel_id(instruction); + if (all_reduce_sets.count(arid) == 0) { + // Already processed. 
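
The grouping key above orders all-reduces by reduction opcode, accumulation type, domain id, cross-shard flag, and replica groups. Since the std::map only needs a strict weak ordering, the same criteria can be expressed compactly on plain stand-in types, as in the sketch below (simplified field types, not the real HLO structures): std::tie for the scalar fields, then a size-first, element-wise comparison of the replica groups, mirroring the operator< in the pass.

```
// Sketch of the grouping-key ordering on plain types; any consistent strict
// weak ordering over these fields would group equivalent all-reduces together.
#include <cstdint>
#include <map>
#include <tuple>
#include <vector>

struct GroupKeySketch {
  int opcode;        // stand-in for the reduction root opcode (add/min/max/...)
  int accum_type;    // stand-in for the accumulation element type
  int64_t domain_id;
  bool is_cross_shard;
  std::vector<std::vector<int64_t>> replica_groups;

  bool operator<(const GroupKeySketch& other) const {
    auto lhs = std::tie(opcode, accum_type, domain_id, is_cross_shard);
    auto rhs = std::tie(other.opcode, other.accum_type, other.domain_id,
                        other.is_cross_shard);
    if (lhs != rhs) return lhs < rhs;
    if (replica_groups.size() != other.replica_groups.size())
      return replica_groups.size() < other.replica_groups.size();
    for (size_t i = 0; i < replica_groups.size(); ++i) {
      const auto& a = replica_groups[i];
      const auto& b = other.replica_groups[i];
      if (a.size() != b.size()) return a.size() < b.size();
      if (a != b) return a < b;
    }
    return false;
  }
};

int main() {
  GroupKeySketch k1{0, 1, 0, false, {{0, 1}}};
  GroupKeySketch k2{0, 1, 0, false, {{0, 1}}};  // identical criteria
  GroupKeySketch k3{2, 1, 0, false, {{0, 1}}};  // different reduction op
  std::map<GroupKeySketch, int> groups;
  ++groups[k1];
  ++groups[k2];
  ++groups[k3];
  return groups.size() == 2 ? 0 : 1;  // k1 and k2 land in the same group
}
```
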
+ continue; + } + + std::vector group_ids; + std::vector instructions; + for (const auto& hlo : all_reduce_sets[arid]) { + group_ids.push_back(hlo.first); + instructions.push_back(hlo.second); + } + all_reduce_group_map[group_ids].push_back(std::move(instructions)); + all_reduce_sets.erase(arid); + } + CHECK(all_reduce_sets.empty()); + + InstructionGroups groups; + for (const auto& all_reduce_group : all_reduce_group_map) { + groups.push_back(all_reduce_group.second); + } + return std::move(groups); +} + +} // namespace + +AllReduceCombiner::AllReduceCombiner(int64 combine_threshold_in_bytes, + int64 combine_threshold_count) + : combine_threshold_in_bytes_(combine_threshold_in_bytes), + combine_threshold_count_(combine_threshold_count) {} + +StatusOr AllReduceCombiner::Run(HloModule* module) { + VLOG(1) << "Running AllReduceCombiner with threshold of " + << combine_threshold_in_bytes_ << " bytes"; + + if (hlo_query::ContainsLayoutConstrainedAllReduce(*module)) { + VLOG(1) << "Skip AllReduceCombiner because the module contains all-reduce " + "with constrained layouts"; + return false; + } + + bool changed = false; + for (HloComputation* computation : module->MakeNonfusionComputations()) { + TF_ASSIGN_OR_RETURN(auto groups, CreateComputationGroups(computation)); + for (auto group : groups) { + // Recompute reachability after every combine group because we can't + // maintain a cross group topolgical order to be able to rely on the + // transitive dependencies to detect cycles. + auto reachability = HloReachabilityMap::Build(computation); + + // Create a map to be able to find an instruction group based on the first + // instruction in the group. It will be used during the post order + // iteration to be able to process full groups at a time. Doing it only + // for one instruction in every group will be sufficient because all + // instruction have to schedule at the same time due to cross core + // dependencies. + absl::flat_hash_map*> + group_map; + for (auto& instruction : group) { + group_map[instruction.front()] = &instruction; + } + + // Collect sets of AllReduce instructions to combine. + std::vector>> combine_sets(1); + int64 current_size_in_bytes = 0; + int64 current_operand_count = 0; + + // Iterate all instructions in post order and skip the ones not in the + // current group. We have to create a new post order iteration for every + // group because merging instructions in the previous group can made the + // original post order no longer hold. + // This will make it likely that we won't increase memory pressure much + // above combine_threshold_in_bytes, since two AllReduces that are + // near in post order are most likely, but not for sure, also near in + // scheduled order. + // + // TODO(b/70235266): This should usually be fine, but it's probably + // possible to construct some case where the memory usage increases beyond + // the threshold due to reordering of the instructions in scheduling. If + // this ever comes up as a real problem, it would be nice to implement + // safeguards so that that cannot possibly happen. + for (const HloInstruction* inst : + computation->MakeInstructionPostOrder()) { + auto it = group_map.find(inst); + if (it == group_map.end()) { + // Instruction belongs to a different group. 
+ continue; + } + const auto& instructions = *it->second; + + VLOG(1) << "Considering HLO " << instructions.front()->ToString() + << " with current set size of " << current_size_in_bytes + << " and current operand count of " << current_operand_count; + + // We do not handle AllReduce ops that do not have exactly 1 + // operand since that is simpler and this pass is the only way to + // generate such ops and it should rarely be important to consider the + // same ops again. + if (instructions.front()->operands().size() != 1) { + VLOG(1) << "Skipping due to " + << instructions.front()->operands().size() << " operands"; + continue; + } + + int64 size_in_bytes; + TF_RET_CHECK(instructions.front()->shape().IsArray()); + size_in_bytes = ShapeUtil::ByteSizeOf(instructions.front()->shape()); + + if (size_in_bytes > combine_threshold_in_bytes_) { + VLOG(1) << "Skipping due to size " << size_in_bytes + << " above threshold"; + // If the instruction is greater than the threshold, then we can + // never combine it with anything. + continue; + } + + // If the current set is dependent on the instruction, then create a new + // one to avoid the dependency. We move on from the current set instead + // of ignoring the instruction since otherwise a single AllReduce + // instruction that all the other ones depend on (such as one on the + // forward pass of a model) could disable this optimization entirely. + TF_RET_CHECK(!combine_sets.empty()); + for (const auto& previous : combine_sets.back()) { + // The reachability information does not reflect the planned + // combination from combine_sets. We cannot just bring it up to date + // cheaply since HloReachabilityMap does not track reachability + // updates transitively and doing it directly is expensive. However, + // leaving it stale has no effect on the reachability queries that we + // are doing here because we are considering the ops in a topological + // order, so we can just leave it stale. + // + // Proof: Suppose A is the instruction we are looking to combine and B + // is an element of the current combine set that we are looking to + // combine A into. + // + // First of all, we check that all elements in each set do not depend + // on each other, so combining the *current* combine set cannot create + // new dependencies between A and B. It remains to prove that + // combining the prior combine sets also cannot create a dependency + // between A and B. + // + // Assume to get a contradiction that there are two AllReduce + // ops C and D in combine_sets that will be combined and that A and B + // are not connected now but that they will be after combining C and + // D. Then there exist paths in the dependency graph such that one of + // these cases is true: + // + // A -> ... -> C and D -> ... -> B + // A -> ... -> D and C -> ... -> B + // B -> ... -> C and D -> ... -> A + // B -> ... -> D and C -> ... -> A + // + // None of these cases are possible because we are visiting the nodes + // in a topological order, so C and D cannot be in-between A and B. + // That is a contradiction, so combining the prior combine sets also + // cannot create a dependency between A and B.
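
A small worked example of how the two thresholds behave, assuming the 30 MiB / 256-operand values this patch later wires into gpu_compiler.cc (the shapes are invented):

    f32[1024,1024] operand: 1024 * 1024 * 4 B = 4 MiB
      -> at most 7 such all-reduces accumulate into one combine set
         (7 * 4 MiB = 28 MiB <= 30 MiB; an 8th would reach 32 MiB and start a new set)
    f32[16,1024,1024] operand: 64 MiB > 30 MiB
      -> skipped by the size check above and never combined with anything
    combine_threshold_count = 256 only binds for very small operands
      (more than 256 tensors averaging under 30 MiB / 256 = 120 KiB each)
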
+ bool new_set = false; + for (int64 i = 0; i < instructions.size(); ++i) { + if (reachability->IsReachable(previous[i], instructions[i])) { + VLOG(1) << "Starting new set due to dependency between " + << previous[i]->ToString() << " AND " + << instructions[i]->ToString(); + new_set = true; + break; + } + } + if (new_set) { + combine_sets.emplace_back(); + current_size_in_bytes = 0; + current_operand_count = 0; + break; + } + } + + if (current_size_in_bytes + size_in_bytes > + combine_threshold_in_bytes_ || + current_operand_count + 1 > combine_threshold_count_) { + VLOG(1) << "The instruction cannot be entered into the set due " + "to the combined size being too large."; + // In this case we cannot include the instruction into the current set + // since then it would grow beyond the threshold. The set of + // instructions to carry forward will either be the current set or the + // instruction by itself, whichever is smaller, since that maximizes + // the chance of being able to combine with the next instruction. + if (size_in_bytes > current_size_in_bytes) { + VLOG(1) << "Skipping as the instruction is larger than the set."; + continue; // keep the current set + } + VLOG(1) + << "Resetting the set as the set is larger than the instruction."; + combine_sets.emplace_back(); + current_size_in_bytes = 0; + current_operand_count = 0; + } + + VLOG(1) << "Adding instruction to set."; + combine_sets.back().push_back(instructions); + current_size_in_bytes += size_in_bytes; + current_operand_count += 1; + TF_RET_CHECK(current_size_in_bytes <= combine_threshold_in_bytes_); + TF_RET_CHECK(current_operand_count <= combine_threshold_count_); + } + VLOG(1) << "Done constructing sets. Final set size is " + << current_size_in_bytes << " bytes and " << current_operand_count + << " operands"; + + // Combine the collected sets of AllReduce instructions. + for (const auto& combine_set : combine_sets) { + if (combine_set.size() >= 2) { + changed = true; + for (int64 i = 0; i < combine_set.front().size(); ++i) { + std::vector to_combine; + to_combine.reserve(combine_set.size()); + for (const auto& c : combine_set) { + to_combine.push_back(c[i]); + } + TF_RETURN_IF_ERROR(CombineAllReduces(to_combine)); + } + } + } + } + } + + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/all_reduce_combiner.h b/tensorflow/compiler/xla/service/all_reduce_combiner.h new file mode 100644 index 00000000000..92f85058552 --- /dev/null +++ b/tensorflow/compiler/xla/service/all_reduce_combiner.h @@ -0,0 +1,51 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_ALL_REDUCE_COMBINER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_ALL_REDUCE_COMBINER_H_ + +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +// Combines small non-dependent AllReduce ops into larger combined +// AllReduce ops. A typical AllReduce implementation has a minimum +// latency-induced time for an AllReduce op so a single combined op can be +// more efficient than many small ones. +class AllReduceCombiner : public HloModulePass { + public: + AllReduceCombiner(int64 combine_threshold_in_bytes, + int64 combine_threshold_count); + + absl::string_view name() const override { return "all-reduce-combiner"; } + + StatusOr<bool> Run(HloModule* module) override; + + private: + // Combine all reduce ops up to this threshold. + int64 combine_threshold_in_bytes_; + + // Combine all reduce ops up to this threshold (number of operands). + int64 combine_threshold_count_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_ALL_REDUCE_COMBINER_H_ diff --git a/tensorflow/compiler/xla/service/all_reduce_combiner_test.cc b/tensorflow/compiler/xla/service/all_reduce_combiner_test.cc new file mode 100644 index 00000000000..0793ba2ba4b --- /dev/null +++ b/tensorflow/compiler/xla/service/all_reduce_combiner_test.cc @@ -0,0 +1,477 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
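
A minimal usage sketch for the new pass; the threshold values shown are the ones this patch itself passes when registering the pass in gpu_compiler.cc further down:

    HloPassPipeline pipeline("all_reduce_combiner");
    pipeline.AddPass<AllReduceCombiner>(
        /*combine_threshold_in_bytes=*/30 * 1024 * 1024,
        /*combine_threshold_count=*/256);
    TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
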
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/all_reduce_combiner.h" + +#include <memory> + +#include "absl/memory/memory.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +using absl::nullopt; +using ::testing::AllOf; +namespace op = xla::testing::opcode_matchers; +int64 kMaxCombineCount = 256; + +int64 AllReduceCount(const HloModule& module) { + int64 count = 0; + for (HloComputation* computation : module.computations()) { + if (computation->IsFusionComputation()) { + continue; + } + for (HloInstruction* hlo : computation->instructions()) { + if (hlo->opcode() == HloOpcode::kAllReduce) { + ++count; + } + } + } + return count; +} + +// inputs[i] will be some op producing a shape of size sizes_in_kib[i] which +// feeds into an all reduce op in all_reduces[i]. Returns a tuple +// of the all_reduces. +HloInstruction* MakeCrossReplicaReductions( + std::vector<int64> sizes_in_kib, std::vector<HloComputation*> reductions, + std::vector<HloInstruction*>* inputs, HloComputation::Builder* b) { + CHECK_EQ(reductions.size(), sizes_in_kib.size()); + std::vector<HloInstruction*> all_reduces; + for (int i = 0; i < sizes_in_kib.size(); i++) { + int64 size_in_kib = sizes_in_kib[i]; + HloComputation* reduction = reductions[i]; + auto constant = b->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.3))); + Shape shape = ShapeUtil::MakeShape( + F32, {static_cast<int64>(size_in_kib * 1024 / sizeof(float))}); + auto input = + b->AddInstruction(HloInstruction::CreateBroadcast(shape, constant, {})); + inputs->push_back(input); + all_reduces.push_back(b->AddInstruction(HloInstruction::CreateAllReduce( + shape, {input}, reduction, /*replica_groups=*/{}, + /*constrain_layout=*/false, /*channel_id=*/nullopt))); + } + return b->AddInstruction(HloInstruction::CreateTuple(all_reduces)); +} + +// Create and add a reduction computation in the given type to the module. +HloComputation* MakeReduction(const HloOpcode type, HloModule* module) { + HloComputation::Builder sum_builder(HloOpcodeString(type)); + auto x = sum_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {}), "x")); + auto y = sum_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {}), "y")); + sum_builder.AddInstruction( + HloInstruction::CreateBinary(ShapeUtil::MakeShape(F32, {}), type, x, y)); + HloComputation* reduction = + module->AddEmbeddedComputation(sum_builder.Build()); + return reduction; +} + +// Creates replica groups for AllReduce. groups[i] represents replica ids +// for group 'i'.
+std::vector<ReplicaGroup> CreateReplicaGroups( + absl::Span<const std::vector<int64>> groups) { + std::vector<ReplicaGroup> replica_groups(groups.size()); + for (int64 i = 0; i < groups.size(); ++i) { + *replica_groups[i].mutable_replica_ids() = {groups[i].begin(), + groups[i].end()}; + } + return replica_groups; +} + +using AllReduceCombinerTest = HloTestBase; + +// Tests combination of several AllReduce instructions. +TEST_F(AllReduceCombinerTest, CombineAllReduces) { + auto module = CreateNewVerifiedModule(); + HloComputation* sum = MakeReduction(HloOpcode::kAdd, module.get()); + + HloComputation::Builder b(TestName()); + std::vector<HloInstruction*> inputs; + auto root = MakeCrossReplicaReductions( + {1, 2, 10, 7, 6}, {sum, sum, sum, sum, sum}, &inputs, &b); + auto computation = module->AddEntryComputation(b.Build()); + + // Run the AllReduce combiner optimization pass. + AllReduceCombiner combine(10 * 1024 * 1024, kMaxCombineCount); + ASSERT_EQ(AllReduceCount(*module), inputs.size()); + TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get())); + ASSERT_EQ(AllReduceCount(*module), 1); + EXPECT_TRUE(changed); + + ASSERT_EQ(root, computation->root_instruction()); + ASSERT_EQ(inputs.size(), root->operands().size()); + + HloInstruction* combined = nullptr; + for (int64 i = 0; i < root->operands().size(); ++i) { + HloInstruction* hlo = root->mutable_operand(i); + ASSERT_TRUE(hlo->opcode() == HloOpcode::kGetTupleElement); + EXPECT_EQ(hlo->tuple_index(), i); + EXPECT_TRUE(ShapeUtil::Equal(inputs[i]->shape(), hlo->shape())); + + if (combined == nullptr) { + // Verify the combined all reduce instruction. + combined = hlo->mutable_operand(0); + ASSERT_TRUE(combined->opcode() == HloOpcode::kAllReduce); + EXPECT_TRUE(ShapeUtil::Equal(root->shape(), combined->shape())); + ASSERT_EQ(combined->operands().size(), inputs.size()); + } + EXPECT_EQ(combined, hlo->operand(0)); + EXPECT_TRUE(ShapeUtil::Equal(inputs[i]->shape(), hlo->shape())); + EXPECT_EQ(combined->operand(i), inputs[i]); + EXPECT_EQ(1, inputs[i]->users().size()); + } + ASSERT_NE(combined, nullptr); +} + +// Tests combination of several cross replica reduction instructions in +// different types. +TEST_F(AllReduceCombinerTest, CombineCrossReplicaReductionsInGroups) { + auto module = CreateNewVerifiedModule(); + HloComputation* sum = MakeReduction(HloOpcode::kAdd, module.get()); + HloComputation* min = MakeReduction(HloOpcode::kMinimum, module.get()); + HloComputation* max = MakeReduction(HloOpcode::kMaximum, module.get()); + HloComputation* sum_2 = MakeReduction(HloOpcode::kAdd, module.get()); + + HloComputation::Builder b(TestName()); + std::vector<HloInstruction*> inputs; + MakeCrossReplicaReductions( + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + {sum, sum_2, min, min, min, max, max, max, sum, sum_2}, &inputs, &b); + module->AddEntryComputation(b.Build()); + + // Run the AllReduce combiner optimization pass. + AllReduceCombiner combine(10 * 1024 * 1024, kMaxCombineCount); + ASSERT_EQ(AllReduceCount(*module), inputs.size()); + TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get())); + ASSERT_EQ(AllReduceCount(*module), 3) + << "expects 3 groups for 3 reduction types."; + EXPECT_TRUE(changed); +} + +// Tests that the combination threshold is respected.
+TEST_F(AllReduceCombinerTest, RespectThreshold) { + auto module = CreateNewVerifiedModule(); + HloComputation* sum = MakeReduction(HloOpcode::kAdd, module.get()); + + HloComputation::Builder b(TestName()); + std::vector inputs; + MakeCrossReplicaReductions({8, 4}, {sum, sum}, &inputs, &b); + module->AddEntryComputation(b.Build()); + + // Run the AllReduce combiner optimization pass with threshold less than + // the combined size of the all reduce ops so that the combination + // cannot occur. + { + AllReduceCombiner combine((8 + 4) * 1024 - 1, kMaxCombineCount); + ASSERT_EQ(AllReduceCount(*module), inputs.size()); + TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get())); + EXPECT_EQ(AllReduceCount(*module), inputs.size()); + EXPECT_FALSE(changed); + } + + // Run the AllReduce combiner optimization pass again with a slightly + // higher threshold so that the combination can occur. + { + AllReduceCombiner combine((8 + 4) * 1024, kMaxCombineCount); + ASSERT_EQ(AllReduceCount(*module), inputs.size()); + TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get())); + EXPECT_EQ(AllReduceCount(*module), 1); + EXPECT_TRUE(changed); + } +} + +// Tests that dependent all reduces are not combined. +TEST_F(AllReduceCombinerTest, NoDependentCombination) { + auto module = CreateNewVerifiedModule(); + HloComputation* reduction = MakeReduction(HloOpcode::kAdd, module.get()); + + HloComputation::Builder b(TestName()); + auto constant = b.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.3))); + auto all_reduce = b.AddInstruction(HloInstruction::CreateAllReduce( + constant->shape(), {constant}, reduction, /*replica_groups=*/{}, + /*constrain_layout=*/false, /*channel_id=*/nullopt)); + b.AddInstruction(HloInstruction::CreateAllReduce( + constant->shape(), {all_reduce}, reduction, + /*replica_groups=*/{}, /*constrain_layout=*/false, + /*channel_id=*/nullopt)); + + module->AddEntryComputation(b.Build()); + + AllReduceCombiner combine(1024 * 1024, kMaxCombineCount); + ASSERT_EQ(AllReduceCount(*module), 2); + TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get())); + EXPECT_EQ(AllReduceCount(*module), 2); + EXPECT_FALSE(changed); +} + +// Tests that AllReduce ops with different groups are not combined. 
+TEST_F(AllReduceCombinerTest, GroupAllReduce) { + auto module = CreateNewVerifiedModule(); + HloComputation::Builder b(TestName()); + HloComputation* reduction = MakeReduction(HloOpcode::kAdd, module.get()); + + auto constant = b.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.3))); + auto crs0 = b.AddInstruction( + HloInstruction::CreateAllReduce(constant->shape(), {constant}, reduction, + CreateReplicaGroups({{0, 1}, {2, 3}}), + /*constrain_layout=*/false, + /*channel_id=*/nullopt)); + auto crs1 = b.AddInstruction( + HloInstruction::CreateAllReduce(constant->shape(), {constant}, reduction, + CreateReplicaGroups({{0, 2}, {1, 3}}), + /*constrain_layout=*/false, + /*channel_id=*/nullopt)); + b.AddInstruction(HloInstruction::CreateTuple({crs0, crs1})); + + module->AddEntryComputation(b.Build()); + + AllReduceCombiner combine(1024 * 1024, kMaxCombineCount); + ASSERT_EQ(AllReduceCount(*module), 2); + TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get())); + EXPECT_EQ(AllReduceCount(*module), 2); + EXPECT_FALSE(changed); +} + +TEST_F(AllReduceCombinerTest, DomainPreventsCombining) { + const char* const hlo_string = R"( +HloModule Module + +summit { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY entry { + param0 = f32[128] parameter(0), sharding={maximal device=0} + param1 = f32[128] parameter(1), sharding={maximal device=1} + crs0 = f32[128] all-reduce(param0), + replica_groups={}, to_apply=summit, sharding={maximal device=0} + crs1 = f32[128] all-reduce(param1), + replica_groups={}, to_apply=summit, sharding={maximal device=1} + domain0 = f32[128] domain(crs0), + domain={kind="sharding", entry={{maximal device=0}, {maximal device=1}}, exit={maximal device=0}} + domain1 = f32[128] domain(crs1), + domain={kind="sharding", entry={{maximal device=0}, {maximal device=1}}, exit={maximal device=1}} + ROOT tuple = (f32[128], f32[128]) tuple(domain0, domain1), + sharding={{maximal device=0}, {maximal device=1}} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + LOG(INFO) << "Original module:\n" << module->ToString(); + + AllReduceCombiner combine(1024 * 1024, kMaxCombineCount); + ASSERT_EQ(AllReduceCount(*module), 2); + TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get())); + EXPECT_EQ(AllReduceCount(*module), 2); + EXPECT_FALSE(changed); +} + +// This test checks that two CRS instructions that are in separate domains +// but with the same domain metadata can be combined. 
+TEST_F(AllReduceCombinerTest, CombineFromTwoDomainsWithSameMetadata) { + const char* const hlo_string = R"( +HloModule Module + +summit { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY entry { + param0 = f32[128] parameter(0), sharding={maximal device=0} + param1 = f32[128] parameter(1), sharding={maximal device=1} + param2 = f32[128] parameter(2), sharding={maximal device=1} + crs0 = f32[128] all-reduce(param0), + replica_groups={}, to_apply=summit, sharding={maximal device=0} + crs1 = f32[128] all-reduce(param1), + replica_groups={}, to_apply=summit, sharding={maximal device=1} + crs2 = f32[128] all-reduce(param2), + replica_groups={}, to_apply=summit, sharding={maximal device=0} + domain0 = f32[128] domain(crs0), + domain={kind="sharding", entry={{maximal device=0}, {maximal device=1}, + {maximal device=0}}, exit={maximal device=0}} + domain1 = f32[128] domain(crs1), + domain={kind="sharding", entry={{maximal device=0}, {maximal device=1}, + {maximal device=0}}, exit={maximal device=1}} + domain2 = f32[128] domain(crs2), + domain={kind="sharding", entry={{maximal device=0}, {maximal device=1}, + {maximal device=0}}, exit={maximal device=0}} + ROOT tuple = (f32[128], f32[128], f32[128]) tuple(domain0, domain1, domain2), + sharding={{maximal device=0}, {maximal device=1}, {maximal device=0}} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AllReduceCombiner combine(1024 * 1024, kMaxCombineCount); + ASSERT_EQ(AllReduceCount(*module), 3); + TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get())); + EXPECT_EQ(AllReduceCount(*module), 2); + EXPECT_TRUE(changed); +} + +TEST_F(AllReduceCombinerTest, DoNotCombineCrossShardAndCrosReplicaInSPMD) { + const char* const hlo_string = R"( +HloModule Module + +summit { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY entry { + param0 = f32[128] parameter(0), sharding={maximal device=0} + param1 = f32[128] parameter(1), sharding={maximal device=1} + cross_shard_ar = f32[128] all-reduce(param0), + replica_groups={{0}}, to_apply=summit, channel_id=1 + cross_replica_ar = f32[128] all-reduce(param1), + replica_groups={{0}}, to_apply=summit, sharding={maximal device=1} + ROOT tuple = (f32[128], f32[128]) tuple(cross_shard_ar, cross_replica_ar) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AllReduceCombiner combine(1024 * 1024, kMaxCombineCount); + ASSERT_EQ(AllReduceCount(*module), 2); + TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get())); + EXPECT_EQ(AllReduceCount(*module), 2); + EXPECT_FALSE(changed); +} + +TEST_F(AllReduceCombinerTest, CrossCoreAllReduce) { + const char* const hlo_string = R"( +HloModule Module + +summit { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY entry { + param0 = f32[128] parameter(0), sharding={maximal device=0} + param1 = f32[128] parameter(1), sharding={maximal device=1} + crs00 = f32[128] all-reduce(param0), + replica_groups={{0}}, channel_id=1, to_apply=summit, + sharding={maximal device=0} + crs01 = f32[128] all-reduce(param1), + replica_groups={{0}}, channel_id=1, to_apply=summit, + sharding={maximal device=1} + crs10 = f32[128] all-reduce(param0), + replica_groups={{0}}, channel_id=2, to_apply=summit, + sharding={maximal device=0} + crs11 = f32[128] all-reduce(param1), + replica_groups={{0}}, channel_id=2, to_apply=summit, + sharding={maximal device=1} + 
domain0 = f32[128] domain(crs00), + domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}} + ROOT add = f32[128] add(domain0, crs11), + sharding={maximal device=1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AllReduceCombiner combine(1024 * 1024, kMaxCombineCount); + ASSERT_EQ(AllReduceCount(*module), 4); + TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get())); + EXPECT_EQ(AllReduceCount(*module), 2); + EXPECT_TRUE(changed); + + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::Add(op::Domain(op::GetTupleElement( + AllOf(op::AllReduce(op::Parameter(0), op::Parameter(0)), + op::Shape("(f32[128], f32[128])")), + 1)), + op::GetTupleElement( + AllOf(op::AllReduce(op::Parameter(1), op::Parameter(1)), + op::Shape("(f32[128], f32[128])")), + 0))); +} + +TEST_F(AllReduceCombinerTest, CrossCombineGroupCycle) { + const char* const hlo_string = R"( +HloModule module + +%add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +%max { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] maximum(lhs, rhs) +} +ENTRY %comp { + p0 = f32[128] parameter(0) + p1 = f32[128] parameter(1) + + crs00 = f32[128] all-reduce(p0), to_apply=add + crs10 = f32[128] all-reduce(p1), to_apply=max + + crs01 = f32[128] all-reduce(crs00), to_apply=max + crs11 = f32[128] all-reduce(crs10), to_apply=add + add0 = f32[128] add(crs01, crs11) + + crs02 = f32[128] all-reduce(add0), to_apply=add + crs12 = f32[128] all-reduce(crs11), to_apply=add + ROOT tuple = (f32[128], f32[128]) tuple(crs02, crs12) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AllReduceCombiner combine(1024 * 1024, kMaxCombineCount); + ASSERT_EQ(AllReduceCount(*module), 6); + TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get())); + EXPECT_EQ(AllReduceCount(*module), 4); + EXPECT_TRUE(changed); + + auto crs0 = op::AllReduce(op::Parameter(0), op::AllReduce(op::Parameter(1))); + auto add = op::Add(op::AllReduce(op::GetTupleElement(crs0, 0)), + op::GetTupleElement(crs0, 1)); + auto crs1 = op::AllReduce(add, op::GetTupleElement(crs0)); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::Tuple(op::GetTupleElement(crs1, 0), op::GetTupleElement(crs1, 1))); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner.cc b/tensorflow/compiler/xla/service/ar_crs_combiner.cc index ec8c391a542..dae9589e0a9 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner.cc +++ b/tensorflow/compiler/xla/service/ar_crs_combiner.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_query.h" #include "tensorflow/compiler/xla/service/hlo_replication_analysis.h" #include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -32,6 +33,60 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" namespace xla { +namespace { + +// In SPMD mode, if there's a cross-replica all-reduce that produces the same +// value for all partitions, replaces it with a global all-reduce and then +// divide by the number of partitions. 
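
A rough sketch of the rewrite just described, in HLO pseudo-syntax with invented names (P stands for the partition count); the new ReplaceReplicatedAllReduceSPMD test further down checks exactly this divide(all-reduce, broadcast(constant)) structure with a channel id set on the all-reduce:

    // before: cross-replica all-reduce whose operand is replicated across partitions
    ar = f32[1024] all-reduce(x), replica_groups={}, to_apply=add

    // after: global (channel) all-reduce, then divide by the partition count
    ar  = f32[1024] all-reduce(x), channel_id=1, replica_groups={}, to_apply=add
    c   = f32[] constant(P)
    b   = f32[1024] broadcast(c), dimensions={}
    out = f32[1024] divide(ar, b)
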
Depending on the topology and the +// implementation of the all-reduce for the backend, this may give better +// performance. +StatusOr<bool> ReplaceReplicatedAllReduce(HloModule* module, + int64 replica_count, + int64 partition_count) { + TF_ASSIGN_OR_RETURN( + auto replication_analysis, + HloReplicationAnalysis::Run(module, /*cross_partition_spmd=*/true)); + + bool changed = false; + int64 next_channel = hlo_query::NextChannelId(*module); + for (auto computation : module->computations()) { + for (auto instruction : computation->instructions()) { + if (auto ar = DynCast<HloAllReduceInstruction>(instruction)) { + const Shape& shape = ar->shape(); + if (ar->channel_id()) { + continue; + } + if (ar->replica_groups().size() > 1) { + continue; + } + if (shape.IsTuple() || shape.element_type() != F32) { + continue; + } + // We would need a cost model for the target, but in general we want to + // rewrite only if the replica count in the original op was large. + if (replica_count < 8 * partition_count) { + continue; + } + if (replication_analysis->HloInstructionIsReplicatedAt(ar, {})) { + VLOG(2) << "Replaced replicated all-reduce:" << ar->ToString(); + ar->set_channel_id(next_channel++); + auto divisor = + computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0<float>(partition_count))); + auto bcast = computation->AddInstruction( + HloInstruction::CreateBroadcast(shape, divisor, {})); + auto div = computation->AddInstruction(HloInstruction::CreateBinary( + ar->shape(), HloOpcode::kDivide, ar, bcast)); + TF_RETURN_IF_ERROR(ar->ReplaceAllUsesWith(div)); + changed = true; + } + } + } + } + return changed; +} + +} // namespace namespace m = match; @@ -508,7 +563,16 @@ StatusOr<bool> ArCrsCombiner::Run(HloModule* module) { TF_RETURN_IF_ERROR(KeepProvablyEqualInstructionGroupsMPMD()); } - return RewriteGraph(); + TF_ASSIGN_OR_RETURN(auto changed, RewriteGraph()); + + if (num_replicas_ > 1 && spmd_partition_) { + TF_ASSIGN_OR_RETURN(auto replaced, + ReplaceReplicatedAllReduce(module, num_replicas_, + num_spatial_partitions_)); + changed |= replaced; + } + + return changed; } } // namespace xla diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc b/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc index 609da2c33a0..2aaac4f2344 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc +++ b/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc @@ -1711,9 +1711,9 @@ HloModule foobar ENTRY %entrycomp (p: bf16[]) -> (f32[]) { %p = bf16[] parameter(0) - %all-reduce.0 = f32[] all-reduce(%p), channel_id=1, replica_groups={{0,1}}, + %all-reduce.0 = f32[] all-reduce(%p), channel_id=1, replica_groups={{0},{1}}, to_apply=%sum.f32 - %all-reduce.2 = f32[] all-reduce(%all-reduce.0), replica_groups={{0,1}}, + %all-reduce.2 = f32[] all-reduce(%all-reduce.0), replica_groups={{0},{1}}, to_apply=%sum.f32 ROOT %tuple = (f32[]) tuple(%all-reduce.2) } @@ -1727,5 +1727,39 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[]) { EXPECT_FALSE(changed); } +TEST_F(ArCrsCombinerTest, ReplaceReplicatedAllReduceSPMD) { + const char* module_str = R"( +HloModule foobar + +%sum.f32 (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: f32[2,4]) -> f32[2,4] { + %p = f32[2,4] parameter(0), sharding={replicated} + ROOT %all-reduce = f32[2,4] all-reduce(%p), replica_groups={{0,1}}, + to_apply=%sum.f32 +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseAndReturnVerifiedModule(module_str)); + ArCrsCombiner
combiner(/*num_spatial_partitions=*/4, /*num_replicas=*/64, + /*spmd_partition=*/true); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_TRUE(changed); + + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Divide(op::AllReduce(op::Parameter()), + op::Broadcast(op::Constant()))); + + auto ar = root->operand(0); + auto divisor = root->operand(1)->operand(0); + EXPECT_TRUE(ar->channel_id()); + EXPECT_TRUE(divisor->literal().IsAllFloat(4)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/collective_ops_utils.h b/tensorflow/compiler/xla/service/collective_ops_utils.h index 8b3c60f76de..2524b4190e9 100644 --- a/tensorflow/compiler/xla/service/collective_ops_utils.h +++ b/tensorflow/compiler/xla/service/collective_ops_utils.h @@ -149,7 +149,6 @@ struct AllReduceParticipantData { explicit AllReduceParticipantData(RendezvousKey rendezvous_key) : rendezvous_key(rendezvous_key) {} - int64 element_count; int64 device_ordinal; RendezvousKey rendezvous_key; @@ -157,20 +156,30 @@ struct AllReduceParticipantData { // source_buffer == destination_buffer if that avoids a NCCL copy (will depend // on how well the NCCL in-place implementation performs vs the out-of-place // implementation). - se::DeviceMemoryBase source_data; - se::DeviceMemoryBase destination_data; + struct Buffer { + int64 element_count; + se::DeviceMemoryBase source_data; + se::DeviceMemoryBase destination_data; + PrimitiveType primitive_type; + }; + std::vector buffers; se::Stream* stream; ReductionKind reduction_kind; - PrimitiveType primitive_type; int num_participants() const { return rendezvous_key.num_participants(); } string ToString() const { + std::vector buffer_strs; + for (const Buffer& buffer : buffers) { + buffer_strs.push_back( + absl::StrFormat("{element_count=%d}", buffer.element_count)); + } return absl::StrFormat( - "AllReduceParticipantData{element_count=%d, rendezvous_key=%s, " + "AllReduceParticipantData{buffers=[%s], rendezvous_key=%s, " "device_ordinal=%d, stream=%p}", - element_count, rendezvous_key.ToString(), device_ordinal, stream); + absl::StrJoin(buffer_strs, ","), rendezvous_key.ToString(), + device_ordinal, stream); } }; @@ -245,7 +254,7 @@ class Rendezvous { // Spot check for consistent replica counts among submitting threads. if (!participants_.empty() && - (participants_.back().element_count != participant.element_count || + (participants_.back().buffers.size() != participant.buffers.size() || participants_.back().rendezvous_key != participant.rendezvous_key)) { return InvalidArgument( "Mismatch among all-reduce participants. Expected same " diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index 1f6107d6f36..c07c3eb3c3b 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -1043,15 +1043,31 @@ Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph, HloInstruction* root = computation->root_instruction(); // Mark nondistinct/ambiguous indices. 
- absl::flat_hash_set seen; + absl::flat_hash_map seen; ShapeUtil::ForEachSubshape( root->shape(), [&](const Shape& /*subshape*/, const ShapeIndex& index) { std::vector buffers_at_index = alias_analysis->ComputeBuffersAt(root, index); bool buffer_seen_before = false; for (const HloBuffer* buffer : buffers_at_index) { - buffer_seen_before |= !seen.insert(buffer).second; + buffer_seen_before |= !seen.emplace(buffer, index).second; } + + if (buffer_seen_before && policy.copy_root_replicated_buffers && + computation == module->entry_computation() && + module->input_output_alias_config().OutputHasAlias(index) && + buffers_at_index.size() == 1) { + absl::optional alias = + module->input_output_alias_config().GetAliasedParameter(index); + CHECK(alias) << "Alias does not exist"; + const ShapeIndex& other_index = seen[buffers_at_index[0]]; + VLOG(2) << "Output indices " << index.ToString() << " and " + << other_index.ToString() << " are both aliased to " + << alias->parameter_number << " copying " << other_index; + add_index_to_copy(root, other_index); + return; + } + if (buffers_at_index.size() > 1 || (buffer_seen_before && policy.copy_root_replicated_buffers)) { VLOG(2) << "Index " << index << " of computation " @@ -1097,6 +1113,18 @@ Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph, return Status::OK(); } +static int64 GetNumExistingCopies(const HloModule* module) { + int64 num_existing_copies = 0; + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kCopy) { + ++num_existing_copies; + } + } + } + return num_existing_copies; +} + Status CopyInsertion::RemoveUnnecessaryCopies(const HloOrdering& ordering, HloModule* module) { TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, @@ -1112,13 +1140,24 @@ Status CopyInsertion::RemoveUnnecessaryCopies(const HloOrdering& ordering, } std::unique_ptr call_graph = CallGraph::Build(module); - for (HloComputation* computation : module->computations()) { - for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kCopy && - copy_remover.TryElideCopy(instruction)) { - TF_RETURN_IF_ERROR(StripControlDependenciesFrom(instruction)); - TF_RETURN_IF_ERROR( - instruction->ReplaceAllUsesWith(instruction->mutable_operand(0))); + + int64 num_existing_copies = GetNumExistingCopies(module); + bool changed = true; + int64 num_iterations = -1; + while (changed) { + CHECK_LE(++num_iterations, num_existing_copies); + changed = false; + VLOG(2) << "Running fixpoint iteration " << num_iterations + << " of copy elision"; + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kCopy && + copy_remover.TryElideCopy(instruction)) { + changed = true; + TF_RETURN_IF_ERROR(StripControlDependenciesFrom(instruction)); + TF_RETURN_IF_ERROR( + instruction->ReplaceAllUsesWith(instruction->mutable_operand(0))); + } } } } @@ -1156,17 +1195,6 @@ StatusOr CopyInsertion::Run(HloModule* module) { "Call graph must be flattened before copy insertion."); } - int64 num_existing_copies = 0; - if (VLOG_IS_ON(1)) { - for (HloComputation* computation : module->computations()) { - for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kCopy) { - ++num_existing_copies; - } - } - } - } - TF_RETURN_IF_ERROR(AddCopiesToResolveInterference(module)); // 
Simplify the tuple structures introduced by the deep copies. This should be @@ -1185,7 +1213,6 @@ StatusOr CopyInsertion::Run(HloModule* module) { RemoveUnnecessaryCopies(DependencyHloOrdering(module), module)); DumpHloModuleDuringPassIfEnabled(name(), "after removing unnecessary copies", *module); - TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, module)); DumpHloModuleDuringPassIfEnabled(name(), "after adding special-case copies", *module); @@ -1202,7 +1229,8 @@ StatusOr CopyInsertion::Run(HloModule* module) { } } } - VLOG(1) << "Num copies before copy-insertion: " << num_existing_copies; + VLOG(1) << "Num copies before copy-insertion: " + << GetNumExistingCopies(module); VLOG(1) << "Num copies after copy-insertion: " << num_total_copies; } diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index 8587c79ffb1..d58ee0ef20b 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -2274,5 +2274,69 @@ ENTRY TestComputation { op::While(op::Copy(op::Parameter()))); } +TEST_F(CopyInsertionTest, FixpointComputationRequired) { + const string& hlo_string = R"( +HloModule Module + +fused_computation { + param0 = f32[3,3,96,1] parameter(0) + param1 = f32[] parameter(1) + broadcast = f32[3,3,96,1] broadcast(f32[] param1), dimensions={} + ROOT %add.0 = f32[3,3,96,1] add(f32[3,3,96,1] param0, f32[3,3,96,1] broadcast) +} + +ENTRY entry_computation { + arg0 = f32[3,3,96,1] parameter(0) + arg1 = f32[] parameter(1) + fusion = f32[3,3,96,1] fusion(f32[3,3,96,1] arg0, f32[] arg1), + kind=kLoop, calls=fused_computation + negate = f32[] negate(f32[] arg1) + ROOT tuple = (f32[3,3,96,1], f32[3,3,96,1], f32[], f32[]) tuple( + f32[3,3,96,1] fusion, + f32[3,3,96,1] arg0, + f32[] negate, + f32[] arg1) +} + )"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + // Set up the aliasing manually which normally would be set by + // alias_passthrough_params pass. + ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( + /*output_index=*/{1}, + /*param_number=*/0, + /*param_index=*/{}, HloInputOutputAliasConfig::AliasKind::kUserAlias)); + ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( + /*output_index=*/{3}, + /*param_number=*/1, + /*param_index=*/{}, HloInputOutputAliasConfig::AliasKind::kUserAlias)); + + InsertCopies(module.get()); + + // There should be no copies inserted. 
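
Why RemoveUnnecessaryCopies now loops to a fixpoint, sketched with an invented two-copy chain: removing one copy can be what makes another copy removable on a later sweep, and since every productive sweep removes at least one copy, the number of sweeps is bounded by the initial copy count (the CHECK_LE in the new loop).

    // sweep 1 visits copy.1 first (not yet elidable), then elides copy.2
    copy.1 = f32[] copy(param)
    copy.2 = f32[] copy(copy.1)
    // sweep 2: with copy.2 gone, copy.1 no longer interferes and is elided too;
    // sweep 3 changes nothing, so the loop terminates
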
+ EXPECT_EQ(CountCopies(*module), 0); +} + +TEST_F(CopyInsertionTest, NoAliasCheckViolation) { + const string& hlo_string = R"( +HloModule cluster + +ENTRY Entry { + %arg = f32[8,28,28,1] parameter(0) + %bitcast.2 = f32[8,1,28,28] bitcast(f32[8,28,28,1] %arg) + ROOT %tuple.1 = (f32[8,1,28,28], f32[8,28,28,1]) tuple(f32[8,1,28,28] %bitcast.2, f32[8,28,28,1] %arg) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( + /*output_index=*/{1}, + /*param_number=*/0, + /*param_index=*/{}, HloInputOutputAliasConfig::AliasKind::kUserAlias)); + InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 1); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc index 56d663f7b24..98c23b679fa 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -262,7 +262,8 @@ class CpuAllReduceRendezvous : public xla::Rendezvous { protected: xla::StatusOr> SubmitParticipantImpl( xla::AllReduceParticipantData participant) override { - xla::PrimitiveType datatype = participant.primitive_type; + TF_RET_CHECK(participant.buffers.size() == 1); + xla::PrimitiveType datatype = participant.buffers.front().primitive_type; bool primary = [&] { tensorflow::mutex_lock lock(mu_); if (!initialized_) { @@ -316,10 +317,8 @@ class CpuAllReduceRendezvous : public xla::Rendezvous { using T = typename xla::primitive_util::PrimitiveTypeToNative::type; tensorflow::mutex_lock lock(mu_); CHECK(!participants_.empty()); - xla::int64 element_count = participant.element_count; xla::ReductionKind reduction_kind = participant.reduction_kind; for (const auto& p : participants_) { - CHECK_EQ(p.element_count, element_count); CHECK(p.reduction_kind == reduction_kind); } @@ -329,11 +328,19 @@ class CpuAllReduceRendezvous : public xla::Rendezvous { output_buffers.reserve(participants_.size()); for (auto& p : participants_) { - input_buffers.emplace_back(static_cast(p.source_data.opaque()), - element_count); - output_buffers.emplace_back(static_cast(p.destination_data.opaque()), - element_count); + CHECK_EQ(p.buffers.size(), 1); + CHECK_EQ(p.buffers.front().element_count, + participants_.front().buffers.front().element_count); + xla::int64 element_count = participant.buffers.front().element_count; + input_buffers.emplace_back( + static_cast(p.buffers.front().source_data.opaque()), + element_count); + output_buffers.emplace_back( + static_cast(p.buffers.front().destination_data.opaque()), + element_count); } + xla::int64 element_count = + participants_.front().buffers.front().element_count; auto compute = [reduction_kind](T a, T b) -> T { switch (reduction_kind) { @@ -416,7 +423,6 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce( xla::RendezvousKey rendezvous_key(run_options->run_id(), participating_replicas_vec, op_kind, op_id); - auto shape_str = ShapeString(shape_ptr, shape_length); VLOG(2) << "All-reduce input/output shape : " << shape_str; @@ -426,14 +432,16 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce( << "All-reduce on CPU is implemented only for dense arrays"; xla::AllReduceParticipantData participant(rendezvous_key); - participant.element_count = xla::ShapeUtil::ElementsIn(shape); participant.device_ordinal = device_ordinal; - participant.primitive_type = shape.element_type(); participant.stream = run_options->stream(); - 
participant.source_data = + xla::AllReduceParticipantData::Buffer buffer; + buffer.element_count = xla::ShapeUtil::ElementsIn(shape); + buffer.primitive_type = shape.element_type(); + buffer.source_data = se::DeviceMemoryBase(input_buffer, xla::ShapeUtil::ByteSizeOf(shape)); - participant.destination_data = + buffer.destination_data = se::DeviceMemoryBase(output_buffer, xla::ShapeUtil::ByteSizeOf(shape)); + participant.buffers = {buffer}; participant.reduction_kind = static_cast(reduction_kind); TF_CHECK_OK( diff --git a/tensorflow/compiler/xla/service/dump.cc b/tensorflow/compiler/xla/service/dump.cc index 05186f26ef6..3cb0eb78c5b 100644 --- a/tensorflow/compiler/xla/service/dump.cc +++ b/tensorflow/compiler/xla/service/dump.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/proto_serialization.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/path.h" #include "tensorflow/core/platform/regexp.h" namespace xla { @@ -110,10 +111,7 @@ struct CanonicalDebugOptions { string dump_to_lower = absl::AsciiStrToLower(opts.xla_dump_to()); if (dump_to_lower == "sponge" || dump_to_lower == "test_undeclared_outputs_dir") { - const char* dir = getenv("TEST_UNDECLARED_OUTPUTS_DIR"); - if (dir != nullptr) { - dump_to = dir; - } else { + if (!tensorflow::io::GetTestUndeclaredOutputsDir(&dump_to)) { LOG(ERROR) << "--xla_dump_to=" << opts.xla_dump_to() << ", but environment variable TEST_UNDECLARED_OUTPUTS_DIR " "is not set, so cannot dump anywhere."; diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 1f1efbd8545..d13eca30cdc 100755 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -1131,6 +1131,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:algebraic_simplifier", + "//tensorflow/compiler/xla/service:all_reduce_combiner", "//tensorflow/compiler/xla/service:batchnorm_expander", "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:call_inliner", @@ -1285,6 +1286,7 @@ cc_library( ":reduction_dimension_grouper", ":reduction_layout_normalizer", ":target_constants", + ":tree_reduction_rewriter", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/service:algebraic_simplifier", "//tensorflow/compiler/xla/service:hlo", @@ -1658,6 +1660,7 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core/platform:resource_loader", "//tensorflow/stream_executor:dnn", ], ) diff --git a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc index 0e2e27ee9a3..97013804271 100644 --- a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc @@ -26,6 +26,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/reduction_dimension_grouper.h" #include "tensorflow/compiler/xla/service/gpu/reduction_layout_normalizer.h" #include "tensorflow/compiler/xla/service/gpu/target_constants.h" +#include "tensorflow/compiler/xla/service/gpu/tree_reduction_rewriter.h" #include "tensorflow/compiler/xla/service/hlo_constant_folding.h" #include "tensorflow/compiler/xla/service/hlo_cse.h" #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" diff --git a/tensorflow/compiler/xla/service/gpu/dummy_all_reduce_thunk.cc b/tensorflow/compiler/xla/service/gpu/dummy_all_reduce_thunk.cc index 8e562387aac..7c3d76c1c92 100644 --- a/tensorflow/compiler/xla/service/gpu/dummy_all_reduce_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/dummy_all_reduce_thunk.cc @@ -42,15 +42,11 @@ NcclAllReduceThunk::DevicesWithOpenNcclChannels() { struct NcclAllReduceThunk::AuxData {}; NcclAllReduceThunk::NcclAllReduceThunk( - int64 replica_count, int64 element_count, - const BufferAllocation::Slice& source_buffer, - const BufferAllocation::Slice& destination_buffer, + int64 replica_count, std::vector buffers, const HloInstruction* all_reduce) : Thunk(Thunk::kNcclAllReduce, all_reduce), replica_count_(replica_count), - element_count_(element_count), - source_buffer_(source_buffer), - destination_buffer_(destination_buffer) {} + buffers_(std::move(buffers)) {} } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index bccf13b6104..e4c57203543 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -31,6 +31,7 @@ limitations under the License. #include "llvm/IR/Verifier.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" +#include "tensorflow/compiler/xla/service/all_reduce_combiner.h" #include "tensorflow/compiler/xla/service/batchnorm_expander.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/call_inliner.h" @@ -291,7 +292,13 @@ Status GpuCompiler::OptimizeHloModule( horizontal_fusion.AddPass(); TF_RETURN_IF_ERROR(horizontal_fusion.Run(hlo_module).status()); } - + { + HloPassPipeline pipeline("all_reduce_combiner"); + pipeline.AddPass( + /*combine_threshold_in_bytes=*/30 * 1024 * 1024, + /*combine_threshold_count=*/256); + TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); + } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/hlo_algorithm_blacklist_test.cc b/tensorflow/compiler/xla/service/gpu/hlo_algorithm_blacklist_test.cc index bf9ac31559a..bc24f486668 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_algorithm_blacklist_test.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_algorithm_blacklist_test.cc @@ -17,6 +17,8 @@ limitations under the License. 
#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/path.h" +#include "tensorflow/core/platform/resource_loader.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/stream_executor/dnn.h" @@ -31,9 +33,9 @@ class BlacklistTest : public testing::Test { "XLA_FLAGS", absl::StrCat( "--xla_gpu_algorithm_blacklist_path=", - tensorflow::io::JoinPath(tensorflow::testing::TensorFlowSrcRoot(), - "compiler", "xla", "service", "gpu", - "data", "hlo_algorithm_blacklist.pbtxt")) + tensorflow::GetDataDependencyFilepath(tensorflow::io::JoinPath( + "tensorflow", "compiler", "xla", "service", "gpu", "data", + "hlo_algorithm_blacklist.pbtxt"))) .data(), 0); } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index c6b167f7402..8efcd2384a3 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -1210,10 +1210,7 @@ Status IrEmitterUnnested::HandleCollectivePermute(HloInstruction* hlo) { return Status::OK(); } -namespace { - - -} // namespace +namespace {} // namespace Status IrEmitterUnnested::HandleAllReduce(HloInstruction* crs) { VLOG(2) << "AllReduce; replica count: " << hlo_module_config_.replica_count() @@ -1226,13 +1223,37 @@ Status IrEmitterUnnested::HandleAllReduce(HloInstruction* crs) { NcclAllReduceThunk::CanImplement(crs); if (should_use_nccl_thunk) { - CHECK(crs->operand(0)->shape().IsArray()) - << "Operands to all-reduce must be arrays: " << crs->ToString(); - AddThunkToThunkSequence(absl::make_unique( + std::vector buffers; + std::vector tuple_element_buffers; + buffers.resize(crs->operand_count()); + tuple_element_buffers.reserve(crs->operand_count()); + CHECK(crs->shape().IsArray() && crs->operand_count() == 1 || + crs->shape().IsTuple() && + crs->shape().tuple_shapes_size() == crs->operand_count()); + for (int i = 0; i < crs->operand_count(); ++i) { + CHECK(crs->operand(i)->shape().IsArray()) + << "Operands to all-reduce must be arrays: " << crs->ToString(); + buffers[i].element_count = + ShapeUtil::ElementsIn(crs->operand(i)->shape()); + buffers[i].source_buffer = GetAllocationSlice(*crs->operand(i)); + buffers[i].destination_buffer = GetAllocationSlice( + *crs, crs->shape().IsTuple() ? ShapeIndex({i}) : ShapeIndex({})); + tuple_element_buffers.push_back(buffers[i].destination_buffer); + } + auto all_reduce_thunk = absl::make_unique( /*replica_count=*/hlo_module_config_.replica_count(), - /*elements=*/ShapeUtil::ElementsIn(crs->operand(0)->shape()), - /*source_address=*/GetAllocationSlice(*crs->operand(0)), - /*destination_buffer=*/GetAllocationSlice(*crs), crs)); + /*buffers=*/std::move(buffers), crs); + if (crs->shape().IsTuple()) { + std::vector> thunks; + thunks.push_back(std::move(all_reduce_thunk)); + thunks.push_back(absl::make_unique( + tuple_element_buffers, GetAllocationSlice(*crs), nullptr)); + AddThunkToThunkSequence( + absl::make_unique(std::move(thunks), crs)); + } else { + AddThunkToThunkSequence(std::move(all_reduce_thunk)); + } + return Status::OK(); } @@ -1957,32 +1978,32 @@ void IrEmitterUnnested::EmitTile( // // TODO(cheshire): Once ptxas is fixed and TF switches to it, remove the // workaround. 
- ksl->For( - loop_name + "_y_in_tile", - /*start=*/constant(0), - /*end=*/ - ceil_of_ratio(b_.CreateSub(tile_height, thread_id_info.thread_id_y), - num_threads_y), - /*step=*/constant(1), [&](llvm::Value* y_indvar) { - llvm::Value* y_loc = b_.CreateAdd( - thread_id_info.thread_id_y, b_.CreateMul(y_indvar, num_threads_y)); - for (int64 j = 0; j < x_num_steps; j++) { - llvm::Value* x_loc = - b_.CreateAdd(constant(j * step_x), start_offset_x, "x_loc"); - IrArray::Index source_idx_x = - source_idx.AddOffsetToDim(y_loc, kDimY, &b_) - .AddOffsetToDim(constant(j * step_x), kDimX, &b_); - auto emit_element = [&] { - return emit_elem_function(source_idx_x, y_loc, x_loc, j); - }; - if (!x_tile_fits) { - ksl->If(loop_name + "_x_in_tile", - b_.CreateICmpULT(x_loc, tile_width), emit_element); - } else { - emit_element(); - } - } - }); + ksl->For(loop_name + "_y_in_tile", + /*start=*/constant(0), + /*end=*/ + ceil_of_ratio(b_.CreateSub(tile_height, thread_id_info.thread_id_y), + num_threads_y), + /*step=*/constant(1), [&](llvm::Value* y_indvar) { + llvm::Value* y_loc = + b_.CreateAdd(thread_id_info.thread_id_y, + b_.CreateMul(y_indvar, num_threads_y)); + for (int64 j = 0; j < x_num_steps; j++) { + llvm::Value* x_loc = + b_.CreateAdd(constant(j * step_x), start_offset_x, "x_loc"); + IrArray::Index source_idx_x = + source_idx.AddOffsetToDim(y_loc, kDimY, &b_) + .AddOffsetToDim(constant(j * step_x), kDimX, &b_); + auto emit_element = [&] { + return emit_elem_function(source_idx_x, y_loc, x_loc, j); + }; + if (!x_tile_fits) { + ksl->If(loop_name + "_x_in_tile", + b_.CreateICmpULT(x_loc, tile_width), emit_element); + } else { + emit_element(); + } + } + }); } // Emits code to process a tensor element in a tile for the given kCopy HLO that diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD index f1083553c57..1419a4f792d 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD @@ -69,6 +69,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "//tensorflow/core/platform:resource_loader", "@llvm-project//llvm:core", "@llvm-project//llvm:support", ], diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils_test.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils_test.cc index 8c7f70ebcfb..84e3520c873 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils_test.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils_test.cc @@ -17,25 +17,28 @@ limitations under the License. 
#include -#include "tensorflow/core/lib/io/path.h" - #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/resource_loader.h" #include "tensorflow/core/platform/test.h" namespace xla { namespace gpu { namespace { -const char kSaxpyIRFile[] = - "compiler/xla/service/gpu/llvm_gpu_backend/tests_data/saxpy.ll"; +string SaxpyIRFile() { + return tensorflow::io::JoinPath("tensorflow", "compiler", "xla", "service", + "gpu", "llvm_gpu_backend", "tests_data", + "saxpy.ll"); +} TEST(UtilsTest, TestLoadIRModule) { llvm::LLVMContext llvm_context; string test_srcdir = tensorflow::testing::TensorFlowSrcRoot(); std::unique_ptr module = LoadIRModule( - tensorflow::io::JoinPath(test_srcdir, kSaxpyIRFile), &llvm_context); + tensorflow::GetDataDependencyFilepath(SaxpyIRFile()), &llvm_context); // Sanity check that the module was loaded properly. ASSERT_NE(nullptr, module); ASSERT_NE(std::string::npos, module->getModuleIdentifier().find("saxpy.ll")); diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc index 9b2662a9a05..4498793113a 100644 --- a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc @@ -154,10 +154,6 @@ ncclRedOp_t ReductionKindToNccl(ReductionKind kind) { } } -PrimitiveType AllReducePrimitiveType(const HloInstruction* instr) { - return instr->operand(0)->shape().element_type(); -} - absl::optional DatatypeToNccl(PrimitiveType element_type) { switch (element_type) { case S8: @@ -402,9 +398,6 @@ RendezvousNcclAllReduce::SubmitParticipantImpl( VLOG(3) << "Performing all reduce from device ordinal: " << participant.device_ordinal; ncclRedOp_t computation = ReductionKindToNccl(participant.reduction_kind); - absl::optional allreduce_datatype = - DatatypeToNccl(participant.primitive_type); - CHECK(allreduce_datatype.has_value()); se::StreamExecutor* executor = participant.stream->parent(); se::cuda::ScopedActivateExecutorContext scoped_context(executor); @@ -412,19 +405,26 @@ RendezvousNcclAllReduce::SubmitParticipantImpl( participant.stream->implementation()->GpuStreamMemberHack()); VLOG(3) << "Using stream pointer: " << cu_stream << " on device: " << participant.device_ordinal; - void* send_buffer = participant.source_data.opaque(); - void* recv_buffer = participant.destination_data.opaque(); - VLOG(3) << absl::StreamFormat( - "Calling ncclAllReduce(send_buffer=%p, recv_buffer=%p, count=%d, " - "comm=%p, stream=%p)", - send_buffer, recv_buffer, participant.element_count, - static_cast(comm), cu_stream); - XLA_CUDA_RETURN_IF_ERROR(ncclAllReduce(send_buffer, recv_buffer, - /*count=*/participant.element_count, - /*datatype=*/*allreduce_datatype, - /*op=*/computation, - /*comm=*/comm, - /*stream=*/*cu_stream)); + XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart()); + for (auto& buffer : participant.buffers) { + void* send_buffer = buffer.source_data.opaque(); + void* recv_buffer = buffer.destination_data.opaque(); + absl::optional allreduce_datatype = + DatatypeToNccl(buffer.primitive_type); + CHECK(allreduce_datatype.has_value()); + VLOG(3) << absl::StreamFormat( + "Calling ncclAllReduce(send_buffer=%p, recv_buffer=%p, count=%d, " + "comm=%p, stream=%p)", + send_buffer, recv_buffer, buffer.element_count, + static_cast(comm), cu_stream); + XLA_CUDA_RETURN_IF_ERROR(ncclAllReduce(send_buffer, recv_buffer, + 
/*count=*/buffer.element_count, + /*datatype=*/*allreduce_datatype, + /*op=*/computation, + /*comm=*/comm, + /*stream=*/*cu_stream)); + } + XLA_CUDA_RETURN_IF_ERROR(ncclGroupEnd()); VLOG(3) << "Done performing all reduce for ordinal: " << participant.device_ordinal; @@ -453,11 +453,14 @@ struct NcclAllReduceThunk::AuxData { }; /*static*/ bool NcclAllReduceThunk::CanImplement(const HloInstruction* crs) { + auto operands_are_supported = [crs]() { + return absl::c_all_of(crs->operands(), [](HloInstruction* operand) { + return LayoutUtil::IsDenseArray(operand->shape()) && + DatatypeToNccl(operand->shape().element_type()).has_value(); + }); + }; return MatchReductionComputation(crs->to_apply()).has_value() && - DatatypeToNccl(AllReducePrimitiveType(crs)).has_value() && - crs->IsCrossReplicaAllReduce() && - crs->operand_count() == 1 && // One array to reduce. - LayoutUtil::IsDenseArray(crs->operand(0)->shape()); + crs->IsCrossReplicaAllReduce() && operands_are_supported(); } /*static*/ absl::flat_hash_set @@ -471,16 +474,14 @@ NcclAllReduceThunk::DevicesWithOpenNcclChannels() { } NcclAllReduceThunk::NcclAllReduceThunk( - int64 replica_count, int64 element_count, - const BufferAllocation::Slice& source_buffer, - const BufferAllocation::Slice& destination_buffer, + int64 replica_count, std::vector buffers, const HloInstruction* all_reduce) : Thunk(Thunk::kNcclAllReduce, all_reduce), replica_count_(replica_count), - element_count_(element_count), - source_buffer_(source_buffer), - destination_buffer_(destination_buffer), - aux_data_(absl::make_unique()) {} + buffers_(std::move(buffers)), + aux_data_(absl::make_unique()) { + CHECK_EQ(hlo_instruction()->operand_count(), buffers_.size()); +} // Figures out which devices (named by their replica-ids) are participating in // the all-reduce subgroup that contains device_ordinal. 
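The hunk above replaces the single ncclAllReduce call with a loop over participant.buffers bracketed by ncclGroupStart()/ncclGroupEnd(), so an all-reduce with several operands is issued as one fused NCCL group instead of one collective launch per operand. Below is a minimal standalone sketch of that grouped-call pattern; DeviceBuffer, GroupedAllReduce, the plain error handling, and the include paths are illustrative stand-ins for the thunk's AllReduceParticipantData::Buffer and XLA_CUDA_RETURN_IF_ERROR, not part of the patch.

#include <cstddef>
#include <vector>

#include <cuda_runtime.h>
#include <nccl.h>

// Illustrative per-operand descriptor, mirroring the Buffer data carried by
// the thunk above (source/destination pointers, element count, NCCL dtype).
struct DeviceBuffer {
  const void* send;
  void* recv;
  size_t count;
  ncclDataType_t dtype;
};

// Issues one ncclAllReduce per buffer inside a single NCCL group, so the
// collectives for all operands are launched together on the same
// communicator and stream.
ncclResult_t GroupedAllReduce(const std::vector<DeviceBuffer>& buffers,
                              ncclRedOp_t op, ncclComm_t comm,
                              cudaStream_t stream) {
  ncclResult_t result = ncclGroupStart();
  if (result != ncclSuccess) return result;
  for (const DeviceBuffer& buffer : buffers) {
    // Returning early on failure mirrors what XLA_CUDA_RETURN_IF_ERROR does
    // in the hunk above.
    result = ncclAllReduce(buffer.send, buffer.recv, buffer.count,
                           buffer.dtype, op, comm, stream);
    if (result != ncclSuccess) return result;
  }
  return ncclGroupEnd();
}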
@@ -506,18 +507,24 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) { << absl::StrJoin(participating_replicas, ", "); AllReduceParticipantData participant(rendezvous_key); - participant.element_count = element_count_; participant.device_ordinal = device_ordinal; - participant.source_data = - params.buffer_allocations->GetDeviceAddress(source_buffer_); - participant.destination_data = - params.buffer_allocations->GetDeviceAddress(destination_buffer_); + for (size_t i = 0; i < buffers_.size(); ++i) { + const NcclAllReduceThunk::Buffer& buffer = buffers_[i]; + AllReduceParticipantData::Buffer pbuffer; + pbuffer.element_count = buffer.element_count; + pbuffer.source_data = + params.buffer_allocations->GetDeviceAddress(buffer.source_buffer); + pbuffer.destination_data = + params.buffer_allocations->GetDeviceAddress(buffer.destination_buffer); + pbuffer.primitive_type = + hlo_instruction()->operand(i)->shape().element_type(); + participant.buffers.push_back(pbuffer); + } participant.stream = params.stream; auto reduction_kind = MatchReductionComputation(hlo_instruction()->to_apply()); CHECK(reduction_kind.has_value()); participant.reduction_kind = *reduction_kind; - participant.primitive_type = AllReducePrimitiveType(hlo_instruction()); TF_ASSIGN_OR_RETURN( std::shared_ptr clique, diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h index 36b757ae567..7633a99794f 100644 --- a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h @@ -50,9 +50,12 @@ class NcclAllReduceThunk : public Thunk { // TODO(b/125951860): Support all-reduces with replica groups, i.e. // all-reduces that compute multiple sums across subsets of all replicas. - NcclAllReduceThunk(int64 replica_count, int64 element_count, - const BufferAllocation::Slice& source_buffer, - const BufferAllocation::Slice& destination_buffer, + struct Buffer { + int64 element_count; + BufferAllocation::Slice source_buffer; + BufferAllocation::Slice destination_buffer; + }; + NcclAllReduceThunk(int64 replica_count, std::vector buffers, const HloInstruction* all_reduce); ~NcclAllReduceThunk() override; @@ -70,9 +73,7 @@ class NcclAllReduceThunk : public Thunk { struct AuxData; const int64 replica_count_; - const int64 element_count_; - const BufferAllocation::Slice source_buffer_; - const BufferAllocation::Slice destination_buffer_; + const std::vector buffers_; std::unique_ptr aux_data_; }; diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index f61ccd77c86..a1a901f0b94 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -55,6 +55,7 @@ limitations under the License. #include "tensorflow/core/platform/cuda_libdevice_path.h" #include "tensorflow/core/platform/tracing.h" #include "tensorflow/core/profiler/lib/traceme.h" +#include "tensorflow/core/util/env_var.h" #include "tensorflow/stream_executor/cuda/cuda_diagnostics.h" #include "tensorflow/stream_executor/gpu/asm_compiler.h" @@ -151,6 +152,16 @@ Status NVPTXCompiler::OptimizeHloConvolutionCanonicalization( return Status::OK(); } +// TODO(cheshire): Duplication with gpu_conv_algorithm picker, figure out a +// right way to share this. 
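// The helper below reads the TF_DETERMINISTIC_OPS environment variable via
// ReadBoolFromEnvVar; deterministic tree reductions can therefore be requested
// at run time, in addition to the xla_gpu_deterministic_reductions debug
// option that is still checked in the pipeline hunk further down.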
+static bool RequireDeterminism() { + bool deterministic_ops = false; + TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_DETERMINISTIC_OPS", + /*default_val=*/false, + &deterministic_ops)); + return deterministic_ops; +} + Status NVPTXCompiler::OptimizeHloPostLayoutAssignment( HloModule* hlo_module, se::StreamExecutor* stream_exec, se::DeviceMemoryAllocator* device_allocator) { @@ -172,7 +183,8 @@ Status NVPTXCompiler::OptimizeHloPostLayoutAssignment( options.set_is_layout_sensitive(true); pipeline.AddPass>(options); - if (hlo_module->config().debug_options().xla_gpu_deterministic_reductions()) { + if (RequireDeterminism() || + hlo_module->config().debug_options().xla_gpu_deterministic_reductions()) { pipeline.AddPass>(); } diff --git a/tensorflow/compiler/xla/service/gpu/tests/tree_reduction_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/tests/tree_reduction_rewriter_test.cc index c0210ff941d..eb821c36fae 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/tree_reduction_rewriter_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/tree_reduction_rewriter_test.cc @@ -67,24 +67,23 @@ ENTRY main { zero = f32[] constant(0) ROOT out = f32[] reduce(input, zero), dimensions={0}, to_apply=add } - )"; // TODO(cheshire): a more generic check, do not hardcode the names. MatchOptimizedHloWithShapes(hlo_text, R"( -// CHECK: %fused_computation (param_0.2: f32[50000]) -> f32[7] { +// CHECK: %fused_computation (param_0.2: f32[50000]) -> f32[224] { // CHECK: %param_0.2 = f32[50000]{0} parameter(0) // CHECK: %zero_1 = f32[] constant(0) -// CHECK: %pad.1 = f32[57344]{0} pad(f32[50000]{0} %param_0.2, f32[] %zero_1), padding=0_7344 -// CHECK: %bitcast.1 = f32[7,8192]{1,0} bitcast(f32[57344]{0} %pad.1) -// CHECK: ROOT %reduce.2 = f32[7]{0} reduce(f32[7,8192]{1,0} %bitcast.1, f32[] %zero_1), dimensions={1}, to_apply=%add +// CHECK: %pad.1 = f32[50176]{0} pad(f32[50000]{0} %param_0.2, f32[] %zero_1), padding=0_176 +// CHECK: %bitcast.1 = f32[224,224]{1,0} bitcast(f32[50176]{0} %pad.1) +// CHECK: ROOT %reduce.2 = f32[224]{0} reduce(f32[224,224]{1,0} %bitcast.1, f32[] %zero_1), dimensions={1}, to_apply=%add // CHECK: } // CHECK: ENTRY %main (input: f32[50000]) -> f32[] { // CHECK: %input = f32[50000]{0} parameter(0) -// CHECK: %fusion = f32[7]{0} fusion(f32[50000]{0} %input), kind=kInput, calls=%fused_computation +// CHECK: %fusion = f32[224]{0} fusion(f32[50000]{0} %input), kind=kInput, calls=%fused_computation // CHECK: %zero = f32[] constant(0) -// CHECK: ROOT %reduce.1 = f32[] reduce(f32[7]{0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add +// CHECK: ROOT %reduce.1 = f32[] reduce(f32[224]{0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add // CHECK: } )"); @@ -107,27 +106,25 @@ ENTRY main { zero = f32[] constant(0) ROOT out = f32[100,100] reduce(input, zero), dimensions={2}, to_apply=add } - )"; EnsureDeterminism(hlo_text); MatchOptimizedHloWithShapes(hlo_text, R"( -// CHECK: %fused_computation (param_0.2: f32[100,100,10000]) -> f32[100,100,2] { +// CHECK: %fused_computation (param_0.2: f32[100,100,10000]) -> f32[100,100,100] { // CHECK: %param_0.2 = f32[100,100,10000]{2,1,0} parameter(0) // CHECK: %zero_1 = f32[] constant(0) -// CHECK: %pad.1 = f32[100,100,16384]{2,1,0} pad(f32[100,100,10000]{2,1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0x0_6384 -// CHECK: %bitcast.1 = f32[100,100,2,8192]{3,2,1,0} bitcast(f32[100,100,16384]{2,1,0} %pad.1) -// CHECK: ROOT %reduce.2 = f32[100,100,2]{2,1,0} reduce(f32[100,100,2,8192]{3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={3}, to_apply=%add 
+// CHECK: %pad.1 = f32[100,100,10000]{2,1,0} pad(f32[100,100,10000]{2,1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0x0_0 +// CHECK: %bitcast.1 = f32[100,100,100,100]{3,2,1,0} bitcast(f32[100,100,10000]{2,1,0} %pad.1) +// CHECK: ROOT %reduce.2 = f32[100,100,100]{2,1,0} reduce(f32[100,100,100,100]{3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={3}, to_apply=%add // CHECK: } // CHECK: ENTRY %main (input: f32[100,100,10000]) -> f32[100,100] { // CHECK: %input = f32[100,100,10000]{2,1,0} parameter(0) -// CHECK: %fusion = f32[100,100,2]{2,1,0} fusion(f32[100,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation +// CHECK: %fusion = f32[100,100,100]{2,1,0} fusion(f32[100,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation // CHECK: %zero = f32[] constant(0) -// CHECK: ROOT %reduce.1 = f32[100,100]{1,0} reduce(f32[100,100,2]{2,1,0} %fusion, f32[] %zero), dimensions={2}, to_apply=%add +// CHECK: ROOT %reduce.1 = f32[100,100]{1,0} reduce(f32[100,100,100]{2,1,0} %fusion, f32[] %zero), dimensions={2}, to_apply=%add // CHECK: } - )"); EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); @@ -149,23 +146,22 @@ ENTRY main { zero = f32[] constant(0) ROOT out = f32[] reduce(input, zero), dimensions={0}, to_apply=add } - )"; MatchOptimizedHloWithShapes(hlo_text, R"( -// CHECK: %fused_computation (param_0.2: f32[1000000]) -> f32[123] { +// CHECK: %fused_computation (param_0.2: f32[1000000]) -> f32[1000] { // CHECK: %param_0.2 = f32[1000000]{0} parameter(0) // CHECK: %zero_1 = f32[] constant(0) -// CHECK: %pad.1 = f32[1007616]{0} pad(f32[1000000]{0} %param_0.2, f32[] %zero_1), padding=0_7616 -// CHECK: %bitcast.1 = f32[123,8192]{1,0} bitcast(f32[1007616]{0} %pad.1) -// CHECK: ROOT %reduce.2 = f32[123]{0} reduce(f32[123,8192]{1,0} %bitcast.1, f32[] %zero_1), dimensions={1}, to_apply=%add +// CHECK: %pad.1 = f32[1000000]{0} pad(f32[1000000]{0} %param_0.2, f32[] %zero_1), padding=0_0 +// CHECK: %bitcast.1 = f32[1000,1000]{1,0} bitcast(f32[1000000]{0} %pad.1) +// CHECK: ROOT %reduce.2 = f32[1000]{0} reduce(f32[1000,1000]{1,0} %bitcast.1, f32[] %zero_1), dimensions={1}, to_apply=%add // CHECK: } // CHECK: ENTRY %main (input: f32[1000000]) -> f32[] { // CHECK: %input = f32[1000000]{0} parameter(0) -// CHECK: %fusion = f32[123]{0} fusion(f32[1000000]{0} %input), kind=kInput, calls=%fused_computation +// CHECK: %fusion = f32[1000]{0} fusion(f32[1000000]{0} %input), kind=kInput, calls=%fused_computation // CHECK: %zero = f32[] constant(0) -// CHECK: ROOT %reduce.1 = f32[] reduce(f32[123]{0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add +// CHECK: ROOT %reduce.1 = f32[] reduce(f32[1000]{0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add // CHECK: } )"); @@ -188,25 +184,24 @@ ENTRY main { zero = f32[] constant(0) ROOT out = f32[100] reduce(input, zero), dimensions={0,2}, to_apply=add } - )"; EnsureDeterminism(hlo_text); MatchOptimizedHloWithShapes(hlo_text, R"( -// CHECK: %fused_computation (param_0.2: f32[8,100,10000]) -> f32[100,2] { +// CHECK: %fused_computation (param_0.2: f32[8,100,10000]) -> f32[100,100] { // CHECK: %param_0.2 = f32[8,100,10000]{2,1,0} parameter(0) // CHECK: %zero_1 = f32[] constant(0) -// CHECK: %pad.1 = f32[8,100,16384]{2,1,0} pad(f32[8,100,10000]{2,1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0x0_6384 -// CHECK: %bitcast.1 = f32[8,100,2,8192]{3,2,1,0} bitcast(f32[8,100,16384]{2,1,0} %pad.1) -// CHECK: ROOT %reduce.2 = f32[100,2]{1,0} reduce(f32[8,100,2,8192]{3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={3,0}, to_apply=%add +// CHECK: %pad.1 = 
f32[8,100,10000]{2,1,0} pad(f32[8,100,10000]{2,1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0x0_0 +// CHECK: %bitcast.1 = f32[8,100,100,100]{3,2,1,0} bitcast(f32[8,100,10000]{2,1,0} %pad.1) +// CHECK: ROOT %reduce.2 = f32[100,100]{1,0} reduce(f32[8,100,100,100]{3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={3,0}, to_apply=%add // CHECK: } // CHECK: ENTRY %main (input: f32[8,100,10000]) -> f32[100] { // CHECK: %input = f32[8,100,10000]{2,1,0} parameter(0) -// CHECK: %fusion = f32[100,2]{1,0} fusion(f32[8,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation +// CHECK: %fusion = f32[100,100]{1,0} fusion(f32[8,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation // CHECK: %zero = f32[] constant(0) -// CHECK: ROOT %reduce.1 = f32[100]{0} reduce(f32[100,2]{1,0} %fusion, f32[] %zero), dimensions={1}, to_apply=%add +// CHECK: ROOT %reduce.1 = f32[100]{0} reduce(f32[100,100]{1,0} %fusion, f32[] %zero), dimensions={1}, to_apply=%add // CHECK: } )"); @@ -234,23 +229,19 @@ ENTRY main { MatchOptimizedHloWithShapes(hlo_text, R"( -// CHECK: %fused_computation (param_0.4: f32[32,100,2]) -> f32[100] { -// CHECK: %param_0.4 = f32[32,100,2]{2,1,0} parameter(0) +// CHECK: %fused_computation (param_0.2: f32[32,100,10000]) -> f32[32,100,100] { +// CHECK: %param_0.2 = f32[32,100,10000]{2,1,0} parameter(0) // CHECK: %zero_1 = f32[] constant(0) -// CHECK: %reduce.5 = f32[32,100]{1,0} reduce(f32[32,100,2]{2,1,0} %param_0.4, f32[] %zero_1), dimensions={2}, to_apply=%add -// CHECK: ROOT %reduce.4 = f32[100]{0} reduce(f32[32,100]{1,0} %reduce.5, f32[] %zero_1), dimensions={0}, to_apply=%add -// CHECK: } -// CHECK: %fused_computation.1 (param_0.5: f32[32,100,10000]) -> f32[32,100,2] { -// CHECK: %param_0.5 = f32[32,100,10000]{2,1,0} parameter(0) -// CHECK: %zero_2 = f32[] constant(0) -// CHECK: %pad.1 = f32[32,100,16384]{2,1,0} pad(f32[32,100,10000]{2,1,0} %param_0.5, f32[] %zero_2), padding=0_0x0_0x0_6384 -// CHECK: %bitcast.1 = f32[32,100,2,8192]{3,2,1,0} bitcast(f32[32,100,16384]{2,1,0} %pad.1) -// CHECK: ROOT %reduce.6 = f32[32,100,2]{2,1,0} reduce(f32[32,100,2,8192]{3,2,1,0} %bitcast.1, f32[] %zero_2), dimensions={3}, to_apply=%add +// CHECK: %pad.1 = f32[32,100,10000]{2,1,0} pad(f32[32,100,10000]{2,1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0x0_0 +// CHECK: %bitcast.1 = f32[32,100,100,100]{3,2,1,0} bitcast(f32[32,100,10000]{2,1,0} %pad.1) +// CHECK: ROOT %reduce.4 = f32[32,100,100]{2,1,0} reduce(f32[32,100,100,100]{3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={3}, to_apply=%add // CHECK: } // CHECK: ENTRY %main (input: f32[32,100,10000]) -> f32[100] { // CHECK: %input = f32[32,100,10000]{2,1,0} parameter(0) -// CHECK: %fusion.1 = f32[32,100,2]{2,1,0} fusion(f32[32,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation.1 -// CHECK: ROOT %fusion = f32[100]{0} fusion(f32[32,100,2]{2,1,0} %fusion.1), kind=kInput, calls=%fused_computation +// CHECK: %fusion = f32[32,100,100]{2,1,0} fusion(f32[32,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation +// CHECK: %zero = f32[] constant(0) +// CHECK: %reduce.3 = f32[32,100]{1,0} reduce(f32[32,100,100]{2,1,0} %fusion, f32[] %zero), dimensions={2}, to_apply=%add +// CHECK: ROOT %reduce.1 = f32[100]{0} reduce(f32[32,100]{1,0} %reduce.3, f32[] %zero), dimensions={0}, to_apply=%add // CHECK: } )"); @@ -274,22 +265,22 @@ ENTRY main { zero = f32[] constant(0) ROOT out = f32[100] reduce(input, zero), dimensions={0}, to_apply=add } - )"; MatchOptimizedHloWithShapes(hlo_text, R"( -// CHECK: %fused_computation (param_0.2: 
f32[10000,100]) -> f32[100] { -// CHECK: %param_0.2 = f32[10000,100]{1,0} parameter(0) -// CHECK: %zero_1 = f32[] constant(0) -// CHECK: %pad.1 = f32[12288,100]{1,0} pad(f32[10000,100]{1,0} %param_0.2, f32[] %zero_1), padding=0_2288x0_0 -// CHECK: %bitcast.1 = f32[3,4096,100]{2,1,0} bitcast(f32[12288,100]{1,0} %pad.1) -// CHECK: %reduce.3 = f32[4096,100]{1,0} reduce(f32[3,4096,100]{2,1,0} %bitcast.1, f32[] %zero_1), dimensions={0}, to_apply=%add -// CHECK: ROOT %reduce.2 = f32[100]{0} reduce(f32[4096,100]{1,0} %reduce.3, f32[] %zero_1), dimensions={0}, to_apply=%add +// CHECK: %fused_computation (param_0.2: f32[10000,100]) -> f32[100,100] { +// CHECK: %param_0.2 = f32[10000,100]{1,0} parameter(0) +// CHECK: %zero_1 = f32[] constant(0) +// CHECK: %pad.1 = f32[10000,100]{1,0} pad(f32[10000,100]{1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0 +// CHECK: %bitcast.1 = f32[100,100,100]{2,1,0} bitcast(f32[10000,100]{1,0} %pad.1) +// CHECK: ROOT %reduce.2 = f32[100,100]{1,0} reduce(f32[100,100,100]{2,1,0} %bitcast.1, f32[] %zero_1), dimensions={0}, to_apply=%add // CHECK: } // CHECK: ENTRY %main (input: f32[10000,100]) -> f32[100] { -// CHECK: %input = f32[10000,100]{1,0} parameter(0) -// CHECK: ROOT %fusion = f32[100]{0} fusion(f32[10000,100]{1,0} %input), kind=kInput, calls=%fused_computation +// CHECK: %input = f32[10000,100]{1,0} parameter(0) +// CHECK: %fusion = f32[100,100]{1,0} fusion(f32[10000,100]{1,0} %input), kind=kInput, calls=%fused_computation +// CHECK: %zero = f32[] constant(0) +// CHECK: ROOT %reduce.1 = f32[100]{0} reduce(f32[100,100]{1,0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add // CHECK: } )"); @@ -316,17 +307,18 @@ ENTRY main { MatchOptimizedHloWithShapes(hlo_text, R"( -// CHECK: %fused_computation (param_0.2: f32[10000,2,2,2]) -> f32[2,2,2] { -// CHECK: %param_0.2 = f32[10000,2,2,2]{3,2,1,0} parameter(0) -// CHECK: %zero_1 = f32[] constant(0) -// CHECK: %pad.1 = f32[12288,2,2,2]{3,2,1,0} pad(f32[10000,2,2,2]{3,2,1,0} %param_0.2, f32[] %zero_1), padding=0_2288x0_0x0_0x0_0 -// CHECK: %bitcast.1 = f32[3,4096,2,2,2]{4,3,2,1,0} bitcast(f32[12288,2,2,2]{3,2,1,0} %pad.1) -// CHECK: %reduce.3 = f32[4096,2,2,2]{3,2,1,0} reduce(f32[3,4096,2,2,2]{4,3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={0}, to_apply=%add -// CHECK: ROOT %reduce.2 = f32[2,2,2]{2,1,0} reduce(f32[4096,2,2,2]{3,2,1,0} %reduce.3, f32[] %zero_1), dimensions={0}, to_apply=%add +// CHECK: %fused_computation (param_0.2: f32[10000,2,2,2]) -> f32[100,2,2,2] { +// CHECK: %param_0.2 = f32[10000,2,2,2]{3,2,1,0} parameter(0) +// CHECK: %zero_1 = f32[] constant(0) +// CHECK: %pad.1 = f32[10000,2,2,2]{3,2,1,0} pad(f32[10000,2,2,2]{3,2,1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0x0_0x0_0 +// CHECK: %bitcast.1 = f32[100,100,2,2,2]{4,3,2,1,0} bitcast(f32[10000,2,2,2]{3,2,1,0} %pad.1) +// CHECK: ROOT %reduce.2 = f32[100,2,2,2]{3,2,1,0} reduce(f32[100,100,2,2,2]{4,3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={0}, to_apply=%add // CHECK: } // CHECK: ENTRY %main (input: f32[10000,2,2,2]) -> f32[2,2,2] { -// CHECK: %input = f32[10000,2,2,2]{3,2,1,0} parameter(0) -// CHECK: ROOT %fusion = f32[2,2,2]{2,1,0} fusion(f32[10000,2,2,2]{3,2,1,0} %input), kind=kInput, calls=%fused_computation +// CHECK: %input = f32[10000,2,2,2]{3,2,1,0} parameter(0) +// CHECK: %fusion = f32[100,2,2,2]{3,2,1,0} fusion(f32[10000,2,2,2]{3,2,1,0} %input), kind=kInput, calls=%fused_computation +// CHECK: %zero = f32[] constant(0) +// CHECK: ROOT %reduce.1 = f32[2,2,2]{2,1,0} reduce(f32[100,2,2,2]{3,2,1,0} %fusion, f32[] %zero), dimensions={0}, 
to_apply=%add // CHECK: } )"); @@ -355,18 +347,18 @@ ENTRY main { MatchOptimizedHloWithShapes(hlo_text, R"( -// CHECK: %fused_computation (param_0.2: f32[1000000,5]) -> f32[4096,5] { -// CHECK: %param_0.2 = f32[1000000,5]{1,0} parameter(0) -// CHECK: %zero_1 = f32[] constant(0) -// CHECK: %pad.1 = f32[1003520,5]{1,0} pad(f32[1000000,5]{1,0} %param_0.2, f32[] %zero_1), padding=0_3520x0_0 -// CHECK: %bitcast.1 = f32[245,4096,5]{2,1,0} bitcast(f32[1003520,5]{1,0} %pad.1) -// CHECK: ROOT %reduce.2 = f32[4096,5]{1,0} reduce(f32[245,4096,5]{2,1,0} %bitcast.1, f32[] %zero_1), dimensions={0}, to_apply=%add +// CHECK: %fused_computation (param_0.2: f32[1000000,5]) -> f32[1000,5] { +// CHECK: %param_0.2 = f32[1000000,5]{1,0} parameter(0) +// CHECK: %zero_1 = f32[] constant(0) +// CHECK: %pad.1 = f32[1000000,5]{1,0} pad(f32[1000000,5]{1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0 +// CHECK: %bitcast.1 = f32[1000,1000,5]{2,1,0} bitcast(f32[1000000,5]{1,0} %pad.1) +// CHECK: ROOT %reduce.2 = f32[1000,5]{1,0} reduce(f32[1000,1000,5]{2,1,0} %bitcast.1, f32[] %zero_1), dimensions={0}, to_apply=%add // CHECK: } // CHECK: ENTRY %main (input: f32[1000000,5]) -> f32[5] { -// CHECK: %input = f32[1000000,5]{1,0} parameter(0) -// CHECK: %fusion = f32[4096,5]{1,0} fusion(f32[1000000,5]{1,0} %input), kind=kInput, calls=%fused_computation -// CHECK: %zero = f32[] constant(0) -// CHECK: ROOT %reduce.1 = f32[5]{0} reduce(f32[4096,5]{1,0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add +// CHECK: %input = f32[1000000,5]{1,0} parameter(0) +// CHECK: %fusion = f32[1000,5]{1,0} fusion(f32[1000000,5]{1,0} %input), kind=kInput, calls=%fused_computation +// CHECK: %zero = f32[] constant(0) +// CHECK: ROOT %reduce.1 = f32[5]{0} reduce(f32[1000,5]{1,0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add // CHECK: } )"); diff --git a/tensorflow/compiler/xla/service/gpu/tree_reduction_rewriter.cc b/tensorflow/compiler/xla/service/gpu/tree_reduction_rewriter.cc index 5dad97dab39..e6d4569478c 100644 --- a/tensorflow/compiler/xla/service/gpu/tree_reduction_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/tree_reduction_rewriter.cc @@ -46,6 +46,11 @@ static constexpr int64 kColumnAtomicFreeBound = kWarpSize * 128; // decreased column/row tiling. static constexpr int64 kBatchedAtomicFreeBound = 8; +// Returns the square root of the input rounded up to the nearest square. +static int64 SqrtOfRoundUpToNearestSquare(int64 input) { + return static_cast(std::ceil(std::sqrt(input))); +} + class ReductionRewriterVisitor : public DfsHloRewriteVisitor { public: explicit ReductionRewriterVisitor() {} @@ -105,39 +110,29 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor { int64 reduced_dim_size = input_shape.dimensions(reduced_input_dimension); VLOG(3) << "reduced_dim_size = " << reduced_dim_size; - // TODO(cheshire): if atomic_free_bound is large, num_fit is likely to be - // small. Generating a reduction with very small reduced dimension is not - // efficient, it would be better to split the dimension sizes more evenly. - // - // One possible idea is to pad to a nearest square (ceil(sqrt(x)))^2. - // Given that: + + // We pad to a nearest square (ceil(sqrt(x)))^2. Given that: // // (n + 1)^2 = n^2 + (2n+1) // // it can be seen that the distance to the nearest square is at most twice // the square root of the input number. - int64 num_fit = CeilOfRatio(reduced_dim_size, atomic_free_bound); + int64 num_fit = SqrtOfRoundUpToNearestSquare(reduced_dim_size); // Pad reduced dimension to the required number of elements. 
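// Worked example, using the shapes from the updated tree_reduction_rewriter
// tests above: for a reduced dimension of 50000 elements,
// num_fit = ceil(sqrt(50000)) = 224, so the input is padded to
// 224 * 224 = 50176 elements and bitcast to [..., 224, 224]. Both reduction
// steps then operate on dimensions of size 224, rather than one step over
// 8192 elements and another over only 7 as before.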
HloInstruction *padded = [&] { - // TODO(cheshire): if atomic_free_bound is very large, padding all the way - // up to to atomic_free_bound is wasteful, we could pad to a much smaller - // value. - if (reduced_dim_size % atomic_free_bound != 0) { - int64 padded_num_elements = num_fit * atomic_free_bound; - PaddingConfig padding_config = MakeNoPaddingConfig(input_shape.rank()); - padding_config.mutable_dimensions(reduced_input_dimension) - ->set_edge_padding_high(padded_num_elements - reduced_dim_size); - std::vector padded_dimensions(input_shape.dimensions().begin(), - input_shape.dimensions().end()); - padded_dimensions[reduced_input_dimension] = padded_num_elements; - Shape padded_shape = - ShapeUtil::MakeShape(input_shape.element_type(), padded_dimensions); - VLOG(3) << "Generated padded shape: " << padded_shape.ToString(); - return hlo->parent()->AddInstruction(HloInstruction::CreatePad( - padded_shape, input, initial_value, padding_config)); - } - return input; + int64 padded_num_elements = num_fit * num_fit; + PaddingConfig padding_config = MakeNoPaddingConfig(input_shape.rank()); + padding_config.mutable_dimensions(reduced_input_dimension) + ->set_edge_padding_high(padded_num_elements - reduced_dim_size); + std::vector padded_dimensions(input_shape.dimensions().begin(), + input_shape.dimensions().end()); + padded_dimensions[reduced_input_dimension] = padded_num_elements; + Shape padded_shape = + ShapeUtil::MakeShape(input_shape.element_type(), padded_dimensions); + VLOG(3) << "Generated padded shape: " << padded_shape.ToString(); + return hlo->parent()->AddInstruction(HloInstruction::CreatePad( + padded_shape, input, initial_value, padding_config)); }(); VLOG(1) << "Generated padding: " << padded->ToString(); @@ -146,7 +141,7 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor { dim_idx++) { if (dim_idx == reduced_input_dimension) { reshaped_dimensions.push_back(num_fit); - reshaped_dimensions.push_back(atomic_free_bound); + reshaped_dimensions.push_back(num_fit); } else { reshaped_dimensions.push_back(padded->shape().dimensions(dim_idx)); } diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 122122aae55..22d9f1bc648 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -309,6 +309,8 @@ Status HloComputation::RemoveInstructionImpl(HloInstruction* instruction, auto inst_it = instruction_iterators_.find(instruction); TF_RET_CHECK(inst_it != instruction_iterators_.end()); (*inst_it->second)->set_parent(nullptr); + to_be_deleted_.emplace_back(inst_it->second->release()); + to_be_deleted_.back()->DetachFromOperandsAndUsers(); instructions_.erase(inst_it->second); instruction_iterators_.erase(inst_it); return Status::OK(); diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 9ca60403929..f1568858d9f 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -469,6 +469,12 @@ class HloComputation { int64 unique_id() const { return unique_id_; } + // Deallocate instructions that are marked by "RemoveInstruction". The two + // stage clean up process is designed such that HloPass can have stable + // internal pointers to HloInstructions while we create and remove + // HloInstructions in a pass. 
+ void Cleanup() { to_be_deleted_.clear(); } + private: explicit HloComputation( const string& name, int parameter_count, @@ -527,6 +533,10 @@ class HloComputation { absl::flat_hash_map instruction_iterators_; + // Removed instructions are moved into to_be_deleted_ first and then + // deallocated when Cleanup is called. + std::vector> to_be_deleted_; + std::vector param_instructions_; TF_DISALLOW_COPY_AND_ASSIGN(HloComputation); diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index ef3809c1b94..2e089f34bac 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -751,7 +751,7 @@ Status HloCostAnalysis::HandleRngBitGenerator(const HloInstruction* random) { // cost changes with the implementation and the distribution. For now, assume // the cost of each RNG is same as a transcendental operation. current_properties_[kTranscendentalsKey] = - ShapeUtil::ElementsIn(random->shape()); + ShapeUtil::ElementsInRecursive(random->shape()); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc index 846b9cfbeb5..dd174772c62 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -33,6 +33,15 @@ limitations under the License. namespace xla { using absl::StrCat; +StatusOr MakeUnaryHlo(HloOpcode opcode, + HloInstruction* operand) { + HloComputation* computation = operand->parent(); + TF_ASSIGN_OR_RETURN(Shape unary_op_shape, + ShapeInference::InferUnaryOpShape(opcode, operand)); + return computation->AddInstruction( + HloInstruction::CreateUnary(unary_op_shape, opcode, operand)); +} + StatusOr MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs, HloInstruction* rhs) { HloComputation* computation = lhs->parent(); @@ -344,6 +353,15 @@ StatusOr MakeReduceHlo(HloInstruction* operand, scalar_shape, operand, init_value, all_dims, reduce_computation)); } +StatusOr MakeReverseHlo(HloInstruction* operand, + absl::Span dimensions) { + HloComputation* computation = operand->parent(); + TF_ASSIGN_OR_RETURN(Shape reverse_shape, ShapeInference::InferReverseShape( + operand->shape(), dimensions)); + return computation->AddInstruction( + HloInstruction::CreateReverse(reverse_shape, operand, dimensions)); +} + StatusOr MakeSelectHlo(HloInstruction* pred, HloInstruction* on_true, HloInstruction* on_false, diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h index 754f7e2be33..3f2e3aa25a1 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.h +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h @@ -27,6 +27,11 @@ namespace xla { // ergonomic. We don't have a complete set of helpers yet -- I expect we'll // expand this interface as needed on an ad-hoc basis. +// Creates a unary HLO instruction and adds it to the computation containing +// `operand`. +StatusOr MakeUnaryHlo(HloOpcode opcode, + HloInstruction* operand); + // Creates a binary HLO instruction and adds it to the computation containing // `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation). StatusOr MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs, @@ -145,6 +150,11 @@ StatusOr MakeReduceHlo(HloInstruction* operand, HloOpcode binary_opcode, HloModule* module); +// Creates a Reverse HLO instruction and adds it to the computation containing +// `operand`. 
+StatusOr MakeReverseHlo(HloInstruction* operand, + absl::Span dimensions); + // Creates a Select HLO instruction and adds it to the computation containing // the predicate. The on_true and on_false instructions must also be contained // in the same computation. If on_true and on_false are tuples, create a tuple diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc index a58fcf4460a..373f4f12ba4 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.cc +++ b/tensorflow/compiler/xla/service/hlo_cse.cc @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/platform/hash.h" namespace xla { @@ -96,17 +97,54 @@ StatusOr CombineConstants(HloComputation* computation, // share the exact same set of operands. int64 CseHash(const HloInstruction* instruction) { int64 hash = std::hash()(static_cast(instruction->opcode())); + auto c_hash = [](auto c) { + return tensorflow::Hash64(reinterpret_cast(c.data()), + c.size() * sizeof(c[0])); + }; + auto proto_hash = [](auto proto) { + return std::hash{}(proto.ByteSizeLong()); + }; hash = tensorflow::Hash64Combine( hash, instruction->opcode() == HloOpcode::kGetTupleElement ? instruction->tuple_index() - : -1); + : c_hash(instruction->shape().dimensions())); for (auto operand : instruction->operands()) { hash = tensorflow::Hash64Combine(hash, operand->unique_id()); } - if (instruction->opcode() == HloOpcode::kConstant) { - hash = tensorflow::Hash64Combine(hash, instruction->literal().Hash()); + for (auto c : instruction->called_computations()) { + hash = tensorflow::Hash64Combine( + hash, std::hash()( + static_cast(c->root_instruction()->opcode()))); + } + switch (instruction->opcode()) { + case HloOpcode::kConstant: + return tensorflow::Hash64Combine(hash, instruction->literal().Hash()); + case HloOpcode::kSlice: + return tensorflow::Hash64Combine( + tensorflow::Hash64Combine(hash, c_hash(instruction->slice_starts())), + c_hash(instruction->slice_strides())); + case HloOpcode::kPad: + return tensorflow::Hash64Combine( + hash, proto_hash(instruction->padding_config())); + case HloOpcode::kDot: + return tensorflow::Hash64Combine( + hash, proto_hash(instruction->dot_dimension_numbers())); + case HloOpcode::kConvolution: + return tensorflow::Hash64Combine( + tensorflow::Hash64Combine( + hash, proto_hash(instruction->convolution_dimension_numbers())), + proto_hash(instruction->window())); + case HloOpcode::kReduceWindow: + return tensorflow::Hash64Combine(hash, proto_hash(instruction->window())); + case HloOpcode::kConcatenate: + case HloOpcode::kBroadcast: + case HloOpcode::kTranspose: + case HloOpcode::kIota: + case HloOpcode::kReduce: + return tensorflow::Hash64Combine(hash, c_hash(instruction->dimensions())); + default: + return hash; } - return hash; } } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index fc9d42c1b17..803004225d2 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -133,6 +133,10 @@ class HloEvaluator : public DfsHloVisitorWithDefault { dynamic_dimension_inference_ = dynamic_dimension_inference; } + DynamicDimensionInference* dynamic_dimension_inference() { + return dynamic_dimension_inference_; + } + // Enable the fast path for certain operations like dot or convolution. 
void set_use_fast_path(bool value) { use_fast_path_ = value; } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 9f45cac028c..8aeb92b40de 100755 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -1661,7 +1661,11 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( return clone; } -HloInstruction::~HloInstruction() { +void HloInstruction::DetachFromOperandsAndUsers() { + if (cleaned_up_) { + return; + } + cleaned_up_ = true; // Detach from operands. An instruction may be repeated as an operand. To // avoid calling RemoveUser twice on the same operand, check before remove. for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) { diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index a108a91d5f9..33c0daca686 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -480,7 +480,11 @@ class HloInstruction { kCustom, }; - virtual ~HloInstruction(); + virtual ~HloInstruction() { DetachFromOperandsAndUsers(); } + + // Detaches an instruction from its operands and users. That is, remove the + // instruction from each operand's user set and user's operand set. + void DetachFromOperandsAndUsers(); // Creates an instruction from the given proto. Arguments: // @@ -2025,6 +2029,10 @@ class HloInstruction { // a default configuration. bool is_default_config_ = false; + // True if this instruction has already been detached from its user and + // operands. + bool cleaned_up_ = false; + // String identifier for instruction. string name_; diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index 5e662e0bebc..f25f4694f21 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -184,6 +184,13 @@ class HloModule { // Gets the number of instructions in this module. int64 instruction_count() const; + // Deallocate removed instructions in each computation. + void Cleanup() { + for (auto& comp : computations_) { + comp->Cleanup(); + } + } + // Compute and return a post order of all computations in the module. The sort // is defined like so: if computation A has an instruction which calls // computation B, then A will appear after B in the sort. diff --git a/tensorflow/compiler/xla/service/hlo_module_group.h b/tensorflow/compiler/xla/service/hlo_module_group.h index c4b10f3b22a..217f65b4a75 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group.h +++ b/tensorflow/compiler/xla/service/hlo_module_group.h @@ -64,6 +64,13 @@ class HloModuleGroup { string ToString() const; + // Deallocate removed instructions in each module. + void Cleanup() { + for (auto& module : modules_) { + module->Cleanup(); + } + } + // Serialize the module group to/from a proto. HloModuleGroupProto ToProto() const; static StatusOr CreateFromProto( diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h index ad4070e3e23..16fad113b0d 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h @@ -104,11 +104,15 @@ class HloPassPipeline : public HloPassInterface { // helpers enable templating of the core of the pipeline logic by providing // HloModule and HloModuleGroup specific methods with the same name. 
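// The Cleanup() calls added below deallocate instructions that were removed
// while the pass ran: RemoveInstruction now only detaches an instruction and
// parks it in its computation's to_be_deleted_ list, so pointers to
// HloInstructions held by a pass stay valid until the pass finishes.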
static StatusOr RunHelper(HloPassInterface* pass, HloModule* module) { - return pass->Run(module); + TF_ASSIGN_OR_RETURN(bool changed, pass->Run(module)); + module->Cleanup(); + return changed; } static StatusOr RunHelper(HloPassInterface* pass, HloModuleGroup* module_group) { - return pass->RunOnModuleGroup(module_group); + TF_ASSIGN_OR_RETURN(bool changed, pass->RunOnModuleGroup(module_group)); + module_group->Cleanup(); + return changed; } const string name_; diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 5a34c502071..21be4216469 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -989,7 +989,6 @@ Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item, ItemList unplaced_users; for (Item* user : old_buffer.users) { if (user->placed) { - CHECK(IsFinished(user)) << user->instruction->name(); placed_users.push_back(user); } else { unplaced_users.push_back(user); diff --git a/tensorflow/compiler/xla/service/mlir_gpu/BUILD b/tensorflow/compiler/xla/service/mlir_gpu/BUILD index 36e20656974..afceefdeae6 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/BUILD +++ b/tensorflow/compiler/xla/service/mlir_gpu/BUILD @@ -73,6 +73,7 @@ cc_library( "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Support", "@llvm-project//mlir:TargetNVVMIR", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:compiler", @@ -158,7 +159,6 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@llvm-project//mlir:AffineToStandardTransforms", - "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:CFGTransforms", "@llvm-project//mlir:GPUDialect", "@llvm-project//mlir:GPUToNVVMTransforms", @@ -193,7 +193,9 @@ cc_library( "//tensorflow/compiler/xla/tests:codegen_test_base", "//tensorflow/compiler/xla/tests:filecheck", "//tensorflow/compiler/xla/tests:verified_hlo_module", + "//tensorflow/core:lib", "//tensorflow/core:test", + "//tensorflow/core/platform:resource_loader", "//tensorflow/core/platform:test", "@com_google_absl//absl/memory", "@llvm-project//llvm:support", diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_irgen_test_base.cc b/tensorflow/compiler/xla/service/mlir_gpu/mlir_irgen_test_base.cc index dbc6efe9ec9..fa2167a4bd9 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/mlir_irgen_test_base.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/mlir_irgen_test_base.cc @@ -32,6 +32,9 @@ limitations under the License. 
#include "tensorflow/compiler/xla/tests/filecheck.h" #include "tensorflow/compiler/xla/tests/verified_hlo_module.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/path.h" +#include "tensorflow/core/platform/resource_loader.h" #include "tensorflow/core/platform/test.h" namespace xla { @@ -46,8 +49,10 @@ void MlirIrGenTestBase::CompileIr(std::unique_ptr hlo_module, TF_ASSERT_OK(status); } -void MlirIrGenTestBase::PatternMatch(const string& str, const string& pattern) { - StatusOr filecheck_result = RunFileCheck(str, pattern); +void MlirIrGenTestBase::PatternMatch(const std::string& str, + const std::string& pattern_file) { + StatusOr filecheck_result = + RunFileCheckWithPatternFile(str, pattern_file); TF_ASSERT_OK(filecheck_result.status()); EXPECT_TRUE(filecheck_result.ValueOrDie()); } @@ -55,7 +60,7 @@ void MlirIrGenTestBase::PatternMatch(const string& str, const string& pattern) { string MlirIrGenTestBase::CompileIr( std::unique_ptr hlo_module, MlirCompiler::IRHook::LoweringStage printing_stage) { - string ir; + std::string ir; CompileIr(std::move(hlo_module), {[&ir](mlir::ModuleOp module) -> Status { std::string buffer_string; @@ -70,23 +75,21 @@ string MlirIrGenTestBase::CompileIr( } void MlirIrGenTestBase::CompileAndVerifyIr( - std::unique_ptr hlo_module, const string& pattern, + std::unique_ptr hlo_module, const std::string& pattern_file, LoweringStage printing_stage) { - string ir = CompileIr(std::move(hlo_module), printing_stage); - PatternMatch(ir, pattern); + std::string ir = CompileIr(std::move(hlo_module), printing_stage); + PatternMatch(ir, pattern_file); } -void MlirIrGenTestBase::CompileAndVerifyIr(const string& hlo_text, - const string& expected_llvm_ir, +void MlirIrGenTestBase::CompileAndVerifyIr(const std::string& hlo_text_filename, LoweringStage printing_stage) { - HloModuleConfig config; - config.set_debug_options(GetDebugOptionsForTest()); - auto module = absl::make_unique( - "Module", config, /*verifier_layout_sensitive=*/true, - /*allow_mixed_precision_in_hlo_verifier=*/false, - /*shape_size_function=*/ShapeUtil::ByteSizeOfElements); - TF_ASSERT_OK(module->ParseHloStringAndVerifyModule(hlo_text)); - CompileAndVerifyIr(std::move(module), expected_llvm_ir, printing_stage); + std::string hlo_text_absolute_filename = + tensorflow::GetDataDependencyFilepath(hlo_text_filename); + TF_ASSERT_OK_AND_ASSIGN(auto module, + GetVerifiedHloModule(hlo_text_absolute_filename)); + CompileAndVerifyIr(std::move(module), + /*pattern_file=*/hlo_text_absolute_filename, + printing_stage); } MlirCompiler::IRHook MlirIrGenTestBase::getIRHookBreakingLoweringStage( @@ -104,7 +107,7 @@ MlirCompiler::IRHook MlirIrGenTestBase::getIRHookBreakingLoweringStage( StatusOr MlirIrGenTestBase::CompileAndInjectErrors( std::unique_ptr hlo_module, LoweringStage breaking_stage) { - string errors; + std::string errors; auto error_handler = [&errors](const EmissionContext::ErrorMap& error_map, HloModule* hlo_module) { errors = "ERRORS FOUND: "; @@ -127,19 +130,32 @@ StatusOr MlirIrGenTestBase::CompileAndInjectErrors( return status; } -void MlirIrGenTestBase::CompileAndVerifyErrors(const string& hlo_text, - const string& expected_errors, - LoweringStage breaking_stage) { +void MlirIrGenTestBase::CompileAndVerifyErrors( + const std::string& hlo_text_filename, LoweringStage breaking_stage) { + std::string test_srcdir = tensorflow::testing::TensorFlowSrcRoot(); + std::string hlo_text_absolute_filename = + 
tensorflow::GetDataDependencyFilepath(hlo_text_filename); + TF_ASSERT_OK_AND_ASSIGN(auto module, + GetVerifiedHloModule(hlo_text_absolute_filename)); + TF_ASSERT_OK_AND_ASSIGN( + std::string errors, + CompileAndInjectErrors(std::move(module), breaking_stage)); + PatternMatch(errors, /*pattern_file=*/hlo_text_absolute_filename); +} + +StatusOr> +MlirIrGenTestBase::GetVerifiedHloModule(const std::string& hlo_text_filename) { HloModuleConfig config; config.set_debug_options(GetDebugOptionsForTest()); auto module = absl::make_unique( "Module", config, /*verifier_layout_sensitive=*/true, /*allow_mixed_precision_in_hlo_verifier=*/false, /*shape_size_function=*/ShapeUtil::ByteSizeOfElements); - TF_ASSERT_OK(module->ParseHloStringAndVerifyModule(hlo_text)); - TF_ASSERT_OK_AND_ASSIGN( - string errors, CompileAndInjectErrors(std::move(module), breaking_stage)); - PatternMatch(errors, expected_errors); + std::string hlo_text; + TF_RETURN_IF_ERROR(tensorflow::ReadFileToString( + tensorflow::Env::Default(), hlo_text_filename, &hlo_text)); + TF_RETURN_IF_ERROR(module->ParseHloStringAndVerifyModule(hlo_text)); + return std::move(module); } MlirCompiler* MlirIrGenTestBase::GetMLIRCompiler() { diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_irgen_test_base.h b/tensorflow/compiler/xla/service/mlir_gpu/mlir_irgen_test_base.h index a46b606d75e..46246c0d4d6 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/mlir_irgen_test_base.h +++ b/tensorflow/compiler/xla/service/mlir_gpu/mlir_irgen_test_base.h @@ -39,38 +39,36 @@ class MlirIrGenTestBase : public CodegenTestBase { // steps to LLVM IR are applied; otherwise, the IR before lowering is // matched. void CompileAndVerifyIr(std::unique_ptr hlo_module, - const string& pattern, LoweringStage printing_stage); + const std::string& pattern_file, + LoweringStage printing_stage); - // A thin wrapper around CompileAndVerifyIr that parses `hlo_text` to create - // an HLO module. - void CompileAndVerifyIr(const string& hlo_text, - const string& expected_llvm_ir, + // A thin wrapper around CompileAndVerifyIr that parses the hlo text in + // `hlo_text_filename` to create an HLO module. + void CompileAndVerifyIr(const std::string& hlo_text_filename, LoweringStage printing_stage = LoweringStage::LHLO); - // Compiles and returns module with optimizations from a given HLO. - StatusOr> GetOptimizedModule( - absl::string_view hlo); - // Adds the InjectErrorsForTestingPass to MLIRCompiler on the provided - // lowering stage, compiles the given HLO module, and returns a string + // lowering stage, compiles the given HLO module, and returns a std::string // representation of all the errors occurred during compiling. StatusOr CompileAndInjectErrors(std::unique_ptr hlo_module, LoweringStage breaking_stage); // Adds the InjectErrorsForTestingPass to MLIRCompiler on the provided // lowering stage, parses and compiles `hlo_text`, and verifies that the - // string representation of all the errors occurred during compiling matches - // the given pattern. - void CompileAndVerifyErrors(const string& hlo_text, - const string& expected_errors, + // std::string representation of all the errors occurred during compiling + // matches the given pattern. 
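// The file named by `hlo_text_filename` is used twice: it is resolved via
// GetDataDependencyFilepath and parsed as the HLO module to compile, and the
// same file is handed to FileCheck as the pattern file, so each tests/*.hlo
// file carries both the module text and its "// CHECK:" expectations.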
+ void CompileAndVerifyErrors(const std::string& hlo_text_filename, LoweringStage breaking_stage); private: + StatusOr> GetVerifiedHloModule( + const std::string& hlo_text_filename); + void CompileIr(std::unique_ptr hlo_module, const MlirCompiler::IRHook& ir_hook); - void PatternMatch(const string& str, const string& pattern); - string CompileIr(std::unique_ptr hlo_module, - LoweringStage printing_stage); + void PatternMatch(const std::string& str, const std::string& pattern_file); + std::string CompileIr(std::unique_ptr hlo_module, + LoweringStage printing_stage); MlirCompiler::IRHook getIRHookBreakingLoweringStage( LoweringStage breaking_stage); MlirCompiler* GetMLIRCompiler(); diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/BUILD b/tensorflow/compiler/xla/service/mlir_gpu/tests/BUILD index 84f1c7668e5..aeaaf0b16c4 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/BUILD @@ -25,14 +25,39 @@ package_group( tf_cc_test( name = "mlir_gpu_lhlo_gen_test", srcs = if_cuda_is_configured(["mlir_gpu_lhlo_gen_test.cc"]), - tags = tf_cuda_tests_tags() + [ - "no_rocm", - "no_oss", # TODO(b/149544192): Fix the test. + data = [ + "abs.hlo", + "add.hlo", + "add_as_kernel.hlo", + "add_in_gpu_dialect.hlo", + "add_multiply.hlo", + "add_multiply_gpu.hlo", + "add_reduce.hlo", + "broadcast.hlo", + "broken_add.hlo", + "ceil.hlo", + "compare.hlo", + "const.hlo", + "copy.hlo", + "cos.hlo", + "exp.hlo", + "fused_reduce.hlo", + "iota.hlo", + "iota_add_multiply.hlo", + "log.hlo", + "neg.hlo", + "rem.hlo", + "rsqrt.hlo", + "select.hlo", + "sign.hlo", + "tanh.hlo", ], + tags = tf_cuda_tests_tags() + ["no_rocm"], deps = [ "//tensorflow/core:test_main", "//tensorflow/core:test", ] + if_cuda_is_configured([ + "//tensorflow/core:lib", "//tensorflow/compiler/xla/service:gpu_plugin_mlir", "//tensorflow/compiler/xla/service/mlir_gpu:mlir_irgen_test_base", "//tensorflow/stream_executor/lib", diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/abs.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/abs.hlo new file mode 100644 index 00000000000..6a4353d8d45 --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/abs.hlo @@ -0,0 +1,9 @@ +HloModule Abs +ENTRY %Abs (val: f32[2,2]) -> f32[2,2] { + %val = f32[2,2]{1,0} parameter(0) + ROOT %abs = f32[2,2]{1,0} abs(f32[2,2]{1,0} %val) +} + +// CHECK: func @abs(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) { +// CHECK: "xla_lhlo.abs"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () +// CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/add.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/add.hlo new file mode 100644 index 00000000000..d48fcf89658 --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/add.hlo @@ -0,0 +1,11 @@ +HloModule Add + +ENTRY %Add (x: f32[2,2], y: f32[2,2]) -> f32[2,2] { + %x = f32[2,2]{1,0} parameter(0) + %y = f32[2,2]{1,0} parameter(1) + ROOT %add = f32[2,2]{1,0} add(f32[2,2]{1,0} %x, f32[2,2]{1,0} %y) +} + +// CHECK: func @add(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]], %[[ARG2:.*]]: [[TYPE]]) { +// CHECK: "xla_lhlo.add"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> () +// CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/add_as_kernel.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/add_as_kernel.hlo new file mode 100644 index 00000000000..c477cc99c39 --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/add_as_kernel.hlo @@ -0,0 +1,62 @@ +HloModule Add + 
+ENTRY %Add (x: f32[2,2], y: f32[2,2]) -> f32[2,2] { + %x = f32[2,2]{1,0} parameter(0) + %y = f32[2,2]{1,0} parameter(1) + ROOT %add = f32[2,2]{1,0} add(f32[2,2]{1,0} %x, f32[2,2]{1,0} %y) +} + +// CHECK: func @add_kernel(%[[ARG0:.*]]: [[TYPE:!llvm<.*]], %[[ARG1:.*]]: [[TYPE]], %[[ARG2:.*]]: [[TYPE]] + +// +// Check that relevant sizes and strides are emitted. +// +// CHECK: %[[CAST0:.*]] = llvm.bitcast %[[ARG0:.*]] : !llvm<"i8*"> to !llvm<"float*"> +// CHECK: %[[SIZE00:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64 +// CHECK: %[[SIZE01:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64 +// CHECK: %[[STRIDE01:.*]] = llvm.mlir.constant(1 : i64) : !llvm.i64 +// CHECK: %[[STRIDE00:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64 + +// CHECK: %[[CAST1:.*]] = llvm.bitcast %[[ARG1:.*]] : !llvm<"i8*"> to !llvm<"float*"> +// CHECK: %[[SIZE10:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64 +// CHECK: %[[SIZE11:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64 +// CHECK: %[[STRIDE11:.*]] = llvm.mlir.constant(1 : i64) : !llvm.i64 +// CHECK: %[[STRIDE10:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64 + +// CHECK: %[[CAST2:.*]] = llvm.bitcast %[[ARG2:.*]] : !llvm<"i8*"> to !llvm<"float*"> +// CHECK: %[[SIZE20:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64 +// CHECK: %[[SIZE21:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64 +// CHECK: %[[STRIDE21:.*]] = llvm.mlir.constant(1 : i64) : !llvm.i64 +// CHECK: %[[STRIDE20:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64 + +// +// Check that the emitted sizes and strides, as well the pointers to HLO buffers, +// are inserted into the memref descriptors. +// +// CHECK: %[[DESC0:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> +// CHECK: %[[DESC01:.*]] = llvm.insertvalue %[[CAST0]], %[[DESC0]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> +// CHECK: %[[DESC02:.*]] = llvm.insertvalue %[[CAST0]], %[[DESC01]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> +// CHECK: %[[DESC03:.*]] = llvm.insertvalue %{{.*}}, %[[DESC02]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> +// CHECK: %[[DESC04:.*]] = llvm.insertvalue %[[SIZE00]], %[[DESC03]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> +// CHECK: %[[DESC05:.*]] = llvm.insertvalue %[[STRIDE00]], %[[DESC04]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> +// CHECK: %[[DESC06:.*]] = llvm.insertvalue %[[SIZE01]], %[[DESC05]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> +// CHECK: %{{.*}} = llvm.insertvalue %[[STRIDE01]], %[[DESC06]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> + +// CHECK: %[[DESC1:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> +// CHECK: %[[DESC11:.*]] = llvm.insertvalue %[[CAST1]], %[[DESC1]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> +// CHECK: %[[DESC12:.*]] = llvm.insertvalue %[[CAST1]], %[[DESC11]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> +// CHECK: %[[DESC13:.*]] = llvm.insertvalue %{{.*}}, %[[DESC12]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> +// CHECK: %[[DESC14:.*]] = llvm.insertvalue %[[SIZE10]], %[[DESC13]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> +// CHECK: %[[DESC15:.*]] = llvm.insertvalue %[[STRIDE10]], %[[DESC14]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> +// CHECK: %[[DESC16:.*]] = llvm.insertvalue %[[SIZE11]], %[[DESC15]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> +// CHECK: %{{.*}} = llvm.insertvalue %[[STRIDE11]], 
%[[DESC16]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> + +// CHECK: %[[DESC2:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> +// CHECK: %[[DESC21:.*]] = llvm.insertvalue %[[CAST2]], %[[DESC2]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> +// CHECK: %[[DESC22:.*]] = llvm.insertvalue %[[CAST2]], %[[DESC21]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> +// CHECK: %[[DESC23:.*]] = llvm.insertvalue %{{.*}}, %[[DESC22]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> +// CHECK: %[[DESC24:.*]] = llvm.insertvalue %[[SIZE20]], %[[DESC23]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> +// CHECK: %[[DESC25:.*]] = llvm.insertvalue %[[STRIDE20]], %[[DESC24]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> +// CHECK: %[[DESC26:.*]] = llvm.insertvalue %[[SIZE21]], %[[DESC25]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> +// CHECK: %{{.*}} = llvm.insertvalue %[[STRIDE21]], %[[DESC26]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> + diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/add_in_gpu_dialect.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/add_in_gpu_dialect.hlo new file mode 100644 index 00000000000..ec7df87af64 --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/add_in_gpu_dialect.hlo @@ -0,0 +1,19 @@ +HloModule Add + +ENTRY %Add (x: f32[2,2], y: f32[2,2]) -> f32[2,2] { + %x = f32[2,2]{1,0} parameter(0) + %y = f32[2,2]{1,0} parameter(1) + ROOT %add = f32[2,2]{1,0} add(f32[2,2]{1,0} %x, f32[2,2]{1,0} %y) +} + +// CHECK: func @add(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]], %[[ARG2:.*]]: [[TYPE]]) { +// CHECK: "gpu.launch_func"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[ARG0]], %[[ARG1]], %[[ARG2]] +// CHECK: } +// CHECK: func @add_kernel(%[[ARG0]]: [[TYPE]], %[[ARG1]]: [[TYPE]], %[[ARG2]]: [[TYPE]] +// CHECK-DAG: std.subview %[[ARG0]]{{\[}}[[INDEX:.*]]] +// CHECK-DAG: std.subview %[[ARG1]]{{\[}}[[INDEX]]] +// CHECK-DAG: std.subview %[[ARG2]]{{\[}}[[INDEX]]] +// CHECK: %[[VAL1:.*]] = load %{{.*\[}}[[INDEX:.*]]] +// CHECK: %[[VAL2:.*]] = load %{{.*\[}}[[INDEX]]] +// CHECK: %[[RES:.*]] = addf %[[VAL1]], %[[VAL2]] +// CHECK: store %[[RES]], %{{.*\[}}[[INDEX]]] diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/add_multiply.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/add_multiply.hlo new file mode 100644 index 00000000000..f4f2e4d2c91 --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/add_multiply.hlo @@ -0,0 +1,21 @@ +HloModule AddMultiply + +ENTRY %AddMultiply (x: f32[2,2], y: f32[2,2], z: f32[2,2]) -> f32[2,2] { + %x = f32[2,2]{1,0} parameter(0) + %y = f32[2,2]{1,0} parameter(1) + %z = f32[2,2]{1,0} parameter(2) + %add = f32[2,2]{1,0} add(f32[2,2]{1,0} %x, f32[2,2]{1,0} %y) + ROOT %mul = f32[2,2]{1,0} multiply(f32[2,2]{1,0} %add, f32[2,2]{1,0} %z) +} + +// CHECK: func @fusion(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]], %[[ARG2:.*]]: [[TYPE]], %[[RESULT:.*]]: [[TYPE]]) +// CHECK: "xla_lhlo.fusion"() ( { +// CHECK: %[[REF0:.*]] = tensor_load %[[ARG0]] : [[TYPE]] +// CHECK: %[[REF1:.*]] = tensor_load %[[ARG1]] : [[TYPE]] +// CHECK: %[[REF2:.*]] = tensor_load %[[ARG2]] : [[TYPE]] +// CHECK: %[[ADD:.*]] = xla_hlo.add %[[REF1]], %[[REF2]] +// CHECK: %[[MUL:.*]] = xla_hlo.mul %[[ADD]], %[[REF0]] +// CHECK: tensor_store %[[MUL]], %[[RESULT]] +// CHECK: "xla_lhlo.terminator"() +// CHECK-NEXT: } + diff --git 
a/tensorflow/compiler/xla/service/mlir_gpu/tests/add_multiply_gpu.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/add_multiply_gpu.hlo new file mode 100644 index 00000000000..e9000956c23 --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/add_multiply_gpu.hlo @@ -0,0 +1,22 @@ +HloModule AddMultiply + +ENTRY %AddMultiply (x: f32[2,2], y: f32[2,2], z: f32[2,2]) -> f32[2,2] { + %x = f32[2,2]{1,0} parameter(0) + %y = f32[2,2]{1,0} parameter(1) + %z = f32[2,2]{1,0} parameter(2) + %add = f32[2,2]{1,0} add(f32[2,2]{1,0} %x, f32[2,2]{1,0} %y) + ROOT %mul = f32[2,2]{1,0} multiply(f32[2,2]{1,0} %add, f32[2,2]{1,0} %z) +} + +// CHECK: func @fusion_kernel(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]], %[[ARG2:.*]]: [[TYPE]], %[[RESULT:.*]]: [[TYPE]]) +// CHECK-DAG: std.subview %[[ARG0]]{{\[}}[[INDEX:.*]]] +// CHECK-DAG: std.subview %[[ARG1]]{{\[}}[[INDEX]]] +// CHECK-DAG: std.subview %[[ARG2]]{{\[}}[[INDEX]]] +// CHECK-DAG: std.subview %[[RESULT]]{{\[}}[[INDEX]]] +// CHECK: %[[V0:.*]] = load %{{.*\[}}[[CSTIDX:.*]]] +// CHECK: %[[V1:.*]] = load %{{.*\[}}[[CSTIDX:.*]]] +// CHECK: %[[ADD:.*]] = addf %[[V0]], %[[V1]] +// CHECK: %[[V2:.*]] = load %{{.*\[}}[[CSTIDX:.*]]] +// CHECK: %[[MUL:.*]] = mulf %[[ADD]], %[[V2]] +// CHECK: store %[[MUL]], %{{.*\[}}[[CSTIDX:.*]]] +// CHECK-NEXT: return diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/add_reduce.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/add_reduce.hlo new file mode 100644 index 00000000000..6df8f284b72 --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/add_reduce.hlo @@ -0,0 +1,23 @@ +HloModule AddReduce + +%add (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %x, f32[] %y) +} + +ENTRY %AddReduce (x: f32[100,10], c: f32[]) -> f32[100] { + %x = f32[100,10]{1,0} parameter(0) + %c = f32[] parameter(1) + ROOT %reduce = f32[100]{0} reduce(f32[100,10]{1,0} %x, f32[] %c), dimensions={1}, to_apply=%add +} + +// CHECK: func @reduce(%[[ARG:.*]]: [[ARGT:.*]], %[[CST:.*]]: memref<f32>, %[[RES:.*]]: [[REST:.*]]) { +// CHECK: "xla_lhlo.reduce"(%[[ARG]], %[[CST]], %[[RES]]) ( { +// CHECK: ^bb0(%[[FARG0:.*]]: memref<f32>, %[[FARG1:.*]]: memref<f32>, %[[FRES:.*]]: memref<f32>): +// CHECK: %[[LHS:.*]] = tensor_load %[[FARG0]] : memref<f32> +// CHECK: %[[RHS:.*]] = tensor_load %[[FARG1]] : memref<f32> +// CHECK: %[[RES:.*]] = xla_hlo.add %[[LHS]], %[[RHS]] : tensor<f32> +// CHECK: tensor_store %[[RES]], %[[FRES]] : memref<f32> +// CHECK: "xla_lhlo.terminator"() : () -> () +// CHECK-NEXT: }) {dimensions = dense<1> : tensor<1xi64>} : ([[ARGT]], memref<f32>, [[REST]]) -> () diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/broadcast.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/broadcast.hlo new file mode 100644 index 00000000000..b0613ac96ac --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/broadcast.hlo @@ -0,0 +1,13 @@ +HloModule Broadcast + +ENTRY %Broadcast (x: f32[10]) -> f32[10, 5] { + %x = f32[10]{0} parameter(0) + ROOT %broadcast = f32[10, 5]{1,0} broadcast(f32[10]{0} %x), dimensions={0} +} + +// CHECK: func @broadcast(%[[IN:.*]]: [[IN_T:.*]], %[[OUT:.*]]: [[OUT_T:.*]]) { +// CHECK: "xla_lhlo.broadcast_in_dim"(%[[IN]], %[[OUT]]) +// CHECK: {broadcast_dimensions = dense<0> : tensor<1xi64>} +// CHECK: : ([[IN_T]], [[OUT_T]]) -> () +// CHECK: } + diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/broken_add.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/broken_add.hlo new file mode 100644 index 00000000000..b4b22f42f29 --- /dev/null +++
b/tensorflow/compiler/xla/service/mlir_gpu/tests/broken_add.hlo @@ -0,0 +1,9 @@ +HloModule Add + +ENTRY %Add (x: f32[2,2,2], y: f32[2,2,2]) -> f32[2,2,2] { + %x = f32[2,2,2]{2,1,0} parameter(0) + %y = f32[2,2,2]{2,1,0} parameter(1) + ROOT %add = f32[2,2,2]{2,1,0} add(f32[2,2,2]{2,1,0} %x, f32[2,2,2]{2,1,0} %y) +} + +// CHECK: ERRORS FOUND: [%add = f32[2,2,2]{2,1,0} add(f32[2,2,2]{2,1,0} %x, f32[2,2,2]{2,1,0} %y): failed for testing: xla_lhlo.add; failed for testing: std.return] diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/ceil.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/ceil.hlo new file mode 100644 index 00000000000..ff4e8191da4 --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/ceil.hlo @@ -0,0 +1,9 @@ +HloModule Ceil +ENTRY %Ceil (val: f32[2,2]) -> f32[2,2] { + %val = f32[2,2]{1,0} parameter(0) + ROOT %ceil = f32[2,2]{1,0} ceil(f32[2,2]{1,0} %val) +} + +// CHECK: func @ceil(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) { +// CHECK: "xla_lhlo.ceil"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () +// CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/compare.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/compare.hlo new file mode 100644 index 00000000000..a0f88efbd2f --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/compare.hlo @@ -0,0 +1,12 @@ +HloModule Compare + +ENTRY %Compare (x: f32[2,2], y: f32[2,2]) -> pred[2,2] { + %x = f32[2,2]{1,0} parameter(0) + %y = f32[2,2]{1,0} parameter(1) + ROOT %compare = pred[2,2]{1,0} compare(f32[2,2]{1,0} %x, f32[2,2]{1,0} %y), direction=EQ +} + +// CHECK: func @compare(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]], %[[PRED:.*]]: [[PRED_TYPE:.*]]) { +// CHECK: "xla_lhlo.compare"(%[[ARG0]], %[[ARG1]], %[[PRED]]) +// CHECK: {comparison_direction = "EQ"} : ([[TYPE]], [[TYPE]], [[PRED_TYPE]]) -> () +// CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/const.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/const.hlo new file mode 100644 index 00000000000..9c28b3619ac --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/const.hlo @@ -0,0 +1,11 @@ +HloModule Const + +ENTRY %Const () -> s32[100] { + %const.0 = s32[] constant(10) + ROOT %broadcast.0 = s32[100]{0} broadcast(s32[] %const.0), dimensions={} +} + +// CHECK: func @constant(%[[ARG0:.*]]: memref<i32>) +// CHECK: "xla_lhlo.constant"(%[[ARG0]]) {value = dense<10> : tensor<i32>} +// CHECK: func @broadcast(%[[ARG1:.*]]: memref<i32>, %[[ARG2:.*]]: memref<100xi32>) +// CHECK: "xla_lhlo.broadcast_in_dim"(%[[ARG1]], %[[ARG2]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/copy.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/copy.hlo new file mode 100644 index 00000000000..a729a4375b6 --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/copy.hlo @@ -0,0 +1,9 @@ +HloModule Copy + +ENTRY %Copy (x: f32[2,4]) -> f32[2,4] { + %x = f32[2,4] parameter(0) + ROOT %copy = f32[2,4] copy(f32[2,4] %x) +} + +// CHECK: func @copy(%[[OPERAND:.*]]: memref<2x4xf32>, %[[RESULT:.*]]: memref<2x4xf32>) { +// CHECK: "xla_lhlo.copy"(%[[OPERAND]], %[[RESULT]]) : (memref<2x4xf32>, memref<2x4xf32>) -> () diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/cos.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/cos.hlo new file mode 100644 index 00000000000..9abc2dad0aa --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/cos.hlo @@ -0,0 +1,9 @@ +HloModule Cos +ENTRY %Cos (val: f32[2,2]) -> f32[2,2] { + %val =
f32[2,2]{1,0} parameter(0) + ROOT %cos = f32[2,2]{1,0} cosine(f32[2,2]{1,0} %val) +} + +// CHECK: func @cosine(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) { +// CHECK: "xla_lhlo.cos"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () +// CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/exp.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/exp.hlo new file mode 100644 index 00000000000..9af0de99d42 --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/exp.hlo @@ -0,0 +1,11 @@ +HloModule Exp + +ENTRY %Exp (x: f32[2,2]) -> f32[2,2] { + %x = f32[2,2]{1,0} parameter(0) + ROOT %exp = f32[2,2]{1,0} exponential(f32[2,2]{1,0} %x) +} + +// CHECK: func @exponential(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) { +// CHECK: "xla_lhlo.exp"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () +// CHECK: } + diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/fused_reduce.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/fused_reduce.hlo new file mode 100644 index 00000000000..a673469977f --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/fused_reduce.hlo @@ -0,0 +1,34 @@ +HloModule FusedReduce + +%add (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %x, f32[] %y) +} + +%fused_computation (param: f32[100,10]) -> f32[10] { + %param = f32[100,10] parameter(0) + %constant = f32[] constant(0) + ROOT %reduce = f32[10]{0} reduce(f32[100,10]{1,0} %param, f32[] %constant), + dimensions={0}, to_apply=%add +} + +ENTRY %FusedReduce (x: f32[100,10]) -> f32[10] { + %x = f32[100,10] parameter(0) + ROOT %fusion = f32[10]{0} fusion(f32[100,10]{1,0} %x), kind=kInput, + calls=%fused_computation +} + +// CHECK: func @fusion(%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[RTYPE:.*]]) +// CHECK: "xla_lhlo.fusion"() ( { +// CHECK: %[[REF0:.*]] = tensor_load %arg0 : [[TYPE]] +// CHECK: %[[CT0:.*]] = xla_hlo.constant dense<0.000000e+00> +// CHECK: %[[RED:.*]] = "xla_hlo.reduce"(%0, %1) ( { +// CHECK: ^bb0(%[[BARG0:.*]]: [[ETYPE:.*]], %[[BARG1:.*]]: [[ETYPE]]) +// CHECK: %[[ADD:.*]] = xla_hlo.add %[[BARG0]], %[[BARG1]] : [[ETYPE]] +// CHECK: "xla_hlo.return"(%[[ADD]]) +// CHECK: }) +// CHECK: tensor_store %[[RED]], %[[RESULT]] : [[RTYPE]] +// CHECK: "xla_lhlo.terminator"() +// CHECK-NEXT: }) + diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/iota.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/iota.hlo new file mode 100644 index 00000000000..d622ed0e528 --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/iota.hlo @@ -0,0 +1,10 @@ +HloModule Iota + + ENTRY %Iota() -> s64[10, 5] { + ROOT %iota = s64[10, 5]{1,0} iota(), iota_dimension=0 +} + +// CHECK: func @iota(%[[OUT:.*]]: [[OUT_T:.*]]) { +// CHECK: "xla_lhlo.iota"(%[[OUT]]) +// CHECK: {iota_dimension = 0 : i64} : ([[OUT_T]]) -> () +// CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/iota_add_multiply.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/iota_add_multiply.hlo new file mode 100644 index 00000000000..89b7a43a102 --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/iota_add_multiply.hlo @@ -0,0 +1,15 @@ +HloModule AddMultiply + +ENTRY %AddMultiply (x: s32[2,2], y: s32[2,2]) -> s32[2,2] { + %x = s32[2,2]{1,0} parameter(0) + %y = s32[2,2]{1,0} parameter(1) + + %add = s32[2,2]{1,0} add(s32[2,2]{1,0} %x, s32[2,2]{1,0} %y) + %iota = s32[2, 2]{1,0} iota(), iota_dimension=0 + + ROOT %mul = s32[2,2]{1,0} multiply(s32[2,2]{1,0} %add, s32[2,2]{1,0} %iota) +} + +// CHECK-NOT: 
store +// CHECK: %[[RESULT:.*]] = muli %{{.*}}, %{{.*}} +// CHECK: store %[[RESULT]] diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/log.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/log.hlo new file mode 100644 index 00000000000..c7e2574558a --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/log.hlo @@ -0,0 +1,10 @@ +HloModule Log + +ENTRY %Log (x: f32[2,2]) -> f32[2,2] { + %x = f32[2,2]{1,0} parameter(0) + ROOT %log = f32[2,2]{1,0} log(f32[2,2]{1,0} %x) +} + +// CHECK: func @log(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) { +// CHECK: "xla_lhlo.log"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () +// CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc b/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc index 9a23ff8748e..7afb7e9281d 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/mlir_gpu/mlir_irgen_test_base.h" +#include "tensorflow/core/platform/path.h" namespace xla { namespace mlir_gpu { @@ -21,513 +22,174 @@ namespace mlir_gpu { class LhloGenTest : public MlirIrGenTestBase {}; TEST_F(LhloGenTest, Const) { - CompileAndVerifyIr(R"( -HloModule Const - -ENTRY %Const () -> s32[100] { - %const.0 = s32[] constant(10) - ROOT %broadcast.0 = s32[100]{0} broadcast(s32[] %const.0), dimensions={} -})", - R"( -;CHECK: func @constant(%[[ARG0:.*]]: memref) -;CHECK: "xla_lhlo.constant"(%[[ARG0]]) {value = dense<10> : tensor} -;CHECK: func @broadcast(%[[ARG1:.*]]: memref, %[[ARG2:.*]]: memref<100xi32>) -;CHECK: "xla_lhlo.broadcast_in_dim"(%[[ARG1]], %[[ARG2]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} -)", - LoweringStage::LHLO); + CompileAndVerifyIr( + /*hlo_text_filename=*/tensorflow::io::JoinPath( + "tensorflow", "compiler", "xla", "service", "mlir_gpu", "tests", + "const.hlo"), + LoweringStage::LHLO); } TEST_F(LhloGenTest, BrokenAdd) { CompileAndVerifyErrors( - R"( -HloModule Add - -ENTRY %Add (x: f32[2,2,2], y: f32[2,2,2]) -> f32[2,2,2] { - %x = f32[2,2,2]{2,1,0} parameter(0) - %y = f32[2,2,2]{2,1,0} parameter(1) - ROOT %add = f32[2,2,2]{2,1,0} add(f32[2,2,2]{2,1,0} %x, f32[2,2,2]{2,1,0} %y) -})", - R"(CHECK: ERRORS FOUND: [%add = f32[2,2,2]{2,1,0} add(f32[2,2,2]{2,1,0} %x, f32[2,2,2]{2,1,0} %y): failed for testing: xla_lhlo.add; failed for testing: std.return])", + /*hlo_text_filename=*/ + tensorflow::io::JoinPath("tensorflow", "compiler", "xla", "service", + "mlir_gpu", "tests", "broken_add.hlo"), LoweringStage::LHLO); } TEST_F(LhloGenTest, Add) { - CompileAndVerifyIr(R"( -HloModule Add - -ENTRY %Add (x: f32[2,2], y: f32[2,2]) -> f32[2,2] { - %x = f32[2,2]{1,0} parameter(0) - %y = f32[2,2]{1,0} parameter(1) - ROOT %add = f32[2,2]{1,0} add(f32[2,2]{1,0} %x, f32[2,2]{1,0} %y) -})", - R"( -;CHECK: func @add(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]], %[[ARG2:.*]]: [[TYPE]]) { -;CHECK: "xla_lhlo.add"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> () -;CHECK: } - )"); + CompileAndVerifyIr( + /*hlo_text_filename=*/tensorflow::io::JoinPath( + "tensorflow", "compiler", "xla", "service", "mlir_gpu", "tests", + "add.hlo")); } TEST_F(LhloGenTest, Compare) { - CompileAndVerifyIr(R"( -HloModule Compare - -ENTRY %Compare (x: f32[2,2], y: f32[2,2]) -> pred[2,2] { - %x = 
f32[2,2]{1,0} parameter(0) - %y = f32[2,2]{1,0} parameter(1) - ROOT %compare = pred[2,2]{1,0} compare(f32[2,2]{1,0} %x, f32[2,2]{1,0} %y), direction=EQ -})", - R"( -;CHECK: func @compare(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]], %[[PRED:.*]]: [[PRED_TYPE:.*]]) { -;CHECK: "xla_lhlo.compare"(%[[ARG0]], %[[ARG1]], %[[PRED]]) -;CHECK: {comparison_direction = "EQ"} : ([[TYPE]], [[TYPE]], [[PRED_TYPE]]) -> () -;CHECK: } -)"); + CompileAndVerifyIr( + /*hlo_text_filename=*/tensorflow::io::JoinPath( + "tensorflow", "compiler", "xla", "service", "mlir_gpu", "tests", + "compare.hlo")); } TEST_F(LhloGenTest, Copy) { - CompileAndVerifyIr(R"( -HloModule Copy - -ENTRY %Copy (x: f32[2,4]) -> f32[2,4] { - %x = f32[2,4] parameter(0) - ROOT %copy = f32[2,4] copy(f32[2,4] %x) -})", - R"( -;CHECK: func @copy(%[[OPERAND:.*]]: memref<2x4xf32>, %[[RESULT:.*]]: memref<2x4xf32>) { -;CHECK: "xla_lhlo.copy"(%[[OPERAND]], %[[RESULT]]) : (memref<2x4xf32>, memref<2x4xf32>) -> () - )"); + CompileAndVerifyIr( + /*hlo_text_filename=*/tensorflow::io::JoinPath( + "tensorflow", "compiler", "xla", "service", "mlir_gpu", "tests", + "copy.hlo")); } TEST_F(LhloGenTest, Select) { - CompileAndVerifyIr(R"( -HloModule Select - -ENTRY %Select (p: pred[2,2], x: f32[2,2], y: f32[2,2]) -> f32[2,2] { - %p = pred[2,2]{1,0} parameter(0) - %x = f32[2,2]{1,0} parameter(1) - %y = f32[2,2]{1,0} parameter(2) - ROOT %select = f32[2,2]{1,0} select(pred[2,2]{1,0} %p, f32[2,2]{1,0} %x, f32[2,2]{1,0} %y) -})", - R"( -;CHECK: func @select(%[[PRED:.*]]: [[PRED_TYPE:.*]], %[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]], %[[ARG2:.*]]: [[TYPE]]) { -;CHECK: "xla_lhlo.select"(%[[PRED]], %[[ARG0]], %[[ARG1]], %[[ARG2]]) : ([[PRED_TYPE]], [[TYPE]], [[TYPE]], [[TYPE]]) -> () -;CHECK: } - )"); + CompileAndVerifyIr( + /*hlo_text_filename=*/tensorflow::io::JoinPath( + "tensorflow", "compiler", "xla", "service", "mlir_gpu", "tests", + "select.hlo")); } TEST_F(LhloGenTest, Exp) { - CompileAndVerifyIr(R"( -HloModule Exp - -ENTRY %Exp (x: f32[2,2]) -> f32[2,2] { - %x = f32[2,2]{1,0} parameter(0) - ROOT %exp = f32[2,2]{1,0} exponential(f32[2,2]{1,0} %x) -})", - R"( -;CHECK: func @exponential(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) { -;CHECK: "xla_lhlo.exp"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () -;CHECK: } - )"); + CompileAndVerifyIr( + /*hlo_text_filename=*/tensorflow::io::JoinPath( + "tensorflow", "compiler", "xla", "service", "mlir_gpu", "tests", + "exp.hlo")); } TEST_F(LhloGenTest, Log) { - CompileAndVerifyIr(R"( -HloModule Log - -ENTRY %Log (x: f32[2,2]) -> f32[2,2] { - %x = f32[2,2]{1,0} parameter(0) - ROOT %log = f32[2,2]{1,0} log(f32[2,2]{1,0} %x) -})", - R"( -;CHECK: func @log(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) { -;CHECK: "xla_lhlo.log"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () -;CHECK: } - )"); + CompileAndVerifyIr( + /*hlo_text_filename=*/tensorflow::io::JoinPath( + "tensorflow", "compiler", "xla", "service", "mlir_gpu", "tests", + "log.hlo")); } TEST_F(LhloGenTest, AddInGPUDialect) { - CompileAndVerifyIr(R"( -HloModule Add - -ENTRY %Add (x: f32[2,2], y: f32[2,2]) -> f32[2,2] { - %x = f32[2,2]{1,0} parameter(0) - %y = f32[2,2]{1,0} parameter(1) - ROOT %add = f32[2,2]{1,0} add(f32[2,2]{1,0} %x, f32[2,2]{1,0} %y) -})", - R"( -;CHECK: func @add(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]], %[[ARG2:.*]]: [[TYPE]]) { -;CHECK: "gpu.launch_func"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[ARG0]], %[[ARG1]], %[[ARG2]] -;CHECK: } -;CHECK: func @add_kernel(%[[ARG0]]: [[TYPE]], %[[ARG1]]: 
[[TYPE]], %[[ARG2]]: [[TYPE]] -;CHECK-DAG: std.subview %[[ARG0]]{{\[}}[[INDEX:.*]]] -;CHECK-DAG: std.subview %[[ARG1]]{{\[}}[[INDEX]]] -;CHECK-DAG: std.subview %[[ARG2]]{{\[}}[[INDEX]]] -;CHECK: %[[VAL1:.*]] = load %{{.*\[}}[[INDEX:.*]]] -;CHECK: %[[VAL2:.*]] = load %{{.*\[}}[[INDEX]]] -;CHECK: %[[RES:.*]] = addf %[[VAL1]], %[[VAL2]] -;CHECK: store %[[RES]], %{{.*\[}}[[INDEX]]] - )", - LoweringStage::GPU); + CompileAndVerifyIr( + /*hlo_text_filename=*/ + tensorflow::io::JoinPath("tensorflow", "compiler", "xla", "service", + "mlir_gpu", "tests", "add_in_gpu_dialect.hlo"), + LoweringStage::GPU); } // This test verifies that the kernel signature is amended correctly. The actual // body of the generated function does not matter, it is already checked at the // GPU level above. TEST_F(LhloGenTest, AddAsKernel) { - CompileAndVerifyIr(R"( -HloModule Add - -ENTRY %Add (x: f32[2,2], y: f32[2,2]) -> f32[2,2] { - %x = f32[2,2]{1,0} parameter(0) - %y = f32[2,2]{1,0} parameter(1) - ROOT %add = f32[2,2]{1,0} add(f32[2,2]{1,0} %x, f32[2,2]{1,0} %y) -})", - R"( -;CHECK: func @add_kernel(%[[ARG0:.*]]: [[TYPE:!llvm<.*]], %[[ARG1:.*]]: [[TYPE]], %[[ARG2:.*]]: [[TYPE]] - -; -; Check that relevant sizes and strides are emitted. -; -;CHECK: %[[CAST0:.*]] = llvm.bitcast %[[ARG0:.*]] : !llvm<"i8*"> to !llvm<"float*"> -;CHECK: %[[SIZE00:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64 -;CHECK: %[[SIZE01:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64 -;CHECK: %[[STRIDE01:.*]] = llvm.mlir.constant(1 : i64) : !llvm.i64 -;CHECK: %[[STRIDE00:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64 - -;CHECK: %[[CAST1:.*]] = llvm.bitcast %[[ARG1:.*]] : !llvm<"i8*"> to !llvm<"float*"> -;CHECK: %[[SIZE10:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64 -;CHECK: %[[SIZE11:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64 -;CHECK: %[[STRIDE11:.*]] = llvm.mlir.constant(1 : i64) : !llvm.i64 -;CHECK: %[[STRIDE10:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64 - -;CHECK: %[[CAST2:.*]] = llvm.bitcast %[[ARG2:.*]] : !llvm<"i8*"> to !llvm<"float*"> -;CHECK: %[[SIZE20:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64 -;CHECK: %[[SIZE21:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64 -;CHECK: %[[STRIDE21:.*]] = llvm.mlir.constant(1 : i64) : !llvm.i64 -;CHECK: %[[STRIDE20:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64 - -; -; Check that the emitted sizes and strides, as well the pointers to HLO buffers, -; are inserted into the memref descriptors. 
-; -;CHECK: %[[DESC0:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> -;CHECK: %[[DESC01:.*]] = llvm.insertvalue %[[CAST0]], %[[DESC0]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> -;CHECK: %[[DESC02:.*]] = llvm.insertvalue %[[CAST0]], %[[DESC01]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> -;CHECK: %[[DESC03:.*]] = llvm.insertvalue %{{.*}}, %[[DESC02]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> -;CHECK: %[[DESC04:.*]] = llvm.insertvalue %[[SIZE00]], %[[DESC03]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> -;CHECK: %[[DESC05:.*]] = llvm.insertvalue %[[STRIDE00]], %[[DESC04]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> -;CHECK: %[[DESC06:.*]] = llvm.insertvalue %[[SIZE01]], %[[DESC05]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> -;CHECK: %{{.*}} = llvm.insertvalue %[[STRIDE01]], %[[DESC06]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> - -;CHECK: %[[DESC1:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> -;CHECK: %[[DESC11:.*]] = llvm.insertvalue %[[CAST1]], %[[DESC1]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> -;CHECK: %[[DESC12:.*]] = llvm.insertvalue %[[CAST1]], %[[DESC11]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> -;CHECK: %[[DESC13:.*]] = llvm.insertvalue %{{.*}}, %[[DESC12]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> -;CHECK: %[[DESC14:.*]] = llvm.insertvalue %[[SIZE10]], %[[DESC13]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> -;CHECK: %[[DESC15:.*]] = llvm.insertvalue %[[STRIDE10]], %[[DESC14]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> -;CHECK: %[[DESC16:.*]] = llvm.insertvalue %[[SIZE11]], %[[DESC15]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> -;CHECK: %{{.*}} = llvm.insertvalue %[[STRIDE11]], %[[DESC16]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> - -;CHECK: %[[DESC2:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> -;CHECK: %[[DESC21:.*]] = llvm.insertvalue %[[CAST2]], %[[DESC2]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> -;CHECK: %[[DESC22:.*]] = llvm.insertvalue %[[CAST2]], %[[DESC21]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> -;CHECK: %[[DESC23:.*]] = llvm.insertvalue %{{.*}}, %[[DESC22]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> -;CHECK: %[[DESC24:.*]] = llvm.insertvalue %[[SIZE20]], %[[DESC23]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> -;CHECK: %[[DESC25:.*]] = llvm.insertvalue %[[STRIDE20]], %[[DESC24]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> -;CHECK: %[[DESC26:.*]] = llvm.insertvalue %[[SIZE21]], %[[DESC25]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> -;CHECK: %{{.*}} = llvm.insertvalue %[[STRIDE21]], %[[DESC26]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> - )", - LoweringStage::KERNEL); + CompileAndVerifyIr( + tensorflow::io::JoinPath("tensorflow", "compiler", "xla", "service", + "mlir_gpu", "tests", "add_as_kernel.hlo"), + LoweringStage::KERNEL); } // TODO(b/149302060) Reenable once fusion is fixed. 
TEST_F(LhloGenTest, DISABLED_AddMultiply) { - CompileAndVerifyIr(R"( -HloModule AddMultiply - -ENTRY %AddMultiply (x: f32[2,2], y: f32[2,2], z: f32[2,2]) -> f32[2,2] { - %x = f32[2,2]{1,0} parameter(0) - %y = f32[2,2]{1,0} parameter(1) - %z = f32[2,2]{1,0} parameter(2) - %add = f32[2,2]{1,0} add(f32[2,2]{1,0} %x, f32[2,2]{1,0} %y) - ROOT %mul = f32[2,2]{1,0} multiply(f32[2,2]{1,0} %add, f32[2,2]{1,0} %z) -})", - R"( -;CHECK: func @fusion(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]], %[[ARG2:.*]]: [[TYPE]], %[[RESULT:.*]]: [[TYPE]]) -;CHECK: "xla_lhlo.fusion"() ( { -;CHECK: %[[REF0:.*]] = tensor_load %[[ARG0]] : [[TYPE]] -;CHECK: %[[REF1:.*]] = tensor_load %[[ARG1]] : [[TYPE]] -;CHECK: %[[REF2:.*]] = tensor_load %[[ARG2]] : [[TYPE]] -;CHECK: %[[ADD:.*]] = xla_hlo.add %[[REF1]], %[[REF2]] -;CHECK: %[[MUL:.*]] = xla_hlo.mul %[[ADD]], %[[REF0]] -;CHECK: tensor_store %[[MUL]], %[[RESULT]] -;CHECK: "xla_lhlo.terminator"() -;CHECK-NEXT: } - )"); + CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla", + "service", "mlir_gpu", "tests", + "add_multiply.hlo")); } // TODO(b/149302060) Reenable once fusion is fixed. TEST_F(LhloGenTest, DISABLED_IotaAddMultiply) { - CompileAndVerifyIr(R"( -HloModule AddMultiply - -ENTRY %AddMultiply (x: s32[2,2], y: s32[2,2]) -> s32[2,2] { - %x = s32[2,2]{1,0} parameter(0) - %y = s32[2,2]{1,0} parameter(1) - - %add = s32[2,2]{1,0} add(s32[2,2]{1,0} %x, s32[2,2]{1,0} %y) - %iota = s32[2, 2]{1,0} iota(), iota_dimension=0 - - ROOT %mul = s32[2,2]{1,0} multiply(s32[2,2]{1,0} %add, s32[2,2]{1,0} %iota) -})", - R"( -;CHECK-NOT: store -;CHECK: %[[RESULT:.*]] = muli %{{.*}}, %{{.*}} -;CHECK: store %[[RESULT]] -)", - LoweringStage::GPU); + CompileAndVerifyIr( + tensorflow::io::JoinPath("tensorflow", "compiler", "xla", "service", + "mlir_gpu", "tests", "iota_add_multiply.hlo"), + LoweringStage::GPU); } TEST_F(LhloGenTest, AddMultiplyGPU) { - CompileAndVerifyIr(R"( -HloModule AddMultiply - -ENTRY %AddMultiply (x: f32[2,2], y: f32[2,2], z: f32[2,2]) -> f32[2,2] { - %x = f32[2,2]{1,0} parameter(0) - %y = f32[2,2]{1,0} parameter(1) - %z = f32[2,2]{1,0} parameter(2) - %add = f32[2,2]{1,0} add(f32[2,2]{1,0} %x, f32[2,2]{1,0} %y) - ROOT %mul = f32[2,2]{1,0} multiply(f32[2,2]{1,0} %add, f32[2,2]{1,0} %z) -})", - R"( -;CHECK: func @fusion_kernel(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]], %[[ARG2:.*]]: [[TYPE]], %[[RESULT:.*]]: [[TYPE]]) -;CHECK-DAG: std.subview %[[ARG0]]{{\[}}[[INDEX:.*]]] -;CHECK-DAG: std.subview %[[ARG1]]{{\[}}[[INDEX]]] -;CHECK-DAG: std.subview %[[ARG2]]{{\[}}[[INDEX]]] -;CHECK-DAG: std.subview %[[RESULT]]{{\[}}[[INDEX]]] -;CHECK: %[[V0:.*]] = load %{{.*\[}}[[CSTIDX:.*]]] -;CHECK: %[[V1:.*]] = load %{{.*\[}}[[CSTIDX:.*]]] -;CHECK: %[[ADD:.*]] = addf %[[V0]], %[[V1]] -;CHECK: %[[V2:.*]] = load %{{.*\[}}[[CSTIDX:.*]]] -;CHECK: %[[MUL:.*]] = mulf %[[ADD]], %[[V2]] -;CHECK: store %[[MUL]], %{{.*\[}}[[CSTIDX:.*]]] -;CHECK-NEXT: return - )", - LoweringStage::GPU); + CompileAndVerifyIr( + tensorflow::io::JoinPath("tensorflow", "compiler", "xla", "service", + "mlir_gpu", "tests", "add_multiply_gpu.hlo"), + LoweringStage::GPU); } // TODO(b/137624192): Reenable once we can fuse reductions. 
TEST_F(LhloGenTest, DISABLED_FusedReduce) { - CompileAndVerifyIr(R"( -HloModule FusedReduce - -%add (x: f32[], y: f32[]) -> f32[] { - %x = f32[] parameter(0) - %y = f32[] parameter(1) - ROOT %add = f32[] add(f32[] %x, f32[] %y) -} - -%fused_computation (param: f32[100,10]) -> f32[10] { - %param = f32[100,10] parameter(0) - %constant = f32[] constant(0) - ROOT %reduce = f32[10]{0} reduce(f32[100,10]{1,0} %param, f32[] %constant), - dimensions={0}, to_apply=%add -} - -ENTRY %FusedReduce (x: f32[100,10]) -> f32[10] { - %x = f32[100,10] parameter(0) - ROOT %fusion = f32[10]{0} fusion(f32[100,10]{1,0} %x), kind=kInput, - calls=%fused_computation -} -)", - R"( -;CHECK: func @fusion(%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[RTYPE:.*]]) -;CHECK: "xla_lhlo.fusion"() ( { -;CHECK: %[[REF0:.*]] = tensor_load %arg0 : [[TYPE]] -;CHECK: %[[CT0:.*]] = xla_hlo.constant dense<0.000000e+00> -;CHECK: %[[RED:.*]] = "xla_hlo.reduce"(%0, %1) ( { -;CHECK: ^bb0(%[[BARG0:.*]]: [[ETYPE:.*]], %[[BARG1:.*]]: [[ETYPE]]) -;CHECK: %[[ADD:.*]] = xla_hlo.add %[[BARG0]], %[[BARG1]] : [[ETYPE]] -;CHECK: "xla_hlo.return"(%[[ADD]]) -;CHECK: }) -;CHECK: tensor_store %[[RED]], %[[RESULT]] : [[RTYPE]] -;CHECK: "xla_lhlo.terminator"() -;CHECK-NEXT: }) - )"); + CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla", + "service", "mlir_gpu", "tests", + "fused_reduce.hlo")); } TEST_F(LhloGenTest, Broadcast) { - CompileAndVerifyIr(R"( -HloModule Broadcast - -ENTRY %Broadcast (x: f32[10]) -> f32[10, 5] { - %x = f32[10]{0} parameter(0) - ROOT %broadcast = f32[10, 5]{1,0} broadcast(f32[10]{0} %x), dimensions={0} -})", - R"( -;CHECK: func @broadcast(%[[IN:.*]]: [[IN_T:.*]], %[[OUT:.*]]: [[OUT_T:.*]]) { -;CHECK: "xla_lhlo.broadcast_in_dim"(%[[IN]], %[[OUT]]) -;CHECK: {broadcast_dimensions = dense<0> : tensor<1xi64>} -;CHECK: : ([[IN_T]], [[OUT_T]]) -> () -;CHECK: } -)"); + CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla", + "service", "mlir_gpu", "tests", + "broadcast.hlo")); } TEST_F(LhloGenTest, Iota) { - CompileAndVerifyIr(R"( - HloModule Iota - - ENTRY %Iota() -> s64[10, 5] { - ROOT %iota = s64[10, 5]{1,0} iota(), iota_dimension=0 -})", - R"( -;CHECK: func @iota(%[[OUT:.*]]: [[OUT_T:.*]]) { -;CHECK: "xla_lhlo.iota"(%[[OUT]]) -;CHECK: {iota_dimension = 0 : i64} : ([[OUT_T]]) -> () -;CHECK: } -)"); + CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla", + "service", "mlir_gpu", "tests", + "iota.hlo")); } TEST_F(LhloGenTest, AddReduce) { - CompileAndVerifyIr(R"( -HloModule AddReduce - -%add (x: f32[], y: f32[]) -> f32[] { - %x = f32[] parameter(0) - %y = f32[] parameter(1) - ROOT %add = f32[] add(f32[] %x, f32[] %y) -} - -ENTRY %AddReduce (x: f32[100,10], c: f32[]) -> f32[100] { - %x = f32[100,10]{1,0} parameter(0) - %c = f32[] parameter(1) - ROOT %reduce = f32[100]{0} reduce(f32[100,10]{1,0} %x, f32[] %c), dimensions={1}, to_apply=%add -})", - R"( -;CHECK: func @reduce(%[[ARG:.*]]: [[ARGT:.*]], %[[CST:.*]]: memref, %[[RES:.*]]: [[REST:.*]]) { -;CHECK: "xla_lhlo.reduce"(%[[ARG]], %[[CST]], %[[RES]]) ( { -;CHECK: ^bb0(%[[FARG0:.*]]: memref, %[[FARG1:.*]]: memref, %[[FRES:.*]]: memref): -;CHECK: %[[LHS:.*]] = tensor_load %[[FARG0]] : memref -;CHECK: %[[RHS:.*]] = tensor_load %[[FARG1]] : memref -;CHECK: %[[RES:.*]] = xla_hlo.add %[[LHS]], %[[RHS]] : tensor -;CHECK: tensor_store %[[RES]], %[[FRES]] : memref -;CHECK: "xla_lhlo.terminator"() : () -> () -;CHECK-NEXT: }) {dimensions = dense<1> : tensor<1xi64>} : ([[ARGT]], memref, [[REST]]) -> () - )"); + 
CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla", + "service", "mlir_gpu", "tests", + "add_reduce.hlo")); } TEST_F(LhloGenTest, Abs) { - CompileAndVerifyIr(R"( -HloModule Abs -ENTRY %Abs (val: f32[2,2]) -> f32[2,2] { - %val = f32[2,2]{1,0} parameter(0) - ROOT %abs = f32[2,2]{1,0} abs(f32[2,2]{1,0} %val) -})", - R"( -;CHECK: func @abs(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) { -;CHECK: "xla_lhlo.abs"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () -;CHECK: } - )"); + CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla", + "service", "mlir_gpu", "tests", + "abs.hlo")); } TEST_F(LhloGenTest, Ceil) { - CompileAndVerifyIr(R"( -HloModule Ceil -ENTRY %Ceil (val: f32[2,2]) -> f32[2,2] { - %val = f32[2,2]{1,0} parameter(0) - ROOT %ceil = f32[2,2]{1,0} ceil(f32[2,2]{1,0} %val) -})", - R"( -;CHECK: func @ceil(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) { -;CHECK: "xla_lhlo.ceil"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () -;CHECK: } - )"); + CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla", + "service", "mlir_gpu", "tests", + "ceil.hlo")); } TEST_F(LhloGenTest, Cos) { - CompileAndVerifyIr(R"( -HloModule Cos -ENTRY %Cos (val: f32[2,2]) -> f32[2,2] { - %val = f32[2,2]{1,0} parameter(0) - ROOT %cos = f32[2,2]{1,0} cosine(f32[2,2]{1,0} %val) -})", - R"( -;CHECK: func @cosine(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) { -;CHECK: "xla_lhlo.cos"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () -;CHECK: } - )"); + CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla", + "service", "mlir_gpu", "tests", + "cos.hlo")); } TEST_F(LhloGenTest, Neg) { - CompileAndVerifyIr(R"( -HloModule Neg -ENTRY %Neg (val: f32[2,2]) -> f32[2,2] { - %val = f32[2,2]{1,0} parameter(0) - ROOT %neg = f32[2,2]{1,0} negate(f32[2,2]{1,0} %val) -})", - R"( -;CHECK: func @negate(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) { -;CHECK: "xla_lhlo.neg"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () -;CHECK: } - )"); + CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla", + "service", "mlir_gpu", "tests", + "neg.hlo")); } TEST_F(LhloGenTest, Rem) { - CompileAndVerifyIr(R"( -HloModule Rem -ENTRY %Rem(x: f32[2,2], y: f32[2,2]) -> f32[2,2] { - %x = f32[2,2]{1,0} parameter(0) - %y = f32[2,2]{1,0} parameter(1) - ROOT %rem = f32[2,2]{1,0} remainder(f32[2,2]{1,0} %x, f32[2,2]{1,0} %y) -})", - R"( -;CHECK: func @remainder(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]], %[[ARG2:.*]]: [[TYPE]]) { -;CHECK: "xla_lhlo.remainder"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> () -;CHECK: } - )"); + CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla", + "service", "mlir_gpu", "tests", + "rem.hlo")); } TEST_F(LhloGenTest, Rsqrt) { - CompileAndVerifyIr(R"( -HloModule Rsqrt - -ENTRY %Rsqrt (x: f32[2,2]) -> f32[2,2] { - %x = f32[2,2]{1,0} parameter(0) - ROOT %rsqrt = f32[2,2]{1,0} rsqrt(f32[2,2]{1,0} %x) -})", - R"( -;CHECK: func @rsqrt(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) { -;CHECK: "xla_lhlo.rsqrt"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () -;CHECK: } - )"); + CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla", + "service", "mlir_gpu", "tests", + "rsqrt.hlo")); } TEST_F(LhloGenTest, Sign) { - CompileAndVerifyIr(R"( -HloModule Sign -ENTRY %Sign (val: f32[2,2]) -> f32[2,2] { - %val = f32[2,2]{1,0} parameter(0) - ROOT %sign = f32[2,2]{1,0} sign(f32[2,2]{1,0} %val) -})", - R"( -;CHECK: func 
@sign(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) { -;CHECK: "xla_lhlo.sign"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () -;CHECK: } - )"); + CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla", + "service", "mlir_gpu", "tests", + "sign.hlo")); } TEST_F(LhloGenTest, Tanh) { - CompileAndVerifyIr(R"( -HloModule Tanh -ENTRY %Tanh (val: f32[2,2]) -> f32[2,2] { - %val = f32[2,2]{1,0} parameter(0) - ROOT %tanh = f32[2,2]{1,0} tanh(f32[2,2]{1,0} %val) -})", - R"( -;CHECK: func @tanh(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) { -;CHECK: "xla_lhlo.tanh"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () -;CHECK: } - )"); + CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla", + "service", "mlir_gpu", "tests", + "tanh.hlo")); } } // namespace mlir_gpu diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/neg.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/neg.hlo new file mode 100644 index 00000000000..caead37c995 --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/neg.hlo @@ -0,0 +1,9 @@ +HloModule Neg +ENTRY %Neg (val: f32[2,2]) -> f32[2,2] { + %val = f32[2,2]{1,0} parameter(0) + ROOT %neg = f32[2,2]{1,0} negate(f32[2,2]{1,0} %val) +} + +// CHECK: func @negate(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) { +// CHECK: "xla_lhlo.neg"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () +// CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/rem.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/rem.hlo new file mode 100644 index 00000000000..441ace6ef94 --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/rem.hlo @@ -0,0 +1,10 @@ +HloModule Rem +ENTRY %Rem(x: f32[2,2], y: f32[2,2]) -> f32[2,2] { + %x = f32[2,2]{1,0} parameter(0) + %y = f32[2,2]{1,0} parameter(1) + ROOT %rem = f32[2,2]{1,0} remainder(f32[2,2]{1,0} %x, f32[2,2]{1,0} %y) +} + +// CHECK: func @remainder(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]], %[[ARG2:.*]]: [[TYPE]]) { +// CHECK: "xla_lhlo.remainder"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> () +// CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/rsqrt.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/rsqrt.hlo new file mode 100644 index 00000000000..a10f9ada92b --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/rsqrt.hlo @@ -0,0 +1,10 @@ +HloModule Rsqrt + +ENTRY %Rsqrt (x: f32[2,2]) -> f32[2,2] { + %x = f32[2,2]{1,0} parameter(0) + ROOT %rsqrt = f32[2,2]{1,0} rsqrt(f32[2,2]{1,0} %x) +} + +// CHECK: func @rsqrt(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) { +// CHECK: "xla_lhlo.rsqrt"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () +// CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/select.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/select.hlo new file mode 100644 index 00000000000..0cbe8c73700 --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/select.hlo @@ -0,0 +1,13 @@ +HloModule Select + +ENTRY %Select (p: pred[2,2], x: f32[2,2], y: f32[2,2]) -> f32[2,2] { + %p = pred[2,2]{1,0} parameter(0) + %x = f32[2,2]{1,0} parameter(1) + %y = f32[2,2]{1,0} parameter(2) + ROOT %select = f32[2,2]{1,0} select(pred[2,2]{1,0} %p, f32[2,2]{1,0} %x, f32[2,2]{1,0} %y) +} + +// CHECK: func @select(%[[PRED:.*]]: [[PRED_TYPE:.*]], %[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]], %[[ARG2:.*]]: [[TYPE]]) { +// CHECK: "xla_lhlo.select"(%[[PRED]], %[[ARG0]], %[[ARG1]], %[[ARG2]]) : ([[PRED_TYPE]], [[TYPE]], [[TYPE]], [[TYPE]]) -> () +// CHECK: } +
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/sign.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/sign.hlo new file mode 100644 index 00000000000..a0ff329938b --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/sign.hlo @@ -0,0 +1,9 @@ +HloModule Sign +ENTRY %Sign (val: f32[2,2]) -> f32[2,2] { + %val = f32[2,2]{1,0} parameter(0) + ROOT %sign = f32[2,2]{1,0} sign(f32[2,2]{1,0} %val) +} + +// CHECK: func @sign(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) { +// CHECK: "xla_lhlo.sign"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () +// CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/tanh.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/tanh.hlo new file mode 100644 index 00000000000..d539b3002dc --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/tanh.hlo @@ -0,0 +1,9 @@ +HloModule Tanh +ENTRY %Tanh (val: f32[2,2]) -> f32[2,2] { + %val = f32[2,2]{1,0} parameter(0) + ROOT %tanh = f32[2,2]{1,0} tanh(f32[2,2]{1,0} %val) +} + +// CHECK: func @tanh(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) { +// CHECK: "xla_lhlo.tanh"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () +// CHECK: } diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.cc b/tensorflow/compiler/xla/service/multi_output_fusion.cc index a8a4b7ef872..d97893b6d04 100644 --- a/tensorflow/compiler/xla/service/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/multi_output_fusion.cc @@ -368,12 +368,12 @@ bool MultiOutputFusion::Perform() { int changed = false; // Pick the top candidate from queue and try to merge. while (!worklist_.empty()) { - ToBeFused candidate = worklist_.top(); - worklist_.pop(); + ToBeFused candidate = worklist_.pop(); HloInstruction* instr1 = candidate.instr1; HloInstruction* instr2 = candidate.instr2; + // Candidates are already fused. if (is_fused(instr1) || is_fused(instr2)) { continue; } diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.h b/tensorflow/compiler/xla/service/multi_output_fusion.h index 18069e2f76c..f0b56eeff90 100644 --- a/tensorflow/compiler/xla/service/multi_output_fusion.h +++ b/tensorflow/compiler/xla/service/multi_output_fusion.h @@ -136,9 +136,34 @@ class MultiOutputFusion : public HloModulePass { HloInstruction* instr1; HloInstruction* instr2; int64 score; - ToBeFused(HloInstruction* instr1, HloInstruction* instr2, int64 score) - : instr1(instr1), instr2(instr2), score(score) {} - bool operator<(const ToBeFused& rhs) const { return score < rhs.score; } + int64 timestamp; + ToBeFused(HloInstruction* instr1, HloInstruction* instr2, int64 score, + int64 timestamp) + : instr1(instr1), instr2(instr2), score(score), timestamp(timestamp) {} + bool operator<(const ToBeFused& rhs) const { + return std::pair<int64, int64>(score, timestamp) < + std::pair<int64, int64>(rhs.score, rhs.timestamp); + } + }; + + // Stable priority queue where each insertion has a timestamp for + // deterministic popping. + class WorkList { + public: + bool empty() { return worklist_.empty(); } + ToBeFused pop() { + ToBeFused tmp = worklist_.top(); + worklist_.pop(); + return tmp; + } + template <typename... Args> + void emplace(Args&&...
args) { + worklist_.emplace(std::forward<Args>(args)..., timestamp_++); + } + + private: + std::priority_queue<ToBeFused> worklist_; + int64 timestamp_ = 0; }; // Update the internal data structures before instr1 and instr2 are fused into @@ -169,7 +194,7 @@ class MultiOutputFusion : public HloModulePass { } std::vector candidates_; - std::priority_queue<ToBeFused> worklist_; + WorkList worklist_; // A map that maps an instruction to the index_. absl::flat_hash_map candidates_index_; diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index bf2a1d64476..540a63405ef 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -262,6 +262,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "//tensorflow/core/platform:resource_loader", ], ) @@ -2429,15 +2430,16 @@ tf_cc_test( tags = tf_cuda_tests_tags(), deps = [ ":hlo_test_base", + ":literal_test_util", + ":xla_internal_test_main", # fixdeps: keep "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/service:cpu_plugin", # reference backend "//tensorflow/compiler/xla/service:gpu_plugin", # test backend "//tensorflow/compiler/xla/service:platform_util", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", "//tensorflow/core:test", + "//tensorflow/core/platform:resource_loader", ], ) diff --git a/tensorflow/compiler/xla/tests/collective_ops_test.cc b/tensorflow/compiler/xla/tests/collective_ops_test.cc index 56c5f688312..5cdf9633ca4 100644 --- a/tensorflow/compiler/xla/tests/collective_ops_test.cc +++ b/tensorflow/compiler/xla/tests/collective_ops_test.cc @@ -368,6 +368,55 @@ XLA_TEST_F(CollectiveOpsTest, AllReduce_ManyConcurrentAllReduces) { done.Wait(); } +// Runs the same executable many times concurrently. The all-reduces should not +// conflict with one another.
+XLA_TEST_F(CollectiveOpsTest, AllReduce_CombinableAllReduces) { + std::string hlo_string = R"( + HloModule test + + apply_op { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT apply_op = f32[] add(x, y) + } + + ENTRY test_computation { + p0 = f32[5] parameter(0) + p1 = f32[5] parameter(1) + crs0 = f32[5] all-reduce(p0), replica_groups={}, to_apply=apply_op + crs1 = f32[5] all-reduce(p1), replica_groups={}, to_apply=apply_op + ROOT out = (f32[5], f32[5]) tuple(f32[5] crs0, f32[5] crs1) + } + )"; + static constexpr int kNumReplicas = 2; + auto config = GetModuleConfigForTest(); + config.set_replica_count(kNumReplicas); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseAndReturnVerifiedModule(hlo_string, config)); + + std::vector<float> input0_vec = {1., 2., 3., 4., 5.}; + auto input0_literal = LiteralUtil::CreateR1<float>(input0_vec); + std::vector<float> input1_vec = {7., 3., 4., 1., 2.}; + auto input1_literal = LiteralUtil::CreateR1<float>(input1_vec); + + TF_ASSERT_OK_AND_ASSIGN( + std::vector<Literal> results, + ExecuteReplicated(std::move(module), {&input0_literal, &input1_literal}, + /*num_replicas=*/kNumReplicas, + /*use_threads=*/true)); + std::vector<float> expected0_vec = {2., 4., 6., 8., 10.}; + auto expected0_literal = LiteralUtil::CreateR1<float>(expected0_vec); + std::vector<float> expected1_vec = {14., 6., 8., 2., 4.}; + auto expected1_literal = LiteralUtil::CreateR1<float>(expected1_vec); + for (int replica_idx = 0; replica_idx < kNumReplicas; replica_idx++) { + auto rs = results[replica_idx].DecomposeTuple(); + EXPECT_TRUE(LiteralTestUtil::NearOrEqual(expected0_literal, rs[0], + ErrorSpec{1e-5, 1e-5})); + EXPECT_TRUE(LiteralTestUtil::NearOrEqual(expected1_literal, rs[1], + ErrorSpec{1e-5, 1e-5})); + } +} + // Runs an all-reduce with three partitions: // {0}, {1,2}, {3} // meaning, the all-reduce is a nop for devices 0 and 3, and only devices 1 and diff --git a/tensorflow/compiler/xla/tests/filecheck.cc b/tensorflow/compiler/xla/tests/filecheck.cc index 91d1052fc64..068d6dc8fca 100644 --- a/tensorflow/compiler/xla/tests/filecheck.cc +++ b/tensorflow/compiler/xla/tests/filecheck.cc @@ -22,36 +22,35 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/path.h" +#include "tensorflow/core/platform/resource_loader.h" #include "tensorflow/core/platform/subprocess.h" namespace xla { StatusOr<bool> RunFileCheck(const std::string& input, absl::string_view pattern) { - using tensorflow::io::JoinPath; - // Generate an input file for the FileCheck pattern. - string pattern_path; + std::string pattern_path; auto env = tensorflow::Env::Default(); if (!env->LocalTempFilename(&pattern_path)) { return tensorflow::errors::Internal("couldn't get a pattern file name"); } TF_RETURN_IF_ERROR(tensorflow::WriteStringToFile(env, pattern_path, pattern)); + return RunFileCheckWithPatternFile(input, pattern_path); +} + +StatusOr<bool> RunFileCheckWithPatternFile(const std::string& input, + const std::string& pattern_file) { // Invoke FileCheck to check whether input matches `pattern`.
- const char* file_check_path_suffix = - "org_tensorflow/external/llvm-project/llvm/FileCheck"; - string file_check_path; - if (const char* test_srcdir = getenv("TEST_SRCDIR")) { - file_check_path = JoinPath(test_srcdir, file_check_path_suffix); - } else { - file_check_path = file_check_path_suffix; - } + std::string file_check_path = tensorflow::GetDataDependencyFilepath( + tensorflow::io::JoinPath("external", "llvm-project", "llvm", "FileCheck")); tensorflow::SubProcess file_check_process; file_check_process.SetProgram( file_check_path, - {file_check_path, "-v", "-dump-input=fail", pattern_path}); + {file_check_path, "-v", "-dump-input=fail", pattern_file}); file_check_process.SetChannelAction(tensorflow::CHAN_STDIN, tensorflow::ACTION_PIPE); file_check_process.SetChannelAction(tensorflow::CHAN_STDERR, @@ -60,7 +59,7 @@ StatusOr<bool> RunFileCheck(const std::string& input, return tensorflow::errors::Internal("couldn't start FileCheck"); } - string standard_error; + std::string standard_error; int exit_status = file_check_process.Communicate( /*stdin_input=*/&input, /*stdout_output=*/nullptr, /*stderr_output=*/&standard_error); @@ -68,6 +67,7 @@ StatusOr<bool> RunFileCheck(const std::string& input, // FileCheck returns 0 when the inputs match. If matching failed, log // the error message generated by FileCheck and the inputs. bool succeeded = (exit_status == 0); + auto env = tensorflow::Env::Default(); if (!succeeded) { LOG(WARNING) << "Tried to execute FileCheck at " << file_check_path; if (!env->FileExists(file_check_path).ok()) { @@ -75,8 +75,6 @@ StatusOr<bool> RunFileCheck(const std::string& input, } LOG(WARNING) << "FileCheck error:\n" << standard_error; - LOG(WARNING) << "FileCheck pattern was:"; - XLA_LOG_LINES(tensorflow::WARNING, pattern); } else if (!standard_error.empty()) { LOG(INFO) << "FileCheck stderr:"; XLA_LOG_LINES(tensorflow::INFO, standard_error); diff --git a/tensorflow/compiler/xla/tests/filecheck.h b/tensorflow/compiler/xla/tests/filecheck.h index 23f71c11b78..2723ccc2e9d 100644 --- a/tensorflow/compiler/xla/tests/filecheck.h +++ b/tensorflow/compiler/xla/tests/filecheck.h @@ -26,7 +26,14 @@ namespace xla { // Runs FileCheck with the given pattern over given input string. Provided that // FileCheck can execute, returns true if and only if FileCheck succeeded in // matching the input. -StatusOr<bool> RunFileCheck(const string& input, absl::string_view pattern); +StatusOr<bool> RunFileCheck(const std::string& input, + absl::string_view pattern); + +// Runs FileCheck with the given pattern file over given input string. Provided +// that FileCheck can execute, returns true if and only if FileCheck succeeded +// in matching the input. +StatusOr<bool> RunFileCheckWithPatternFile(const std::string& input, + const std::string& pattern_file); } // namespace xla diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc index 4dd59cdca5d..bb82193ae33 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util.cc @@ -18,6 +18,7 @@ limitations under the License. #include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/literal_comparison.h" #include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/path.h" #include "tensorflow/core/platform/test.h" namespace xla { @@ -30,10 +31,7 @@ void WriteLiteralToTempFile(const LiteralSlice& literal, const string& name) { // TEST_UNDECLARED_OUTPUTS_DIR.
This plays well with tools that inspect test // results, especially when they're run on remote machines. string outdir; - const char* undeclared_outputs_dir = getenv("TEST_UNDECLARED_OUTPUTS_DIR"); - if (undeclared_outputs_dir != nullptr) { - outdir = undeclared_outputs_dir; - } else { + if (!tensorflow::io::GetTestUndeclaredOutputsDir(&outdir)) { outdir = tensorflow::testing::TmpDir(); } diff --git a/tensorflow/compiler/xla/tests/literal_test_util_test.cc b/tensorflow/compiler/xla/tests/literal_test_util_test.cc index 66373af5686..e2ad5a7e08f 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util_test.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/path.h" #include "tensorflow/core/platform/test.h" namespace xla { @@ -129,10 +130,7 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) { tensorflow::Env* env = tensorflow::Env::Default(); string outdir; - const char* undeclared_outputs_dir = getenv("TEST_UNDECLARED_OUTPUTS_DIR"); - if (undeclared_outputs_dir != nullptr) { - outdir = undeclared_outputs_dir; - } else { + if (!tensorflow::io::GetTestUndeclaredOutputsDir(&outdir)) { outdir = tensorflow::testing::TmpDir(); } string pattern = tensorflow::io::JoinPath(outdir, "tempfile-*.pb"); diff --git a/tensorflow/compiler/xla/tests/sample_file_test.cc b/tensorflow/compiler/xla/tests/sample_file_test.cc index 31b104f4e37..d793dfc7960 100644 --- a/tensorflow/compiler/xla/tests/sample_file_test.cc +++ b/tensorflow/compiler/xla/tests/sample_file_test.cc @@ -25,6 +25,8 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/path.h" +#include "tensorflow/core/platform/resource_loader.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -41,10 +43,10 @@ class SampleFileTest : public HloTestBase { }; TEST_F(SampleFileTest, Convolution) { - const string& filename = "compiler/xla/tests/isolated_convolution.hlo"; - string test_srcdir = tensorflow::testing::TensorFlowSrcRoot(); - EXPECT_TRUE(RunAndCompareFromFile( - tensorflow::io::JoinPath(test_srcdir, filename), ErrorSpec{0.01})); + const string& filename = tensorflow::GetDataDependencyFilepath( + tensorflow::io::JoinPath("tensorflow", "compiler", "xla", "tests", + "isolated_convolution.hlo")); + EXPECT_TRUE(RunAndCompareFromFile(filename, ErrorSpec{0.01})); } } // namespace diff --git a/tensorflow/compiler/xla/text_literal_writer_test.cc b/tensorflow/compiler/xla/text_literal_writer_test.cc index 5cbaf2fcc19..667d6296117 100644 --- a/tensorflow/compiler/xla/text_literal_writer_test.cc +++ b/tensorflow/compiler/xla/text_literal_writer_test.cc @@ -23,6 +23,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" @@ -35,12 +36,12 @@ TEST(TextLiteralWriterTest, WritesFloatLiteral) { {3.14, 2.17}, {1.23, 4.56}, }); - string path = - tensorflow::io::JoinPath(tensorflow::testing::TmpDir(), "/whatever"); + string path; + ASSERT_TRUE(tensorflow::Env::Default()->LocalTempFilename(&path)); ASSERT_IS_OK(TextLiteralWriter::WriteToPath(literal, path)); string contents; - TF_CHECK_OK(tensorflow::ReadFileToString(tensorflow::Env::Default(), path, - &contents)); + TF_ASSERT_OK(tensorflow::ReadFileToString(tensorflow::Env::Default(), path, + &contents)); const string expected = R"(f32[2,2] (0, 0): 3.14 (0, 1): 2.17 diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index a6c1b80ff54..5002f80c059 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -160,6 +160,7 @@ package_group( "//learning/freud/topic_models/tensorflow/...", "//perftools/accelerators/xprof/api/...", "//quality/webanswers/brain/tokenization/custom_tf_ops/kernels/...", + "//smartass/brain/server/...", ], ) @@ -471,6 +472,7 @@ tf_cuda_library( "//tensorflow/core/framework:memory_types.h", "//tensorflow/core/framework:node_def_builder.h", "//tensorflow/core/framework:node_def_util.h", + "//tensorflow/core/framework:node_properties.h", "//tensorflow/core/framework:numeric_op.h", "//tensorflow/core/framework:numeric_types.h", "//tensorflow/core/framework:op.h", @@ -1995,7 +1997,7 @@ cc_library( "//tensorflow/core/util:env_var", "//tensorflow/core/util:reporter", # TODO(gunan): REMOVE as soon as cc_shared_library is supported. "@snappy", - "@zlib_archive//:zlib", + "@zlib", "@double_conversion//:double-conversion", "@com_google_protobuf//:protobuf", ] + tf_protos_all_impl() + tf_protos_grappler_impl() + tf_protos_profiler_impl(), @@ -2322,6 +2324,7 @@ tf_cuda_library( "//tensorflow/core/framework:bfloat16", "//tensorflow/core/framework:common_shape_fns", "//tensorflow/core/framework:node_def_util", + "//tensorflow/core/framework:node_properties", "//tensorflow/core/framework:numeric_types", "//tensorflow/core/framework:op", "//tensorflow/core/framework:op_def_builder", @@ -3074,7 +3077,7 @@ tf_cc_tests( "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:optional", - "@zlib_archive//:zlib", + "@zlib", ], ) diff --git a/tensorflow/core/api_def/base_api/api_def_QuantizeAndDequantizeV2Grad.pbtxt b/tensorflow/core/api_def/base_api/api_def_QuantizeAndDequantizeV2Grad.pbtxt new file mode 100644 index 00000000000..6a7a2f38897 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_QuantizeAndDequantizeV2Grad.pbtxt @@ -0,0 +1,8 @@ +op { + graph_op_name: "QuantizeAndDequantizeV2Grad" + summary: "Returns the gradient of `QuantizeAndDequantizeV2`." + description: <