diff --git a/.bazelrc b/.bazelrc
index 30c138e07a4..9ac5a1bbf40 100644
--- a/.bazelrc
+++ b/.bazelrc
@@ -238,6 +238,10 @@ build:linux --copt=-w
 build:macos --copt=-w
 build:windows --copt=/w
 
+# TensorFlow uses M_* math constants that are only defined by MSVC headers if
+# _USE_MATH_DEFINES is defined.
+build:windows --copt=/D_USE_MATH_DEFINES
+
 # Default paths for TF_SYSTEM_LIBS
 build:linux --define=PREFIX=/usr
 build:linux --define=LIBDIR=$(PREFIX)/lib
@@ -258,9 +262,8 @@ build:windows --host_cxxopt=/std:c++14
 # On windows, we still link everything into a single DLL.
 build:windows --config=monolithic
 
-# On linux and macos, we dynamically link small amount of kernels
+# On linux, we dynamically link a small number of kernels
 build:linux --config=dynamic_kernels
-build:macos --config=dynamic_kernels
 
 # Make sure to include as little of windows.h as possible
 build:windows --copt=-DWIN32_LEAN_AND_MEAN
@@ -378,9 +381,9 @@ build:rbe_linux_py3 --python_path="/usr/bin/python3"
 build:rbe_linux_py3 --repo_env=TF_PYTHON_CONFIG_REPO="@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/py3"
 
 build:rbe_win --config=rbe
-build:rbe_win --crosstool_top="@org_tensorflow//third_party/toolchains/preconfig/win_1803/bazel_026:toolchain"
+build:rbe_win --crosstool_top="@org_tensorflow//third_party/toolchains/preconfig/win_1803/bazel_121:toolchain"
 build:rbe_win --extra_execution_platforms="@org_tensorflow//third_party/toolchains/preconfig/win_1803:rbe_windows_1803"
-build:rbe_win --extra_toolchains="@org_tensorflow//third_party/toolchains/preconfig/win_1803/bazel_026:cc-toolchain-x64_windows"
+build:rbe_win --extra_toolchains="@org_tensorflow//third_party/toolchains/preconfig/win_1803/bazel_121:cc-toolchain-x64_windows"
 build:rbe_win --host_javabase="@org_tensorflow//third_party/toolchains/preconfig/win_1803:windows_jdk8"
 build:rbe_win --host_platform="@org_tensorflow//third_party/toolchains/preconfig/win_1803:rbe_windows_1803"
 build:rbe_win --javabase="@org_tensorflow//third_party/toolchains/preconfig/win_1803:windows_jdk8"
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 756b7f06eb3..b4dc0e73975 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -72,7 +72,7 @@ TensorFlow coding style.
    [tensorflow/core](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core)
    and
    [tensorflow/python](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/python).
-   TensorFlow has reached version 1 and hence cannot make
+   TensorFlow has passed version 1.0 and hence cannot make
    non-backward-compatible API changes without a major release. Reviewers of
    your pull request will comment on any API compatibility issues.
*   When you contribute a new feature to TensorFlow, the maintenance burden is
diff --git a/README.md b/README.md
index 58775b1d6d9..56baa0740c3 100644
--- a/README.md
+++ b/README.md
@@ -37,18 +37,18 @@ See the [TensorFlow install guide](https://www.tensorflow.org/install) for the
 [Docker container](https://www.tensorflow.org/install/docker), and
 [build from source](https://www.tensorflow.org/install/source).
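As context for the `/D_USE_MATH_DEFINES` copt added to `.bazelrc` above (the same define is removed from `tensorflow/cc/gradients/math_grad.cc` later in this diff, since the build flag now covers it globally), a minimal sketch, not part of the patch, of the MSVC behavior in question:

```
// Minimal sketch: MSVC's <cmath>/<math.h> only define M_PI and the other M_*
// constants when _USE_MATH_DEFINES is set before the header is included.
// With the /D_USE_MATH_DEFINES copt above, sources no longer need the define.
#define _USE_MATH_DEFINES  // redundant under the new copt; shown for clarity
#include <cmath>
#include <cstdio>

int main() {
  std::printf("pi = %f\n", M_PI);  // on MSVC, undeclared without the define
  return 0;
}
```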
-To install the current release for CPU-only: +To install the current release, which includes support for +[CUDA-enabled GPU cards](https://www.tensorflow.org/install/gpu) *(Ubuntu and +Windows)*: ``` $ pip install tensorflow ``` -Use the GPU package for -[CUDA-enabled GPU cards](https://www.tensorflow.org/install/gpu) *(Ubuntu and -Windows)*: +A smaller CPU-only package is also available: ``` -$ pip install tensorflow-gpu +$ pip install tensorflow-cpu ``` To update TensorFlow to the latest version, add `--upgrade` flag to the above @@ -56,7 +56,7 @@ commands. *Nightly binaries are available for testing using the [tf-nightly](https://pypi.python.org/pypi/tf-nightly) and -[tf-nightly-gpu](https://pypi.python.org/pypi/tf-nightly-gpu) packages on PyPi.* +[tf-nightly-cpu](https://pypi.python.org/pypi/tf-nightly-cpu) packages on PyPi.* #### *Try your first TensorFlow program* @@ -150,17 +150,3 @@ Learn more about the ## License [Apache License 2.0](LICENSE) - -## Feature Prioritization Survey - -The TensorFlow team is working on building/improving features, and understands -that it is very important to prioritize these efforts based on what TF users -need. - -The goal of this short, < 5min -[survey](https://google.qualtrics.com/jfe/form/SV_d5nqhCEbkDkQ7ad), is to help -the TensorFlow team better understand what features to prioritize based on your -feedback. Participation is of course optional. - -Take the survey -[HERE](https://google.qualtrics.com/jfe/form/SV_d5nqhCEbkDkQ7ad). diff --git a/configure.py b/configure.py index 93c386240ce..b98cc9fdccc 100644 --- a/configure.py +++ b/configure.py @@ -147,14 +147,16 @@ def write_action_env_to_bazelrc(var_name, var): write_to_bazelrc('build --action_env %s="%s"' % (var_name, str(var))) -def run_shell(cmd, allow_non_zero=False): +def run_shell(cmd, allow_non_zero=False, stderr=None): + if stderr is None: + stderr = sys.stdout if allow_non_zero: try: - output = subprocess.check_output(cmd) + output = subprocess.check_output(cmd, stderr=stderr) except subprocess.CalledProcessError as e: output = e.output else: - output = subprocess.check_output(cmd) + output = subprocess.check_output(cmd, stderr=stderr) return output.decode('UTF-8').strip() @@ -169,10 +171,12 @@ def get_python_path(environ_cp, python_bin_path): if environ_cp.get('PYTHONPATH'): python_paths = environ_cp.get('PYTHONPATH').split(':') try: + stderr = open(os.devnull, 'wb') library_paths = run_shell([ python_bin_path, '-c', 'import site; print("\\n".join(site.getsitepackages()))' - ]).split('\n') + ], + stderr=stderr).split('\n') except subprocess.CalledProcessError: library_paths = [ run_shell([ @@ -1179,10 +1183,17 @@ def system_specific_test_config(env): write_to_bazelrc('test --test_env=LD_LIBRARY_PATH') else: test_and_build_filters.append('-gpu') - write_to_bazelrc('test --test_tag_filters=%s' % + + # Disable tests with "v1only" tag in "v2" Bazel config, but not in "v1" config + write_to_bazelrc('test:v1 --test_tag_filters=%s' % ','.join(test_and_build_filters + test_only_filters)) - write_to_bazelrc('test --build_tag_filters=%s' % + write_to_bazelrc('test:v1 --build_tag_filters=%s' % ','.join(test_and_build_filters)) + write_to_bazelrc( + 'test:v2 --test_tag_filters=%s' % + ','.join(test_and_build_filters + test_only_filters + ['-v1only'])) + write_to_bazelrc('test:v2 --build_tag_filters=%s' % + ','.join(test_and_build_filters + ['-v1only'])) def set_system_libs_flag(environ_cp): diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 081edb21ae1..d8a681c3999 100644 --- 
a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -860,7 +860,7 @@ gen_api_init_files( output_files = TENSORFLOW_API_INIT_FILES_V1, output_package = "tensorflow._api.v1", root_file_name = "v1.py", - root_init_template = "api_template_v1.__init__.py", + root_init_template = "$(location api_template_v1.__init__.py)", ) gen_api_init_files( @@ -883,7 +883,7 @@ gen_api_init_files( output_files = TENSORFLOW_API_INIT_FILES_V2, output_package = "tensorflow._api.v2", root_file_name = "v2.py", - root_init_template = "api_template.__init__.py", + root_init_template = "$(location api_template.__init__.py)", ) py_library( diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py index c515cc76b9a..a8cd6d1782c 100644 --- a/tensorflow/api_template.__init__.py +++ b/tensorflow/api_template.__init__.py @@ -89,6 +89,7 @@ except ImportError: # Enable TF2 behaviors from tensorflow.python.compat import v2_compat as _compat # pylint: disable=g-import-not-at-top _compat.enable_v2_behavior() +_major_api_version = 2 # Load all plugin libraries from site-packages/tensorflow-plugins if we are @@ -119,8 +120,14 @@ def _running_from_pip_package(): _current_file_location.startswith(dir_) for dir_ in _site_packages_dirs) if _running_from_pip_package(): + # TODO(gunan): Add sanity checks to loaded modules here. for _s in _site_packages_dirs: - # TODO(gunan): Add sanity checks to loaded modules here. + # Load first party dynamic kernels. + _main_dir = _os.path.join(_s, 'tensorflow_core/core/kernels') + if _fi.file_exists(_main_dir): + _ll.load_library(_main_dir) + + # Load third party dynamic kernels. _plugin_dir = _os.path.join(_s, 'tensorflow-plugins') if _fi.file_exists(_plugin_dir): _ll.load_library(_plugin_dir) diff --git a/tensorflow/api_template_v1.__init__.py b/tensorflow/api_template_v1.__init__.py index 2b2899c3fe0..b6b5e36f0d5 100644 --- a/tensorflow/api_template_v1.__init__.py +++ b/tensorflow/api_template_v1.__init__.py @@ -104,6 +104,8 @@ from tensorflow.python.platform import flags # pylint: disable=g-import-not-at- _current_module.app.flags = flags # pylint: disable=undefined-variable setattr(_current_module, "flags", flags) +_major_api_version = 1 + # Load all plugin libraries from site-packages/tensorflow-plugins if we are # running under pip. # TODO(gunan): Enable setting an environment variable to define arbitrary plugin @@ -132,8 +134,14 @@ def _running_from_pip_package(): _current_file_location.startswith(dir_) for dir_ in _site_packages_dirs) if _running_from_pip_package(): + # TODO(gunan): Add sanity checks to loaded modules here. for _s in _site_packages_dirs: - # TODO(gunan): Add sanity checks to loaded modules here. + # Load first party dynamic kernels. + _main_dir = _os.path.join(_s, 'tensorflow_core/core/kernels') + if _fi.file_exists(_main_dir): + _ll.load_library(_main_dir) + + # Load third party dynamic kernels. 
_plugin_dir = _os.path.join(_s, 'tensorflow-plugins') if _fi.file_exists(_plugin_dir): _ll.load_library(_plugin_dir) diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index efe01f7e049..76a02090c3b 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -53,6 +53,20 @@ filegroup( visibility = ["//visibility:public"], ) +filegroup( + name = "pywrap_eager_hdrs", + srcs = [ + "c_api_internal.h", + "tf_status_helper.h", + "tf_status_internal.h", + "tf_tensor_internal.h", + ], + visibility = [ + "//tensorflow/core:__pkg__", + "//tensorflow/python:__pkg__", + ], +) + tf_cuda_library( name = "c_api_internal", hdrs = [ diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 130e9a0c3c7..92e994183a2 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -88,6 +88,18 @@ tf_cuda_library( alwayslink = 1, ) +filegroup( + name = "pywrap_eager_hdrs", + srcs = [ + "c_api_experimental.h", + "c_api_internal.h", + ], + visibility = [ + "//tensorflow/core:__pkg__", + "//tensorflow/python:__pkg__", + ], +) + tf_cuda_library( name = "c_api_internal", srcs = ["c_api_experimental.h"], diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 66a2a4aaa3c..c1aa187876f 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -464,7 +464,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( &new_remote_device_mgr)); remote_device_mgr = new_remote_device_mgr.get(); } else { - ctx->context->ClearCaches(); + ctx->context->ClearCachesAndDefaultExecutor(); // TODO(b/143914772): Potential memory leak if rendezvous has pending // tensors for removed / replaced workers. @@ -754,7 +754,9 @@ TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) { return list; } -void TFE_ContextClearCaches(TFE_Context* ctx) { ctx->context->ClearCaches(); } +void TFE_ContextClearCaches(TFE_Context* ctx) { + ctx->context->ClearCachesAndThreadExecutors(); +} // Set server_def on the context, possibly updating it. TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx, diff --git a/tensorflow/c/eager/c_api_internal.cc b/tensorflow/c/eager/c_api_internal.cc index f6092715e17..4f3de479ba7 100644 --- a/tensorflow/c/eager/c_api_internal.cc +++ b/tensorflow/c/eager/c_api_internal.cc @@ -14,6 +14,7 @@ limitations under the License. 
==============================================================================*/
 #include "tensorflow/c/eager/c_api_internal.h"
 
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/host_info.h"
 
 TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
@@ -26,29 +27,22 @@ TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
   if (!status->status.ok()) {
     return nullptr;
   }
-  auto create_or_reset =
-      [&op_to_reset, &ctx, &name, &types, &raw_device_name, &status](
-          bool is_function, TFE_OpInferenceContext* inference_ctx) -> TFE_Op* {
-    if (op_to_reset) {
-      status->status = op_to_reset->Reset(ctx, name, is_function, types,
-                                          raw_device_name, inference_ctx);
-      return op_to_reset;
-    } else {
-      TFE_Op* new_op = new TFE_Op(ctx, name, is_function, types, inference_ctx);
-      status->status = new_op->operation.SetDeviceName(raw_device_name);
-      return new_op;
-    }
-  };
+  if (op_to_reset && op_to_reset->ctx != ctx) {
+    status->status = tensorflow::errors::Internal(
+        "Cannot reset a TFE_Op from another TFE_Context");
+    return nullptr;
+  }
+
+  std::unique_ptr<TFE_OpInferenceContext> inference_ctx;
   if (!is_function) {
     const tensorflow::OpDef* op_def;
     status->status = tensorflow::OpDefForOp(op_or_function_name, &op_def);
     if (!status->status.ok()) {
       return nullptr;
     }
-    return create_or_reset(false, new TFE_OpInferenceContext(op_def));
-  }
-  if (!ctx->context->FindFunctionByName(name)) {
+    inference_ctx.reset(new TFE_OpInferenceContext(op_def));
+  } else if (!ctx->context->FindFunctionByName(name)) {
     status->status = tensorflow::errors::NotFound(
         "'", name,
         "' is neither a type of a primitive operation nor a name "
@@ -58,5 +52,15 @@ TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
         "registered in the binary running in this process.");
     return nullptr;
   }
-  return create_or_reset(true, nullptr);
+
+  if (op_to_reset) {
+    status->status = op_to_reset->Reset(
+        name, is_function, types, raw_device_name, std::move(inference_ctx));
+    return op_to_reset;
+  }
+
+  TFE_Op* new_op =
+      new TFE_Op(ctx, name, is_function, types, std::move(inference_ctx));
+  status->status = new_op->operation.SetDeviceName(raw_device_name);
+  return new_op;
 }
diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h
index 29106e2998d..df192913b72 100644
--- a/tensorflow/c/eager/c_api_internal.h
+++ b/tensorflow/c/eager/c_api_internal.h
@@ -125,24 +125,26 @@ struct TFE_OpInferenceContext {
 struct TFE_Op {
   TFE_Op(TFE_Context* ctx, const char* op, bool is_function,
          const tensorflow::AttrTypeMap* t,
-         TFE_OpInferenceContext* inference_ctx)
-      : operation(ctx->context, op, is_function, t),
-        inference_ctx(inference_ctx) {}
+         std::unique_ptr<TFE_OpInferenceContext> inference_ctx)
+      : ctx(ctx),
+        operation(ctx->context, op, is_function, t),
+        inference_ctx(std::move(inference_ctx)) {}
 
   void Clear() {
     operation.Clear();
     inference_ctx.reset();
   }
 
-  tensorflow::Status Reset(TFE_Context* ctx, const char* op, bool is_function,
+  tensorflow::Status Reset(const char* op, bool is_function,
                            const tensorflow::AttrTypeMap* t,
                            const char* raw_device_name,
-                           TFE_OpInferenceContext* infer_ctx) {
-    inference_ctx.reset(infer_ctx);
+                           std::unique_ptr<TFE_OpInferenceContext> infer_ctx) {
+    inference_ctx = std::move(infer_ctx);
     return operation.Reset(ctx->context, op, is_function, t, raw_device_name,
                            nullptr);
   }
 
+  TFE_Context* ctx;
   tensorflow::EagerOperation operation;
   std::unique_ptr<TFE_OpInferenceContext> inference_ctx;
 };
diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD
index 70dadf79dbe..a9f429b8bd3 100644
--- a/tensorflow/cc/BUILD
+++ b/tensorflow/cc/BUILD
@@ -233,6 +233,7 @@ cc_library_with_android_deps(
     deps = [
         "//tensorflow/core:core_cpu",
         "//tensorflow/core:lib",
+        "//tensorflow/core:lib_experimental",
         "//tensorflow/core:protos_all_cc",
     ],
 )
diff --git a/tensorflow/cc/client/client_session.cc b/tensorflow/cc/client/client_session.cc
index 97038b2ba86..c4add1589e7 100644
--- a/tensorflow/cc/client/client_session.cc
+++ b/tensorflow/cc/client/client_session.cc
@@ -127,6 +127,33 @@ Status ClientSession::Run(const RunOptions& run_options, const FeedType& inputs,
                           target_node_names, outputs, run_metadata);
 }
 
+Status ClientSession::Run(
+    const RunOptions& run_options, const FeedType& inputs,
+    const std::vector<Output>& fetch_outputs,
+    const std::vector<Operation>& run_outputs, std::vector<Tensor>* outputs,
+    RunMetadata* run_metadata,
+    const thread::ThreadPoolOptions& threadpool_options) const {
+  std::vector<std::pair<string, Tensor>> feeds;
+  for (auto const& feed : inputs) {
+    TF_RETURN_IF_ERROR(feed.second.status);
+    feeds.emplace_back(feed.first.name(), feed.second.tensor);
+  }
+  std::vector<string> output_tensor_names;
+  output_tensor_names.reserve(fetch_outputs.size());
+  for (auto const& output : fetch_outputs) {
+    output_tensor_names.push_back(output.name());
+  }
+  std::vector<string> target_node_names;
+  target_node_names.reserve(run_outputs.size());
+  for (auto const& output : run_outputs) {
+    target_node_names.push_back(output.node()->name());
+  }
+  TF_RETURN_IF_ERROR(impl()->MaybeExtendGraph());
+  return impl()->session_->Run(run_options, feeds, output_tensor_names,
+                               target_node_names, outputs, run_metadata,
+                               threadpool_options);
+}
+
 Status ClientSession::MakeCallable(const CallableOptions& callable_options,
                                    CallableHandle* out_handle) {
   TF_RETURN_IF_ERROR(impl()->MaybeExtendGraph());
diff --git a/tensorflow/cc/client/client_session.h b/tensorflow/cc/client/client_session.h
index b0bb6c3fa6d..3765eaec9bf 100644
--- a/tensorflow/cc/client/client_session.h
+++ b/tensorflow/cc/client/client_session.h
@@ -93,6 +93,14 @@ class ClientSession {
              const std::vector<Operation>& run_outputs,
              std::vector<Tensor>* outputs, RunMetadata* run_metadata) const;
 
+  /// Same as above. Additionally allows user to provide custom threadpool
+  /// implementation via ThreadPoolOptions.
+  Status Run(const RunOptions& run_options, const FeedType& inputs,
+             const std::vector<Output>& fetch_outputs,
+             const std::vector<Operation>& run_outputs,
+             std::vector<Tensor>* outputs, RunMetadata* run_metadata,
+             const thread::ThreadPoolOptions& threadpool_options) const;
+
   /// \brief A handle to a subgraph, created with
   /// `ClientSession::MakeCallable()`.
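The new `ClientSession::Run` overload declared above threads a `thread::ThreadPoolOptions` through to `Session::Run`. A minimal usage sketch, not part of the patch; `RunWithCustomPools` is an illustrative name, and null pool pointers fall back to the default pools:

```
// Illustrative sketch of calling the new overload; not part of the patch.
#include <vector>

#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/platform/threadpool_options.h"

tensorflow::Status RunWithCustomPools(
    tensorflow::thread::ThreadPoolInterface* inter_op,
    tensorflow::thread::ThreadPoolInterface* intra_op) {
  using namespace tensorflow;       // Scope, ClientSession, Tensor, ...
  using namespace tensorflow::ops;  // Add

  Scope root = Scope::NewRootScope();
  auto sum = Add(root, {1, 2}, {3, 4});
  ClientSession session(root);

  thread::ThreadPoolOptions pool_opts;
  pool_opts.inter_op_threadpool = inter_op;  // nullptr -> default pool
  pool_opts.intra_op_threadpool = intra_op;

  std::vector<Tensor> outputs;
  // The overload added in this diff forwards pool_opts to Session::Run.
  return session.Run(RunOptions(), ClientSession::FeedType{}, {sum},
                     /*run_outputs=*/{}, &outputs, /*run_metadata=*/nullptr,
                     pool_opts);
}
```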
   typedef int64 CallableHandle;
diff --git a/tensorflow/cc/client/client_session_test.cc b/tensorflow/cc/client/client_session_test.cc
index 3d1c38483c4..27ec4c0871d 100644
--- a/tensorflow/cc/client/client_session_test.cc
+++ b/tensorflow/cc/client/client_session_test.cc
@@ -112,7 +112,7 @@ TEST(ClientSessionTest, Extend) {
   test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({31, 42}, {2}));
 }
 
-TEST(ClientSessionTest, MultiThreaded) {
+TEST(ClientSessionTest, MultiThreadedWithDefaultThreadpool) {
   Scope root = Scope::NewRootScope();
   auto a = Add(root, {1, 2}, {3, 4});
   auto b = Mul(root, {1, 2}, {3, 4});
@@ -138,6 +138,49 @@
   test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({-1, 2}, {2}));
 }
 
+TEST(ClientSessionTest, MultiThreadedWithCustomThreadpool) {
+  Scope root = Scope::NewRootScope();
+  int num_threads = 3;
+  auto a = Add(root, {1, 2}, {3, 4});
+  auto b = Mul(root, {1, 2}, {3, 4});
+  ClientSession session(root);
+
+  auto inter_op_threadpool =
+      absl::make_unique<CustomThreadPoolImpl>(num_threads);
+  ASSERT_EQ(inter_op_threadpool->GetNumScheduleCalled(), 0);
+
+  auto intra_op_threadpool =
+      absl::make_unique<CustomThreadPoolImpl>(num_threads);
+  ASSERT_EQ(intra_op_threadpool->GetNumScheduleCalled(), 0);
+
+  tensorflow::thread::ThreadPoolOptions threadPoolOptions;
+  threadPoolOptions.inter_op_threadpool = inter_op_threadpool.get();
+  threadPoolOptions.intra_op_threadpool = intra_op_threadpool.get();
+
+  {
+    thread::ThreadPool thread_pool(Env::Default(), "pool", 2);
+    thread_pool.Schedule([&session, a]() {
+      std::vector<Tensor> outputs;
+      TF_EXPECT_OK(session.Run(RunOptions(), ClientSession::FeedType{}, {a}, {},
+                               &outputs, nullptr, thread::ThreadPoolOptions()));
+      test::ExpectTensorEqual<int>(outputs[0],
+                                   test::AsTensor<int>({4, 6}, {2}));
+    });
+    thread_pool.Schedule([&session, b]() {
+      std::vector<Tensor> outputs;
+      TF_EXPECT_OK(session.Run(RunOptions(), ClientSession::FeedType{}, {b}, {},
+                               &outputs, nullptr, thread::ThreadPoolOptions()));
+      test::ExpectTensorEqual<int>(outputs[0],
+                                   test::AsTensor<int>({3, 8}, {2}));
+    });
+  }
+  auto c = Sub(root, b, a);
+  std::vector<Tensor> outputs;
+  TF_EXPECT_OK(session.Run(RunOptions(), ClientSession::FeedType{}, {c}, {},
+                           &outputs, nullptr, thread::ThreadPoolOptions()));
+  test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({-1, 2}, {2}));
+}
+
 TEST(ClientSessionTest, CallableWithDefaultThreadPool) {
   Scope root = Scope::NewRootScope();
   auto a = Placeholder(root, DT_INT32);
diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc
index b3c1e6a913a..f67c6f91d6c 100644
--- a/tensorflow/cc/gradients/math_grad.cc
+++ b/tensorflow/cc/gradients/math_grad.cc
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
 limitations under the License.
==============================================================================*/ -#define _USE_MATH_DEFINES #include #include "tensorflow/cc/ops/array_ops_internal.h" diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index 71709e40b36..b64f0f55417 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -125,11 +125,11 @@ cc_library( deps = [ ":constants", "@com_google_absl//absl/container:flat_hash_set", - "//tensorflow/core/platform:strcat", - "//tensorflow/core/util/tensor_bundle", ] + if_not_mobile([ "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:strcat", + "//tensorflow/core/util/tensor_bundle", ]), ) diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index b6e260f00a5..a17ad6d27a9 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -75,8 +75,8 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "@com_google_absl//absl/strings", - "@llvm//:support", # fixdeps: keep - "@llvm//:x86_code_gen", # fixdeps: keep + "@llvm-project//llvm:support", # fixdeps: keep + "@llvm-project//llvm:x86_code_gen", # fixdeps: keep ], ) @@ -104,11 +104,11 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/strings", - "@llvm//:aarch64_code_gen", # fixdeps: keep - "@llvm//:arm_code_gen", # fixdeps: keep - "@llvm//:powerpc_code_gen", # fixdeps: keep - "@llvm//:target", - "@llvm//:x86_code_gen", # fixdeps: keep + "@llvm-project//llvm:aarch64_code_gen", # fixdeps: keep + "@llvm-project//llvm:arm_code_gen", # fixdeps: keep + "@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep + "@llvm-project//llvm:target", + "@llvm-project//llvm:x86_code_gen", # fixdeps: keep ], ) @@ -205,9 +205,9 @@ cc_library( "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@llvm//:core", - "@llvm//:support", - "@llvm//:target", + "@llvm-project//llvm:core", + "@llvm-project//llvm:support", + "@llvm-project//llvm:target", ], ) diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index fb81266a048..c8bbb1a473c 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -407,6 +407,7 @@ def target_llvm_triple(): "//tensorflow:android_arm64": "aarch64-none-android", "//tensorflow:android_x86": "i686-none-android", "//tensorflow:ios": "arm64-none-ios", + "//tensorflow:ios_x86_64": "x86_64-apple-ios", "//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu", "//tensorflow:macos": "x86_64-none-darwin", "//conditions:default": "x86_64-pc-linux", diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 4526090d68a..15e53b7be67 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -500,6 +500,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:graph", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", ], ) diff --git a/tensorflow/compiler/jit/build_xla_ops_pass.h b/tensorflow/compiler/jit/build_xla_ops_pass.h index 58f7c4b3a0d..8487802aa66 100644 --- a/tensorflow/compiler/jit/build_xla_ops_pass.h +++ b/tensorflow/compiler/jit/build_xla_ops_pass.h @@ -22,8 +22,9 @@ limitations under the License. namespace tensorflow { -// Adds _XlaCompile and _XlaRun operations to the TF graph that compiles and -// executes (using XLA) TF function calls marked with "_XlaCompiledKernel". 
+// Replaces TF function calls marked with `_XlaCompiledKernel` with _XlaCompile +// and _XlaRun nodes (which compile and launch, respectively, the corresponding +// HLO module). class BuildXlaOpsPass : public GraphOptimizationPass { public: // If enable_lazy_compilation is not nullopt then *enable_lazy_compilation diff --git a/tensorflow/compiler/jit/defs.cc b/tensorflow/compiler/jit/defs.cc index b23f6ec35f5..4bea71e8fc1 100644 --- a/tensorflow/compiler/jit/defs.cc +++ b/tensorflow/compiler/jit/defs.cc @@ -17,6 +17,8 @@ limitations under the License. namespace tensorflow { +const char* const kXlaMustCompileAttr = "_XlaMustCompile"; + const char* const kXlaCompileAttr = "_XlaCompile"; // User-provided through jit_scope APIs. Effective only when auto_jit is OFF. diff --git a/tensorflow/compiler/jit/defs.h b/tensorflow/compiler/jit/defs.h index bf8009344df..9eb4c2ca2e8 100644 --- a/tensorflow/compiler/jit/defs.h +++ b/tensorflow/compiler/jit/defs.h @@ -22,7 +22,16 @@ limitations under the License. namespace tensorflow { // Name of attribute used to tag operators for compilation with XLA + +// Implies must-compile semantics: either it will be compiled +// with XLA, or an error will be thrown. +extern const char* const kXlaMustCompileAttr; // "_XlaMustCompile" + +// Implies auto-clustering: tagged nodes will be clustered and compiled with XLA +// on a best-effort basis. extern const char* const kXlaCompileAttr; // "_XlaCompile" + +// Implies auto-clustering within the given scope. extern const char* const kXlaScopeAttr; // "_XlaScope" extern const char* const kXlaInternalScopeAttr; // "_XlaInternalScope" diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h index 8b627cd959a..bf8b2c41e0e 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h @@ -27,6 +27,15 @@ limitations under the License. namespace tensorflow { +// EncapsulateSubgraphs pass takes all the nodes with the same cluster ID +// (derived from kXlaClusterAttr=ID (kXlaClusterAttr) attribute), puts them into +// a TF function, and replaces the subgraph in the main graph with a call to +// that TF function annotated with kXlaCompiledKernelAttr (_XlaCompiledKernel). +class EncapsulateSubgraphsPass : public GraphOptimizationPass { + public: + Status Run(const GraphOptimizationPassOptions& options) override; +}; + // A rewriting function to apply to each subgraph during encapsulation. // 'arg_source_tensors' are the tensors corresponding to the arguments in the // original source graph (*not* 'graph'). @@ -100,11 +109,6 @@ extern const char* const kXlaHasReferenceVarsAttr; // TODO(hpucha): Move the utilities to a more appropriate place. 
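Given the `kXlaMustCompileAttr` / `kXlaCompileAttr` split documented above (must-compile vs. best-effort auto-clustering), a small sketch of what a must-compile check over a `NodeDef` looks like; it mirrors the `CanCreateXlaKernel` simplification later in this diff, and `MustCompile` is an illustrative name:

```
// Sketch of a must-compile check; mirrors CanCreateXlaKernel() as rewritten
// later in this diff. MustCompile is an illustrative name, not TF API.
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/core/framework/node_def.pb.h"

bool MustCompile(const tensorflow::NodeDef& node_def) {
  // _XlaMustCompile=true means compile-or-error; absence (or false) leaves
  // the node to best-effort auto-clustering via _XlaCompile/_XlaScope.
  const auto it = node_def.attr().find(tensorflow::kXlaMustCompileAttr);
  return it != node_def.attr().end() && it->second.b();
}
```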
void SortControlInputs(GraphDef* gdef); -class EncapsulateSubgraphsPass : public GraphOptimizationPass { - public: - Status Run(const GraphOptimizationPassOptions& options) override; -}; - } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_ENCAPSULATE_SUBGRAPHS_PASS_H_ diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h index 99e9dfd598f..3057e4c7469 100644 --- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h @@ -28,7 +28,7 @@ #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/platform/env.h" - namespace tensorflow { +namespace tensorflow { // Encapsulates nodes marked with the _xla_compile_id attribute into // XlaLaunch operators. diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc index 90fa15bc29b..277c8dbc594 100644 --- a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc @@ -2130,6 +2130,53 @@ Status ExtractOutsideCompilationForNodesWithAssociatedFunctions( return Status::OK(); } +Status CopyOutsideCompilationConstNodes( + Graph* g, const string& outside_compilation_attr_name) { + for (Node* n : g->op_nodes()) { + if (!n->IsConstant() || + !HasNodeAttr(n->def(), outside_compilation_attr_name)) { + continue; + } + + std::vector out_edges(n->out_edges().begin(), + n->out_edges().end()); + bool has_non_oc_output = false; + for (const Edge* e : out_edges) { + if (!e->IsControlEdge() && + !HasNodeAttr(e->dst()->def(), outside_compilation_attr_name)) { + has_non_oc_output = true; + break; + } + } + if (!has_non_oc_output) { + continue; + } + + NodeDef copy_def = n->def(); + copy_def.set_name(g->NewName(n->name())); + copy_def.mutable_attr()->erase(outside_compilation_attr_name); + Status s; + Node* copy_node = g->AddNode(copy_def, &s); + TF_RETURN_IF_ERROR(s); + for (const Edge* e : n->in_edges()) { + if (e->IsControlEdge()) { + g->AddControlEdge(e->src(), copy_node); + } + } + for (const Edge* e : out_edges) { + if (!e->IsControlEdge() && + !HasNodeAttr(e->dst()->def(), outside_compilation_attr_name)) { + Node* dst = e->dst(); + int dst_input = e->dst_input(); + g->RemoveEdge(e); + g->AddEdge(copy_node, 0, dst, dst_input); + } + } + } + + return Status::OK(); +} + } // namespace Status RewriteOutsideCompilationSubgraphFn::operator()( @@ -2279,6 +2326,10 @@ Status ExtractOutsideCompilationForFunction( std::vector outside_compilation_host_graphs; std::vector shape_inference_graphs_to_rewrite; if (*has_outside_compilation) { + // Copy outside compilation Const nodes with non outside compilation users. + TF_RETURN_IF_ERROR(CopyOutsideCompilationConstNodes( + fbody->graph, outside_compilation_attr_name)); + // Find dependencies between outside compilation clusters. 
     TF_ASSIGN_OR_RETURN(auto cluster_deps,
                         OutsideCompilationClusterDependencies(
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index a5fadc08094..edcec281802 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -1187,7 +1187,7 @@ Status MarkForCompilationPassImpl::FindCompilationCandidates() {
     }
 
     if (!whitelist.empty() && !whitelist.contains(node->def().op())) {
-      VLOG(1) << "Rejecting " << node->name()
+      VLOG(1) << "Rejecting TF operation " << node->def().op()
              << " as it is not listed in --tf_xla_ops_to_cluster.";
       continue;
     }
@@ -1770,9 +1770,10 @@ absl::flat_hash_map<string, std::vector<string>>* GetWhitelistTable() {
           {"ComplexAbs", "Angle", "Conj", "Abs", "Acos", "Acosh", "Asin",
            "Atan", "Atanh", "Ceil", "Cos", "Cosh", "Sin", "Exp", "Expm1",
            "Floor", "IsFinite", "IsInf", "IsNan", "Inv", "Reciprocal", "Log",
-           "Log1p", "Invert", "LogicalNot", "Neg", "Rint", "Round", "Rsqrt",
-           "Sigmoid", "Sign", "Sinh", "Softplus", "Softsign", "Sqrt", "Square",
-           "Tan", "Tanh", "Real", "Imag", "Erf", "Erfc", "Lgamma", "Digamma",
+           "Log1p", "Invert", "LogicalNot", "Ndtri", "Neg", "Rint", "Round",
+           "Rsqrt", "Sigmoid", "Sign", "Sinh", "Softplus", "Softsign", "Sqrt",
+           "Square", "Tan", "Tanh", "Real", "Imag", "Erf", "Erfc", "Erfinv",
+           "Lgamma", "Digamma",
            // Binary
            "Add", "AddV2", "Sub", "Mul", "Div", "Atan2", "Complex", "DivNoNan",
            "MulNoNan", "FloorDiv", "Xlogy", "Xdivy", "FloorMod", "BitwiseAnd",
@@ -2035,6 +2036,7 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
       "XlaDynamicSlice",
       "XlaDynamicUpdateSlice",
       "XlaEinsum",
+      "XlaGather",
       "XlaIf",
       "XlaKeyValueSort",
       "XlaPad",
@@ -2042,6 +2044,7 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
       "XlaReduce",
       "XlaReduceWindow",
       "XlaReplicaId",
+      "XlaScatter",
       "XlaSelectAndScatter",
       "XlaSelfAdjointEig",
       "XlaSend",
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.h b/tensorflow/compiler/jit/mark_for_compilation_pass.h
index 0c9b40776f5..8b660710898 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.h
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.h
@@ -34,8 +34,9 @@ extern const char* const kXlaClusterAttr;
 // compilation by the encapsulate subgraphs pass.
 extern const char* const kXlaOutsideCompilationAttr;
 
-// Pass that marks a subset of operators in the graph with attribute
-// _XlaCluster so they are compiled by the EncapsulateSubgraphsPass.
+// Marks a subset of nodes in the graph which are to be clustered
+// with an attribute _XlaCluster=<cluster id> so they are picked up by the
+// EncapsulateSubgraphsPass.
 class MarkForCompilationPass : public GraphOptimizationPass {
  public:
   MarkForCompilationPass() = default;
diff --git a/tensorflow/compiler/jit/shape_inference.cc b/tensorflow/compiler/jit/shape_inference.cc
index 2ed085d021f..72804ff57e4 100644
--- a/tensorflow/compiler/jit/shape_inference.cc
+++ b/tensorflow/compiler/jit/shape_inference.cc
@@ -17,7 +17,10 @@ limitations under the License.
#include "tensorflow/compiler/jit/shape_inference_helpers.h" #include "tensorflow/core/common_runtime/shape_refiner.h" +#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/util/dump_graph.h" @@ -39,7 +42,7 @@ Status ShapeHandleToTensorShape(shape_inference::InferenceContext* context, return PartialTensorShape::MakePartialShape(dims.data(), dims.size(), shape); } -Status PropagateShapes(const Graph& graph, +Status PropagateShapes(Graph* graph, const std::map& arg_shapes, const std::vector& back_edges, ShapeRefiner* shape_refiner) { @@ -54,7 +57,7 @@ Status PropagateShapes(const Graph& graph, // shapes. // TODO(phawkins): handle cyclic graphs. std::vector order; - GetReversePostOrder(graph, &order); + GetReversePostOrder(*graph, &order); for (Node* n : order) { // Ignore the status returned by the shape_refiner. We want the best effort @@ -99,6 +102,67 @@ Status PropagateShapes(const Graph& graph, } } + // Sometimes we have VariableShape nodes in while loop (after Enter nodes). + // They won't be constant-folded because TensorFlow constant folding does + // not handle Enter nodes (and thus does not handle any nodes after Enter + // nodes). We try to replace such VariableShape nodes with Const nodes here. + if (n->type_string() == "VariableShape") { + shape_inference::InferenceContext* context = shape_refiner->GetContext(n); + auto handle_shapes_and_types = context->input_handle_shapes_and_types(0); + if (handle_shapes_and_types && !handle_shapes_and_types->empty()) { + shape_inference::ShapeHandle handle = + handle_shapes_and_types->at(0).shape; + TensorShapeProto shape_proto; + context->ShapeHandleToProto(handle, &shape_proto); + if (!shape_proto.unknown_rank()) { + NodeDef const_def; + const_def.set_op("Const"); + Node* var_node; + TF_RETURN_IF_ERROR(n->input_node(0, &var_node)); + const_def.set_name( + graph->NewName(absl::StrCat("var_shape_", var_node->name()))); + DataType dtype = n->output_type(0); + AddNodeAttr("dtype", dtype, &const_def); + TensorProto value; + value.set_dtype(dtype); + value.mutable_tensor_shape()->add_dim()->set_size( + shape_proto.dim_size()); + for (const auto& dim : shape_proto.dim()) { + if (dtype == DT_INT32) { + value.add_int_val(dim.size()); + } else { + value.add_int64_val(dim.size()); + } + } + AddNodeAttr("value", value, &const_def); + for (auto const& attr : n->attrs()) { + if (*attr.first.begin() == '_') { + AddNodeAttr(attr.first, attr.second, &const_def); + } + } + + Status s; + Node* const_node = graph->AddNode(const_def, &s); + TF_RETURN_IF_ERROR(s); + + graph->AddControlEdge(var_node, const_node); + std::vector out_edges(n->out_edges().begin(), + n->out_edges().end()); + for (const Edge* e : out_edges) { + if (e->IsControlEdge()) { + graph->AddControlEdge(const_node, e->dst()); + graph->RemoveEdge(e); + } else { + Node* dst = e->dst(); + int dst_input = e->dst_input(); + graph->RemoveEdge(e); + graph->AddEdge(const_node, 0, dst, dst_input); + } + } + } + } + } + // Merge node causes a loop so we remove NextIteration->Merge edge before // performing shape inference. But removing those edges also prevents us // from inferring output shape for Merge node (we need shapes for all its @@ -196,7 +260,7 @@ Status InferShapes(Graph* graph, const std::map& arg_shapes, // the shape inference is complete. 
BackEdgeHelper back_edge; TF_RETURN_IF_ERROR(back_edge.Remove(graph)); - TF_RETURN_IF_ERROR(PropagateShapes(*graph, arg_shapes, + TF_RETURN_IF_ERROR(PropagateShapes(graph, arg_shapes, back_edge.RemovedEdges(), &shape_refiner)); TF_RETURN_IF_ERROR(back_edge.Replace()); diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index 99e95314f64..34ff0c55615 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -191,7 +191,7 @@ class XlaAssignVariableOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("IteratorGetNextAsOptional").Device(DEVICE), \ data::IteratorGetNextAsOptionalOp); \ REGISTER_KERNEL_BUILDER(Name("IteratorGetNextSync").Device(DEVICE), \ - data::IteratorGetNextSyncOp); \ + data::IteratorGetNextOp); \ REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle") \ .Device(DEVICE) \ .HostMemory("string_handle"), \ diff --git a/tensorflow/compiler/jit/xla_kernel_creator.cc b/tensorflow/compiler/jit/xla_kernel_creator.cc index e3706a09278..23bd7425dbd 100644 --- a/tensorflow/compiler/jit/xla_kernel_creator.cc +++ b/tensorflow/compiler/jit/xla_kernel_creator.cc @@ -21,7 +21,7 @@ namespace tensorflow { bool XlaKernelCreator::CanCreateKernel(const FunctionLibraryRuntime& flr, const NodeDef& node_def) const { - return CanCreateXlaKernel(flr, node_def); + return CanCreateXlaKernel(node_def); } Status XlaKernelCreator::CreateKernel(FunctionLibraryRuntime* flr, diff --git a/tensorflow/compiler/jit/xla_kernel_creator_test.cc b/tensorflow/compiler/jit/xla_kernel_creator_test.cc index 28606abf2b2..7ec37332906 100644 --- a/tensorflow/compiler/jit/xla_kernel_creator_test.cc +++ b/tensorflow/compiler/jit/xla_kernel_creator_test.cc @@ -95,15 +95,17 @@ AttrValue BoolAttr(bool b) { TEST_F(XlaKernelCreatorTest, OneFloatOneResourceArgument) { FunctionDef fdef = XTimesY(); - (*fdef.mutable_attr())["_XlaCompile"] = BoolAttr(true); + (*fdef.mutable_attr())["_XlaMustCompile"] = BoolAttr(true); Init({fdef}); XlaKernelCreator xla_kernel_creator; - - Status status = xla_kernel_creator.CreateKernel( - flr_, ToNodeDef(R"pb( + NodeDef callsite = + ToNodeDef(R"pb( name: 'XTimesY' op: 'XTimesY' input: 'a' input: 'b' - )pb"), - &kernel_); + )pb"); + (*callsite.mutable_attr())["_XlaMustCompile"] = BoolAttr(true); + + // Note: need to set attribute on the created node. + Status status = xla_kernel_creator.CreateKernel(flr_, callsite, &kernel_); ASSERT_TRUE(status.ok()) << status.ToString(); EXPECT_EQ("XTimesY", kernel_->name()); @@ -137,7 +139,7 @@ TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrNotSet) { TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrIsSetToFalse) { FunctionDef fdef = XTimesY(); - (*fdef.mutable_attr())["_XlaCompile"] = BoolAttr(false); + (*fdef.mutable_attr())["_XlaMustCompile"] = BoolAttr(false); Init({fdef}); XlaKernelCreator xla_kernel_creator; diff --git a/tensorflow/compiler/jit/xla_kernel_creator_util.cc b/tensorflow/compiler/jit/xla_kernel_creator_util.cc index 6441dd3ed28..94727fdf35a 100644 --- a/tensorflow/compiler/jit/xla_kernel_creator_util.cc +++ b/tensorflow/compiler/jit/xla_kernel_creator_util.cc @@ -23,7 +23,9 @@ limitations under the License. 
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/util/ptr_util.h" @@ -68,40 +70,10 @@ class SinglePassSearch { }; } // namespace -bool CanCreateXlaKernel(const FunctionLibraryRuntime& flr, - const NodeDef& node_def) { - const FunctionDef* function_def = - flr.GetFunctionLibraryDefinition()->Find(node_def.name()); - if (function_def == nullptr) { - // The node def is not calling a function. Individual ops can be - // run directly using on-demand mode, no need to create XlaLaunch - // kernel for them. - return false; - } - - // If kXlaCompileAttr is set on the node_def, use its value. - const auto& it = node_def.attr().find(kXlaCompileAttr); - if (it != node_def.attr().end()) { - return it->second.b(); - } - - // kXlaCompileAttr is not set on node_def, check if it is set on - // FunctionDef. - bool xla_compile = false; - Status status = flr.GetFunctionLibraryDefinition()->GetAttr( - node_def, kXlaCompileAttr, &xla_compile); - if (!status.ok() || !xla_compile) { - if (VLOG_IS_ON(3)) { - if (!status.ok()) { - VLOG(3) << "No " << kXlaCompileAttr << " attr defined for " - << node_def.op() << ". status=" << status.ToString(); - } else { - VLOG(3) << node_def.op() << " is explicitly marked not to be compiled"; - } - } - return false; - } - return true; +bool CanCreateXlaKernel(const NodeDef& node_def) { + // If kXlaMustCompileAttr is set on the node_def, use its value. + const auto& it = node_def.attr().find(kXlaMustCompileAttr); + return it != node_def.attr().end() && it->second.b(); } // Given a FunctionLibraryRuntime and a NodeDef calling a function in the @@ -118,8 +90,11 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr, FunctionLibraryRuntime::Handle handle; // If node_def is not instantiable, e.g., the function does not exist, // simply bail out. + NameAttrList function; + TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function)); + TF_RETURN_IF_ERROR( - flr->Instantiate(node_def.op(), AttrSlice(&node_def.attr()), &handle)); + flr->Instantiate(function.name(), AttrSlice(&function.attr()), &handle)); *fbody = flr->GetFunctionBody(handle); CHECK(*fbody); // Can't be nullptr since we just instantiated it. const DataTypeVector& arg_types = (*fbody)->arg_types; @@ -149,7 +124,7 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr, Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def, std::unique_ptr* kernel) { - if (!CanCreateXlaKernel(*flr, node_def)) { + if (!CanCreateXlaKernel(node_def)) { return errors::Internal("Invalid node: ", node_def.ShortDebugString()); } @@ -241,9 +216,7 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def, // Create the kernel. 
NameAttrList function; - function.set_name(node_def.op()); - *(function.mutable_attr()) = node_def.attr(); - + TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function)); Device* dev = flr->device(); Status s; OpKernelConstruction construction( diff --git a/tensorflow/compiler/jit/xla_kernel_creator_util.h b/tensorflow/compiler/jit/xla_kernel_creator_util.h index 71398c334fc..5ec8df01f77 100644 --- a/tensorflow/compiler/jit/xla_kernel_creator_util.h +++ b/tensorflow/compiler/jit/xla_kernel_creator_util.h @@ -24,11 +24,9 @@ namespace tensorflow { class FunctionLibraryRuntime; class OpKernel; - // Given a NodeDef 'node_def' and the function library runtime 'flr', returns - // true if 'node_def' is a call to a compilable function defined in 'flr', - // with the kXlaCompileAttr set. -bool CanCreateXlaKernel(const FunctionLibraryRuntime& flr, - const NodeDef& node_def); +// Given a NodeDef `node_def` returns true iff `node_def` has kXlaCompileAttr +// set. +bool CanCreateXlaKernel(const NodeDef& node_def); // Given a supported NodeDef, returns a XlaLaunchOp that computes the node. Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def, diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD index b54d5867487..554288a0937 100644 --- a/tensorflow/compiler/mlir/BUILD +++ b/tensorflow/compiler/mlir/BUILD @@ -6,7 +6,7 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_binary") package( default_visibility = [ "//tensorflow/compiler/tf2xla:__subpackages__", - "@local_config_mlir//:friends", + "@llvm-project//mlir:friends", ], licenses = ["notice"], # Apache 2.0 ) @@ -30,8 +30,8 @@ cc_library( hdrs = ["op_or_arg_name_mapper.h"], deps = [ "@com_google_absl//absl/strings", - "@llvm//:support", - "@local_config_mlir//:IR", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", ], ) @@ -43,11 +43,11 @@ cc_library( ":passes", "//tensorflow/core:lib", "//tensorflow/core/platform:logging", - "@llvm//:support", - "@local_config_mlir//:MlirOptLib", - "@local_config_mlir//:Pass", - "@local_config_mlir//:Support", - "@local_config_mlir//test:TestTransforms", + "@llvm-project//llvm:support", + "@llvm-project//mlir:MlirOptLib", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir/test:TestTransforms", ], ) @@ -80,9 +80,9 @@ cc_library( "//tensorflow/compiler/mlir/xla:xla_legalize_tf", "//tensorflow/compiler/mlir/xla:xla_legalize_to_standard", "//tensorflow/compiler/mlir/xla:xla_lower", - "@local_config_mlir//:AffineDialectRegistration", - "@local_config_mlir//:QuantOps", - "@local_config_mlir//:QuantOpsDialectRegistration", + "@llvm-project//mlir:AffineDialectRegistration", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:QuantOpsDialectRegistration", ], ) @@ -92,7 +92,7 @@ cc_library( hdrs = ["init_mlir.h"], deps = [ "//tensorflow/core:lib", - "@llvm//:support", + "@llvm-project//llvm:support", ], ) @@ -122,11 +122,11 @@ tf_cc_binary( "//tensorflow/core:tensorflow", "//tensorflow/stream_executor/lib", "@com_google_absl//absl/strings", - "@llvm//:support", - "@local_config_mlir//:IR", - "@local_config_mlir//:Support", - "@local_config_mlir//:TranslateClParser", - "@local_config_mlir//:Translation", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TranslateClParser", + "@llvm-project//mlir:Translation", ], ) diff --git a/tensorflow/compiler/mlir/README.md b/tensorflow/compiler/mlir/README.md index f86b329b39f..cbb0b08503a 100644 --- 
a/tensorflow/compiler/mlir/README.md +++ b/tensorflow/compiler/mlir/README.md @@ -1,11 +1,11 @@ # MLIR dialects and utilities for TensorFlow, TensorFlow Lite and XLA. This module contains the MLIR -([Multi-Level Intermediate Representation](https://github.com/tensorflow/mlir)) +([Multi-Level Intermediate Representation](https://mlir.llvm.org)) dialects and utilities for 1. TensorFlow 2. XLA 3. TF Lite -See [MLIR repo](https://github.com/tensorflow/mlir) for complete documentation. \ No newline at end of file +See [MLIR's website](https://mlir.llvm.org) for complete documentation. diff --git a/tensorflow/compiler/mlir/glob_lit_test.bzl b/tensorflow/compiler/mlir/glob_lit_test.bzl index f82f719f2ce..fda2f819b98 100644 --- a/tensorflow/compiler/mlir/glob_lit_test.bzl +++ b/tensorflow/compiler/mlir/glob_lit_test.bzl @@ -10,7 +10,7 @@ load("@bazel_skylib//lib:paths.bzl", "paths") # Default values used by the test runner. _default_test_file_exts = ["mlir", ".pbtxt", ".td"] -_default_driver = "@local_config_mlir//:run_lit.sh" +_default_driver = "@llvm-project//mlir:run_lit.sh" _default_size = "small" _default_tags = ["no_rocm"] @@ -50,16 +50,16 @@ def _run_lit_test(name, data, size, tags, driver, features): native.py_test( name = name, - srcs = ["@llvm//:lit"], + srcs = ["@llvm-project//llvm:lit"], tags = tags, args = [ "tensorflow/compiler/mlir/" + paths.basename(data[-1]) + " --config-prefix=runlit -v", ] + features, data = data + [ "//tensorflow/compiler/mlir:litfiles", - "@llvm//:FileCheck", - "@llvm//:count", - "@llvm//:not", + "@llvm-project//llvm:FileCheck", + "@llvm-project//llvm:count", + "@llvm-project//llvm:not", ], size = size, main = "lit.py", diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 6f08ca3b5e8..700b2e6bb16 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -1,6 +1,6 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", "tf_native_cc_binary") load( - "@local_config_mlir//:tblgen.bzl", + "//third_party/mlir:tblgen.bzl", "gentbl", ) @@ -8,13 +8,14 @@ package( default_visibility = [ # TODO(jpienaar): Make the visibility more restrictive. 
":friends", + "//tensorflow/lite/experimental/tf_runtime:__subpackages__", ], licenses = ["notice"], # Apache 2.0 ) package_group( name = "friends", - includes = ["@local_config_mlir//:subpackages"], + includes = ["//third_party/mlir:subpackages"], packages = [ "//learning/brain/experimental/mlir/...", "//learning/brain/google/xla/...", @@ -27,7 +28,7 @@ filegroup( srcs = [ "ir/tfl_ops.td", "//tensorflow/compiler/mlir/lite/quantization:quantization_td_files", - "@local_config_mlir//:OpBaseTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", ], ) @@ -47,7 +48,7 @@ gentbl( "g3doc/tfl_ops.md", ), ], - tblgen = "@local_config_mlir//:mlir-tblgen", + tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/tfl_ops.td", td_srcs = [ ":tensorflow_lite_ops_td_files", @@ -62,11 +63,11 @@ gentbl( "transforms/generated_prepare_tf.inc", ), ], - tblgen = "@local_config_mlir//:mlir-tblgen", + tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/prepare_patterns.td", td_srcs = [ ":tensorflow_lite_ops_td_files", - "@local_config_mlir//:StdOpsTdFiles", + "@llvm-project//mlir:StdOpsTdFiles", "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files", "//tensorflow/compiler/mlir/tensorflow:tensorflow_optimize_td_files", ], @@ -80,11 +81,11 @@ gentbl( "transforms/generated_lower_static_tensor_list.inc", ), ], - tblgen = "@local_config_mlir//:mlir-tblgen", + tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/tensorlist_patterns.td", td_srcs = [ ":tensorflow_lite_ops_td_files", - "@local_config_mlir//:StdOpsTdFiles", + "@llvm-project//mlir:StdOpsTdFiles", "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files", ], ) @@ -97,11 +98,11 @@ gentbl( "transforms/generated_legalize_tf.inc", ), ], - tblgen = "@local_config_mlir//:mlir-tblgen", + tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/legalize_patterns.td", td_srcs = [ ":tensorflow_lite_ops_td_files", - "@local_config_mlir//:StdOpsTdFiles", + "@llvm-project//mlir:StdOpsTdFiles", "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files", ], ) @@ -114,11 +115,12 @@ gentbl( "transforms/generated_optimize.inc", ), ], - tblgen = "@local_config_mlir//:mlir-tblgen", + tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/optimize_patterns.td", td_srcs = [ ":tensorflow_lite_ops_td_files", - "@local_config_mlir//:StdOpsTdFiles", + "@llvm-project//mlir:StdOpsTdFiles", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files", ], ) @@ -130,11 +132,11 @@ gentbl( "transforms/generated_quantize.inc", ), ], - tblgen = "@local_config_mlir//:mlir-tblgen", + tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/quantize_patterns.td", td_srcs = [ ":tensorflow_lite_ops_td_files", - "@local_config_mlir//:StdOpsTdFiles", + "@llvm-project//mlir:StdOpsTdFiles", ], ) @@ -146,11 +148,11 @@ gentbl( "transforms/generated_post_quantize.inc", ), ], - tblgen = "@local_config_mlir//:mlir-tblgen", + tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/post_quantize_patterns.td", td_srcs = [ ":tensorflow_lite_ops_td_files", - "@local_config_mlir//:StdOpsTdFiles", + "@llvm-project//mlir:StdOpsTdFiles", ], ) @@ -163,9 +165,9 @@ cc_library( "utils/validators.h", ], deps = [ - "@local_config_mlir//:Dialect", - "@local_config_mlir//:IR", - "@local_config_mlir//:StandardOps", + "@llvm-project//mlir:Dialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:StandardOps", ], ) @@ -183,21 +185,21 @@ cc_library( "transforms/passes.h", "utils/attribute_utils.h", 
"//tensorflow/compiler/mlir/lite/quantization:quantization_traits.h", - "@local_config_mlir//:include/mlir/Transforms/InliningUtils.h", + "@llvm-project//mlir:include/mlir/Transforms/InliningUtils.h", ], deps = [ ":tensorflow_lite_ops_inc_gen", ":validators", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/lite/schema:schema_fbs", - "@llvm//:support", - "@local_config_mlir//:Analysis", - "@local_config_mlir//:Dialect", - "@local_config_mlir//:IR", - "@local_config_mlir//:Pass", - "@local_config_mlir//:QuantOps", - "@local_config_mlir//:StandardOps", - "@local_config_mlir//:Support", + "@llvm-project//llvm:support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:Dialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", ], alwayslink = 1, ) @@ -214,10 +216,10 @@ cc_library( deps = [ ":tensorflow_lite", "//tensorflow/compiler/mlir/tensorflow", - "@llvm//:support", - "@local_config_mlir//:IR", - "@local_config_mlir//:StandardOps", - "@local_config_mlir//:Support", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", ], ) @@ -231,9 +233,9 @@ cc_library( ], deps = [ ":tensorflow_lite", - "@llvm//:support", - "@local_config_mlir//:IR", - "@local_config_mlir//:StandardOps", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:StandardOps", ], ) @@ -246,10 +248,10 @@ tf_cc_test( "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/core:test", "//tensorflow/core:test_main", - "@llvm//:support", - "@local_config_mlir//:IR", - "@local_config_mlir//:StandardOps", - "@local_config_mlir//:Support", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", ], ) @@ -290,14 +292,14 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:logging", "@com_google_absl//absl/memory", - "@llvm//:support", - "@local_config_mlir//:Analysis", - "@local_config_mlir//:IR", - "@local_config_mlir//:Pass", - "@local_config_mlir//:QuantOps", - "@local_config_mlir//:StandardOps", - "@local_config_mlir//:Support", - "@local_config_mlir//:Transforms", + "@llvm-project//llvm:support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", ], alwayslink = 1, ) @@ -315,12 +317,12 @@ cc_library( ":tensorflow_lite", ":validators", "//tensorflow/compiler/mlir/tensorflow", - "@llvm//:support", - "@local_config_mlir//:Analysis", - "@local_config_mlir//:IR", - "@local_config_mlir//:Pass", - "@local_config_mlir//:StandardOps", - "@local_config_mlir//:Support", + "@llvm-project//llvm:support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", ], alwayslink = 1, ) @@ -346,13 +348,13 @@ cc_library( "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/memory", - "@llvm//:support", - "@local_config_mlir//:Analysis", - "@local_config_mlir//:IR", - "@local_config_mlir//:Pass", - "@local_config_mlir//:QuantOps", - "@local_config_mlir//:StandardOps", - "@local_config_mlir//:Support", + "@llvm-project//llvm:support", + "@llvm-project//mlir:Analysis", + 
"@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", ], alwayslink = 1, ) @@ -374,7 +376,7 @@ genrule( "utils/generated_op_quant_spec_getters.inc", ], cmd = ("$(location //tensorflow/compiler/mlir/lite/quantization:op_quant_spec_getters_gen) " + - "-I external/local_config_mlir/include " + + "-I external/llvm-project/mlir/include " + "-I external/org_tensorflow " + "$(location //tensorflow/compiler/mlir/lite:ir/tfl_ops.td) " + " -o $@"), tools = ["//tensorflow/compiler/mlir/lite/quantization:op_quant_spec_getters_gen"], @@ -388,7 +390,7 @@ cc_library( ], deps = [ ":tensorflow_lite", - "@local_config_mlir//:IR", + "@llvm-project//mlir:IR", ], alwayslink = 1, ) @@ -399,9 +401,9 @@ tf_native_cc_binary( "operator_converter_gen.cc", ], deps = [ - "@llvm//:support", - "@llvm//:tablegen", - "@local_config_mlir//:TableGen", + "@llvm-project//llvm:support", + "@llvm-project//llvm:tablegen", + "@llvm-project//mlir:TableGen", ], ) @@ -434,12 +436,17 @@ cc_library( deps = [ ":tensorflow_lite", "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/core/platform:errors", + "//tensorflow/core/platform:status", + "//tensorflow/lite/kernels/internal:kernel_utils", "//tensorflow/lite/schema:schema_fbs", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", "@flatbuffers", - "@llvm//:support", - "@local_config_mlir//:IR", - "@local_config_mlir//:TransformUtils", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:TransformUtils", ], ) @@ -462,7 +469,7 @@ cc_library( ], deps = [ "//tensorflow/lite/core/api", - "@local_config_mlir//:IR", + "@llvm-project//mlir:IR", ], ) @@ -507,14 +514,14 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@flatbuffers", - "@llvm//:support", - "@local_config_mlir//:IR", - "@local_config_mlir//:QuantOps", - "@local_config_mlir//:QuantOpsDialectRegistration", - "@local_config_mlir//:StandardDialectRegistration", - "@local_config_mlir//:StandardOps", - "@local_config_mlir//:Support", - "@local_config_mlir//:Translation", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:QuantOpsDialectRegistration", + "@llvm-project//mlir:StandardDialectRegistration", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Translation", ], alwayslink = 1, ) @@ -523,7 +530,7 @@ tf_cc_binary( name = "flatbuffer_translate", deps = [ ":flatbuffer_translate_lib", - "@local_config_mlir//:MlirTranslateMain", + "@llvm-project//mlir:MlirTranslateMain", ], ) @@ -536,7 +543,7 @@ cc_library( "tf_tfl_translate_cl.h", ], deps = [ - "@llvm//:support", + "@llvm-project//llvm:support", ], alwayslink = 1, ) @@ -548,7 +555,7 @@ cc_library( ], deps = [ "//tensorflow/compiler/mlir/lite/quantization:quantization_config", - "@llvm//:support", + "@llvm-project//llvm:support", ], ) @@ -576,9 +583,9 @@ tf_cc_binary( "//tensorflow/lite/schema:schema_fbs", "//tensorflow/stream_executor/lib", "@com_google_absl//absl/strings", - "@llvm//:support", - "@local_config_mlir//:IR", - "@local_config_mlir//:Support", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", ], ) @@ -589,16 +596,15 @@ tf_cc_binary( ":flatbuffer_translate_lib", "//tensorflow/core:lib", "//tensorflow/core/platform:logging", - 
"//tensorflow/core/platform/default/build_config:base", "//tensorflow/lite:framework", "//tensorflow/lite/delegates/flex:delegate", "//tensorflow/lite/kernels:builtin_ops", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings", - "@llvm//:support", - "@local_config_mlir//:IR", - "@local_config_mlir//:Parser", - "@local_config_mlir//:Support", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Support", ], ) @@ -621,12 +627,12 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes", "//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass", "//tensorflow/compiler/mlir/tensorflow:translate_lib", - "@local_config_mlir//:Analysis", - "@local_config_mlir//:IR", - "@local_config_mlir//:Pass", - "@local_config_mlir//:QuantOps", - "@local_config_mlir//:QuantOpsDialectRegistration", - "@local_config_mlir//:Transforms", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:QuantOpsDialectRegistration", + "@llvm-project//mlir:Transforms", ], ) @@ -653,15 +659,15 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/lite/tools/optimize:quantize_weights", "//tensorflow/stream_executor/lib", - "@llvm//:support", - "@local_config_mlir//:Analysis", - "@local_config_mlir//:IR", - "@local_config_mlir//:Parser", - "@local_config_mlir//:Pass", - "@local_config_mlir//:QuantOps", - "@local_config_mlir//:QuantOpsDialectRegistration", - "@local_config_mlir//:Support", - "@local_config_mlir//:Transforms", + "@llvm-project//llvm:support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:QuantOpsDialectRegistration", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", ], ) diff --git a/tensorflow/compiler/mlir/lite/emit_error_reporter.h b/tensorflow/compiler/mlir/lite/emit_error_reporter.h index 40e89c5dec8..76cc1f612bb 100644 --- a/tensorflow/compiler/mlir/lite/emit_error_reporter.h +++ b/tensorflow/compiler/mlir/lite/emit_error_reporter.h @@ -18,7 +18,7 @@ limitations under the License. #include -#include "mlir/IR/Module.h" // TF:local_config_mlir +#include "mlir/IR/Module.h" // TF:llvm-project #include "tensorflow/lite/core/api/error_reporter.h" namespace tflite { diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc index 11d97120d00..43974e02bba 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -43,24 +44,24 @@ limitations under the License. 
#include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:local_config_mlir -#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir -#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Diagnostics.h" // TF:local_config_mlir -#include "mlir/IR/Function.h" // TF:local_config_mlir -#include "mlir/IR/Location.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/Module.h" // TF:local_config_mlir -#include "mlir/IR/Operation.h" // TF:local_config_mlir -#include "mlir/IR/OperationSupport.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/IR/Types.h" // TF:local_config_mlir -#include "mlir/IR/Value.h" // TF:local_config_mlir -#include "mlir/Support/Functional.h" // TF:local_config_mlir -#include "mlir/Support/LLVM.h" // TF:local_config_mlir -#include "mlir/Translation.h" // TF:local_config_mlir +#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project +#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Diagnostics.h" // TF:llvm-project +#include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/Location.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/OperationSupport.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/Types.h" // TF:llvm-project +#include "mlir/IR/Value.h" // TF:llvm-project +#include "mlir/Support/Functional.h" // TF:llvm-project +#include "mlir/Support/LLVM.h" // TF:llvm-project +#include "mlir/Translation.h" // TF:llvm-project #include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_translate_flags.h" @@ -103,12 +104,26 @@ using llvm::cl::opt; // Commandline flag to enable the control of flatbuffer import. bool use_external_constant; +// Commandline flag to enable graph pruning. +bool experimental_prune_unreachable_nodes_unconditionally; + // NOLINTNEXTLINE static opt use_external_constant_flag( "use-external-constant", llvm::cl::desc("Use external constant during flatbuffer import"), llvm::cl::location(use_external_constant), llvm::cl::init(false)); +// TODO(b/147111261): After the importer supports generic custom ops, we should +// change the flag to a more lightwise flag, e.g. +// "import_custom_ops_as_side_effect_free_ops", and let the MLIR DCE to prune +// the operations. 
+// NOLINTNEXTLINE
+static opt<bool, true> experimental_prune_unreachable_nodes_unconditionally_flg(
+    "experimental-prune-unreachable-nodes-unconditionally",
+    llvm::cl::desc("Prune nodes that are not ancestors of the output nodes."),
+    llvm::cl::location(experimental_prune_unreachable_nodes_unconditionally),
+    llvm::cl::init(false));
+
 namespace {
 bool IsScalar(const TensorT& tensor) {
   // TODO(b/138222071) We can't distinguish scalars and unranked tensors
@@ -212,12 +227,12 @@ StatusOr<mlir::TensorType> GetTensorType(const TensorT& tensor, Builder builder,
 // type, thus no stats op is required and nullptr is returned.
 // If the min max information is invalid, nullptr is returned.
 mlir::Operation* ConvertMinMaxToStatsOp(const TensorT& tensor, OpBuilder b,
-                                        Value* res) {
+                                        Value res) {
   // If the `tensor` has scale/zero_point, it must have been quantized, then
   // the min/max stats are just for reference, so ignore them.
   if (!tensor.quantization || IsQuantized(tensor)) return nullptr;
   // If the result isn't float (and hence not quantizable), the min/max is
   // ignored.
-  if (!res->getType()
+  if (!res.getType()
           .cast<mlir::ShapedType>()
           .getElementType()
           .isa<mlir::FloatType>()) {
@@ -255,10 +270,23 @@ mlir::Operation* ConvertMinMaxToStatsOp(const TensorT& tensor, OpBuilder b,
 }
 
 StatusOr<std::string> OpNameForOpCode(const tflite::OperatorCodeT opcode) {
-  // TODO(krzysd) Support custom ops
+  // TODO(b/143872630): Support custom ops
   if (opcode.builtin_code == tflite::BuiltinOperator_CUSTOM) {
-    return errors::Unimplemented("unsupported custom operation: ",
-                                 opcode.custom_code);
+    // Add some custom ops supported on GPU.
+    const absl::string_view custom_name = opcode.custom_code;
+    if (custom_name == "MaxPoolingWithArgmax2D") {
+      return std::string("tfl.max_pooling_with_argmax_2d");
+    }
+    if (custom_name == "Convolution2DTransposeBias") {
+      return std::string("tfl.convolution_2d_transpose_bias");
+    }
+    if (custom_name == "MaxUnpooling2D") {
+      return std::string("tfl.max_unpooling_2d");
+    }
+    // Use an unsupported op name instead of throwing an error here in case the
+    // op is pruned during the import.
+    return std::string(
+        llvm::Twine("tfl.UNSUPPORTED_custom_", opcode.custom_code).str());
   }
   if (opcode.builtin_code == tflite::BuiltinOperator_IF) {
     return std::string("tf.If");
@@ -495,14 +523,21 @@ bool IsBasicLSTMOp(tflite::BuiltinOptionsUnion op_union) {
   }
 }
 
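For reference, the `use-external-constant` and `experimental-prune-unreachable-nodes-unconditionally` flags above follow LLVM's external-storage CommandLine pattern: the parsed value lands in a plain global bool that the rest of the importer reads. A minimal standalone sketch of that pattern, with illustrative flag and variable names:

#include "llvm/Support/CommandLine.h"

// External storage: the importer reads this bool, not the cl::opt object.
static bool example_prune_flag;

// With llvm::cl::location, the second template argument of cl::opt must be
// `true` (ExternalStorage); parsing writes into `example_prune_flag`.
static llvm::cl::opt<bool, true> example_prune_flag_opt(
    "example-prune-flag",
    llvm::cl::desc("Illustrative boolean flag bound to external storage."),
    llvm::cl::location(example_prune_flag), llvm::cl::init(false));

int main(int argc, char** argv) {
  llvm::cl::ParseCommandLineOptions(argc, argv);
  return example_prune_flag ? 0 : 1;
}
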
+// Returns true if this is a custom op.
+bool IsCustomOp(const std::string& op_name) {
+  return op_name == "tfl.max_pooling_with_argmax_2d" ||
+         op_name == "tfl.max_unpooling_2d" ||
+         op_name == "tfl.convolution_2d_transpose_bias";
+}
+
 // TODO(krzysd) Handle function calls
 StatusOr<Operation*> ConvertOp(
-    const tflite::OperatorT& op, const std::vector<Value*>& vals_map,
-    Value* optional_arg_marker, const std::vector<std::string>& op_names,
+    const tflite::OperatorT& op, const std::vector<Value>& vals_map,
+    Value optional_arg_marker, const std::vector<std::string>& op_names,
     const std::vector<std::string>& func_names,
     const std::vector<std::unique_ptr<tflite::TensorT>>& tensors, Location loc,
     OpBuilder builder) {
-  llvm::SmallVector<Value*, 4> operands;
+  llvm::SmallVector<Value, 4> operands;
   llvm::SmallVector<mlir::Type, 2> outputTypes;
 
   if (op.outputs.empty()) {
@@ -557,7 +592,15 @@ StatusOr<Operation*> ConvertOp(
   }
 
   llvm::SmallVector<mlir::NamedAttribute, 2> attrs;
-  mlir::BuiltinOptionsToAttributes(op.builtin_options, builder, attrs);
+  if (IsCustomOp(op_name)) {
+    auto status = mlir::CustomOptionsToAttributes(op_name, op.custom_options,
+                                                  builder, loc, &attrs);
+    if (!status.ok()) {
+      return emitError(loc, status.ToString()), status;
+    }
+  } else {
+    mlir::BuiltinOptionsToAttributes(op.builtin_options, builder, attrs);
+  }
   op_state.addAttributes(attrs);
 
   // Handle the conversion from subgraph index to functions for If and While
@@ -619,6 +662,49 @@ mlir::NamedAttribute BuildTFEntryFunctionAttribute(
       name, builder->getStringAttr(llvm::join(tensor_names, ",")));
 }
 
+// Given a list of output indices, traverses the subgraph and returns the set
+// of ops that are ancestors of the output tensors.
+StatusOr<absl::flat_hash_set<const tflite::OperatorT*>> PruneSubgraph(
+    const tflite::SubGraphT& subgraph, ArrayRef<int32_t> output_indices) {
+  // Create a map from tensor index to defining op.
+  absl::flat_hash_map<int32_t, const tflite::OperatorT*> defining_op;
+  for (const auto& op : subgraph.operators) {
+    for (int32_t output : op->outputs) {
+      defining_op[output] = op.get();
+    }
+  }
+
+  std::vector<const tflite::OperatorT*> queue;
+  for (int32_t output : output_indices) {
+    if (auto& op = defining_op[output]) {
+      queue.push_back(op);
+    } else {
+      return errors::InvalidArgument("Output tensor doesn't have defining op");
+    }
+  }
+
+  // Traverse the graph towards inputs.
+  absl::flat_hash_set<const tflite::OperatorT*> visited;
+  while (!queue.empty()) {
+    const tflite::OperatorT* op = queue.back();
+    queue.pop_back();
+    if (!visited.insert(op).second) {
+      // The node has already been visited.
+      continue;
+    }
+
+    for (int32_t input : op->inputs) {
+      // Input tensor may not have a defining op in case it is a subgraph input
+      // or a constant tensor.
+      if (auto& op = defining_op[input]) {
+        queue.push_back(op);
+      }
+    }
+  }
+
+  return visited;
+}
+
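PruneSubgraph above is a standard reverse reachability walk: seed a worklist with the ops defining the subgraph outputs, then repeatedly pop an op and enqueue the defining ops of its inputs. The same idea, sketched independently of the flatbuffer types (the `Node` struct and function name here are hypothetical):

#include <unordered_map>
#include <unordered_set>
#include <vector>

struct Node {
  std::vector<int> inputs;   // tensor indices this node consumes
  std::vector<int> outputs;  // tensor indices this node produces
};

// Returns the nodes reachable by walking backwards from the given outputs.
std::unordered_set<const Node*> AncestorsOfOutputs(
    const std::vector<Node>& nodes, const std::vector<int>& output_tensors) {
  // Map each tensor index to its defining node (graph inputs and constant
  // tensors have none).
  std::unordered_map<int, const Node*> defining_node;
  for (const Node& node : nodes)
    for (int out : node.outputs) defining_node[out] = &node;

  std::vector<const Node*> queue;
  for (int out : output_tensors)
    if (const Node* def = defining_node[out]) queue.push_back(def);

  std::unordered_set<const Node*> visited;
  while (!queue.empty()) {
    const Node* node = queue.back();
    queue.pop_back();
    if (!visited.insert(node).second) continue;  // already visited
    for (int in : node->inputs)
      if (const Node* def = defining_node[in]) queue.push_back(def);
  }
  return visited;
}

Any op not in the returned set is dead with respect to the requested outputs, which is why the import loop below skips ops missing from pruned_subgraph_ops.
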
 // Build a FuncOp from a tflite SubGraph
 // The op_names are a mapping from indexes into the TFLite operators array to
 // the operator name MLIR expects (tfl.foo_op). The buffers are directly taken
@@ -635,7 +721,8 @@ StatusOr<FuncOp> ConvertSubgraph(
     const std::vector<std::unique_ptr<tflite::BufferT>>& buffers,
     Location base_loc, Builder builder,
     const std::vector<std::string>& ordered_output_arrays, bool is_entry_point,
-    bool use_external_constant) {
+    bool use_external_constant,
+    bool experimental_prune_unreachable_nodes_unconditionally) {
   llvm::SmallVector<mlir::Type, 2> ret_types;
   llvm::SmallVector<mlir::Type, 4> input_types;
 
@@ -692,19 +779,19 @@ StatusOr<FuncOp> ConvertSubgraph(
   auto& body = func.getBody();
   OpBuilder op_builder{body};
 
-  std::vector<Value*> vals_map(subgraph.tensors.size(), nullptr);
-  Value* maybe_optional_arg_marker = nullptr;
+  std::vector<Value> vals_map(subgraph.tensors.size(), nullptr);
+  Value maybe_optional_arg_marker = nullptr;
 
   // Get or construct MLIR values for each input
   for (int i = 0, e = subgraph.inputs.size(); i < e; i++) {
     auto input_tensor = subgraph.inputs[i];
     const auto& tensor = *subgraph.tensors.at(input_tensor);
     auto loc = TensorLoc(tensor, builder, base_loc);
-    if (nullptr != vals_map[input_tensor]) {
+    if (vals_map[input_tensor]) {
       auto err = errors::FailedPrecondition("duplicate input arguments");
       return emitError(loc, err.ToString()), err;
     }
-    Value* input_value = func.getArgument(i);
+    Value input_value = func.getArgument(i);
 
     // If the `tensor` has min/max and doesn't have scale/zero_point
     // information, a stats op is created to use the input_value, then the
@@ -731,8 +818,19 @@ StatusOr<FuncOp> ConvertSubgraph(
     func.setAttr("tf.entry_function", builder.getDictionaryAttr(attributes));
   }
 
+  absl::flat_hash_set<const tflite::OperatorT*> pruned_subgraph_ops;
+  if (experimental_prune_unreachable_nodes_unconditionally) {
+    TF_ASSIGN_OR_RETURN(pruned_subgraph_ops,
+                        PruneSubgraph(subgraph, func_outputs));
+  }
+
   // Construct MLIR operators from TFLite operators
   for (auto& op : subgraph.operators) {
+    if (experimental_prune_unreachable_nodes_unconditionally &&
+        !pruned_subgraph_ops.contains(op.get())) {
+      continue;
+    }
+
     for (auto input_num : op->inputs) {
       // The operators in a graph are topologically sorted
       // and so if no previous operation has produced a tensor
@@ -745,7 +843,7 @@ StatusOr<FuncOp> ConvertSubgraph(
                              builder.getUnitAttr())
                          .getResult();
       }
-    } else if (nullptr == vals_map.at(input_num)) {
+    } else if (!vals_map.at(input_num)) {
       auto& const_tensor = *subgraph.tensors[input_num];
       auto const_loc = TensorLoc(const_tensor, builder, base_loc);
       auto op_or_err =
@@ -768,7 +866,7 @@ StatusOr<FuncOp> ConvertSubgraph(
            ?
base_loc : TensorLoc(*subgraph.tensors[op->outputs[0]], builder, base_loc); // If there's an optional argument, maybe_optional_arg_marker has been set - // to a valid Value* + // to a valid Value TF_ASSIGN_OR_RETURN( auto* mlir_op, ConvertOp(*op, vals_map, maybe_optional_arg_marker, op_names, @@ -791,9 +889,9 @@ StatusOr ConvertSubgraph( } // Construct return values - llvm::SmallVector return_operands; + llvm::SmallVector return_operands; for (auto index : func_outputs) { - if (nullptr == vals_map.at(index)) { + if (!vals_map.at(index)) { auto& const_tensor = *subgraph.tensors[index]; auto const_loc = TensorLoc(const_tensor, builder, base_loc); auto op_or_err = @@ -837,7 +935,8 @@ std::string SubgraphName(unsigned index, const tflite::SubGraphT& subgraph) { OwningModuleRef tflite::FlatBufferToMlir( absl::string_view buffer, MLIRContext* context, Location base_loc, const std::vector& ordered_output_arrays, - bool use_external_constant) { + bool use_external_constant, + bool experimental_prune_unreachable_nodes_unconditionally) { auto model_ptr = FlatBufferModel::VerifyAndBuildFromBuffer(buffer.data(), buffer.length()); if (nullptr == model_ptr) { @@ -892,7 +991,8 @@ OwningModuleRef tflite::FlatBufferToMlir( // TODO(b/131175224,b/132239787) Support multiple entry points builder, ordered_output_arrays, /*is_entry_point=*/e.index() == 0, - /*use_external_constant=*/use_external_constant); + /*use_external_constant=*/use_external_constant, + experimental_prune_unreachable_nodes_unconditionally); if (!func_or_error.ok()) { return emitError(base_loc, "could not translate function ") << subgraph->name, @@ -905,9 +1005,10 @@ OwningModuleRef tflite::FlatBufferToMlir( return OwningModuleRef(module); } -static OwningModuleRef FlatBufferFileToMlirTrans(llvm::SourceMgr* source_mgr, - MLIRContext* context, - bool use_external_constant) { +static OwningModuleRef FlatBufferFileToMlirTrans( + llvm::SourceMgr* source_mgr, MLIRContext* context, + bool use_external_constant, + bool experimental_prune_unreachable_nodes_unconditionally) { const llvm::MemoryBuffer* input = source_mgr->getMemoryBuffer(source_mgr->getMainFileID()); std::string error; @@ -924,12 +1025,14 @@ static OwningModuleRef FlatBufferFileToMlirTrans(llvm::SourceMgr* source_mgr, return tflite::FlatBufferToMlir( absl::string_view(input->getBufferStart(), input->getBufferSize()), - context, loc, outputs, use_external_constant); + context, loc, outputs, use_external_constant, + experimental_prune_unreachable_nodes_unconditionally); } static mlir::TranslateToMLIRRegistration FlatBufferFileToMlirTransReg( "tflite-flatbuffer-to-mlir", [](llvm::SourceMgr& source_mgr, MLIRContext* context) { - return FlatBufferFileToMlirTrans(&source_mgr, context, - use_external_constant); + return FlatBufferFileToMlirTrans( + &source_mgr, context, use_external_constant, + experimental_prune_unreachable_nodes_unconditionally); }); diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.h b/tensorflow/compiler/mlir/lite/flatbuffer_import.h index 66b31c54c80..e3210c6d03f 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_import.h +++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.h @@ -17,9 +17,9 @@ limitations under the License. 
 #define TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_IMPORT_H_
 
 #include "absl/strings/string_view.h"
-#include "mlir/IR/Location.h"  // TF:local_config_mlir
-#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
-#include "mlir/IR/Module.h"  // TF:local_config_mlir
+#include "mlir/IR/Location.h"  // TF:llvm-project
+#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
+#include "mlir/IR/Module.h"  // TF:llvm-project
 
 namespace tflite {
 // Converts a TFLite flatbuffer stored in `buffer` to an MLIR module
@@ -31,11 +31,14 @@ namespace tflite {
 // on failure, and more specific errors will be emitted via the context.
 // If `use_external_constant` is true, it will create `tfl.external_const`
 // instead of `tfl.const`.
+// If `experimental_prune_unreachable_nodes_unconditionally` is true, nodes
+// that are not ancestors of the output nodes will be pruned.
 mlir::OwningModuleRef FlatBufferToMlir(
     absl::string_view buffer, mlir::MLIRContext* context,
     mlir::Location base_loc,
     const std::vector<std::string>& ordered_output_arrays,
-    bool use_external_constant = false);
+    bool use_external_constant = false,
+    bool experimental_prune_unreachable_nodes_unconditionally = false);
 }  // namespace tflite
 
 #endif  // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_IMPORT_H_
diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc b/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc
index 851292b10fa..d9680a51ae0 100644
--- a/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc
+++ b/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc
@@ -17,15 +17,45 @@ limitations under the License.
 
 #include
 
+#include "absl/strings/str_cat.h"
+#include "flatbuffers/flexbuffers.h"  // TF:flatbuffers
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringSwitch.h"
-#include "mlir/IR/Attributes.h"  // TF:local_config_mlir
-#include "mlir/IR/Builders.h"  // TF:local_config_mlir
-#include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
+#include "mlir/IR/Attributes.h"  // TF:llvm-project
+#include "mlir/IR/Builders.h"  // TF:llvm-project
+#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/lite/kernels/internal/kernel_utils.h"
 #include "tensorflow/lite/schema/schema_generated.h"
 
+namespace {
+
+using ::tensorflow::Status;
+using ::tensorflow::errors::InvalidArgument;
+using ::xla::StatusOr;
+
+StatusOr<mlir::StringAttr> GetPaddingAttr(TfLitePadding pad_params,
+                                          mlir::Builder builder,
+                                          mlir::Location loc) {
+  auto padding = tflite::Padding::Padding_VALID;
+  if (pad_params == TfLitePadding::kTfLitePaddingSame) {
+    padding = tflite::Padding_SAME;
+  } else if (pad_params == TfLitePadding::kTfLitePaddingValid) {
+    padding = tflite::Padding_VALID;
+  } else {
+    return InvalidArgument(
+        absl::StrCat("Invalid padding type: ", std::to_string(pad_params)));
+  }
+
+  const char* option_name = tflite::EnumNamePadding(padding);
+  return builder.getStringAttr(option_name);
+}
+
+}  // namespace
+
 // TODO(jpienaar): This is a placeholder. This should be done in a more
 // efficient way as part of the translation of the module.
 static tflite::ActivationFunctionType ConvertTFL_AFAttrForOptionWriter(
@@ -212,5 +242,44 @@ static mlir::Attribute BuildTFL_PaddingAttr(tflite::Padding value,
   return builder.getStringAttr(option_name);
 }
 
+Status mlir::CustomOptionsToAttributes(
+    const std::string& op_name, const std::vector<uint8_t>& custom_options,
+    mlir::Builder builder, mlir::Location loc,
+    llvm::SmallVectorImpl<mlir::NamedAttribute>* attributes) {
+  if (op_name == "tfl.max_pooling_with_argmax_2d" ||
+      op_name == "tfl.max_unpooling_2d") {
+    auto* pool_params =
+        reinterpret_cast<const TfLitePoolParams*>(custom_options.data());
+    TF_ASSIGN_OR_RETURN(auto padding_attribute,
+                        GetPaddingAttr(pool_params->padding, builder, loc));
+    attributes->emplace_back(
+        builder.getNamedAttr("padding", padding_attribute));
+    attributes->emplace_back(builder.getNamedAttr(
+        "stride_h", builder.getI32IntegerAttr(pool_params->stride_height)));
+    attributes->emplace_back(builder.getNamedAttr(
+        "stride_w", builder.getI32IntegerAttr(pool_params->stride_width)));
+    attributes->emplace_back(builder.getNamedAttr(
+        "filter_w", builder.getI32IntegerAttr(pool_params->filter_width)));
+    attributes->emplace_back(builder.getNamedAttr(
+        "filter_h", builder.getI32IntegerAttr(pool_params->filter_height)));
+    return Status::OK();
+
+  } else if (op_name == "tfl.convolution_2d_transpose_bias") {
+    auto* conv_params = reinterpret_cast<const TfLiteTransposeConvParams*>(
+        custom_options.data());
+    TF_ASSIGN_OR_RETURN(auto padding_attribute,
+                        GetPaddingAttr(conv_params->padding, builder, loc));
+    attributes->emplace_back(
+        builder.getNamedAttr("padding", padding_attribute));
+    attributes->emplace_back(builder.getNamedAttr(
+        "stride_h", builder.getI32IntegerAttr(conv_params->stride_height)));
+    attributes->emplace_back(builder.getNamedAttr(
+        "stride_w", builder.getI32IntegerAttr(conv_params->stride_width)));
+    return Status::OK();
+  }
+
+  return InvalidArgument(absl::StrCat("invalid custom op type: ", op_name));
+}
+
 // Pull in FlatBuffer writers for TFLite generated using TableGen
 #include "tensorflow/compiler/mlir/lite/operator_converters.inc"
diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_operator.h b/tensorflow/compiler/mlir/lite/flatbuffer_operator.h
index 35293c1b812..fdc0fd81f8f 100644
--- a/tensorflow/compiler/mlir/lite/flatbuffer_operator.h
+++ b/tensorflow/compiler/mlir/lite/flatbuffer_operator.h
@@ -26,9 +26,10 @@ limitations under the License.
 #include "flatbuffers/flatbuffers.h"  // TF:flatbuffers
 #include "llvm/ADT/Optional.h"
 #include "llvm/ADT/SmallVector.h"
-#include "mlir/IR/Attributes.h"  // TF:local_config_mlir
-#include "mlir/IR/Builders.h"  // TF:local_config_mlir
-#include "mlir/IR/Operation.h"  // TF:local_config_mlir
+#include "mlir/IR/Attributes.h"  // TF:llvm-project
+#include "mlir/IR/Builders.h"  // TF:llvm-project
+#include "mlir/IR/Operation.h"  // TF:llvm-project
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/lite/schema/schema_generated.h"
 
 namespace mlir {
@@ -45,7 +46,7 @@ llvm::Optional<flatbuffers::Offset<tflite::Operator>> CreateFlatBufferOperator(
     const std::vector<int32_t> &operands, const std::vector<int32_t> &results,
     flatbuffers::FlatBufferBuilder *fbb);
 
-// Populate the array of mlir::NamedAttributes corresponding to the given
+// Populates the array of mlir::NamedAttributes corresponding to the given
 // tflite::FlatbufferOptionsUnion.
 // We use an out parameter per LLVM convention
 void BuiltinOptionsToAttributes(
@@ -53,6 +54,15 @@ void BuiltinOptionsToAttributes(
     // NOLINTNEXTLINE
     llvm::SmallVectorImpl<mlir::NamedAttribute> &attributes);
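CustomOptionsToAttributes above relies on the convention that, for these GPU custom ops, `custom_options` holds the raw bytes of the corresponding TfLite params struct, so a reinterpret_cast recovers the fields. A simplified sketch of that round trip, using a stand-in struct rather than the real TfLitePoolParams layout:

#include <cstdint>
#include <cstring>
#include <vector>

// Stand-in params struct; the real layouts come from the TFLite C headers.
struct ExamplePoolParams {
  int32_t stride_h;
  int32_t stride_w;
  int32_t filter_h;
  int32_t filter_w;
};

// Export side: dump the struct into the opaque custom_options buffer.
std::vector<uint8_t> PackOptions(const ExamplePoolParams& params) {
  std::vector<uint8_t> bytes(sizeof(params));
  std::memcpy(bytes.data(), &params, sizeof(params));
  return bytes;
}

// Import side: view the bytes as the struct again. This is only safe when
// producer and consumer agree on the struct layout and ABI, which is the
// assumption the importer makes here.
const ExamplePoolParams* UnpackOptions(const std::vector<uint8_t>& bytes) {
  if (bytes.size() != sizeof(ExamplePoolParams)) return nullptr;
  return reinterpret_cast<const ExamplePoolParams*>(bytes.data());
}
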
+// Populates the array of mlir::NamedAttributes corresponding to the given
+// custom_options.
+// We use an out parameter per LLVM convention
+tensorflow::Status CustomOptionsToAttributes(
+    const std::string &op_name, const std::vector<uint8_t> &custom_options,
+    mlir::Builder builder,
+    // NOLINTNEXTLINE
+    Location loc, llvm::SmallVectorImpl<mlir::NamedAttribute> *attributes);
+
 }  // namespace mlir
 
 #endif  // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_OPERATOR_H_
diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc
index 3ed8eb87eb9..5abd37b22fa 100644
--- a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc
+++ b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc
@@ -41,21 +41,22 @@ limitations under the License.
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/ToolOutputFile.h"
-#include "mlir/Dialect/QuantOps/QuantTypes.h"  // TF:local_config_mlir
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:local_config_mlir
-#include "mlir/IR/Builders.h"  // TF:local_config_mlir
-#include "mlir/IR/Function.h"  // TF:local_config_mlir
-#include "mlir/IR/Location.h"  // TF:local_config_mlir
-#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
-#include "mlir/IR/Module.h"  // TF:local_config_mlir
-#include "mlir/IR/Operation.h"  // TF:local_config_mlir
-#include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
-#include "mlir/IR/Types.h"  // TF:local_config_mlir
-#include "mlir/IR/Value.h"  // TF:local_config_mlir
-#include "mlir/Support/LogicalResult.h"  // TF:local_config_mlir
-#include "mlir/Translation.h"  // TF:local_config_mlir
+#include "mlir/Dialect/QuantOps/QuantTypes.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/IR/Builders.h"  // TF:llvm-project
+#include "mlir/IR/Function.h"  // TF:llvm-project
+#include "mlir/IR/Location.h"  // TF:llvm-project
+#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
+#include "mlir/IR/Module.h"  // TF:llvm-project
+#include "mlir/IR/Operation.h"  // TF:llvm-project
+#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
+#include "mlir/IR/Types.h"  // TF:llvm-project
+#include "mlir/IR/Value.h"  // TF:llvm-project
+#include "mlir/Support/LogicalResult.h"  // TF:llvm-project
+#include "mlir/Translation.h"  // TF:llvm-project
 #include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h"
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
+#include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
 #include "tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h"
 #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
@@ -230,19 +231,19 @@ static bool IsConst(Operation* op) {
 }
 
 template <typename T>
-static bool HasValidTFLiteType(Value* value, T& error_handler) {
+static bool HasValidTFLiteType(Value value, T& error_handler) {
   // None type is allowed to represent unspecified operands.
- if (value->getType().isa()) return true; + if (value.getType().isa()) return true; - auto type = value->getType().dyn_cast(); + auto type = value.getType().dyn_cast(); if (!type) { - if (auto op = value->getDefiningOp()) { + if (auto op = value.getDefiningOp()) { error_handler.emitError() << '\'' << op << "' should produce value of tensor type instead of " - << value->getType(); + << value.getType(); return false; } - error_handler.emitError("expected tensor type, got ") << value->getType(); + error_handler.emitError("expected tensor type, got ") << value.getType(); return false; } @@ -279,9 +280,9 @@ static bool IsValidTFLiteMlirModule(ModuleOp module) { } auto& bb = fn.getBlocks().front(); - for (auto* arg : bb.getArguments()) { + for (auto arg : bb.getArguments()) { if (!HasValidTFLiteType(arg, fn)) - return fn.emitError("invalid TFLite type: ") << arg->getType(), false; + return fn.emitError("invalid TFLite type: ") << arg.getType(), false; } // Verify that all operations except the terminator have exactly one @@ -289,9 +290,9 @@ static bool IsValidTFLiteMlirModule(ModuleOp module) { for (auto& inst : bb) { if (inst.isKnownTerminator()) break; - for (auto* result : inst.getResults()) { + for (auto result : inst.getResults()) { if (!HasValidTFLiteType(result, inst)) - return fn.emitError("invalid TFLite type: ") << result->getType(), + return fn.emitError("invalid TFLite type: ") << result.getType(), false; } } @@ -361,7 +362,7 @@ class Translator { // Builds TFLite tensor from the given value. `buffer_idx` is index of the // corresponding buffer. Emits error and returns llvm::None on failure. - Optional> BuildTensor(Value* value, + Optional> BuildTensor(Value value, const std::string& name, unsigned buffer_idx); @@ -419,7 +420,7 @@ class Translator { bool IsStatefulOperand(mlir::Operation* op, int operand_index); // Returns a unique name for `val`. - std::string UniqueName(mlir::Value* val); + std::string UniqueName(mlir::Value val); ModuleOp module_; @@ -449,7 +450,7 @@ class Translator { std::vector failed_custom_ops_; }; -std::string Translator::UniqueName(mlir::Value* val) { +std::string Translator::UniqueName(mlir::Value val) { return name_mapper_.GetUniqueName(val); } @@ -502,8 +503,8 @@ Optional> Translator::BuildBuffer( } Optional> Translator::BuildTensor( - Value* value, const std::string& name, unsigned buffer_idx) { - auto type = value->getType().cast(); + Value value, const std::string& name, unsigned buffer_idx) { + auto type = value.getType().cast(); // TFLite requires tensor shape only for the inputs and constants. // However, we output all known shapes for better round-tripping @@ -515,7 +516,7 @@ Optional> Translator::BuildTensor( if (std::any_of(shape_ref.begin(), shape_ref.end(), is_out_of_range)) return mlir::emitError( - value->getLoc(), + value.getLoc(), "result shape dimensions out of 32 bit int type range"); return mlir::success(); @@ -527,7 +528,7 @@ Optional> Translator::BuildTensor( if (mlir::failed(check_shape(shape_ref))) return llvm::None; shape = std::vector(shape_ref.begin(), shape_ref.end()); - } else if (auto* inst = value->getDefiningOp()) { + } else if (auto* inst = value.getDefiningOp()) { if (IsConst(inst)) { // Const op can have a result of dynamic shaped type (e.g. due to constant // folding), but we can still derive the shape of a constant tensor for @@ -570,7 +571,7 @@ Optional> Translator::BuildTensor( // marked as a stateful. If so, set the tensor's is_variable as true // This is v1 ref variable semantics in the TFLite runtime. 
bool is_variable = false; - for (auto& use : value->getUses()) { + for (auto& use : value.getUses()) { is_variable = IsStatefulOperand(use.getOwner(), use.getOperandNumber()); if (is_variable) { break; @@ -669,6 +670,16 @@ Translator::CreateFlexBuilderWithNodeAttrs( case ::tensorflow::AttrValue::kS: flex_builder->String(key, attr.s()); break; + case ::tensorflow::AttrValue::kType: { + auto status_or_tfl_type = tflite::TfTypeToTflType(attr.type()); + if (status_or_tfl_type.ok()) { + flex_builder->Int(key, status_or_tfl_type.ValueOrDie()); + } else { + emitWarning(loc, "ignoring unsupported tensorflow type: ") + << std::to_string(attr.type()); + } + break; + } case ::tensorflow::AttrValue::kI: flex_builder->Int(key, attr.i()); break; @@ -906,13 +917,13 @@ Optional> Translator::BuildSubGraph(FuncOp fn) { bool has_input_attr = false; InitializeNamesFromAttribute(fn, &has_input_attr); std::vector> tensors; - llvm::DenseMap tensor_index_map; + llvm::DenseMap tensor_index_map; // Builds tensor and buffer for argument or operation result. Returns false // on failure. - auto build_tensor_and_buffer = [&](Value* value, const std::string& name) { + auto build_tensor_and_buffer = [&](Value value, const std::string& name) { // NoneType represents optional and may be skipped here. - if (value->getType().isa()) { + if (value.getType().isa()) { return true; } @@ -925,7 +936,7 @@ Optional> Translator::BuildSubGraph(FuncOp fn) { // make the Buffer empty apart from setting the buffer_idx=0 in the Tensor. // This does not seem to affect runtime behavior for RNN/LSTM, but would be // good for reducing memory footprint. - if (auto* inst = value->getDefiningOp()) { + if (auto* inst = value.getDefiningOp()) { auto buffer_or = BuildBuffer(inst); if (!buffer_or) return false; buffers_.push_back(*buffer_or); @@ -942,7 +953,7 @@ Optional> Translator::BuildSubGraph(FuncOp fn) { // have associated tensor and buffer. Build FlatBuffer tensor and buffer for // other functions. for (unsigned i = 0, e = bb.getNumArguments(); i < e; ++i) { - mlir::BlockArgument* arg = bb.getArgument(i); + mlir::BlockArgument arg = bb.getArgument(i); std::string name; if (has_input_attr) name = name_mapper_.GetUniqueName(arg); if (name.empty()) name = absl::StrCat("arg", i); @@ -964,15 +975,15 @@ Optional> Translator::BuildSubGraph(FuncOp fn) { // Fetch operand and result tensor indices. std::vector operands; operands.reserve(inst.getNumOperands()); - for (auto* operand : inst.getOperands()) { - if (operand->getType().isa()) + for (auto operand : inst.getOperands()) { + if (operand.getType().isa()) operands.push_back(kTfLiteOptionalTensor); else operands.push_back(tensor_index_map.lookup(operand)); } std::vector results; results.reserve(inst.getNumOperands()); - for (auto* result : inst.getResults()) { + for (auto result : inst.getResults()) { results.push_back(tensor_index_map.lookup(result)); } @@ -986,10 +997,10 @@ Optional> Translator::BuildSubGraph(FuncOp fn) { // Get input and output tensor indices for the subgraph. 
std::vector inputs, outputs; - for (auto* arg : bb.getArguments()) { + for (auto arg : bb.getArguments()) { inputs.push_back(tensor_index_map[arg]); } - for (auto* result : bb.getTerminator()->getOperands()) { + for (auto result : bb.getTerminator()->getOperands()) { outputs.push_back(tensor_index_map[result]); } diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_translate.h b/tensorflow/compiler/mlir/lite/flatbuffer_translate.h index a69921c3b09..03f92ddbf03 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_translate.h +++ b/tensorflow/compiler/mlir/lite/flatbuffer_translate.h @@ -18,14 +18,15 @@ limitations under the License. #include -#include "mlir/IR/Module.h" // TF:local_config_mlir +#include "mlir/IR/Module.h" // TF:llvm-project #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" namespace tflite { // Translates the given MLIR `module` into a FlatBuffer and stores the -// serialized flatbuffer into the string. This uses OpLocNameMapper to convert -// location of the op to name in flatbuffer. +// serialized flatbuffer into the string. This uses OpOrArgLocNameMapper to +// convert location of the op to name in flatbuffer. Returns true if translation +// fails, otherwise returns false. bool MlirToFlatBufferTranslateFunction(mlir::ModuleOp module, std::string* serialized_flatbuffer, bool emit_builtin_tflite_ops, diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc index 221c9aa2adc..b72b519a724 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc @@ -25,17 +25,17 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/FormatVariadic.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Matchers.h" // TF:local_config_mlir -#include "mlir/IR/OpImplementation.h" // TF:local_config_mlir -#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir -#include "mlir/Support/LLVM.h" // TF:local_config_mlir -#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir -#include "mlir/Transforms/InliningUtils.h" // TF:local_config_mlir +#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Matchers.h" // TF:llvm-project +#include "mlir/IR/OpImplementation.h" // TF:llvm-project +#include "mlir/IR/PatternMatch.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/TypeUtilities.h" // TF:llvm-project +#include "mlir/Support/LLVM.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/Transforms/InliningUtils.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" namespace mlir { @@ -301,14 +301,14 @@ Attribute ConstFoldUnaryOp(Type result_type, Attribute operand, return {}; } -void buildComparisonBinOp(Builder *builder, OperationState &result, Value *lhs, - Value *rhs) { +void buildComparisonBinOp(Builder *builder, OperationState &result, Value lhs, + Value rhs) { auto result_type = - OpTrait::util::getBroadcastedType(lhs->getType(), rhs->getType()); + OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType()); if 
(!result_type) emitError(result.location) - << "non-broadcastable operands: " << lhs->getType() << " and " - << rhs->getType(); + << "non-broadcastable operands: " << lhs.getType() << " and " + << rhs.getType(); result.addOperands({lhs, rhs}); // Comparison binary ops always return i1 tensor. if (auto shaped_type = result_type.dyn_cast()) { @@ -321,15 +321,15 @@ void buildComparisonBinOp(Builder *builder, OperationState &result, Value *lhs, } void buildFusedBroadcastableBinOp(Builder *builder, OperationState &result, - Value *lhs, Value *rhs, + Value lhs, Value rhs, StringAttr fused_activation_function) { auto result_type = - OpTrait::util::getBroadcastedType(lhs->getType(), rhs->getType()); + OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType()); if (!result_type) emitError(result.location) - << "non-broadcastable operands: " << lhs->getType() << " and " - << rhs->getType(); + << "non-broadcastable operands: " << lhs.getType() << " and " + << rhs.getType(); result.addOperands({lhs, rhs}); result.addAttribute("fused_activation_function", fused_activation_function); @@ -358,7 +358,7 @@ OpFoldResult AddOp::fold(ArrayRef operands) { namespace { int64_t GetConcatenationOpAxis(ConcatenationOp op) { - auto output_type = op.output()->getType().cast(); + auto output_type = op.output().getType().cast(); int64_t axis = op.axis().getSExtValue(); if (axis < 0) axis += output_type.getRank(); return axis; @@ -452,7 +452,7 @@ LogicalResult VerifyConcatenationOpTypes(Operation *op, } LogicalResult Verify(ConcatenationOp op) { - auto output_type = op.output()->getType().dyn_cast(); + auto output_type = op.output().getType().dyn_cast(); // If the output type is unranked, there is nothing else to be verified. if (!output_type) return success(); @@ -462,8 +462,8 @@ LogicalResult Verify(ConcatenationOp op) { return op.emitOpError("concatenation dimension must be in [-rank, rank)"); SmallVector operand_types; - for (Value *operand : op.values()) - operand_types.push_back(operand->getType().cast()); + for (Value operand : op.values()) + operand_types.push_back(operand.getType().cast()); return VerifyConcatenationOpTypes(op.getOperation(), output_type, operand_types, axis); @@ -520,7 +520,7 @@ DenseElementsAttr ConstFoldConcatenateOpDense(ArrayRef operands, OpFoldResult ConcatenationOp::fold(ArrayRef operands) { if (fused_activation_function() == "NONE") { - if (auto output_type = output()->getType().dyn_cast()) { + if (auto output_type = output().getType().dyn_cast()) { const int64_t axis = GetConcatenationOpAxis(*this); if (IsConcatenationOpConstFoldable(*this, operands, output_type, axis)) return ConstFoldConcatenateOpDense(operands, output_type, axis); @@ -528,9 +528,9 @@ OpFoldResult ConcatenationOp::fold(ArrayRef operands) { } // Remove all empty values. 
- SmallVector non_empty_values; - for (Value *value : this->values()) { - const auto shaped_type = value->getType().cast(); + SmallVector non_empty_values; + for (Value value : this->values()) { + const auto shaped_type = value.getType().cast(); if (shaped_type.hasStaticShape() && shaped_type.getNumElements() == 0) { continue; } @@ -559,8 +559,8 @@ OpFoldResult ConcatenationOp::fold(ArrayRef operands) { //===----------------------------------------------------------------------===// LogicalResult Verify(FullyConnectedOp op) { - ShapedType input_type = op.input()->getType().cast(); - ShapedType filter_type = op.filter()->getType().cast(); + ShapedType input_type = op.input().getType().cast(); + ShapedType filter_type = op.filter().getType().cast(); if (filter_type.hasRank() && filter_type.getRank() != 2) { return op.emitOpError("expect 2d filter, got ") << filter_type; } @@ -582,7 +582,7 @@ LogicalResult Verify(FullyConnectedOp op) { // format. if (op.weights_format() == "DEFAULT") { ShapedType output_type = - (*op.output().begin())->getType().cast(); + (*op.output().begin()).getType().cast(); if (!output_type.hasStaticShape()) { return mlir::success(); } @@ -609,9 +609,9 @@ LogicalResult Verify(FullyConnectedOp op) { //===----------------------------------------------------------------------===// static void BuildGatherOp(Builder *builder, OperationState &result, - Value *params, Value *indices, IntegerAttr axis) { - auto params_type = params->getType().cast(); - auto indices_type = indices->getType().cast(); + Value params, Value indices, IntegerAttr axis) { + auto params_type = params.getType().cast(); + auto indices_type = indices.getType().cast(); // If params/indices is unranked, then output is unranked. if (!params_type.hasRank() || !indices_type.hasRank()) @@ -704,8 +704,8 @@ static LogicalResult Verify(PackOp op) { if (op.getOperation()->getNumOperands() != op.values_count()) return op.emitOpError("input count should match 'values_count' attribute"); - Value *operand0 = op.getOperand(0); - auto input_type = operand0->getType().cast(); + Value operand0 = op.getOperand(0); + auto input_type = operand0.getType().cast(); // Check axis bounds. if (input_type.hasRank()) { @@ -717,8 +717,8 @@ static LogicalResult Verify(PackOp op) { // Make sure all inputs have the same shape and element type. // TODO(rahulsp): Simplify once b/135032064 is fixed. - for (Value *operand : op.getOperands()) { - auto other_type = operand->getType().cast(); + for (Value operand : op.getOperands()) { + auto other_type = operand.getType().cast(); if (input_type != other_type) return op.emitOpError("operands should be of the same type. 
got ") << input_type << ", " << other_type; @@ -732,9 +732,9 @@ static LogicalResult Verify(PackOp op) { //===----------------------------------------------------------------------===// static LogicalResult Verify(PReluOp op) { - auto input_type = op.input()->getType().cast(); - auto alpha_type = op.alpha()->getType().cast(); - auto output_type = op.output()->getType().cast(); + auto input_type = op.input().getType().cast(); + auto alpha_type = op.alpha().getType().cast(); + auto output_type = op.output().getType().cast(); if (input_type.hasStaticShape() && alpha_type.hasStaticShape()) { if (input_type.getRank() != alpha_type.getRank() + 1) { @@ -783,13 +783,13 @@ struct RemoveAdjacentReshape : public RewritePattern { PatternMatchResult match(Operation *op) const override { auto thisOp = cast(op); - auto prevOp = thisOp.getOperand(0)->getDefiningOp(); + auto prevOp = thisOp.getOperand(0).getDefiningOp(); return isa_and_nonnull(prevOp) ? matchSuccess() : matchFailure(); } void rewrite(Operation *op, PatternRewriter &rewriter) const override { auto thisOp = cast(op); - auto prevOp = cast(thisOp.getOperand(0)->getDefiningOp()); + auto prevOp = cast(thisOp.getOperand(0).getDefiningOp()); // Replace // %1 = "tfl.reshape"(%0, %shape0) @@ -807,7 +807,7 @@ struct RemoveAdjacentReshape : public RewritePattern { OpFoldResult ReshapeOp::fold(ArrayRef operands) { // Remove identity reshape with both static result and input shape. auto result_type = getType().cast(); - auto input_type = getOperand(0)->getType().cast(); + auto input_type = getOperand(0).getType().cast(); if (result_type.hasStaticShape() && result_type == input_type) { return getOperand(0); } @@ -865,7 +865,7 @@ struct RemoveRedundantUnpackPack : public RewritePattern { PatternMatchResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { TFL::PackOp pack_op = cast(op); - Operation *first_input = pack_op.getOperand(0)->getDefiningOp(); + Operation *first_input = pack_op.getOperand(0).getDefiningOp(); if (!first_input) return matchFailure(); auto input_unpack_op = dyn_cast_or_null(first_input); if (!input_unpack_op) return matchFailure(); @@ -880,8 +880,8 @@ struct RemoveRedundantUnpackPack : public RewritePattern { return matchFailure(); for (auto input_output : llvm::zip(pack_op.getOperands(), input_unpack_op.getResults())) { - Value *pack_input = std::get<0>(input_output); - Value *unpack_output = std::get<1>(input_output); + Value pack_input = std::get<0>(input_output); + Value unpack_output = std::get<1>(input_output); // Make sure the ordering is the same for the pack op & unpack op. 
if (pack_input != unpack_output) return matchFailure(); } @@ -905,9 +905,9 @@ void PackOp::getCanonicalizationPatterns(OwningRewritePatternList &results, //===----------------------------------------------------------------------===// static LogicalResult Verify(SliceOp op) { - auto input_type = op.input()->getType().cast(); - auto begin_type = op.begin()->getType().cast(); - auto size_type = op.size()->getType().cast(); + auto input_type = op.input().getType().cast(); + auto begin_type = op.begin().getType().cast(); + auto size_type = op.size().getType().cast(); if (input_type.hasStaticShape() && begin_type.hasStaticShape() && size_type.hasStaticShape()) { if (input_type.getRank() != begin_type.getNumElements()) { @@ -984,8 +984,8 @@ OpFoldResult SubOp::fold(ArrayRef operands) { // TopKOp //===----------------------------------------------------------------------===// -static void BuildTopKOp(Builder *builder, OperationState &result, Value *input, - Value *k) { +static void BuildTopKOp(Builder *builder, OperationState &result, Value input, + Value k) { // Output size is only known if k is constant value. A negative dimension is // considered dynamic so use -1 here if k is not a constant value. int const_k = -1; @@ -995,7 +995,7 @@ static void BuildTopKOp(Builder *builder, OperationState &result, Value *input, // TODO(jpienaar): This should use a helper function. const_k = cst.getValue({}).getValue().getSExtValue(); - auto val_type = input->getType().cast(); + auto val_type = input.getType().cast(); // If value is unranked, then so is results. if (!val_type.hasRank()) return TFL::TopKV2Op::build( @@ -1035,7 +1035,7 @@ struct DropFakeQuant : public RewritePattern { // If all the users of this op have valid "minmax" attributes, it is matched // and can be removed. auto fakeQuantOp = cast(op); - for (auto *operand : fakeQuantOp.getResult()->getUsers()) + for (auto *operand : fakeQuantOp.getResult().getUsers()) if (!HasValidMinMaxAttribute(operand)) return matchFailure(); return matchSuccess(); @@ -1075,7 +1075,7 @@ static LogicalResult Verify(UnpackOp op) { // Extracts and returns the signed integer constant in a 0-rank integer tensor // or 1-element 1-rank integer tensor if 'value' is a constant. -static llvm::Optional ExtractConstantIntFromTensor(Value *value) { +static llvm::Optional ExtractConstantIntFromTensor(Value value) { ElementsAttr attr; if (!matchPattern(value, m_Constant(&attr))) return {}; if (attr.getNumElements() != 1) return {}; @@ -1101,8 +1101,8 @@ static LogicalResult VerifySplitOpOutputTypes( ExpectedOutputTypeGetter get_expected_output_type) { for (int64_t i = 0; i < num_splits; ++i) { auto expected_output_type = get_expected_output_type(i); - Value *output = op->getResult(i); - auto output_type = output->getType().dyn_cast(); + Value output = op->getResult(i); + auto output_type = output.getType().dyn_cast(); if (!output_type || output_type != expected_output_type) return op->emitOpError() << "output #" << i << " should be " << expected_output_type; @@ -1121,7 +1121,7 @@ static LogicalResult Verify(SplitOp op) { if (!split_dim_opt) return success(); // If 'input' is not a ranked tensor, there are no other checks. - auto input_type = op.value()->getType().dyn_cast(); + auto input_type = op.value().getType().dyn_cast(); if (!input_type) return success(); int64_t split_dim = split_dim_opt.getValue(); @@ -1157,7 +1157,7 @@ static LogicalResult Verify(SplitVOp op) { if (!split_dim_opt) return success(); // If 'input' is not a ranked tensor, there are no other checks. 
- auto input_type = op.value()->getType().dyn_cast(); + auto input_type = op.value().getType().dyn_cast(); if (!input_type) return success(); int64_t split_dim = split_dim_opt.getValue(); @@ -1177,8 +1177,7 @@ static LogicalResult Verify(SplitVOp op) { return success(); if (size_splits_attr.getNumElements() != num_splits) { - auto size_splits_type = - op.size_splits()->getType().cast(); + auto size_splits_type = op.size_splits().getType().cast(); RankedTensorType expected_size_splits_type = RankedTensorType::get({num_splits}, size_splits_type.getElementType()); return op.emitOpError("'size_splits' should be ") @@ -1414,7 +1413,7 @@ OpFoldResult RankOp::fold(ArrayRef operands) { } // Also fold if `input` has a known rank. - auto input_type = input()->getType().cast(); + auto input_type = input().getType().cast(); // Do not fold if rank is zero because the TFLite converter doesn't // distinguish between unranked input and scalar input due to b/138865275. // TODO(b/138865275): Remove `input_type.getRank() != 0` in the following @@ -1438,6 +1437,56 @@ OpFoldResult ConstOp::fold(ArrayRef operands) { return value(); } +//===----------------------------------------------------------------------===// +// SelectV2Op +//===----------------------------------------------------------------------===// + +static void BuildSelectV2Op(Builder *builder, OperationState &result, + Value cond, Value x, Value y) { + auto operand_type = + OpTrait::util::getBroadcastedType(x.getType(), y.getType()); + + if (!operand_type) + emitError(result.location) << "non-broadcastable operands: " << x.getType() + << " and " << y.getType(); + + bool has_static_cond_shape = false; + bool has_static_operand_shape = false; + ArrayRef cond_shape; + ArrayRef operand_shape; + + if (auto shaped_type = cond.getType().dyn_cast()) { + if (shaped_type.hasStaticShape()) { + has_static_cond_shape = true; + cond_shape = shaped_type.getShape(); + } + } + if (auto shaped_type = operand_type.dyn_cast()) { + if (shaped_type.hasStaticShape()) { + has_static_operand_shape = true; + operand_shape = shaped_type.getShape(); + } + } + + SmallVector broadcastedShape; + if (has_static_cond_shape && has_static_operand_shape && + !OpTrait::util::getBroadcastedShape(cond_shape, operand_shape, + broadcastedShape)) { + emitError(result.location) << "non-broadcastable operands: " << operand_type + << " and " << cond.getType(); + } + + result.addOperands({cond, x, y}); + + auto elementType = x.getType().dyn_cast().getElementType(); + if (has_static_cond_shape && has_static_operand_shape) { + result.types.push_back( + RankedTensorType::get(broadcastedShape, elementType)); + } else { + result.types.push_back(UnrankedTensorType::get(elementType)); + } +} + //===----------------------------------------------------------------------===// // RangeOp //===----------------------------------------------------------------------===// @@ -1521,9 +1570,8 @@ OpFoldResult RangeOp::fold(ArrayRef operands) { //===----------------------------------------------------------------------===// static LogicalResult Verify(TransposeConvOp op) { - ShapedType output_type = op.output()->getType().cast(); - ShapedType output_shape_type = - op.output_shape()->getType().cast(); + ShapedType output_type = op.output().getType().cast(); + ShapedType output_shape_type = op.output_shape().getType().cast(); if (output_type.hasRank() && output_shape_type.hasStaticShape()) { if (output_type.getRank() != output_shape_type.getDimSize(0)) { return op.emitOpError(llvm::formatv( @@ -1629,9 +1677,9 @@ 
OpFoldResult TransposeOp::fold(ArrayRef operands) { } static LogicalResult Verify(TransposeOp op) { - auto input_type = op.x()->getType().cast(); - auto perm_type = op.perm()->getType().cast(); - auto output_type = op.y()->getType().cast(); + auto input_type = op.x().getType().cast(); + auto perm_type = op.perm().getType().cast(); + auto output_type = op.y().getType().cast(); if (input_type.hasStaticShape() && perm_type.hasStaticShape()) { if (perm_type.getNumElements() != input_type.getRank()) { return op.emitOpError( diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.h b/tensorflow/compiler/mlir/lite/ir/tfl_ops.h index 4fcfea7e9c7..c3c880d8cb6 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.h +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.h @@ -18,15 +18,15 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_OPS_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_OPS_H_ -#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:local_config_mlir -#include "mlir/Dialect/Traits.h" // TF:local_config_mlir -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Dialect.h" // TF:local_config_mlir -#include "mlir/IR/OpImplementation.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/Support/Functional.h" // TF:local_config_mlir -#include "mlir/Support/LLVM.h" // TF:local_config_mlir +#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project +#include "mlir/Dialect/Traits.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Dialect.h" // TF:llvm-project +#include "mlir/IR/OpImplementation.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/Support/Functional.h" // TF:llvm-project +#include "mlir/Support/LLVM.h" // TF:llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_traits.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h" #include "tensorflow/lite/schema/schema_generated.h" diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index fdc256acf41..691264d32a4 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -135,7 +135,7 @@ def TFL_FpOrI32OrI64Tensor : TensorOf<[AnyFloat, TFL_Int32Or64]>; //===----------------------------------------------------------------------===// class TFL_OperandIsUnrankedPred : - CPred<"$_op.getOperand(" # n # ")->getType().isa()">; + CPred<"$_op.getOperand(" # n # ").getType().isa()">; // TODO: Some of these could be generalized and/or moved to more general // location. @@ -144,38 +144,38 @@ class TFL_OperandHasRank : PredOpTrait<"operand " # n # " is " # m # "-D", Or<[TFL_OperandIsUnrankedPred, CPred<"$_op.getOperand(" # n # - ")->getType().cast().getRank() == " # m>]>>; + ").getType().cast().getRank() == " # m>]>>; // Returns true if the n-th operand is ranked and has rank dim. class TFL_OperandHasKnownRank : And<[ - CPred<"$_op.getOperand(" # n # ")->getType().isa()">, - CPred<"$_op.getOperand(" # n # ")->getType().cast().getRank() == " + CPred<"$_op.getOperand(" # n # ").getType().isa()">, + CPred<"$_op.getOperand(" # n # ").getType().cast().getRank() == " # dim>]>; // True if operand n is ranked and has a rank > dim. 
class TFL_OperandIsRankedAndHasDimPred<int n, int dim> : And<[ - CPred<"$_op.getOperand(" # n # ")->getType().isa<RankedTensorType>()">, - CPred<"$_op.getOperand(" # n # ")->getType().cast<RankedTensorType>().getRank() > " + CPred<"$_op.getOperand(" # n # ").getType().isa<RankedTensorType>()">, + CPred<"$_op.getOperand(" # n # ").getType().cast<RankedTensorType>().getRank() > " # dim>]>; class TFL_OperandDimEquals<int n, int dim, int size> : And<[ TFL_OperandIsRankedAndHasDimPred<n, dim>, - CPred<"$_op.getOperand(" # n # ")->getType().cast<RankedTensorType>()" + CPred<"$_op.getOperand(" # n # ").getType().cast<RankedTensorType>()" ".getShape()[" # dim # " ] == " # size>]>; // Returns true if the n-th operand has unknown rank or at least rank m. class TFL_OperandHasAtleastRank<int n, int m> : PredOpTrait<"operand " # n # " is " # m # "-D", - Or<[CPred<"$_op.getOperand(" # n # ")->getType().isa<UnrankedTensorType>()">, + Or<[CPred<"$_op.getOperand(" # n # ").getType().isa<UnrankedTensorType>()">, CPred<"$_op.getOperand(" # n # - ")->getType().cast<RankedTensorType>().getRank() >= " # m>]>>; + ").getType().cast<RankedTensorType>().getRank() >= " # m>]>>; class TFL_OperandRankEquals1DimOfOperand<int x, int y> : PredOpTrait<"operand " # x # "'s rank equals operand " # y # "'s size", CPred<"$_op.getOperand(" # x # - ")->getType().cast<RankedTensorType>().getRank() == " + ").getType().cast<RankedTensorType>().getRank() == " "$_op.getOperand(" # y # - ")->getType().cast<RankedTensorType>().getShape()[0]">>; + ").getType().cast<RankedTensorType>().getShape()[0]">>; class TFL_Operand0DOr1ElementTensor<int x> : PredOpTrait<"operand #" # x # " is an 0-d tensor or 1-d tensor w/ 1 element", @@ -195,7 +195,7 @@ class TFL_OperandHasRankLessThan<int n, int m> : PredOpTrait<"operand " # n # " is maximum " # m # "-D", Or<[TFL_OperandIsUnrankedPred<n>, CPred<"$_op.getOperand(" # n # - ")->getType().cast<RankedTensorType>().getRank() <= " # m>]>>; + ").getType().cast<RankedTensorType>().getRank() <= " # m>]>>; // This is a quantization-aware version of TCresVTEtIsSameAsOp class TFL_TCresVTEtIsSameAsOp<int i, int j> : And<[ @@ -224,10 +224,10 @@ def BinaryOpSameElementTypeConstraint : //===----------------------------------------------------------------------===// def TFL_BroadcastableBinaryBuilder : OpBuilder< - "Builder *builder, OperationState &result, Value *lhs, Value *rhs", + "Builder *builder, OperationState &result, Value lhs, Value rhs", [{ auto resultType = - OpTrait::util::getBroadcastedType(lhs->getType(), rhs->getType()); + OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType()); if (!resultType) mlir::emitError(result.location, "non-broadcastable operands"); result.addOperands({lhs, rhs}); @@ -235,7 +235,7 @@ }]>; def TFL_FusedBroadcastableBinaryBuilder : OpBuilder< - "Builder *builder, OperationState &result, Value *lhs, Value *rhs, " + "Builder *builder, OperationState &result, Value lhs, Value rhs, " "StringAttr fusedActivationFunction", [{ buildFusedBroadcastableBinOp( @@ -243,7 +243,7 @@ }]>; def TFL_ComparisonBinaryBuilder : OpBuilder< - "Builder *builder, OperationState &result, Value *lhs, Value *rhs", + "Builder *builder, OperationState &result, Value lhs, Value rhs", [{ buildComparisonBinOp(builder, result, lhs, rhs); }]>; @@ -427,6 +427,33 @@ def TFL_TransposeConvOp: let verifier = [{ return Verify(*this); }]; } +def TFL_Convolution2DTransposeBiasOp : + Op { + let summary = "Transpose convolution with bias operator"; + + let description = [{ +Performs a transpose convolution operation on its inputs, +with the option of adding a bias. +Note this is a custom op that is not supported in the standard runtime.
+ + Inputs: + `inputs[0]`: required: the input activation tensor + `inputs[1]`: required: the filter weight tensor + `inputs[2]`: optional: the bias tensor + }]; + + let arguments = ( + ins AnyTensor:$input, + AnyTensor:$filter, + TFL_TensorOfOrNone<[AnyType]>:$bias, + TFL_PaddingAttr:$padding, + I32Attr:$stride_h, + I32Attr:$stride_w + ); + + let results = (outs AnyTensor:$output); +} + def TFL_AveragePool2DOp: TFL_Op<"average_pool_2d", [NoSideEffect, SameOperandsAndResultsScale]> { let summary = "Average_pool_2d operator"; @@ -471,7 +498,7 @@ def TFL_ArgMaxOp : TFL_Op<"arg_max", [NoSideEffect]> { let hasOptions = 1; DerivedTFLiteTypeAttr output_type = DerivedTFLiteTypeAttr<[{ - return getResult()->getType().cast<ShapedType>().getElementType(). + return getResult().getType().cast<ShapedType>().getElementType(). cast<IntegerType>().getWidth() > 32 ? tflite::TensorType_INT64 : tflite::TensorType_INT32; }]>; @@ -500,7 +527,7 @@ def TFL_ArgMinOp : TFL_Op<"arg_min", [NoSideEffect]> { let hasOptions = 1; DerivedTFLiteTypeAttr output_type = DerivedTFLiteTypeAttr<[{ - return getResult()->getType().cast<ShapedType>().getElementType(). + return getResult().getType().cast<ShapedType>().getElementType(). cast<IntegerType>().getWidth() > 32 ? tflite::TensorType_INT64 : tflite::TensorType_INT32; }]>; @@ -669,7 +696,7 @@ def TFL_GatherOp : TFL_Op<"gather", [ let builders = [ OpBuilder<"Builder *builder, OperationState &result, " - "Value *params, Value *indices, IntegerAttr axis", + "Value params, Value indices, IntegerAttr axis", [{ BuildGatherOp(builder, result, params, indices, axis); }]> ]; @@ -932,7 +959,7 @@ def TFL_NotEqualOp : TFL_Op<"not_equal", [ let builders = [ OpBuilder< - "Builder *builder, OperationState &result, Value *lhs, Value *rhs", + "Builder *builder, OperationState &result, Value lhs, Value rhs", [{ buildComparisonBinOp(builder, result, lhs, rhs); }]> @@ -1427,6 +1454,63 @@ def TFL_MaxPool2DOp : TFL_Op<"max_pool_2d", [ let customOption = "Pool2DOptions"; } +def TFL_MaxPoolingWithArgMax2DOp : + Op { + let summary = "Max Pool 2D with argmax op"; + + let description = [{ + Performs max pooling on the input and outputs both max values and indices. + Each index is a flattened index in a sub-array of "filter_w" x "filter_h" size. + Note this is a custom op that is not supported in the standard runtime. + + Inputs: + `inputs[0]`: required: the input activation tensor + }]; + + let arguments = ( + ins AnyTensor:$input, + TFL_PaddingAttr:$padding, + I32Attr:$stride_w, + I32Attr:$stride_h, + I32Attr:$filter_w, + I32Attr:$filter_h + ); + + let results = (outs + AnyTensor:$value, + AnyTensor:$indices + ); +} + +def TFL_MaxUnpooling2DOp : + Op { + let summary = "Max Unpool 2D"; + + let description = [{ + Performs a max unpooling operation. + To some extent this is the reverse operation of max pooling: + the elements in the input activation tensor are stored into the positions + specified by the input indices. + Note this is a custom op that is not supported in the standard runtime.
+ + Inputs: + `inputs[0]`: required: the input activation tensor + `inputs[1]`: required: the input indices + }]; + + let arguments = ( + ins AnyTensor:$input, + AnyTensor:$indices, + TFL_PaddingAttr:$padding, + I32Attr:$stride_w, + I32Attr:$stride_h, + I32Attr:$filter_w, + I32Attr:$filter_h + ); + + let results = (outs AnyTensor:$outputs); +} + def TFL_MaximumOp : TFL_Op<"maximum", [ Broadcastable, NoSideEffect, Commutative, SameOperandsAndResultsScale, TFL_OperandHasRankLessThan<0, 4>, TFL_OperandHasRankLessThan<1, 4>]> { @@ -1996,7 +2080,7 @@ def TFL_ShapeOp: TFL_Op<"shape", [NoSideEffect]> { let results = (outs AnyTensor:$output); DerivedTypeAttr out_type = DerivedTypeAttr<[{ - return getResult()->getType().cast<ShapedType>().getElementType(); + return getResult().getType().cast<ShapedType>().getElementType(); }]>; let hasOptions = 1; @@ -2081,9 +2165,9 @@ def TFL_SelectOp : TFL_Op<"select", [NoSideEffect, // TODO(jpienaar): autogenerate this. let builders = [OpBuilder<"Builder *builder, OperationState &result, " - "Value *condition, Value *x, Value *y", + "Value condition, Value x, Value y", [{ - auto resultType = x->getType(); + auto resultType = x.getType(); result.addOperands({condition, x, y}); result.types.push_back(resultType); }]>]; @@ -2091,6 +2175,32 @@ let hasOptions = 1; } +def TFL_SelectV2Op : TFL_Op<"select_v2", [NoSideEffect]> { + let summary = "SelectV2 operator"; + + let description = [{ + Selects values of 'x' if the corresponding value of 'condition' is true, or + the value of 'y' if false. There are two valid condition input sizes: + + 1. Either the same shape (in which case the select is elementwise), or + 2. Broadcastable shapes between 'condition', 'x' and 'y'. + }]; + + let arguments = (ins + TFL_BoolTensor:$condition, + TensorOf<[F32, I1, I8, I16, I32, I64, TFL_Uint8]>:$x, + TensorOf<[F32, I1, I8, I16, I32, I64, TFL_Uint8]>:$y); + let results = (outs AnyTensor:$output); + + let builders = [OpBuilder<"Builder *builder, OperationState &result, " + "Value cond, Value x, Value y", + [{ + BuildSelectV2Op(builder, result, cond, x, y); + }]>]; + + let hasOptions = 1; +} + def TFL_SinOp: TFL_Op<"sin", [ NoSideEffect, SameOperandsAndResultType, NoQuantizableResult]> { let summary = "Sine operator"; @@ -2277,7 +2387,7 @@ def TFL_TopKV2Op: TFL_Op<"topk_v2", [NoSideEffect, TFL_OperandHasRank<1,0>, I32Tensor:$indices); let builders = [OpBuilder<"Builder *builder, OperationState &result, " - "Value *input, Value *k", + "Value input, Value k", [{ BuildTopKOp(builder, result, input, k); }]>]; let hasOptions = 1; @@ -2333,14 +2443,14 @@ def TFL_UnpackOp : TFL_Op<"unpack", [NoSideEffect]> { }]; let arguments = (ins - TensorOf<[F32, I8, I32, QI8, QUI8]>:$input, + TensorOf<[F32, I1, I8, I32, QI8, QUI8]>:$input, I32Attr:$num, I32Attr:$axis ); let results = (outs - Variadic<TensorOf<[F32, I8, I32, QI8, QUI8]>>:$outputs + Variadic<TensorOf<[F32, I1, I8, I32, QI8, QUI8]>>:$outputs ); let verifier = [{ return Verify(*this); }]; @@ -2707,7 +2817,7 @@ in the unique output `y`. In other words: ); DerivedTFLiteTypeAttr idx_out_type = DerivedTFLiteTypeAttr<[{ - return getResult(1)->getType().cast<ShapedType>().getElementType(). + return getResult(1).getType().cast<ShapedType>().getElementType(). cast<IntegerType>().getWidth() > 32 ?
tflite::TensorType_INT64 : tflite::TensorType_INT32; }]>; diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_traits.h b/tensorflow/compiler/mlir/lite/ir/tfl_traits.h index 0ec63531658..c489dc825d0 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_traits.h +++ b/tensorflow/compiler/mlir/lite/ir/tfl_traits.h @@ -19,7 +19,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_TRAITS_H_ #include "mlir/IR/OpDefinition.h" -#include "mlir/Support/LLVM.h" // TF:local_config_mlir +#include "mlir/Support/LLVM.h" // TF:llvm-project namespace mlir { namespace OpTrait { diff --git a/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc b/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc index eb840eeeeb4..3099cbeb1fa 100644 --- a/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc +++ b/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc @@ -30,10 +30,10 @@ limitations under the License. #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/SMLoc.h" #include "llvm/Support/SourceMgr.h" -#include "mlir/IR/Function.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/Module.h" // TF:local_config_mlir -#include "mlir/Parser.h" // TF:local_config_mlir +#include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/Parser.h" // TF:llvm-project #include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_translate_flags.h" #include "tensorflow/core/platform/init_main.h" diff --git a/tensorflow/compiler/mlir/lite/operator_converter_gen.cc b/tensorflow/compiler/mlir/lite/operator_converter_gen.cc index b2c125d7001..0f23cbefebd 100644 --- a/tensorflow/compiler/mlir/lite/operator_converter_gen.cc +++ b/tensorflow/compiler/mlir/lite/operator_converter_gen.cc @@ -27,7 +27,7 @@ limitations under the License. #include "llvm/TableGen/Main.h" #include "llvm/TableGen/Record.h" #include "llvm/TableGen/TableGenBackend.h" -#include "mlir/TableGen/Attribute.h" // TF:local_config_mlir +#include "mlir/TableGen/Attribute.h" // TF:llvm-project using llvm::DefInit; using llvm::dyn_cast; diff --git a/tensorflow/compiler/mlir/lite/python/BUILD b/tensorflow/compiler/mlir/lite/python/BUILD index 8e2198c2a6a..98f840d3fe7 100644 --- a/tensorflow/compiler/mlir/lite/python/BUILD +++ b/tensorflow/compiler/mlir/lite/python/BUILD @@ -28,10 +28,10 @@ cc_library( "//tensorflow/lite/toco:toco_flags_proto_cc", "//tensorflow/lite/toco:types_proto_cc", "//tensorflow/stream_executor/lib", - "@llvm//:support", - "@local_config_mlir//:IR", - "@local_config_mlir//:Pass", - "@local_config_mlir//:Support", - "@local_config_mlir//:ViewOpGraph", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:ViewOpGraph", ], ) diff --git a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc index e2ba0cb822b..4ea26ee2f06 100644 --- a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc @@ -19,11 +19,11 @@ limitations under the License. 
#include #include "llvm/Support/ToolOutputFile.h" -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/Module.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir -#include "mlir/Support/FileUtilities.h" // TF:local_config_mlir -#include "mlir/Transforms/ViewOpGraph.h" // TF:local_config_mlir +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Support/FileUtilities.h" // TF:llvm-project +#include "mlir/Transforms/ViewOpGraph.h" // TF:llvm-project #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h" #include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h" @@ -151,10 +151,9 @@ Status RegisterCustomBuiltinOps(const std::vector extra_tf_opdefs) { return errors::InvalidArgument("fail to parse extra OpDef"); } // Make sure the op is not already registered. If registered continue. - const OpRegistrationData* op_reg = nullptr; - auto status = - tensorflow::OpRegistry::Global()->LookUp(opdef.name(), &op_reg); - if (status.ok()) continue; + const OpRegistrationData* op_reg = + tensorflow::OpRegistry::Global()->LookUp(opdef.name()); + if (op_reg) continue; tensorflow::OpRegistry::Global()->Register( [opdef](tensorflow::OpRegistrationData* op_reg_data) -> Status { @@ -278,7 +277,6 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags, auto status = ConvertTFExecutorToTFLOrFlatbuffer( module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops, emit_select_tf_ops, emit_custom_ops, quant_specs, result, &pm); - if (toco_flags.has_dump_graphviz_dir()) { TF_RETURN_IF_ERROR(DumpOpGraphToFile( // rename once we enable the new converter feature flag. 
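The `RegisterCustomBuiltinOps` change above switches to the newer `OpRegistry::LookUp` overload that returns the registration data directly, so a null return now stands in for the old not-found `Status`. A minimal sketch of the resulting guard pattern, assuming TensorFlow's op-registry headers as used in the hunk (`RegisterIfMissing` is a hypothetical helper name, not part of this change):

```c++
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"

// Hypothetical helper mirroring the hunk above: register `opdef` only if no
// op of that name is already present in the global registry.
tensorflow::Status RegisterIfMissing(const tensorflow::OpDef& opdef) {
  // The pointer-returning LookUp replaces the old (name, &out) + Status form;
  // a nullptr result simply means the op is not registered yet.
  const tensorflow::OpRegistrationData* op_reg =
      tensorflow::OpRegistry::Global()->LookUp(opdef.name());
  if (op_reg != nullptr) return tensorflow::Status::OK();

  tensorflow::OpRegistry::Global()->Register(
      [opdef](tensorflow::OpRegistrationData* op_reg_data) {
        *op_reg_data = tensorflow::OpRegistrationData(opdef);
        return tensorflow::Status::OK();
      });
  return tensorflow::Status::OK();
}
```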
diff --git a/tensorflow/compiler/mlir/lite/quantization/BUILD b/tensorflow/compiler/mlir/lite/quantization/BUILD index 4ef6ac0b0cb..7cc03adf543 100644 --- a/tensorflow/compiler/mlir/lite/quantization/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/BUILD @@ -13,7 +13,7 @@ package( package_group( name = "friends", - includes = ["@local_config_mlir//:subpackages"], + includes = ["//third_party/mlir:subpackages"], packages = ["//tensorflow/compiler/mlir/..."], ) @@ -26,8 +26,8 @@ filegroup( name = "quantization_td_files", srcs = [ "quantization.td", - "@local_config_mlir//:OpBaseTdFiles", - "@local_config_mlir//:QuantizationOpsTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:QuantizationOpsTdFiles", ], ) @@ -53,13 +53,13 @@ cc_library( "//tensorflow/core:lib_proto_parsing", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", - "@llvm//:support", - "@local_config_mlir//:Analysis", - "@local_config_mlir//:IR", - "@local_config_mlir//:Pass", - "@local_config_mlir//:QuantOps", - "@local_config_mlir//:StandardOps", - "@local_config_mlir//:Support", + "@llvm-project//llvm:support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", ], alwayslink = 1, ) @@ -75,11 +75,11 @@ cc_library( ], deps = [ "@com_google_absl//absl/memory", - "@llvm//:support", - "@local_config_mlir//:IR", - "@local_config_mlir//:QuantOps", - "@local_config_mlir//:StandardOps", - "@local_config_mlir//:Support", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", # TODO(fengliuai): remove this dependence. "//tensorflow/compiler/mlir/lite:tensorflow_lite", "//tensorflow/core:lib_proto_parsing", @@ -97,7 +97,7 @@ cc_library( deps = [ "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/strings", - "@llvm//:support", + "@llvm-project//llvm:support", ], ) @@ -107,8 +107,8 @@ tf_native_cc_binary( "tools/op_quant_spec_getters_gen.cc", ], deps = [ - "@llvm//:support", - "@llvm//:tablegen", - "@local_config_mlir//:TableGen", + "@llvm-project//llvm:support", + "@llvm-project//llvm:tablegen", + "@llvm-project//mlir:TableGen", ], ) diff --git a/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc b/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc index 0326d122c07..4c4d8f1d9a2 100644 --- a/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc +++ b/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc @@ -23,18 +23,18 @@ limitations under the License. 
#include "llvm/ADT/StringSwitch.h" #include "llvm/Support/Regex.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:local_config_mlir -#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:local_config_mlir -#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir -#include "mlir/IR/AffineExpr.h" // TF:local_config_mlir -#include "mlir/IR/AffineMap.h" // TF:local_config_mlir -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Location.h" // TF:local_config_mlir -#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir -#include "mlir/Support/Functional.h" // TF:local_config_mlir -#include "mlir/Support/LLVM.h" // TF:local_config_mlir +#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project +#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/IR/AffineExpr.h" // TF:llvm-project +#include "mlir/IR/AffineMap.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Location.h" // TF:llvm-project +#include "mlir/IR/PatternMatch.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Support/Functional.h" // TF:llvm-project +#include "mlir/Support/LLVM.h" // TF:llvm-project #include "tensorflow/compiler/mlir/lite/quantization/quantization_info.pb.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_passes.h" #include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h" @@ -70,16 +70,16 @@ class ImportQuantStatsPass : public FunctionPass { void ImportAsStatsOps(OpBuilder b, Operation *op, int index, const QuantParamsEntry &info); - void InsertStatsOpAtResult(OpBuilder b, Value *res, ElementsAttr layer_stats, + void InsertStatsOpAtResult(OpBuilder b, Value res, ElementsAttr layer_stats, ElementsAttr axis_stats, IntegerAttr axis); // If the index is out of range, this method returns false. Otherwise it // returns true if the value is a float tensor. bool IsQuantizableResult(Operation *op, int index) { if (index < 0 || index >= op->getNumResults()) return false; - Value *res = op->getResult(index); - return res->getType().isa() && - res->getType().cast().getElementType().isa(); + Value res = op->getResult(index); + return res.getType().isa() && + res.getType().cast().getElementType().isa(); } // A method to retrieve the name for the given op. 
@@ -117,13 +117,13 @@ bool ImportQuantStatsPass::ParseQuantStats(const std::string &stats_str) { return false; } -void ImportQuantStatsPass::InsertStatsOpAtResult(OpBuilder b, Value *res, +void ImportQuantStatsPass::InsertStatsOpAtResult(OpBuilder b, Value res, ElementsAttr layer_stats, ElementsAttr axis_stats, IntegerAttr axis) { auto stats_op = b.create(b.getUnknownLoc(), res, layer_stats, axis_stats, axis); - res->replaceAllUsesWith(stats_op); + res.replaceAllUsesWith(stats_op); stats_op.getOperation()->replaceUsesOfWith(stats_op, res); } diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD index 880f5ae5210..d076911761f 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD @@ -9,7 +9,7 @@ package( package_group( name = "friends", - includes = ["@local_config_mlir//:subpackages"], + includes = ["//third_party/mlir:subpackages"], packages = [ "//learning/brain/experimental/mlir/...", "//tensorflow/lite/...", @@ -36,9 +36,9 @@ cc_library( "//tensorflow/lite/core/api", "//tensorflow/lite/schema:schema_fbs", "@com_google_absl//absl/strings", - "@llvm//:support", - "@local_config_mlir//:IR", - "@local_config_mlir//:Pass", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", ], ) @@ -53,6 +53,6 @@ tf_cc_binary( "//tensorflow/lite:framework", "//tensorflow/lite/schema:schema_fbs", "@com_google_absl//absl/strings", - "@llvm//:support", + "@llvm-project//llvm:support", ], ) diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc index 97aa128653f..d00357be155 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc @@ -17,11 +17,11 @@ limitations under the License. #include "absl/strings/string_view.h" #include "llvm/ADT/SmallVector.h" -#include "mlir/IR/Location.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/Module.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir -#include "mlir/Pass/PassManager.h" // TF:local_config_mlir +#include "mlir/IR/Location.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Pass/PassManager.h" // TF:llvm-project #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_import.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h" diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc index a145e75465e..3fd1ff2ac94 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc @@ -23,17 +23,17 @@ limitations under the License. 
#include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" -#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir -#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Function.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/Matchers.h" // TF:local_config_mlir -#include "mlir/IR/Operation.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/IR/Value.h" // TF:local_config_mlir -#include "mlir/Support/LLVM.h" // TF:local_config_mlir +#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Matchers.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/Value.h" // TF:llvm-project +#include "mlir/Support/LLVM.h" // TF:llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_traits.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h" @@ -146,14 +146,14 @@ class QuantizationDriver { // Adds all the users of index-th result of op to the work list. void AddUserToList(Operation *op, int index) { - for (auto *user : op->getResult(index)->getUsers()) { + for (auto *user : op->getResult(index).getUsers()) { work_list_.push_back(user); } } // Adds the defining op of index-th operand of op to the work list. void AddOperandToList(Operation *op, int index) { - if (auto *inst = op->getOperand(index)->getDefiningOp()) { + if (auto *inst = op->getOperand(index).getDefiningOp()) { work_list_.push_back(inst); } } @@ -183,20 +183,20 @@ class QuantizationDriver { // of the op. void QuantizeOpResult(Operation *op, int index, QuantParams params); - void QuantizeArg(BlockArgument *arg, QuantParams params); + void QuantizeArg(BlockArgument arg, QuantParams params); // Inserts the Quantize and Dequantize ops to quantize the value and returns // the Quantize op. - void QuantizeValue(Value *value, QuantParams params, Location loc); + void QuantizeValue(Value value, QuantParams params, Location loc); // Inserts the Quantize ops for requantizing the index-th result of the op. void RequantizeOpResult(Operation *op, int index, RequantizeState *state); - void RequantizeArg(BlockArgument *arg, RequantizeState *state); + void RequantizeArg(BlockArgument arg, RequantizeState *state); // Inserts the Quantize and Dequantize ops to quantize the value and returns // the Quantize op. - void RequantizeValue(Value *value, RequantizeState *state, Location loc); + void RequantizeValue(Value value, RequantizeState *state, Location loc); // A heuristic to get the quantization parameter satisfies the same scale // constraints for the op. 
Returns an empty option if this quantization @@ -213,7 +213,7 @@ class QuantizationDriver { return states_[result_states_[{op, index}]]; } - QuantState &GetArgQuantState(BlockArgument *arg) { + QuantState &GetArgQuantState(BlockArgument arg) { return states_[arg_states_[arg]]; } @@ -227,7 +227,7 @@ class QuantizationDriver { return rescale_states_[result_states_[{op, index}]]; } - RequantizeState &GetArgRequantizeState(BlockArgument *arg) { + RequantizeState &GetArgRequantizeState(BlockArgument arg) { return rescale_states_[arg_states_[arg]]; } @@ -235,32 +235,45 @@ class QuantizationDriver { // `as_result` is true or index-th operand if `as_result` is false. The state // is immutable if the type is a quantized type. Returns the index of this // new state in the state vector. - int InitializeState(Operation *op, int index, Value *val, bool as_result); + int InitializeState(Operation *op, int index, Value val, bool as_result); + + // Sets the state of an argument. If this value is cached, uses the cached + // result without creating new entry in the state vector. Otherwise, allocate + // a new entry in the state vector. + void InitializeArgState(BlockArgument arg, Value in, + llvm::DenseMap *cache) { + auto cached = cache->insert({in, 0}); + if (!cached.second) { + arg_states_[arg] = cached.first->second; + return; + } + QuantParams params = + quant::QuantizedType::getQuantizedElementType(in.getType()); + bool immutable = !EmptyParams(params); + int next_state_index = states_.size(); + states_.push_back({params, immutable}); + arg_states_[arg] = next_state_index; + cached.first->second = next_state_index; + } // Sets the state of the index-th operand of the op. If this operand is // cached, uses the cached result without creating new entry in the state // vector. Otherwise, allocate a new entry in the state vector. - void InitializeOperandState(Operation *op, int index, Value *in, - llvm::DenseMap *cache, - bool is_argument) { + void InitializeOperandState(Operation *op, int index, Value in, + llvm::DenseMap *cache) { auto cached = cache->insert({in, 0}); if (!cached.second) { operand_states_.insert({{op, index}, cached.first->second}); return; } cached.first->second = InitializeState(op, index, in, /*as_result=*/false); - if (is_argument) { - auto *arg = llvm::cast(in); - arg_states_[arg] = cached.first->second; - args_.push_back(arg); - } } // Sets the state of the index-th result of the op. If this result is cached, // uses the cached result without creating new entry in the state vector. // Otherwise, allocate a new entry in the state vector. - void InitializeResultState(Operation *op, int index, Value *res, - llvm::DenseMap *cache) { + void InitializeResultState(Operation *op, int index, Value res, + llvm::DenseMap *cache) { auto cached = cache->insert({res, 0}); if (!cached.second) { result_states_.insert({{op, index}, cached.first->second}); @@ -279,7 +292,8 @@ class QuantizationDriver { // rest are weights. llvm::DenseSet weights_; - // The weights require narrow_range quantization. If the value of this map is + // The weights require narrow_range quantization. This map collects all the + // weight operands defined by the op quant spec. If the value of the entry is // positive, per-channel quantization is required. llvm::DenseMap optimized_weights_; @@ -300,11 +314,11 @@ class QuantizationDriver { // results and arguments. 
llvm::DenseMap operand_states_; llvm::DenseMap result_states_; - llvm::DenseMap arg_states_; + llvm::DenseMap arg_states_; // This vector is to preserve the arguments order, so the newly inserted // quantized ops for the arguments are deterministically ordered. - llvm::SmallVector args_; + llvm::SmallVector args_; OpQuantSpecGetter op_quant_spec_getter_; }; @@ -321,10 +335,10 @@ bool QuantizationDriver::IsQuantized(Operation *op) { return true; } -int QuantizationDriver::InitializeState(Operation *op, int index, Value *val, +int QuantizationDriver::InitializeState(Operation *op, int index, Value val, bool as_result) { QuantParams params = - quant::QuantizedType::getQuantizedElementType(val->getType()); + quant::QuantizedType::getQuantizedElementType(val.getType()); bool immutable = !EmptyParams(params); int next_state_index = states_.size(); states_.push_back({params, immutable}); @@ -338,7 +352,7 @@ int QuantizationDriver::InitializeState(Operation *op, int index, Value *val, bool QuantizationDriver::SetConstantResultParams(Operation *op) { ElementsAttr attr; - Value *res = op->getResult(0); + Value res = op->getResult(0); if (!matchPattern(res, m_Constant(&attr))) { return false; } @@ -362,7 +376,7 @@ bool QuantizationDriver::SetConstantResultParams(Operation *op) { } else { // per-tensor quantization weight final_type = GetUniformQuantizedTypeForWeight( - attr, /*symmetric=*/is_weight_with_per_channel_support, + attr, /*symmetric=*/is_weight && is_signed_, /*num_bits=*/8, is_signed_, /*narrow_range_=*/is_weight); } @@ -428,18 +442,18 @@ bool QuantizationDriver::SetOperandParams(Operation *op, int index, void QuantizationDriver::QuantizeOpResult(Operation *op, int index, QuantParams params) { builder_.setInsertionPoint(op->getBlock(), ++Block::iterator(op)); - Value *original_result = op->getResult(index); + Value original_result = op->getResult(index); QuantizeValue(original_result, params, op->getLoc()); } -void QuantizationDriver::QuantizeArg(BlockArgument *arg, QuantParams params) { - builder_.setInsertionPointToStart(arg->getOwner()); +void QuantizationDriver::QuantizeArg(BlockArgument arg, QuantParams params) { + builder_.setInsertionPointToStart(arg.getOwner()); QuantizeValue(arg, params, builder_.getUnknownLoc()); } -void QuantizationDriver::QuantizeValue(Value *value, QuantParams params, +void QuantizationDriver::QuantizeValue(Value value, QuantParams params, Location loc) { - Type expressed_type = value->getType(); + Type expressed_type = value.getType(); Type new_type = params.castFromExpressedType(expressed_type); // This value isn't an expressed type (float), skip. if (!new_type) return; @@ -451,7 +465,7 @@ void QuantizationDriver::QuantizeValue(Value *value, QuantParams params, quantize.output()); // `original_result` has a use to `quantize`, so this will replace that use // by the result of `dequantize`. 
Remember to reset that use afterwards - value->replaceAllUsesWith(dequantize); + value.replaceAllUsesWith(dequantize); quantize.getOperation()->replaceUsesOfWith(dequantize, value); } @@ -459,9 +473,9 @@ void QuantizationDriver::RequantizeOpResult(Operation *op, int index, RequantizeState *state) { if (state->pos == RequantizeState::NO_REQUANTIZE) return; builder_.setInsertionPointAfter(op); - Value *value = op->getResult(index); + Value value = op->getResult(index); if (state->pos == RequantizeState::ON_OUTPUT) { - Operation *user = value->getUses().begin().getUser(); + Operation *user = value.getUses().begin().getUser(); if (llvm::isa<TFL::QuantizeOp>(user)) { // The requantize op is inserted between `quantize` and `dequantize` ops. value = user->getResult(0); @@ -471,31 +485,31 @@ void QuantizationDriver::RequantizeOpResult(Operation *op, int index, RequantizeValue(value, state, op->getLoc()); } -void QuantizationDriver::RequantizeArg(BlockArgument *arg, +void QuantizationDriver::RequantizeArg(BlockArgument arg, RequantizeState *state) { - Value *value = arg; - builder_.setInsertionPointToStart(arg->getOwner()); - if (value->hasOneUse()) { - auto user = value->use_begin().getUser(); + Value value = arg; + builder_.setInsertionPointToStart(arg.getOwner()); + if (value.hasOneUse()) { + auto user = value.use_begin().getUser(); if (auto q = llvm::dyn_cast<TFL::QuantizeOp>(user)) { value = q.output(); - builder_.setInsertionPoint(arg->getOwner(), ++Block::iterator(user)); + builder_.setInsertionPoint(arg.getOwner(), ++Block::iterator(user)); } } RequantizeValue(value, state, builder_.getUnknownLoc()); } -void QuantizationDriver::RequantizeValue(Value *value, RequantizeState *state, +void QuantizationDriver::RequantizeValue(Value value, RequantizeState *state, Location loc) { Type new_type; if (state->pos == RequantizeState::ON_INPUT) { - Type expressed_type = value->getType(); + Type expressed_type = value.getType(); // The value needs to be requantized. A Quantize op will be created to use // it as the operand and replace its uses. new_type = state->params.castFromExpressedType(expressed_type); } else { Type expressed_type = - quant::QuantizedType::castToExpressedType(value->getType()); + quant::QuantizedType::castToExpressedType(value.getType()); if (!expressed_type) return; // The value needs to be requantized. A Quantize op will be created to use @@ -508,7 +522,7 @@ void QuantizationDriver::RequantizeValue(Value *value, RequantizeState *state, TypeAttr type_attr = TypeAttr::get(new_type); auto requantize_op = builder_.create<TFL::QuantizeOp>(loc, new_type, value, type_attr); - value->replaceAllUsesWith(requantize_op); + value.replaceAllUsesWith(requantize_op); requantize_op.getOperation()->replaceUsesOfWith(requantize_op, value); } @@ -586,10 +600,10 @@ void QuantizationDriver::PreprocessConstantOps() { auto type = cst.getType().dyn_cast<ShapedType>(); if (!type || !type.getElementType().isa<FloatType>()) return; - Value *value = cst.getResult(); + Value value = cst.getResult(); SmallVector<std::pair<Operation *, int>, 4> bias_users; bool used_as_weight = false; - for (auto &use : value->getUses()) { + for (auto &use : value.getUses()) { auto spec = GetQuantSpec(use.getOwner()); auto biases = spec->biases_params; Operation *user = use.getOwner(); @@ -629,7 +643,20 @@ } void QuantizationDriver::SetupAllStates() { - llvm::DenseMap<Value *, int> value_to_state; + llvm::DenseMap<Value, int> value_to_state; + + for (auto arg : fn_.getArguments()) { + args_.push_back(arg); + Value value = arg; + // If the argument is quantized, it should only have one user.
+ if (arg.hasOneUse()) { + auto user = value.use_begin().getUser(); + if (auto q = llvm::dyn_cast<TFL::QuantizeOp>(user)) { + value = q.output(); + } + } + InitializeArgState(arg, value, &value_to_state); + } fn_.walk([&](Operation *op) { if (op->isKnownTerminator() || @@ -638,26 +665,24 @@ work_list_.push_back(op); for (int i = 0, e = op->getNumOperands(); i != e; ++i) { - auto *operand = op->getOperand(i); - bool is_argument = true; - if (auto *inst = operand->getDefiningOp()) { + auto operand = op->getOperand(i); + if (auto *inst = operand.getDefiningOp()) { // If the operand comes from a tfl.dequantize op, we use the quantized // input of this tfl.dequantize op to set the state. if (auto dq = llvm::dyn_cast<TFL::DequantizeOp>(inst)) { operand = dq.input(); } - is_argument = false; } - InitializeOperandState(op, i, operand, &value_to_state, is_argument); + InitializeOperandState(op, i, operand, &value_to_state); } for (int res = 0, e = op->getNumResults(); res != e; ++res) { - auto *result = op->getResult(res); + Value result = op->getResult(res); // If the result has been quantized, it should only be used by a // tfl.quantize op. For this case, we use the quantized result to // create the state and mark it immutable. - if (result->hasOneUse()) { - auto user = result->use_begin().getUser(); + if (result.hasOneUse()) { + auto user = result.use_begin().getUser(); if (auto q = llvm::dyn_cast<TFL::QuantizeOp>(user)) { result = q.output(); } @@ -746,7 +771,7 @@ bool QuantizationDriver::PropagateParams() { } void QuantizationDriver::Finalize() { - for (auto *arg : args_) { + for (auto arg : args_) { auto &state = GetArgQuantState(arg); auto &requantize = GetArgRequantizeState(arg); if (state.IsEmpty() || diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_passes.h b/tensorflow/compiler/mlir/lite/quantization/quantization_passes.h index 56beb387370..58e9538045b 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_passes.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_passes.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_PASSES_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_PASSES_H_ -#include "mlir/Pass/Pass.h" // TF:local_config_mlir -#include "mlir/Pass/PassManager.h" // TF:local_config_mlir +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Pass/PassManager.h" // TF:llvm-project namespace mlir { namespace quant { diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_traits.h b/tensorflow/compiler/mlir/lite/quantization/quantization_traits.h index 3830d11afe4..ea278344dec 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_traits.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_traits.h @@ -18,8 +18,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_TRAITS_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_TRAITS_H_ -#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir -#include "mlir/Support/LLVM.h" // TF:local_config_mlir +#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project +#include "mlir/Support/LLVM.h" // TF:llvm-project namespace mlir { namespace OpTrait { @@ -70,7 +70,7 @@ class FixedResultUniformScale { QuantizedType GetResultQuantizedType(int index) { auto op = this->getOperation(); auto result_type = - op->getResult(index)->getType().template cast(); + op->getResult(index).getType().template cast(); Builder builder(op->getContext()); IntegerType storage_type = builder.getIntegerType(BitWidth); const double scale = static_cast(ScaleMantissa) * diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc index ca10809be69..5ff4ffa4b92 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc @@ -21,15 +21,15 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:local_config_mlir -#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:local_config_mlir -#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir -#include "mlir/Dialect/QuantOps/QuantizeUtils.h" // TF:local_config_mlir -#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:local_config_mlir -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/Support/LLVM.h" // TF:local_config_mlir +#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project +#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project +#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project +#include "mlir/Dialect/QuantOps/QuantizeUtils.h" // TF:llvm-project +#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/Support/LLVM.h" // TF:llvm-project #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h" namespace mlir { @@ -367,7 +367,7 @@ ElementsAttr Quantize(Attribute real_value, Type tensor_type) { static bool PreferResultScale(Operation* op) { int float_operands = 0; for (auto operand : op->getOperands()) { - if (auto operand_type = operand->getType().dyn_cast()) { + if (auto operand_type = operand.getType().dyn_cast()) { if (operand_type.getElementType().isa()) { if (float_operands++ > 1) return true; } @@ -400,22 +400,22 @@ bool RemoveRedundantStatsOps(mlir::FuncOp func, quant::StatisticsOp stats_op = all_stats_ops.back(); all_stats_ops.pop_back(); - if (auto def = stats_op.arg()->getDefiningOp()) { + if (auto def = stats_op.arg().getDefiningOp()) { if (IsStatsRedundant(def, op_quant_spec_getter)) { redundant_stats_ops.insert(stats_op); } } - for (auto user : stats_op.getResult()->getUsers()) { + for (auto user : stats_op.getResult().getUsers()) { // We don't propagate this parameter down if it has multiple operands. // We want to use the result parameter scales instead. 
if (user->hasTrait() && !PreferResultScale(user)) { - for (Value* res : user->getResults()) { - if (res->hasOneUse()) { + for (Value res : user->getResults()) { + if (res.hasOneUse()) { if (auto next_stats = llvm::dyn_cast( - *res->getUsers().begin())) { + *res.getUsers().begin())) { // quantization parameters can be propagated to next_stats redundant_stats_ops.insert(next_stats); // add next_stats to the work list so propagation can @@ -440,12 +440,12 @@ bool RemoveRedundantStatsOps(mlir::FuncOp func, quant::StatisticsOp stats_op = all_stats_ops.back(); all_stats_ops.pop_back(); - if (auto def = stats_op.arg()->getDefiningOp()) { + if (auto def = stats_op.arg().getDefiningOp()) { if (def->hasTrait() && PreferResultScale(def)) { for (auto input : def->getOperands()) { if (auto next_stats = llvm::dyn_cast_or_null( - input->getDefiningOp())) { + input.getDefiningOp())) { redundant_stats_ops.insert(next_stats); all_stats_ops.push_back(next_stats); } @@ -458,7 +458,7 @@ bool RemoveRedundantStatsOps(mlir::FuncOp func, for (auto it : redundant_stats_ops) { if (!llvm::isa(it)) return true; auto stats_op = llvm::cast(it); - stats_op.getResult()->replaceAllUsesWith(stats_op.arg()); + stats_op.getResult().replaceAllUsesWith(stats_op.arg()); stats_op.erase(); } diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h index 0f7ec91ebc6..60fc2add989 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h @@ -23,18 +23,18 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:local_config_mlir -#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:local_config_mlir -#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir -#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/BlockAndValueMapping.h" // TF:local_config_mlir -#include "mlir/IR/Function.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/Matchers.h" // TF:local_config_mlir -#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/Support/LLVM.h" // TF:local_config_mlir +#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project +#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project +#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project +#include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Matchers.h" // TF:llvm-project +#include "mlir/IR/PatternMatch.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/Support/LLVM.h" // TF:llvm-project #include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h" namespace mlir { @@ -116,7 +116,7 @@ struct ConvertStatsToQDQs : public OpRewritePattern { auto q = rewriter.create(op.getLoc(), result_type, op.arg(), TypeAttr::get(result_type)); auto dq = rewriter.create(op.getLoc(), op.getType(), q); - op.getResult()->replaceAllUsesWith(dq); + op.getResult().replaceAllUsesWith(dq); 
q.getOperation()->replaceUsesOfWith(dq, op.arg()); op.erase(); @@ -161,8 +161,8 @@ struct QuantizationPattern : public RewritePattern { if (op->getNumResults() != 1) { return matchFailure(); } - Value* quantized_value = op->getResult(0); - for (Operation* quantized_op : quantized_value->getUsers()) { + Value quantized_value = op->getResult(0); + for (Operation* quantized_op : quantized_value.getUsers()) { // If it is requantize op, we shouldn't rewrite this op. if (llvm::isa(quantized_op) || llvm::isa(quantized_op)) { return matchFailure(); @@ -176,17 +176,17 @@ struct QuantizationPattern : public RewritePattern { // Collect all the quantized inputs and "clone" the matched op by these // inputs. - SmallVector inputs; + SmallVector inputs; inputs.reserve(quantized_op->getNumOperands()); for (auto operand : quantized_op->getOperands()) { - Type operand_type = operand->getType(); + Type operand_type = operand.getType(); if (operand_type.isa()) { inputs.push_back(operand); continue; } - auto ele_type = operand->getType().cast().getElementType(); - if (auto op_inst = dyn_cast_or_null(operand->getDefiningOp())) { + auto ele_type = operand.getType().cast().getElementType(); + if (auto op_inst = dyn_cast_or_null(operand.getDefiningOp())) { inputs.push_back(op_inst.input()); } else if (ele_type.isa()) { // If the operand is an integer tensor, then it doesn't require the @@ -201,13 +201,13 @@ struct QuantizationPattern : public RewritePattern { // Collect all the quantized outputs and replace them by the results of // the new quantized op. - llvm::SmallDenseMap outputs_replaced; + llvm::SmallDenseMap outputs_replaced; SmallVector output_types; output_types.reserve(quantized_op->getNumResults()); for (auto enumerated_result : llvm::enumerate(quantized_op->getResults())) { - Value* result = enumerated_result.value(); - Type result_type = result->getType(); + Value result = enumerated_result.value(); + Type result_type = result.getType(); // Add this to the test coverage once we create test ops with none type // results. if (result_type.isa()) { @@ -216,20 +216,20 @@ struct QuantizationPattern : public RewritePattern { continue; } Type result_ele_type = - result->getType().cast().getElementType(); + result.getType().cast().getElementType(); // If the user is the Quantize op, it must be the only user. - if (result->hasOneUse() && llvm::isa(*result->user_begin())) { - auto user = llvm::cast(*result->user_begin()); + if (result.hasOneUse() && llvm::isa(*result.user_begin())) { + auto user = llvm::cast(*result.user_begin()); outputs_replaced.insert({user.output(), enumerated_result.index()}); output_types.push_back(user.getType()); } else if (result_ele_type.template isa()) { // If the result is an integer tensor, then it doesn't require the // D op in the pattern. 
outputs_replaced.insert({result, enumerated_result.index()}); - output_types.push_back(result->getType()); + output_types.push_back(result.getType()); } else if (static_cast(this)->AllowHybridResult()) { outputs_replaced.insert({result, enumerated_result.index()}); - output_types.push_back(result->getType()); + output_types.push_back(result.getType()); } else { return matchFailure(); } @@ -241,7 +241,7 @@ struct QuantizationPattern : public RewritePattern { output_types, quantized_op->getAttrs()); Operation* new_op = rewriter.createOperation(new_state); for (auto output : outputs_replaced) { - output.getFirst()->replaceAllUsesWith( + output.getFirst().replaceAllUsesWith( new_op->getResult(output.getSecond())); } @@ -252,7 +252,7 @@ struct QuantizationPattern : public RewritePattern { // For constant operands, the floating-point constant is duplicated in // case it is quantized. for (int i = 0, e = new_op->getNumOperands(); i != e; ++i) { - auto def = new_op->getOperand(i)->getDefiningOp(); + auto def = new_op->getOperand(i).getDefiningOp(); if (auto q = llvm::dyn_cast_or_null(def)) { DenseFPElementsAttr attr; if (!matchPattern(q.input(), m_Constant(&attr))) { @@ -265,7 +265,7 @@ struct QuantizationPattern : public RewritePattern { for (int i = 0, e = new_op->getNumResults(); i != e; ++i) { if (!quantized_op->getResult(i) - ->getType() + .getType() .cast() .getElementType() .isa()) { @@ -283,13 +283,13 @@ struct QuantizationPattern : public RewritePattern { // Find the Dequantize/Dequantize users of the new op results, and // replace the usage. Then all the floating-point ops are connected. // N.B. the return op will use this floating-point result. - for (auto user : new_op->getResult(i)->getUsers()) { + for (auto user : new_op->getResult(i).getUsers()) { // Skip the Requantize op, and we know it has a single user. 
if (llvm::isa(user)) { - user = *user->getResult(0)->getUsers().begin(); + user = *user->getResult(0).getUsers().begin(); } if (auto dequantize = llvm::dyn_cast(user)) { - dequantize.getResult()->replaceAllUsesWith( + dequantize.getResult().replaceAllUsesWith( quantized_op->getResult(i)); } } @@ -316,7 +316,7 @@ struct ConvertUnsignedToSigned : public OpRewritePattern { PatternMatchResult matchAndRewrite(Q op, PatternRewriter& rewriter) const override { - Type output_type = op.output()->getType(); + Type output_type = op.output().getType(); auto qtype = QType::getQuantizedElementType(output_type); if (!qtype || qtype.isSigned()) return this->matchFailure(); diff --git a/tensorflow/compiler/mlir/lite/quantization/tests/BUILD b/tensorflow/compiler/mlir/lite/quantization/tests/BUILD index 9f47185e90a..4faa8d2efe8 100644 --- a/tensorflow/compiler/mlir/lite/quantization/tests/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/tests/BUILD @@ -4,7 +4,7 @@ package(licenses = ["notice"]) glob_lit_tests( data = [":test_utilities"], - driver = "@local_config_mlir//:run_lit.sh", + driver = "@llvm-project//mlir:run_lit.sh", test_file_exts = ["mlir"], ) @@ -14,6 +14,6 @@ filegroup( testonly = True, data = [ "//tensorflow/compiler/mlir:tf-opt", - "@llvm//:FileCheck", + "@llvm-project//llvm:FileCheck", ], ) diff --git a/tensorflow/compiler/mlir/lite/quantization/tools/op_quant_spec_getters_gen.cc b/tensorflow/compiler/mlir/lite/quantization/tools/op_quant_spec_getters_gen.cc index bc49b0df23f..abc38505abd 100644 --- a/tensorflow/compiler/mlir/lite/quantization/tools/op_quant_spec_getters_gen.cc +++ b/tensorflow/compiler/mlir/lite/quantization/tools/op_quant_spec_getters_gen.cc @@ -20,7 +20,7 @@ limitations under the License. #include "llvm/TableGen/Main.h" #include "llvm/TableGen/Record.h" #include "llvm/TableGen/TableGenBackend.h" -#include "mlir/TableGen/Operator.h" // TF:local_config_mlir +#include "mlir/TableGen/Operator.h" // TF:llvm-project using llvm::LessRecord; using llvm::raw_ostream; @@ -36,7 +36,7 @@ using mlir::tblgen::Operator; // NOLINTNEXTLINE static bool OpQuantSpecWriter(raw_ostream &os, RecordKeeper &records) { llvm::Regex acc_uniform_trait_regex{"AccumulatorUniformScale<([0-9]*),"}; - llvm::Regex coeff_index_trait_regex{"AffineOpCoefficient<([0-9]*),"}; + llvm::Regex coeff_index_trait_regex{"AffineOpCoefficient<(-?[0-9]*),"}; llvm::Regex fixed_uniform_trait_regex{ "FixedResultUniformScale<([0-9]+).*(true|false)>"}; emitSourceFileHeader("Generated Ops Quant Spec Getters", os); diff --git a/tensorflow/compiler/mlir/lite/tests/BUILD b/tensorflow/compiler/mlir/lite/tests/BUILD index 9f47185e90a..4faa8d2efe8 100644 --- a/tensorflow/compiler/mlir/lite/tests/BUILD +++ b/tensorflow/compiler/mlir/lite/tests/BUILD @@ -4,7 +4,7 @@ package(licenses = ["notice"]) glob_lit_tests( data = [":test_utilities"], - driver = "@local_config_mlir//:run_lit.sh", + driver = "@llvm-project//mlir:run_lit.sh", test_file_exts = ["mlir"], ) @@ -14,6 +14,6 @@ filegroup( testonly = True, data = [ "//tensorflow/compiler/mlir:tf-opt", - "@llvm//:FileCheck", + "@llvm-project//llvm:FileCheck", ], ) diff --git a/tensorflow/compiler/mlir/lite/tests/debuginfo/BUILD b/tensorflow/compiler/mlir/lite/tests/debuginfo/BUILD index 2498a106f8f..5ef392b0ea0 100644 --- a/tensorflow/compiler/mlir/lite/tests/debuginfo/BUILD +++ b/tensorflow/compiler/mlir/lite/tests/debuginfo/BUILD @@ -7,10 +7,12 @@ glob_lit_tests( ":debug_info_files", ":test_utilities", ], - driver = "@local_config_mlir//:run_lit.sh", + driver = 
"@llvm-project//mlir:run_lit.sh", test_file_exts = [ "pbtxt", - "py", + # TODO(fengliuai): reenable these tests after the fused loc is + # supported in the diagnostic handler. + # "py", ], ) @@ -31,8 +33,8 @@ filegroup( ":saved_model_error", "//tensorflow/compiler/mlir/lite:flatbuffer_to_string", "//tensorflow/compiler/mlir/lite:tf_tfl_translate", - "@llvm//:FileCheck", - "@llvm//:not", + "@llvm-project//llvm:FileCheck", + "@llvm-project//llvm:not", ], ) diff --git a/tensorflow/compiler/mlir/lite/tests/debuginfo/concrete_function_error.py b/tensorflow/compiler/mlir/lite/tests/debuginfo/concrete_function_error.py index 0bb386f4829..7fe587095b6 100644 --- a/tensorflow/compiler/mlir/lite/tests/debuginfo/concrete_function_error.py +++ b/tensorflow/compiler/mlir/lite/tests/debuginfo/concrete_function_error.py @@ -21,6 +21,7 @@ from __future__ import division from __future__ import print_function import sys + from absl import app import tensorflow.compat.v2 as tf diff --git a/tensorflow/compiler/mlir/lite/tests/debuginfo/saved_model_error.py b/tensorflow/compiler/mlir/lite/tests/debuginfo/saved_model_error.py index a4011226f14..fa35d229bc4 100644 --- a/tensorflow/compiler/mlir/lite/tests/debuginfo/saved_model_error.py +++ b/tensorflow/compiler/mlir/lite/tests/debuginfo/saved_model_error.py @@ -21,6 +21,7 @@ from __future__ import division from __future__ import print_function import sys + from absl import app import tensorflow.compat.v2 as tf diff --git a/tensorflow/compiler/mlir/lite/tests/end2end/BUILD b/tensorflow/compiler/mlir/lite/tests/end2end/BUILD index a15b434571c..732fd784bbc 100644 --- a/tensorflow/compiler/mlir/lite/tests/end2end/BUILD +++ b/tensorflow/compiler/mlir/lite/tests/end2end/BUILD @@ -7,7 +7,7 @@ glob_lit_tests( ":quant_stats_files", ":test_utilities", ], - driver = "@local_config_mlir//:run_lit.sh", + driver = "@llvm-project//mlir:run_lit.sh", test_file_exts = [ "pbtxt", ], @@ -20,7 +20,7 @@ filegroup( data = [ "//tensorflow/compiler/mlir/lite:flatbuffer_to_string", "//tensorflow/compiler/mlir/lite:tf_tfl_translate", - "@llvm//:FileCheck", + "@llvm-project//llvm:FileCheck", ], ) diff --git a/tensorflow/compiler/mlir/lite/tests/end2end/custom_opdef.pbtxt b/tensorflow/compiler/mlir/lite/tests/end2end/custom_opdef.pbtxt index 0fcee7d7e8f..80452715b78 100644 --- a/tensorflow/compiler/mlir/lite/tests/end2end/custom_opdef.pbtxt +++ b/tensorflow/compiler/mlir/lite/tests/end2end/custom_opdef.pbtxt @@ -38,6 +38,6 @@ versions { # CHECK: func @main(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<*xi32> # CHECK: attributes {tf.entry_function = {inputs = "input0,input1", outputs = "output"}} { -# CHECK-NEXT: %0 = "tf.BannaPotatoSaladWithColeslaw"(%arg0, %arg1) {T = i32, device = "", name = "output"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<*xi32> +# CHECK-NEXT: %0 = "tf.BannaPotatoSaladWithColeslaw"(%arg0, %arg1) {T = i32, device = ""} : (tensor<4xi32>, tensor<4xi32>) -> tensor<*xi32> # CHECK-NEXT: return %0 : tensor<*xi32> # CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/BUILD b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/BUILD index 87caef0237e..b52b766a10d 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/BUILD +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/BUILD @@ -8,7 +8,7 @@ glob_lit_tests( ":extra_files", ":test_utilities", ], - driver = "@local_config_mlir//:run_lit.sh", + driver = "@llvm-project//mlir:run_lit.sh", test_file_exts = [ "mlir", "cc", @@ -24,7 +24,7 @@ filegroup( ":importer_test_min_max", 
"//tensorflow/compiler/mlir/lite:flatbuffer_to_string", "//tensorflow/compiler/mlir/lite:flatbuffer_translate", - "@llvm//:FileCheck", + "@llvm-project//llvm:FileCheck", ], ) @@ -51,7 +51,7 @@ tf_native_cc_binary( "//tensorflow/lite:framework", "//tensorflow/lite/schema:schema_fbs", "@com_google_absl//absl/strings", - "@llvm//:support", + "@llvm-project//llvm:support", ], ) @@ -67,6 +67,6 @@ tf_native_cc_binary( "//tensorflow/lite:framework", "//tensorflow/lite/schema:schema_fbs", "@com_google_absl//absl/strings", - "@llvm//:support", + "@llvm-project//llvm:support", ], ) diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/output_arrays.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/output_arrays.mlir index d228cc06a88..20df2f75732 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/output_arrays.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/output_arrays.mlir @@ -11,6 +11,8 @@ func @main(tensor<4xf32>) -> tensor<4xf32> { %3 = "tfl.div"(%2, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("div") // CHECK: %[[EXP:.*]] = "tfl.exp" %4 = "tfl.exp"(%3) : (tensor<4xf32>) -> tensor<4xf32> loc("exp") + // tfl.neg should not be pruned + // CHECK: %[[NEG:.*]] = "tfl.neg" %5 = "tfl.neg"(%4) : (tensor<4xf32>) -> tensor<4xf32> loc("neg") // CHECK: return %[[MUL]], %[[EXP]], %[[DIV]] return %5 : tensor<4xf32> diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/pruning.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/pruning.mlir new file mode 100644 index 00000000000..0d7f911f282 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/pruning.mlir @@ -0,0 +1,19 @@ +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate -output-arrays=mul,exp,div --experimental-prune-unreachable-nodes-unconditionally --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s +// Confirm graph pruning. 
+ +func @main(tensor<4xf32>) -> tensor<4xf32> { +^bb0(%arg0: tensor<4xf32>): + %0 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const") + %1 = "tfl.squared_difference"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("squared_difference") + // CHECK: %[[MUL:.*]] = tfl.mul + %2 = "tfl.mul"(%0, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("mul") + // CHECK: %[[DIV:.*]] = tfl.div + %3 = "tfl.div"(%2, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("div") + // CHECK: %[[EXP:.*]] = "tfl.exp" + %4 = "tfl.exp"(%3) : (tensor<4xf32>) -> tensor<4xf32> loc("exp") + // tfl.neg should be pruned + // CHECK-NOT: "tfl.neg" + %5 = "tfl.neg"(%4) : (tensor<4xf32>) -> tensor<4xf32> loc("neg") + // CHECK: return %[[MUL]], %[[EXP]], %[[DIV]] + return %5 : tensor<4xf32> +} diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir index e22198da6ea..e7efc7de99b 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir @@ -521,21 +521,30 @@ func @select_multidim(%arg0: tensor<8xi1>, %arg1: tensor<8x3xf32>, %arg2: tensor // CHECK: return } -func @select_v2(%arg0: tensor<8xi1>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>) -> tensor<8xf32> { +func @select_v2_same_shape(%arg0: tensor<8xi1>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>) -> tensor<8xf32> { %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<8xi1>, tensor<8xf32>, tensor<8xf32>) -> tensor<8xf32> return %0: tensor<8xf32> -// CHECK-LABEL: select_v2 +// CHECK-LABEL: select_v2_same_shape // CHECK: "tfl.select"(%arg0, %arg1, %arg2) // CHECK: return } -func @select_v2_multidim(%arg0: tensor<8xi1>, %arg1: tensor<8x3xf32>, %arg2: tensor<8x3xf32>) -> tensor<8x3xf32> { - %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<8xi1>, tensor<8x3xf32>, tensor<8x3xf32>) -> tensor<8x3xf32> +func @select_v2_multidim(%arg0: tensor<3xi1>, %arg1: tensor<8x3xf32>, %arg2: tensor<8x3xf32>) -> tensor<8x3xf32> { + %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<8x3xf32>, tensor<8x3xf32>) -> tensor<8x3xf32> return %0: tensor<8x3xf32> // CHECK-LABEL: select_v2_multidim -// CHECK: "tfl.select"(%arg0, %arg1, %arg2) +// CHECK: "tfl.select_v2"(%arg0, %arg1, %arg2) +// CHECK: return +} + +func @select_v2_broadcast(%arg0: tensor<4xi1>, %arg1: tensor<3x4xf32>, %arg2: tensor<8x3x4xf32>) -> tensor<8x3x4xf32> { + %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<4xi1>, tensor<3x4xf32>, tensor<8x3x4xf32>) -> tensor<8x3x4xf32> + return %0: tensor<8x3x4xf32> + +// CHECK-LABEL: select_v2_broadcast +// CHECK: "tfl.select_v2"(%arg0, %arg1, %arg2) // CHECK: return } diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/BUILD b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/BUILD index c13df3faafc..c0ae9570225 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/BUILD +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/BUILD @@ -4,7 +4,7 @@ licenses(["notice"]) glob_lit_tests( data = [":test_utilities"], - driver = "@local_config_mlir//:run_lit.sh", + driver = "@llvm-project//mlir:run_lit.sh", test_file_exts = ["mlir"], ) @@ -15,7 +15,7 @@ filegroup( data = [ "//tensorflow/compiler/mlir/lite:flatbuffer_to_string", "//tensorflow/compiler/mlir/lite:flatbuffer_translate", - "@llvm//:FileCheck", - "@llvm//:not", + "@llvm-project//llvm:FileCheck", + 
"@llvm-project//llvm:not", ], ) diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/type_attr.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/type_attr.mlir new file mode 100644 index 00000000000..01410d370d4 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/type_attr.mlir @@ -0,0 +1,40 @@ +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -emit-builtin-tflite-ops=false -o - | flatbuffer_to_string - | FileCheck %s + +// CHECK: { +// CHECK: version: 3, +// CHECK: operator_codes: [ { +// CHECK: builtin_code: CUSTOM, +// CHECK: custom_code: "SomeOperation" +// CHECK: } ], +// CHECK: subgraphs: [ { +// CHECK: tensors: [ { +// CHECK: shape: [ ], +// CHECK: type: INT32, +// CHECK: buffer: 1, +// CHECK: name: "tf.SomeOperation", +// CHECK: quantization: { +// CHECK-EMPTY +// CHECK: } +// CHECK: } ], +// CHECK: inputs: [ ], +// CHECK: outputs: [ 0 ], +// CHECK: operators: [ { +// CHECK: inputs: [ ], +// CHECK: outputs: [ 0 ], +// CHECK: custom_options: [ 100, 116, 121, 112, 101, 0, 1, 7, 1, 1, 1, 2, 4, 2, 36, 1 ] +// CHECK: } ], +// CHECK: name: "main" +// CHECK: } ], +// CHECK: description: "MLIR Converted.", +// CHECK: buffers: [ { +// CHECK-EMPTY +// CHECK: }, { +// CHECK-EMPTY +// CHECK: } ] +// CHECK: } + +func @main() -> tensor<*xi32> { + // Tests that the below type attribute is convertible into the corresponding custom option in flatbuffer. + %0 = "tf.SomeOperation"() {dtype = i32 } : () -> tensor<*xi32> + return %0 : tensor<*xi32> +} diff --git a/tensorflow/compiler/mlir/lite/tests/ops.mlir b/tensorflow/compiler/mlir/lite/tests/ops.mlir index ad3b5540dd7..a60796d1580 100644 --- a/tensorflow/compiler/mlir/lite/tests/ops.mlir +++ b/tensorflow/compiler/mlir/lite/tests/ops.mlir @@ -518,6 +518,20 @@ func @testMaxPool2DWrongOperandStorageType(tensor<1x7x7x16x!quant.uniform) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) { + %0, %1 = "tfl.max_pooling_with_argmax_2d"(%arg0) {filter_h = 2 : i32, filter_w = 2 : i32, padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) + return %0, %1 : tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32> +} + +// ----- + +func @testMaxUnpooling2D(%arg0: tensor<1x8x8x128xf32>, %arg1: tensor<1x8x8x128xf32>) -> tensor<1x8x8x128xf32> { + %0 = "tfl.max_unpooling_2d"(%arg0, %arg1) {filter_h = 2 : i32, filter_w = 2 : i32, padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x8x8x128xf32>, tensor<1x8x8x128xf32>) -> (tensor<1x8x8x128xf32>) + return %0 : tensor<1x8x8x128xf32> +} + +// ----- + // CHECK-LABEL: testLogistic func @testLogistic(tensor<1x2x3x4x5xbf16>) -> tensor<1x2x3x4x5xbf16> { ^bb0(%arg0: tensor<1x2x3x4x5xbf16>): @@ -1942,6 +1956,13 @@ func @testTransposeConv(%arg0: tensor<4xi32>, %arg1: tensor<32x4x4x128xf32>, %ar // ----- +func @testConvolution2DTransposeBias(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<1x32x42x128xf32>, %arg2: tensor<4xi32>) -> tensor<1x64x84x32xf32> { + %0 = "tfl.convolution_2d_transpose_bias"(%arg0, %arg1, %arg2) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<4xi32>) -> tensor<1x64x84x32xf32> + return %0 : tensor<1x64x84x32xf32> +} + +// ----- + func @testTransposeConvBadOutputRank(%arg0: tensor<4xi32>, %arg1: tensor<32x4x4x128xf32>, %arg2: tensor<1x32x42x128xf32>) -> tensor<64x84x32xf32> { // expected-error @+1 {{expect output type has rank = 4, got output type tensor<64x84x32xf32>}} %0 = 
"tfl.transpose_conv"(%arg0, %arg1, %arg2) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>) -> tensor<64x84x32xf32> diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir index 1d51adb16f2..5a07946fd9e 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir @@ -140,22 +140,6 @@ func @fuseAddWithRelu6IntoDepthwiseConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: // CHECK-SAME: fused_activation_function = "RELU6" } -// CHECK-LABEL: intermOpUsedTwice -func @intermOpUsedTwice(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>) -> (tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>) { - %cst = constant dense<1.5> : tensor<16xf32> - %cst_0 = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32> - %0 = "tfl.conv_2d"(%arg0, %arg1, %cst_0) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32> - %1 = "tfl.add"(%0, %cst) {fused_activation_function = "RELU6"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32> - return %0, %1 : tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32> - - // CHECK: %cst = constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, - // CHECK: %cst_0 = constant dense<[2.500000e+00, 3.500000e+00, 4.500000e+00, 5.500000e+00, - // CHECK: %0 = "tfl.conv_2d"(%arg0, %arg1, %cst) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} - // CHECK: %1 = "tfl.conv_2d"(%arg0, %arg1, %cst_0) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "RELU6", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} - // CHECK: return %0, %1 - -} - // CHECK-LABEL: @fuseMulIntoFullyConnected func @fuseMulIntoFullyConnected(%arg0: tensor<4x2xf32>) -> tensor<4x2xf32> { %cst0 = constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf32> @@ -167,8 +151,8 @@ func @fuseMulIntoFullyConnected(%arg0: tensor<4x2xf32>) -> tensor<4x2xf32> { return %1 : tensor<4x2xf32> -// CHECK: %[[CONSTANT:.*]] = "tf.Const"{{.*}} dense<{{\[\[}}1.000000e+00, 4.000000e+00], [3.000000e+00, 8.000000e+00]]> : tensor<2x2xf32> -// CHECK: %[[CONSTANT0:.*]] = "tf.Const"{{.*}} dense<[2.000000e+00, 4.000000e+00]> : tensor<2xf32> +// CHECK: %[[CONSTANT:.*]] = constant dense<{{\[\[}}1.000000e+00, 4.000000e+00], [3.000000e+00, 8.000000e+00]]> : tensor<2x2xf32> +// CHECK: %[[CONSTANT0:.*]] = constant dense<[2.000000e+00, 4.000000e+00]> : tensor<2xf32> // CHECK: %[[RES:.*]] = "tfl.fully_connected"(%arg0, %[[CONSTANT]], %[[CONSTANT0]]) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"} // CHECK: return %[[RES]] : tensor<4x2xf32> } @@ -233,8 +217,8 @@ func @fuseMulIntoFullyConnectedBroadcast(%arg0: tensor<1x3xf32>) -> tensor<1x2xf %1 = "tfl.mul"(%0, %cst2) {fused_activation_function = "RELU6"} : (tensor<1x2xf32>, tensor<2xf32>) -> tensor<1x2xf32> return %1 : tensor<1x2xf32> -// CHECK: %[[CONSTANT:.*]] = "tf.Const"{{.*}} dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [2.000000e+00, 4.000000e+00, 6.000000e+00]]> : tensor<2x3xf32> -// CHECK: %[[CONSTANT0:.*]] = "tf.Const"{{.*}} 
dense<[2.000000e+00, 4.000000e+00]> : tensor<2xf32> +// CHECK: %[[CONSTANT:.*]] = constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [2.000000e+00, 4.000000e+00, 6.000000e+00]]> : tensor<2x3xf32> +// CHECK: %[[CONSTANT0:.*]] = constant dense<[2.000000e+00, 4.000000e+00]> : tensor<2xf32> // CHECK: %[[RES:.*]] = "tfl.fully_connected"(%arg0, %[[CONSTANT]], %[[CONSTANT0]]) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"} // CHECK: return %[[RES]] : tensor<1x2xf32> } @@ -249,7 +233,7 @@ func @fuseMulIntoFullyConnectedNoBias(%arg0: tensor<4x2xf32>, %arg1: none) -> te return %1 : tensor<4x2xf32> -// CHECK: %[[CONSTANT:.*]] = "tf.Const"{{.*}} dense<{{\[\[}}1.000000e+00, 4.000000e+00], [3.000000e+00, 8.000000e+00]]> : tensor<2x2xf32> +// CHECK: %[[CONSTANT:.*]] = constant dense<{{\[\[}}1.000000e+00, 4.000000e+00], [3.000000e+00, 8.000000e+00]]> : tensor<2x2xf32> // CHECK: %[[RES:.*]] = "tfl.fully_connected"(%arg0, %[[CONSTANT]], %arg1) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x2xf32>, tensor<2x2xf32>, none) -> tensor<4x2xf32> // CHECK: return %[[RES]] : tensor<4x2xf32> } @@ -631,3 +615,18 @@ func @fuse_relu_to_add(%arg0: tensor<2x3xf32>, %arg1: tensor<2x3xf32>) -> tensor // CHECK: %[[RES:.*]] = tfl.add %arg0, %arg1 {fused_activation_function = "RELU_N1_TO_1"} // CHECK: return %[[RES]] } + +// CHECK-LABEL: NotfuseAddIntoConv2d_MultipleUsers +func @NotfuseAddIntoConv2d_MultipleUsers(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>) -> (tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>) { + %cst = constant dense<1.5> : tensor<16xf32> + %cst_1 = constant dense<3.5> : tensor<16xf32> + %cst_0 = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32> + %0 = "tfl.conv_2d"(%arg0, %arg1, %cst_0) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32> + %1 = "tfl.add"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32> + %2 = "tfl.add"(%0, %cst_1) {fused_activation_function = "NONE"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32> + return %1, %2 : tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32> + + // CHECK: %[[tfl_conv2d:[0-9].*]] = "tfl.conv_2d" + // CHECK: tfl.add + // CHECK-NEXT: tfl.add +} diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-signed.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-signed.mlir index 29585296217..f6054f3d65d 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-signed.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-signed.mlir @@ -125,3 +125,21 @@ func @prepareDepthwiseConv2D(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x112x112 // PerTensor: %[[dq:.*]] = "tfl.dequantize"(%[[q]]) // PerTensor: %[[conv:.*]] = "tfl.depthwise_conv_2d"(%arg0, %[[dq]] } + +// CHECK-LABEL: QuantizeFullyConnected +func @QuantizeFullyConnected(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x112x112x32xf32> { + %w = constant dense<127.0> : tensor<32x12xf32> + %b = constant dense<0.0> : tensor<32xf32> + %fc = "tfl.fully_connected"(%arg0, %w, %b) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x224x224x3xf32>, tensor<32x12xf32>, 
tensor<32xf32>) -> tensor<1x112x112x32xf32> + return %fc : tensor<1x112x112x32xf32> + +// CHECK: %[[cst:.*]] = constant dense<1.270000e+02> : tensor<32x12xf32> +// CHECK: %[[q:.*]] = "tfl.quantize"(%cst) {qtype = tensor<32x12x!quant.uniform:f32, 1.000000e+00>>} : (tensor<32x12xf32>) -> tensor<32x12x!quant.uniform:f32, 1.000000e+00>> +// CHECK: %[[dq:.*]] = "tfl.dequantize"(%0) : (tensor<32x12x!quant.uniform:f32, 1.000000e+00>>) -> tensor<32x12xf32> +// CHECK: "tfl.fully_connected"(%arg0, %[[dq]] + +// PerTensor: %[[cst:.*]] = constant dense<1.270000e+02> : tensor<32x12xf32> +// PerTensor: %[[q:.*]] = "tfl.quantize"(%cst) {qtype = tensor<32x12x!quant.uniform:f32, 1.000000e+00>>} : (tensor<32x12xf32>) -> tensor<32x12x!quant.uniform:f32, 1.000000e+00>> +// PerTensor: %[[dq:.*]] = "tfl.dequantize"(%0) : (tensor<32x12x!quant.uniform:f32, 1.000000e+00>>) -> tensor<32x12xf32> +// PerTensor: "tfl.fully_connected"(%arg0, %[[dq]] +} diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir index cd111176163..fc9c55089a3 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir @@ -379,26 +379,26 @@ func @QuantizeConcatResToAllNoRequantize(tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> // CHECK: %3 = "tfl.concatenation"(%2, %1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32> // CHECK: %4 = "tfl.quantize"(%3) {qtype = tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> -// CHeCK: return %4 : tensor<2x2x!quant.uniform> +// CHECK: return %4 : tensor<2x2x!quant.uniform> } // CHECK-LABEL: QuantizeConcatResToAllRequantize func @QuantizeConcatResToAllRequantize(tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2x!quant.uniform> { ^bb0(%arg0: tensor<1x2xf32>, %arg1: tensor<1x2xf32>): - %0 = "tfl.quantize"(%arg0) {qtype = tensor<2x!quant.uniform>} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> + %0 = "tfl.quantize"(%arg0) {qtype = tensor<1x2x!quant.uniform>} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> %1 = "tfl.dequantize"(%0) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> %2 = "tfl.concatenation"(%1, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32> %3 = "tfl.quantize"(%2) {qtype = tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> return %3 : tensor<2x2x!quant.uniform> -// CHECK %0 = "tfl.quantize"(%arg0) {qtype = tensor<2x!quant.uniform>} : (tensor<2xf32>) -> tensor<2x!quant.uniform> -// CHECK %1 = "tfl.quantize"(%0) {qtype = tensor<2x!quant.uniform>} : (tensor<2x!quant.uniform>) -> tensor<2x!quant.uniform> -// CHECK %2 = "tfl.dequantize"(%1) : (tensor<2x!quant.uniform>) -> tensor<2xf32> -// CHECK %3 = "tfl.quantize"(%arg1) {qtype = tensor<2x!quant.uniform>} : (tensor<2xf32>) -> tensor<2x!quant.uniform> -// CHECK %4 = "tfl.dequantize"(%3) : (tensor<2x!quant.uniform>) -> tensor<2xf32> -// CHECK %5 = "tfl.concatenation"(%2, %4) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2x2xf32> -// CHECK %6 = "tfl.quantize"(%5) {qtype = tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> -// CHECK return %6 : tensor<2x2x!quant.uniform> +// CHECK: %[[Q1:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform>} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> +// CHECK: %[[DQ1:.*]] = 
"tfl.dequantize"(%[[Q1]]) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> +// CHECK: %[[Q0:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x2x!quant.uniform>} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> +// CHECK: %[[RQ0:.*]] = "tfl.quantize"(%[[Q0]]) {qtype = tensor<1x2x!quant.uniform>} : (tensor<1x2x!quant.uniform>) -> tensor<1x2x!quant.uniform> +// CHECK: %[[DQ0:.*]] = "tfl.dequantize"(%[[RQ0]]) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> +// CHECK: %[[CONC:.*]] = "tfl.concatenation"(%[[DQ0]], %[[DQ1]]) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32> +// CHECK: %[[Q:.*]] = "tfl.quantize"(%[[CONC]]) {qtype = tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> +// CHECK: return %[[Q]] : tensor<2x2x!quant.uniform> } // CHECK-LABEL: QuantizeConcatResToAllRequantizeArg @@ -409,13 +409,13 @@ func @QuantizeConcatResToAllRequantizeArg(tensor<1x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> return %3 : tensor<2x2x!quant.uniform> -// CHECK %1 = "tfl.quantize"(%arg0) {qtype = tensor<2x!quant.uniform>} : (tensor<1x2x!quant.uniform>) -> tensor<1x2x!quant.uniform> -// CHECK %2 = "tfl.dequantize"(%1) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> -// CHECK %3 = "tfl.quantize"(%arg1) {qtype = tensor<2x!quant.uniform>} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> -// CHECK %4 = "tfl.dequantize"(%3) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> -// CHECK %5 = "tfl.concatenation"(%2, %4) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32> -// CHECK %6 = "tfl.quantize"(%5) {qtype = tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> -// CHECK return %6 : tensor<2x2x!quant.uniform> +// CHECK: %[[Q1:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform>} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> +// CHECK: %[[DQ1:.*]] = "tfl.dequantize"(%[[Q1]]) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> +// CHECK: %[[RQ0:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x2x!quant.uniform>} : (tensor<1x2x!quant.uniform>) -> tensor<1x2x!quant.uniform> +// CHECK: %[[DQ0:.*]] = "tfl.dequantize"(%[[RQ0]]) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> +// CHECK: %[[CONC:.*]] = "tfl.concatenation"(%[[DQ0]], %[[DQ1]]) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32> +// CHECK: %[[Q:.*]] = "tfl.quantize"(%[[CONC]]) {qtype = tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> +// CHECK: return %[[Q]] : tensor<2x2x!quant.uniform> } // CHECK-LABEL: RequantizeAlreadyQuantizedModel diff --git a/tensorflow/compiler/mlir/lite/tests/quantize.mlir b/tensorflow/compiler/mlir/lite/tests/quantize.mlir index 225123eb3d3..89d1e7cb7f4 100644 --- a/tensorflow/compiler/mlir/lite/tests/quantize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/quantize.mlir @@ -204,8 +204,9 @@ func @QuantizeConcatRequantize(tensor<1x2x!quant.uniform>, tens %3 = "tfl.quantize"(%2) {qtype = tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> return %3 : tensor<2x2x!quant.uniform> -// CHECK: %[[q:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform>} -// CHECK: %[[cc:.*]] = "tfl.concatenation"(%arg0, %[[q]]) {axis = 0 : i32, fused_activation_function = "NONE"} +// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform>} +// CHECK: %[[q0:.*]] = "tfl.quantize"(%arg0) {qtype = 
tensor<1x2x!quant.uniform>} : (tensor<1x2x!quant.uniform>) -> tensor<1x2x!quant.uniform> +// CHECK: %[[cc:.*]] = "tfl.concatenation"(%[[q0]], %[[q1]]) {axis = 0 : i32, fused_activation_function = "NONE"} // CHECK: return %[[cc]] : tensor<2x2x!quant.uniform> } diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc index 58ff9ce9d2e..e2cf3f9012a 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc @@ -15,11 +15,11 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h" -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Module.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir -#include "mlir/Pass/PassManager.h" // TF:local_config_mlir -#include "mlir/Transforms/Passes.h" // TF:local_config_mlir +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Pass/PassManager.h" // TF:llvm-project +#include "mlir/Transforms/Passes.h" // TF:llvm-project #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_passes.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.h b/tensorflow/compiler/mlir/lite/tf_tfl_passes.h index 7d5b28356dd..651248b1059 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.h +++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_TF_TFL_PASSES_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_TF_TFL_PASSES_H_ -#include "mlir/IR/Module.h" // TF:local_config_mlir -#include "mlir/Pass/PassManager.h" // TF:local_config_mlir +#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/Pass/PassManager.h" // TF:llvm-project #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" namespace tensorflow { diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc index aa7e4f21c74..648f469e9b0 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc @@ -20,11 +20,11 @@ limitations under the License. #include "llvm/Support/InitLLVM.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/ToolOutputFile.h" -#include "mlir/IR/Diagnostics.h" // TF:local_config_mlir -#include "mlir/IR/Function.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/Module.h" // TF:local_config_mlir -#include "mlir/Support/FileUtilities.h" // TF:local_config_mlir +#include "mlir/IR/Diagnostics.h" // TF:llvm-project +#include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/Support/FileUtilities.h" // TF:llvm-project #include "tensorflow/compiler/mlir/init_mlir.h" #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h" @@ -103,7 +103,7 @@ static int PrintFunctionResultMapping(const std::string &result, i = 0; for (auto output : *subgraph->outputs()) { print_buffer(*subgraph, i, output, [&](int i) { - return terminator ? terminator->getOperand(i)->getLoc() : unknown_loc; + return terminator ? 
terminator->getOperand(i).getLoc() : unknown_loc; }); } } diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc index bab2ffff7cb..71deb4a8cb3 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc @@ -15,12 +15,12 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h" -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Module.h" // TF:local_config_mlir -#include "mlir/Parser.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir -#include "mlir/Support/FileUtilities.h" // TF:local_config_mlir -#include "mlir/Transforms/Passes.h" // TF:local_config_mlir +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/Parser.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Support/FileUtilities.h" // TF:llvm-project +#include "mlir/Transforms/Passes.h" // TF:llvm-project #include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h index 0f6b2f384f0..6f002af463b 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h @@ -17,9 +17,9 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_LITE_TF_TO_TFL_FLATBUFFER_H_ #include "llvm/Support/SourceMgr.h" -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/Module.h" // TF:local_config_mlir -#include "mlir/Pass/PassManager.h" // TF:local_config_mlir +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/Pass/PassManager.h" // TF:llvm-project #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" #include "tensorflow/stream_executor/lib/statusor.h" diff --git a/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc b/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc index 52eb6216e90..7aab9f08732 100644 --- a/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc +++ b/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc @@ -21,26 +21,26 @@ limitations under the License. 
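// Overview, inferred from the code below rather than stated in this file:
// ophint annotations mark the boundaries of composite ops. This pass groups
// the annotated ops into OphintCompositeOp instances, aggregates each logical
// input/output across time steps using the "first", "last", or "stack"
// strategies, and outlines the result into a fused function stub.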
#include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/Support/Casting.h" -#include "mlir/Analysis/LoopAnalysis.h" // TF:local_config_mlir -#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Block.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Function.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/Module.h" // TF:local_config_mlir -#include "mlir/IR/Operation.h" // TF:local_config_mlir -#include "mlir/IR/OperationSupport.h" // TF:local_config_mlir -#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/IR/SymbolTable.h" // TF:local_config_mlir -#include "mlir/IR/Types.h" // TF:local_config_mlir -#include "mlir/IR/Value.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir -#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir -#include "mlir/Support/Functional.h" // TF:local_config_mlir -#include "mlir/Support/LLVM.h" // TF:local_config_mlir -#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir +#include "mlir/Analysis/LoopAnalysis.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Block.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/OperationSupport.h" // TF:llvm-project +#include "mlir/IR/PatternMatch.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/SymbolTable.h" // TF:llvm-project +#include "mlir/IR/Types.h" // TF:llvm-project +#include "mlir/IR/Value.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Pass/PassRegistry.h" // TF:llvm-project +#include "mlir/Support/Functional.h" // TF:llvm-project +#include "mlir/Support/LLVM.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h" @@ -188,10 +188,10 @@ struct OphintCompositeOp { // This function will process the aggregated inputs based on different // strategies like "first", "last", "stack". - std::map GetAggregatedInputs(OpBuilder* builder) { - std::map aggregated_inputs; + std::map GetAggregatedInputs(OpBuilder* builder) { + std::map aggregated_inputs; for (const auto& kv : inputs) { - Value* op_input = nullptr; + Value op_input = nullptr; const AggregatedOperand& operand = kv.second; // Dealing with "stack" strategy: // This breaks into two parts: @@ -203,9 +203,9 @@ struct OphintCompositeOp { if (operand.ops.size() == 1) { // If ops size is 1, it will be simply expanding dimensions at dim 0. 
Operation* current_identity_op = operand.ops.begin()->second; - Value* input = current_identity_op->getOperand(0); + Value input = current_identity_op->getOperand(0); RankedTensorType input_type = - input->getType().cast(); + input.getType().cast(); // The Reshape will be {1, (original_shape)} SmallVector reshape_op_shape; reshape_op_shape.push_back(1); @@ -234,21 +234,21 @@ struct OphintCompositeOp { } else { // Insert a pack op to pack all the inputs together. - std::vector pack_input_operands; - std::vector packed_input_consumers; + std::vector pack_input_operands; + std::vector packed_input_consumers; for (int i = 0, e = operand.ops.size(); i < e; ++i) { pack_input_operands.push_back(operand.ops.at(i)->getOperand(0)); packed_input_consumers.push_back(operand.ops.at(i)->getResult(0)); } // Find the first op that consumes the last value of the aggregated // inputs. - Operation* first_use = *(packed_input_consumers.back()->user_begin()); + Operation* first_use = *(packed_input_consumers.back().user_begin()); // The pack reshape will be {N, (original_shape)} SmallVector pack_shape; pack_shape.push_back(pack_input_operands.size()); RankedTensorType type = operand.ops.at(0) ->getResult(0) - ->getType() + .getType() .cast(); for (const auto& dim : type.getShape()) { pack_shape.push_back(dim); @@ -288,9 +288,9 @@ struct OphintCompositeOp { const AggregatedOperand& operand = kv.second; if (operand.aggregation == kStrategyStack) { const int output_numer = operand.ops.size(); - Value* first_output = operand.ops.at(0)->getOperand(0); + Value first_output = operand.ops.at(0)->getOperand(0); RankedTensorType first_output_type = - first_output->getType().cast(); + first_output.getType().cast(); // The aggregated output shape will be {N, original_shape}. SmallVector shape; shape.push_back(output_numer); @@ -300,12 +300,12 @@ struct OphintCompositeOp { aggregated_output_types[kv.first] = RankedTensorType::get(shape, first_output_type.getElementType()); } else if (operand.aggregation == kStrategyLast) { - Value* last_output = + Value last_output = operand.ops.at(operand.ops.size() - 1)->getOperand(0); - aggregated_output_types[kv.first] = last_output->getType(); + aggregated_output_types[kv.first] = last_output.getType(); } else { - Value* first_output = operand.ops.at(0)->getOperand(0); - aggregated_output_types[kv.first] = first_output->getType(); + Value first_output = operand.ops.at(0)->getOperand(0); + aggregated_output_types[kv.first] = first_output.getType(); } } return aggregated_output_types; @@ -329,7 +329,7 @@ struct OphintCompositeOp { Operation* first_output = operand.ops.at(0); Location insert_loc = first_output->getLoc(); SmallVector unpack_output_types( - output_number, first_output->getOperand(0)->getType()); + output_number, first_output->getOperand(0).getType()); builder->setInsertionPoint(first_output); Operation* unpack_op = builder->create( @@ -404,7 +404,7 @@ void PreprocessTopoSortGraph( // should only count as one. 
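// For example, if an op consumes operands {a, a, b}, where both uses of `a`
// are defined by op A and `b` by op B, input_ops becomes {A, B}, so A is
// counted only once.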
llvm::DenseSet input_ops; for (int i = 0; i < op.getNumOperands(); ++i) { - Operation* input_op = op.getOperand(i)->getDefiningOp(); + Operation* input_op = op.getOperand(i).getDefiningOp(); if (input_op) input_ops.insert(input_op); } if (input_ops.empty()) { @@ -507,15 +507,15 @@ LogicalResult TopoSortOperations(OpBuilder* builder) { Operation* BuildFusedFuncOp(StringRef func_name, StringRef fused_func_type, Operation* insert_before_op, - const std::map& inputs, + const std::map& inputs, const std::map& output_types, OpBuilder* builder, ModuleOp* module_op) { SmallVector input_types; - SmallVector input_values; + SmallVector input_values; SmallVector input_indexes; for (const auto& kv : inputs) { - Value* input = kv.second; - input_types.push_back(input->getType()); + Value input = kv.second; + input_types.push_back(input.getType()); input_values.push_back(input); input_indexes.push_back(kv.first); } @@ -588,8 +588,8 @@ llvm::DenseSet BfsForReachableOps(ArrayRef input_ops) { llvm::DenseSet reachable_ops; std::queue ops_queue; for (auto& input_op : input_ops) { - for (Value* value : input_op->getOperands()) { - Operation* op = value->getDefiningOp(); + for (Value value : input_op->getOperands()) { + Operation* op = value.getDefiningOp(); if (op != nullptr) ops_queue.push(op); } } @@ -598,8 +598,8 @@ llvm::DenseSet BfsForReachableOps(ArrayRef input_ops) { Operation* current_op = ops_queue.front(); ops_queue.pop(); reachable_ops.insert(current_op); - for (Value* value : current_op->getOperands()) { - Operation* upstream_op = value->getDefiningOp(); + for (Value value : current_op->getOperands()) { + Operation* upstream_op = value.getDefiningOp(); // Not visited, put it into the queue. if (upstream_op != nullptr && !llvm::is_contained(reachable_ops, upstream_op)) { @@ -625,7 +625,7 @@ LogicalResult ConvertOphintToStub(StringRef stub_name, BfsForReachableOps(ophint_composite_op.GetAllOutputOps()); // Step 3, deal with inputs aggregation strategies. - const std::map& aggregated_inputs = + const std::map& aggregated_inputs = ophint_composite_op.GetAggregatedInputs(builder); // Step 4, get aggregated output types. @@ -642,7 +642,7 @@ LogicalResult ConvertOphintToStub(StringRef stub_name, aggregated_inputs, aggregated_output_types, builder, module_op); for (const auto& kv : aggregated_inputs) { - Operation* op = kv.second->getDefiningOp(); + Operation* op = kv.second.getDefiningOp(); if (op == nullptr) return failure(); op->moveBefore(fused_op); } diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_ophint_func_op.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_ophint_func_op.cc index ed3a9ea5000..e31b143ab43 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_ophint_func_op.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_ophint_func_op.cc @@ -15,23 +15,23 @@ limitations under the License. 
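// Overview, inferred from the builders below: this pass matches calls to
// ophint-generated composite functions and rewrites them into the fused
// TFL UnidirectionalSequenceRnn / UnidirectionalSequenceLSTM ops.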
#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/StringMap.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Block.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Function.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/Module.h" // TF:local_config_mlir -#include "mlir/IR/Operation.h" // TF:local_config_mlir -#include "mlir/IR/OperationSupport.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/IR/SymbolTable.h" // TF:local_config_mlir -#include "mlir/IR/Types.h" // TF:local_config_mlir -#include "mlir/IR/Value.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir -#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir -#include "mlir/Support/LLVM.h" // TF:local_config_mlir -#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir +#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Block.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/OperationSupport.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/SymbolTable.h" // TF:llvm-project +#include "mlir/IR/Types.h" // TF:llvm-project +#include "mlir/IR/Value.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Pass/PassRegistry.h" // TF:llvm-project +#include "mlir/Support/LLVM.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" namespace mlir { @@ -92,18 +92,18 @@ LogicalResult BuildUnidirectionalSequenceRnnOp(FuncOp composite_func_op, if (call_op.getNumResults() != 1) return failure(); // Inputs is indexed at 0. - Value* input = call_op.getOperand(0); + Value input = call_op.getOperand(0); // Input_weight is indexed at 1. - Value* weight = call_op.getOperand(1); + Value weight = call_op.getOperand(1); // Recurrent_weight is indexed at 2. - Value* recurrent_weight = call_op.getOperand(2); + Value recurrent_weight = call_op.getOperand(2); // Bias is indexed at 3. - Value* bias = call_op.getOperand(3); + Value bias = call_op.getOperand(3); // Hidden_state is indexed at 4. - Value* hidden_state = call_op.getOperand(4); + Value hidden_state = call_op.getOperand(4); // Build Output. - auto output_type = call_op.getResult(0)->getType(); + auto output_type = call_op.getResult(0).getType(); // Currently, ophinted RNN only supports time_major = True. const bool time_major = true; @@ -127,7 +127,7 @@ LogicalResult BuildUnidirectionalSequenceLSTMOp(FuncOp composite_func_op, auto input_index_attr = composite_func_op.getAttr(kTfLiteFunctionInputIndex) .cast() .getValue(); - llvm::DenseMap fused_ops_index_to_call_op_args; + llvm::DenseMap fused_ops_index_to_call_op_args; for (int i = 0; i < call_op.getNumOperands(); ++i) { int input_index = input_index_attr[i].cast().getInt(); @@ -139,7 +139,7 @@ LogicalResult BuildUnidirectionalSequenceLSTMOp(FuncOp composite_func_op, // We encounter some optional arguments not filled, so we need to create an // empty Value. 
- Value* none_value; + Value none_value; if (call_op.getNumOperands() < kUnidirectionalSequenceLSTMOpTotalIArgumentNum) { builder->setInsertionPoint(call_op.getOperation()); @@ -148,7 +148,7 @@ LogicalResult BuildUnidirectionalSequenceLSTMOp(FuncOp composite_func_op, } // Prepare all operands for the UnidirectionalSequenceLSTMOp. - SmallVector operands; + SmallVector operands; for (int i = 0; i < kUnidirectionalSequenceLSTMOpTotalIArgumentNum; ++i) { auto operand_it = fused_ops_index_to_call_op_args.find(i); if (operand_it == fused_ops_index_to_call_op_args.end()) { @@ -169,12 +169,12 @@ LogicalResult BuildUnidirectionalSequenceLSTMOp(FuncOp composite_func_op, if (call_op.getNumResults() > 1) { for (int i = 0; i < call_op.getNumResults() - 1; ++i) { // This one should not be used. - Value* unused_output = call_op.getResult(i); - if (!unused_output->use_empty()) return failure(); + Value unused_output = call_op.getResult(i); + if (!unused_output.use_empty()) return failure(); } } output_types.push_back( - call_op.getResult(call_op.getNumResults() - 1)->getType()); + call_op.getResult(call_op.getNumResults() - 1).getType()); // Prepare attributes. SmallVector attributes; @@ -206,11 +206,11 @@ LogicalResult ConvertTfLiteFusedOpIfAvailable(StringRef func_name, LogicalResult build_fused_op_result = BuildUnidirectionalSequenceLSTMOp( composite_func_op, call_op, builder, &fused_op); if (failed(build_fused_op_result)) return build_fused_op_result; - Value* call_output = call_op.getResult(call_op.getNumResults() - 1); - if (call_output->getType() != fused_op->getResult(0)->getType()) { + Value call_output = call_op.getResult(call_op.getNumResults() - 1); + if (call_output.getType() != fused_op->getResult(0).getType()) { return failure(); } - call_output->replaceAllUsesWith(fused_op->getResult(0)); + call_output.replaceAllUsesWith(fused_op->getResult(0)); } else { // If we support more fused op, we should add the conversion here. return failure(); } diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td index 9e9dfa5874f..596809d3bcb 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td @@ -39,7 +39,7 @@ def Merge2AttrsToArray : NativeCodeCall<"$_builder.getArrayAttr({$0, $1})">; // Use the tensor type information from $0 and convert min $1, max $2 and // numBits $3 and narrowRange $4 to a QuantizedType. def ConvertToQuantTypeFromAttrs : NativeCodeCall< - "GetQuantizedTypeAttr($_builder, $0->getType(), $1, $2, -1, $3, $4, /*is_signed=*/false)">; + "GetQuantizedTypeAttr($_builder, $0.getType(), $1, $2, -1, $3, $4, /*is_signed=*/false)">; // Converts an integer attribute $0 to 32-bit with builder. def convertIntAttrTo32Bit : NativeCodeCall< @@ -49,6 +49,11 @@ def convertIntAttrTo32Bit : NativeCodeCall< def ExtractSingleElementAsInteger : NativeCodeCall< "ExtractSingleElementAsInteger($_self.cast<ElementsAttr>())">; +// Checks whether the given operation has static shapes and same shapes of all inputs. +def HasSameStaticShapesPred : CPred<"HasSameStaticShapes($0.getDefiningOp())">; +def HasSameStaticShapes : Constraint<HasSameStaticShapesPred, "op must have static same input shapes">; +def HasNotSameStaticShapes : Constraint<Neg<HasSameStaticShapesPred>, "op must have not static same input shapes">; + //===----------------------------------------------------------------------===// // Nullary ops patterns. 
//===----------------------------------------------------------------------===// @@ -145,10 +150,9 @@ def : Pat<(TF_RoundOp $arg), (TFL_RoundOp $arg)>; def : Pat<(TF_RsqrtOp $arg), (TFL_RsqrtOp $arg)>; def : Pat<(TF_SqrtOp $arg), (TFL_SqrtOp $arg)>; def : Pat<(TF_SquareOp $arg), (TFL_SquareOp $arg)>; -// TODO(jpienaar): this is not true for all selects, TF's select supports rank 0 -// condition def : Pat<(TF_SelectOp $cond, $x, $y), (TFL_SelectOp $cond, $x, $y)>; -def : Pat<(TF_SelectV2Op $cond, $x, $y), (TFL_SelectOp $cond, $x, $y)>; +def : Pat<(TF_SelectV2Op:$src_op $cond, $x, $y), (TFL_SelectOp $cond, $x, $y), [(HasSameStaticShapes $src_op)]>; +def : Pat<(TF_SelectV2Op:$src_op $cond, $x, $y), (TFL_SelectV2Op $cond, $x, $y), [(HasNotSameStaticShapes $src_op)]>; def : Pat<(TF_ShapeOp $arg), (TFL_ShapeOp $arg)>; def : Pat<(TF_SigmoidOp $arg), (TFL_LogisticOp $arg)>; def : Pat<(TF_SinOp F32Tensor:$arg), (TFL_SinOp $arg)>; diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc index 698ba4d4483..5513f2ad546 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc @@ -28,15 +28,15 @@ limitations under the License. #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringSwitch.h" -#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:local_config_mlir -#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:local_config_mlir -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Operation.h" // TF:local_config_mlir -#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir -#include "mlir/Support/Functional.h" // TF:local_config_mlir -#include "mlir/Support/LLVM.h" // TF:local_config_mlir +#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project +#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/PatternMatch.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Support/Functional.h" // TF:llvm-project +#include "mlir/Support/LLVM.h" // TF:llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" @@ -66,6 +66,28 @@ struct LegalizeTF : public FunctionPass { void runOnFunction() override; }; +// Returns true if all tensor values in `values` have static shapes and share the same shape. 
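+// For example: {tensor<2x3xf32>, tensor<2x3xf32>} -> true;
+// {tensor<?x3xf32>, tensor<2x3xf32>} -> false (dynamic dimension);
+// {tensor<2x3xf32>, tensor<3x2xf32>} -> false (mismatched shapes).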
+bool HasSameStaticShapes(Operation* op) { + auto values = op->getOperands(); + int index = 0; + ArrayRef<int64_t> shape; + for (Value value : values) { + auto shaped_type = value.getType().dyn_cast<ShapedType>(); + if (!shaped_type || !shaped_type.hasStaticShape()) { + return false; + } + if (index == 0) { + shape = shaped_type.getShape(); + } else { + if (shape != shaped_type.getShape()) { + return false; + } + } + ++index; + } + return true; +} + #include "tensorflow/compiler/mlir/lite/transforms/generated_legalize_tf.inc" #define DECL_CONVERT_OP(tf_op) \ @@ -100,7 +122,7 @@ PatternMatchResult ConvertTFConcatOp::matchAndRewrite( auto tf_concat_op = cast(op); auto values = tf_concat_op.values(); - auto output_type = tf_concat_op.output()->getType(); + auto output_type = tf_concat_op.output().getType(); // Extract axis attribute from constant concat_dims tensor ElementsAttr axis; if (!matchPattern(tf_concat_op.concat_dim(), m_Constant(&axis))) @@ -119,7 +141,7 @@ PatternMatchResult ConvertTFConcatV2Op::matchAndRewrite( auto tf_concat_op = cast(op); auto values = tf_concat_op.values(); - auto output_type = tf_concat_op.output()->getType(); + auto output_type = tf_concat_op.output().getType(); // Extract axis attribute from constant axis tensor ElementsAttr axis; if (!matchPattern(tf_concat_op.axis(), m_Constant(&axis))) @@ -145,7 +167,7 @@ PatternMatchResult ConvertTFMatMulOp::matchAndRewrite( if (tf_matmul_op.transpose_a()) return matchFailure(); if (!tf_matmul_op.transpose_b()) return matchFailure(); - Type output_type = tf_matmul_op.getResult()->getType(); + Type output_type = tf_matmul_op.getResult().getType(); // TODO(jpienaar): Follow up post shuffle discussion. auto no_input = rewriter.create( op->getLoc(), rewriter.getNoneType(), rewriter.getUnitAttr()); @@ -161,8 +183,8 @@ PatternMatchResult ConvertTFPackOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tf_pack_op = cast(op); - SmallVector values(tf_pack_op.values()); - auto output_type = tf_pack_op.output()->getType(); + SmallVector values(tf_pack_op.values()); + auto output_type = tf_pack_op.output().getType(); auto values_count = rewriter.getI32IntegerAttr(tf_pack_op.N()); // Axis can be negative. auto axis = rewriter.getI32IntegerAttr(tf_pack_op.axis().getSExtValue()); @@ -176,10 +198,10 @@ PatternMatchResult ConvertTFReshapeOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tf_reshape_op = cast(op); - auto* input = tf_reshape_op.tensor(); - auto* shape = tf_reshape_op.shape(); + auto input = tf_reshape_op.tensor(); + auto shape = tf_reshape_op.shape(); - ShapedType shape_type = shape->getType().cast(); + ShapedType shape_type = shape.getType().cast(); // The tfl reshape's #2 operand needs to i32 tensor type, so we have to cast. if (!shape_type.getElementType().isInteger(32)) { auto new_shape = shape_type.getShape(); @@ -191,7 +213,7 @@ PatternMatchResult ConvertTFReshapeOp::matchAndRewrite( rewriter.getBoolAttr(false)) .y(); } - rewriter.replaceOpWithNewOp(op, tf_reshape_op.output()->getType(), + rewriter.replaceOpWithNewOp(op, tf_reshape_op.output().getType(), input, shape); return matchSuccess(); } @@ -200,7 +222,7 @@ PatternMatchResult ConvertTFSplitOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tf_split_op = cast(op); - auto output_types = functional::map([](Value* v) { return v->getType(); }, + auto output_types = functional::map([](Value v) { return v.getType(); }, tf_split_op.output()); // Number of splits cannot be negative. 
auto num_split = rewriter.getI32IntegerAttr(tf_split_op.num_split()); @@ -215,7 +237,7 @@ PatternMatchResult ConvertTFSplitVOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tf_splitv_op = cast(op); - auto output_types = functional::map([](Value* v) { return v->getType(); }, + auto output_types = functional::map([](Value v) { return v.getType(); }, tf_splitv_op.output()); // Number of splits cannot be negative. auto num_split = rewriter.getI32IntegerAttr(tf_splitv_op.num_split()); @@ -226,13 +248,13 @@ PatternMatchResult ConvertTFSplitVOp::matchAndRewrite( return matchSuccess(); } -Value* PadStridedSliceAttributeArray(Operation* op, PatternRewriter& rewriter, - Value* attribute, - ArrayRef padding_val, int* mask) { +Value PadStridedSliceAttributeArray(Operation* op, PatternRewriter& rewriter, + Value attribute, + ArrayRef padding_val, int* mask) { DenseIntElementsAttr dense_elem_attr; SmallVector padded_val; - auto ranked_attr_type = attribute->getType().dyn_cast(); + auto ranked_attr_type = attribute.getType().dyn_cast(); if (!ranked_attr_type || !matchPattern(attribute, m_Constant(&dense_elem_attr))) { // If the input attribute is neither ranked type nor constant, we @@ -258,14 +280,14 @@ PatternMatchResult ConvertTFStridedSliceOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tf_strided_slice_op = cast(op); auto ranked_input_type = - tf_strided_slice_op.input()->getType().dyn_cast(); + tf_strided_slice_op.input().getType().dyn_cast(); if (!ranked_input_type) { // If input is not a ranked tensor, we can't deduce the padding dimensions // from it, so we just do a plain conversion here. rewriter.replaceOpWithNewOp( - op, tf_strided_slice_op.output()->getType(), - tf_strided_slice_op.input(), tf_strided_slice_op.begin(), - tf_strided_slice_op.end(), tf_strided_slice_op.strides(), + op, tf_strided_slice_op.output().getType(), tf_strided_slice_op.input(), + tf_strided_slice_op.begin(), tf_strided_slice_op.end(), + tf_strided_slice_op.strides(), rewriter.getI32IntegerAttr( tf_strided_slice_op.begin_mask().getSExtValue()), rewriter.getI32IntegerAttr( @@ -283,20 +305,20 @@ PatternMatchResult ConvertTFStridedSliceOp::matchAndRewrite( // Pad `begin` array with zero values and update the `begin_mask`. SmallVector begin_pad_val(num_input_dims, 0); int begin_mask = tf_strided_slice_op.begin_mask().getSExtValue(); - Value* padded_begin = PadStridedSliceAttributeArray( + Value padded_begin = PadStridedSliceAttributeArray( op, rewriter, tf_strided_slice_op.begin(), begin_pad_val, &begin_mask); // Pad `end` array with `input_shape` and update the `end_mask`. int end_mask = tf_strided_slice_op.end_mask().getSExtValue(); auto input_shape = ranked_input_type.getShape(); SmallVector end_pad_val(input_shape.begin(), input_shape.end()); - Value* padded_end = PadStridedSliceAttributeArray( + Value padded_end = PadStridedSliceAttributeArray( op, rewriter, tf_strided_slice_op.end(), end_pad_val, &end_mask); // Pad `strides` array with ones. 
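// Worked example (illustrative) for a rank-4 input of shape [N, H, W, C]:
// begin = [1, 2] pads to [1, 2, 0, 0], end = [3, 4] pads to [3, 4, W, C]
// (the trailing input dims), and strides = [1, 1] pads to [1, 1, 1, 1];
// the begin/end masks are extended so the padded dims span their full range.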
SmallVector strides_pad_val(num_input_dims, 1); - Value* padded_strides = PadStridedSliceAttributeArray( + Value padded_strides = PadStridedSliceAttributeArray( op, rewriter, tf_strided_slice_op.strides(), strides_pad_val, nullptr); rewriter.replaceOpWithNewOp( - op, tf_strided_slice_op.output()->getType(), tf_strided_slice_op.input(), + op, tf_strided_slice_op.output().getType(), tf_strided_slice_op.input(), padded_begin, padded_end, padded_strides, rewriter.getI32IntegerAttr(begin_mask), rewriter.getI32IntegerAttr(end_mask), @@ -313,8 +335,8 @@ PatternMatchResult ConvertTFUnpackOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tf_unpack_op = cast(op); - auto* input = tf_unpack_op.value(); - auto output_types = functional::map([](Value* v) { return v->getType(); }, + auto input = tf_unpack_op.value(); + auto output_types = functional::map([](Value v) { return v.getType(); }, tf_unpack_op.output()); auto num = rewriter.getI32IntegerAttr(tf_unpack_op.num()); // Axis can be negative. @@ -338,7 +360,7 @@ bool ConvertTFMatrixDiagV2orV3(Operation* op, PatternRewriter* rewriter) { if (tf_matrix_diag_v2_or_v3_op.getNumOperands() != 5) return false; auto input = tf_matrix_diag_v2_or_v3_op.diagonal(); - auto output_type = tf_matrix_diag_v2_or_v3_op.output()->getType(); + auto output_type = tf_matrix_diag_v2_or_v3_op.output().getType(); // Extract k constant tensor and check value = 0. ElementsAttr k; @@ -478,7 +500,7 @@ PatternMatchResult ConvertTFReciprocalOp::matchAndRewrite( auto status_or_const_op = CreateConstOpWithSingleValue( &rewriter, op->getLoc(), - tf_reciprocal_op.x()->getType().cast(), 1); + tf_reciprocal_op.x().getType().cast(), 1); if (!status_or_const_op.ok()) { return matchFailure(); } diff --git a/tensorflow/compiler/mlir/lite/transforms/load_quantization_recipe.cc b/tensorflow/compiler/mlir/lite/transforms/load_quantization_recipe.cc index f1668b0ffb9..3349261af02 100644 --- a/tensorflow/compiler/mlir/lite/transforms/load_quantization_recipe.cc +++ b/tensorflow/compiler/mlir/lite/transforms/load_quantization_recipe.cc @@ -19,11 +19,11 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/None.h" #include "llvm/ADT/Optional.h" -#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir -#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir +#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" @@ -50,13 +50,13 @@ struct LoadQuantizationRecipe : public FunctionPass { // Create LSTM gates with different weights for input, recurrent and // cell state, and also the layer normalization parameters. 
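// Schematically, CreateGate below computes
//   gate = LayerNorm(FC(in, in_w) + FC(rec, rec_w) [+ cell * cell_w]),
// where the elementwise cell term participates only when `cell` is provided.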
- Operation* CreateGate(Location loc, Value* in, Value* in_w, Value* rec, - Value* rec_w, - llvm::Optional> cell, - Value* ln_w, Value* ln_bias, OpBuilder* builder); + Operation* CreateGate(Location loc, Value in, Value in_w, Value rec, + Value rec_w, + llvm::Optional> cell, + Value ln_w, Value ln_bias, OpBuilder* builder); - Operation* CreateLayerNorm(Location loc, Value* in, Value* ln_w, - Value* ln_bias, OpBuilder* builder); + Operation* CreateLayerNorm(Location loc, Value in, Value ln_w, Value ln_bias, + OpBuilder* builder); // Add the internal implementation of the LSTM to its regions. void LoadForLSTMOp(LSTMOp lstm, OpBuilder* builder); @@ -71,7 +71,7 @@ struct LoadQuantizationRecipe : public FunctionPass { void LoadQuantizationRecipe::Initialize(LSTMOp lstm, OpBuilder* builder) { Type expressed_type = - lstm.input()->getType().cast().getElementType(); + lstm.input().getType().cast().getElementType(); Type int8_storage_type = builder->getIntegerType(8); Type int16_storage_type = builder->getIntegerType(16); auto flag = quant::QuantizationFlags::FlagValue::Signed; @@ -88,12 +88,12 @@ void LoadQuantizationRecipe::Initialize(LSTMOp lstm, OpBuilder* builder) { auto any_int16 = quant::AnyQuantizedType::get( flag, int16_storage_type, expressed_type, int16_min, int16_max); - int8 = any_int8.castFromExpressedType(lstm.input()->getType()); - int16 = any_int16.castFromExpressedType(lstm.input()->getType()); + int8 = any_int8.castFromExpressedType(lstm.input().getType()); + int16 = any_int16.castFromExpressedType(lstm.input().getType()); } -Operation* LoadQuantizationRecipe::CreateLayerNorm(Location loc, Value* in, - Value* ln_w, Value* ln_bias, +Operation* LoadQuantizationRecipe::CreateLayerNorm(Location loc, Value in, + Value ln_w, Value ln_bias, OpBuilder* builder) { // Note that l2_normalization and add ops here are not the execution kernel // implementation for layer_normalization and we just want to use them to @@ -105,8 +105,8 @@ Operation* LoadQuantizationRecipe::CreateLayerNorm(Location loc, Value* in, } Operation* LoadQuantizationRecipe::CreateGate( - Location loc, Value* in, Value* in_w, Value* rec, Value* rec_w, - llvm::Optional> cell, Value* ln_w, Value* ln_bias, + Location loc, Value in, Value in_w, Value rec, Value rec_w, + llvm::Optional> cell, Value ln_w, Value ln_bias, OpBuilder* builder) { auto s1 = builder->create(loc, int16, in, in_w, none_cst, none_af, fc_format, keep_dims); @@ -119,13 +119,13 @@ Operation* LoadQuantizationRecipe::CreateGate( cell.getValue().second, none_af); s4 = builder->create( loc, int16, - llvm::ArrayRef( + llvm::ArrayRef( {*s1.output().begin(), *s2.output().begin(), s3.output()})); } else { s4 = builder->create( loc, int16, - llvm::ArrayRef({*s1.output().begin(), *s2.output().begin()})); + llvm::ArrayRef({*s1.output().begin(), *s2.output().begin()})); } auto s5 = CreateLayerNorm(loc, s4.sum(), ln_w, ln_bias, builder); @@ -144,22 +144,20 @@ void LoadQuantizationRecipe::LoadForLSTMOp(LSTMOp lstm, OpBuilder* builder) { region.push_back(new Block); builder->setInsertionPointToEnd(®ion.front()); Location loc = lstm.getLoc(); - Type int32_type = builder->getIntegerType(32); - Type int32_tensor = UnrankedTensorType::get(int32_type); none_cst = builder->create(loc, builder->getNoneType(), builder->getUnitAttr()); auto input_gate = CreateGate( loc, lstm.input(), lstm.input_to_input_weights(), lstm.input_activation_state(), lstm.recurrent_to_input_weights(), - llvm::Optional>( + llvm::Optional>( {lstm.input_cell_state(), lstm.cell_to_input_weights()}), 
       lstm.input_layer_norm_coefficients(), lstm.input_gate_bias(), builder);
 
   auto forget_gate = CreateGate(
       loc, lstm.input(), lstm.input_to_forget_weights(),
       lstm.input_activation_state(), lstm.recurrent_to_forget_weights(),
-      llvm::Optional<std::pair<Value*, Value*>>(
+      llvm::Optional<std::pair<Value, Value>>(
           {lstm.input_cell_state(), lstm.cell_to_forget_weights()}),
       lstm.forget_layer_norm_coefficients(), lstm.forget_gate_bias(), builder);
@@ -179,7 +177,7 @@ void LoadQuantizationRecipe::LoadForLSTMOp(LSTMOp lstm, OpBuilder* builder) {
   auto output_gate = CreateGate(
       loc, lstm.input(), lstm.input_to_output_weights(),
       lstm.input_activation_state(), lstm.recurrent_to_output_weights(),
-      llvm::Optional<std::pair<Value*, Value*>>(
+      llvm::Optional<std::pair<Value, Value>>(
           {new_cell, lstm.cell_to_output_weights()}),
       lstm.output_layer_norm_coefficients(), lstm.output_gate_bias(), builder);
diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc
index 7c02342eedd..bc8d9162b78 100644
--- a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc
@@ -29,28 +29,28 @@ limitations under the License.
 #include "llvm/ADT/StringSwitch.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/Debug.h"
-#include "mlir/Analysis/LoopAnalysis.h"  // TF:local_config_mlir
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:local_config_mlir
-#include "mlir/IR/Attributes.h"  // TF:local_config_mlir
-#include "mlir/IR/Block.h"  // TF:local_config_mlir
-#include "mlir/IR/Function.h"  // TF:local_config_mlir
-#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
-#include "mlir/IR/Matchers.h"  // TF:local_config_mlir
-#include "mlir/IR/Module.h"  // TF:local_config_mlir
-#include "mlir/IR/Operation.h"  // TF:local_config_mlir
-#include "mlir/IR/OperationSupport.h"  // TF:local_config_mlir
-#include "mlir/IR/PatternMatch.h"  // TF:local_config_mlir
-#include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
-#include "mlir/IR/SymbolTable.h"  // TF:local_config_mlir
-#include "mlir/IR/TypeUtilities.h"  // TF:local_config_mlir
-#include "mlir/IR/Types.h"  // TF:local_config_mlir
-#include "mlir/IR/Value.h"  // TF:local_config_mlir
-#include "mlir/Pass/Pass.h"  // TF:local_config_mlir
-#include "mlir/Pass/PassRegistry.h"  // TF:local_config_mlir
-#include "mlir/Support/Functional.h"  // TF:local_config_mlir
-#include "mlir/Support/LLVM.h"  // TF:local_config_mlir
-#include "mlir/Support/LogicalResult.h"  // TF:local_config_mlir
-#include "mlir/Transforms/DialectConversion.h"  // TF:local_config_mlir
+#include "mlir/Analysis/LoopAnalysis.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/IR/Attributes.h"  // TF:llvm-project
+#include "mlir/IR/Block.h"  // TF:llvm-project
+#include "mlir/IR/Function.h"  // TF:llvm-project
+#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
+#include "mlir/IR/Matchers.h"  // TF:llvm-project
+#include "mlir/IR/Module.h"  // TF:llvm-project
+#include "mlir/IR/Operation.h"  // TF:llvm-project
+#include "mlir/IR/OperationSupport.h"  // TF:llvm-project
+#include "mlir/IR/PatternMatch.h"  // TF:llvm-project
+#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
+#include "mlir/IR/SymbolTable.h"  // TF:llvm-project
+#include "mlir/IR/TypeUtilities.h"  // TF:llvm-project
+#include "mlir/IR/Types.h"  // TF:llvm-project
+#include "mlir/IR/Value.h"  // TF:llvm-project
+#include "mlir/Pass/Pass.h"  // TF:llvm-project
+#include "mlir/Pass/PassRegistry.h"  // TF:llvm-project
+#include "mlir/Support/Functional.h"  // TF:llvm-project
+#include "mlir/Support/LLVM.h"  // TF:llvm-project
+#include "mlir/Support/LogicalResult.h"  // TF:llvm-project
+#include "mlir/Transforms/DialectConversion.h"  // TF:llvm-project
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
 #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
@@ -84,8 +84,8 @@ struct LowerStaticTensorListPass
       TensorListPatternRewriter *rewriter);
 };
 
-Value *CreateI32SplatConst(Location loc, PatternRewriter *rewriter,
-                           ArrayRef<int64_t> shape, int32_t val) {
+Value CreateI32SplatConst(Location loc, PatternRewriter *rewriter,
+                          ArrayRef<int64_t> shape, int32_t val) {
   RankedTensorType type =
       RankedTensorType::get(shape, rewriter->getIntegerType(32));
   DenseElementsAttr attr =
@@ -93,9 +93,9 @@ Value *CreateI32SplatConst(Location loc, PatternRewriter *rewriter,
   return rewriter->create<ConstantOp>(loc, type, attr);
 }
 
-Value *CreateI32SplatTensor(Location loc, PatternRewriter *rewriter,
-                            Value *shape_tensor, int32_t val) {
-  Value *scalar_val = CreateI32SplatConst(loc, rewriter, {}, val);
+Value CreateI32SplatTensor(Location loc, PatternRewriter *rewriter,
+                           Value shape_tensor, int32_t val) {
+  Value scalar_val = CreateI32SplatConst(loc, rewriter, {}, val);
   return rewriter->create<TF::FillOp>(
       loc, RankedTensorType::get({-1}, rewriter->getIntegerType(32)),
       shape_tensor, scalar_val);
@@ -131,32 +131,32 @@ Type GetTensorTypeForTensorList(Type element_type, TF::VariantType handle_dtype,
 // Requires that `start_index` and `size` are scalar tensors and
 // `item_position_shape` is a 1-D tensor with only one element equal to the rank
 // of an item in the tensorlist.
-TF::SliceOp CreateSliceOpForTensorList(Location loc, Value *input_list,
-                                       Value *start_index, Value *size,
-                                       Value *item_rank, Type result_type,
+TF::SliceOp CreateSliceOpForTensorList(Location loc, Value input_list,
+                                       Value start_index, Value size,
+                                       Value item_rank, Type result_type,
                                        PatternRewriter *rewriter) {
   // Create the start position of slice. This is done by concatenating
   // `start_index` and `partial_start_position` together.
   IntegerType shape_dtype = rewriter->getIntegerType(32);
   RankedTensorType position_type = RankedTensorType::get({-1}, shape_dtype);
-  Value *partial_start_position =
+  Value partial_start_position =
       CreateI32SplatTensor(loc, rewriter, item_rank, 0);
-  Value *scalar_zero = CreateI32SplatConst(loc, rewriter, {}, 0);
+  Value scalar_zero = CreateI32SplatConst(loc, rewriter, {}, 0);
   RankedTensorType vector_type = RankedTensorType::get({1}, shape_dtype);
   auto expanded_start_index = rewriter->create<TF::ExpandDimsOp>(
       loc, vector_type, start_index, scalar_zero);
   auto start_position = rewriter->create<TF::ConcatOp>(
       loc, position_type, scalar_zero,
-      ArrayRef<Value *>({expanded_start_index, partial_start_position}));
+      ArrayRef<Value>({expanded_start_index, partial_start_position}));
 
   // Create the slice size tensor. This is done by concatenating `size` and
   // `partial_size`.
auto size_leading_dim = rewriter->create(loc, vector_type, size, scalar_zero); - Value *partial_size = CreateI32SplatTensor(loc, rewriter, item_rank, -1); + Value partial_size = CreateI32SplatTensor(loc, rewriter, item_rank, -1); auto slice_size = rewriter->create( loc, position_type, scalar_zero, - ArrayRef({size_leading_dim, partial_size})); + ArrayRef({size_leading_dim, partial_size})); return rewriter->create(loc, result_type, input_list, start_position, slice_size); @@ -180,31 +180,31 @@ struct ConvertTensorListSetItem : public ConversionPattern { // 0), [-1, -1, ...])), (ExpandDims $item, expand_dim = 0), (Slice // $input, [$index + 1, 0, 0, ...], [-1, -1, ...]))>; PatternMatchResult matchAndRewrite( - Operation *operation, ArrayRef operands, + Operation *operation, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto op = llvm::cast(operation); Location loc = op.getLoc(); - Value *input = operands[0]; - Value *index = operands[1]; - Value *item = operands[2]; + Value input = operands[0]; + Value index = operands[1]; + Value item = operands[2]; IntegerType shape_dtype = rewriter.getIntegerType(32); auto item_rank = rewriter.create( loc, RankedTensorType::get({}, shape_dtype), item); - Value *scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0); + Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0); // Calculate `index` + 1, which is used to generate the start position for // the second slice op. auto suffix_start = - rewriter.create(loc, index->getType(), index, + rewriter.create(loc, index.getType(), index, CreateI32SplatConst(loc, &rewriter, {}, 1)); auto item_position_shape = rewriter.create( loc, RankedTensorType::get({1}, shape_dtype), item_rank, scalar_zero); // Create two slice ops. - Type element_type = input->getType().cast().getElementType(); + Type element_type = input.getType().cast().getElementType(); UnrankedTensorType unranked_tensor = UnrankedTensorType::get(element_type); - Value *scalar_minus_one = CreateI32SplatConst(loc, &rewriter, {}, -1); + Value scalar_minus_one = CreateI32SplatConst(loc, &rewriter, {}, -1); TF::SliceOp slice1 = CreateSliceOpForTensorList(loc, /*input_list=*/input, /*start_index=*/scalar_zero, @@ -225,8 +225,8 @@ struct ConvertTensorListSetItem : public ConversionPattern { // Concatenate three parts together to generate the final result. rewriter.replaceOpWithNewOp( - op, input->getType(), scalar_zero, - ArrayRef({slice1, expanded_item, slice2})); + op, input.getType(), scalar_zero, + ArrayRef({slice1, expanded_item, slice2})); return matchSuccess(); } }; @@ -241,14 +241,14 @@ struct ConvertTensorListInitOp : public ConversionPattern { // Create and return a 1-d tensor with exactly one element equal to the number // of list elements to initialize the output tensor list with. - virtual Value *GetNumElements(OpT op, ArrayRef operands, - PatternRewriter *rewriter) const = 0; + virtual Value GetNumElements(OpT op, ArrayRef operands, + PatternRewriter *rewriter) const = 0; // Rewrites the original op into `tf.fill`. The result tensor shape is // [num_element, element_shape]. All the values in the result tensor will be // initialized to 0. 
PatternMatchResult matchAndRewrite( - Operation *operation, ArrayRef operands, + Operation *operation, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { OpT op = llvm::cast(operation); @@ -263,8 +263,8 @@ struct ConvertTensorListInitOp : public ConversionPattern { return matchFailure(); } - Value *element_shape = operands[0]; - Type shape_dtype = getElementTypeOrSelf(element_shape->getType()); + Value element_shape = operands[0]; + Type shape_dtype = getElementTypeOrSelf(element_shape.getType()); DenseIntElementsAttr dense_elem_attr; if (matchPattern(element_shape, m_Constant(&dense_elem_attr))) { @@ -297,11 +297,10 @@ struct ConvertTensorListInitOp : public ConversionPattern { new_element_shape_values.push_back(dim_value); } - auto attr = - DenseIntElementsAttr::get(element_shape->getType().cast(), - new_element_shape_values); + auto attr = DenseIntElementsAttr::get( + element_shape.getType().cast(), new_element_shape_values); auto new_element_shape = rewriter.create( - op.getLoc(), element_shape->getType(), attr); + op.getLoc(), element_shape.getType(), attr); element_shape = new_element_shape; } @@ -330,11 +329,11 @@ struct ConvertTensorListInitOp : public ConversionPattern { Location loc = op.getLoc(); // Add number of elements as the prefix to the element shape to get shape of // the output tensor. - Value *leading_dim = GetNumElements(op, operands, &rewriter); - Value *scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0); + Value leading_dim = GetNumElements(op, operands, &rewriter); + Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0); auto list_shape = rewriter.create( loc, shape_type, scalar_zero, - ArrayRef({leading_dim, element_shape})); + ArrayRef({leading_dim, element_shape})); // Create a zero-initialized constant tensor that has the same type // as specified by element_dtype. 
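
Per the comment above, both tensor-list init ops become one zero-filled tensor of shape [num_elements, element_shape...]. A host-side analogue of that shape-and-fill computation, assuming a hypothetical DenseList container (a sketch of the idea, not the rewriter logic itself):

    #include <cstdint>
    #include <functional>
    #include <numeric>
    #include <vector>

    struct DenseList {
      std::vector<int64_t> shape;
      std::vector<float> data;
    };

    // num_elements items, each of element_shape, flattened into a single
    // dense tensor of zeros -- the analogue of the tf.fill the pattern emits.
    DenseList InitTensorList(int64_t num_elements,
                             const std::vector<int64_t>& element_shape) {
      DenseList list;
      list.shape.push_back(num_elements);  // leading_dim
      list.shape.insert(list.shape.end(), element_shape.begin(),
                        element_shape.end());
      int64_t total = std::accumulate(list.shape.begin(), list.shape.end(),
                                      int64_t{1}, std::multiplies<int64_t>());
      list.data.assign(total, 0.0f);  // zero-initialized, like the fill value
      return list;
    }
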
@@ -352,11 +351,11 @@ struct ConvertTensorListReserve explicit ConvertTensorListReserve(MLIRContext *context) : ConvertTensorListInitOp(context) {} - Value *GetNumElements(TF::TensorListReserveOp op, ArrayRef operands, - PatternRewriter *rewriter) const override { - Value *scalar_zero = CreateI32SplatConst(op.getLoc(), rewriter, {}, 0); - Type shape_dtype = getElementTypeOrSelf(op.element_shape()->getType()); - Value *num_elements = operands[1]; + Value GetNumElements(TF::TensorListReserveOp op, ArrayRef operands, + PatternRewriter *rewriter) const override { + Value scalar_zero = CreateI32SplatConst(op.getLoc(), rewriter, {}, 0); + Type shape_dtype = getElementTypeOrSelf(op.element_shape().getType()); + Value num_elements = operands[1]; return rewriter->create( op.getLoc(), RankedTensorType::get({1}, shape_dtype), num_elements, scalar_zero); @@ -371,8 +370,8 @@ struct ConvertEmptyTensorList explicit ConvertEmptyTensorList(MLIRContext *context) : ConvertTensorListInitOp(context) {} - Value *GetNumElements(TF::EmptyTensorListOp op, ArrayRef operands, - PatternRewriter *rewriter) const override { + Value GetNumElements(TF::EmptyTensorListOp op, ArrayRef operands, + PatternRewriter *rewriter) const override { return CreateI32SplatConst(op.getLoc(), rewriter, {1}, 0); } }; @@ -383,23 +382,23 @@ struct ConvertTensorListPushBack : public ConversionPattern { context) {} PatternMatchResult matchAndRewrite( - Operation *op, ArrayRef operands, + Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { TF::TensorListPushBackOp push_back_op = cast(op); - Value *input_handle = operands[0]; - Value *item = operands[1]; + Value input_handle = operands[0]; + Value item = operands[1]; // Expand the shape of the item so that it will have rank same as the input // tensor and it is compatible for the Concat Op. Type expanded_item_type = - PrependLeadingDimIfRanked(1, item->getType(), &rewriter); - Value *scalar_zero = CreateI32SplatConst(op->getLoc(), &rewriter, {}, 0); + PrependLeadingDimIfRanked(1, item.getType(), &rewriter); + Value scalar_zero = CreateI32SplatConst(op->getLoc(), &rewriter, {}, 0); auto expanded_item = rewriter.create( op->getLoc(), expanded_item_type, item, scalar_zero); Type elem_type = getElementTypeOrSelf(item); auto handle_dtype = - getElementTypeOrSelf(push_back_op.output_handle()->getType()) + getElementTypeOrSelf(push_back_op.output_handle().getType()) .cast(); Type result_type = GetTensorTypeForTensorList(elem_type, handle_dtype, &rewriter); @@ -408,7 +407,7 @@ struct ConvertTensorListPushBack : public ConversionPattern { // get a tensor equivalent to the TensorList generated by this op. rewriter.replaceOpWithNewOp( push_back_op, result_type, scalar_zero, - ArrayRef({input_handle, expanded_item})); + ArrayRef({input_handle, expanded_item})); return matchSuccess(); } }; @@ -429,14 +428,14 @@ struct ConvertTensorListResize : public ConversionPattern { context) {} PatternMatchResult matchAndRewrite( - Operation *op, ArrayRef operands, + Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { TF::TensorListResizeOp resize_op = cast(op); - Value *input_handle = operands[0]; - Value *size = operands[1]; + Value input_handle = operands[0]; + Value size = operands[1]; Location loc = resize_op.getLoc(); - Value *scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0); + Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0); // Compute the input tensorlist's length and store it in `input_size`. 
IntegerType shape_dtype = rewriter.getIntegerType(32); @@ -446,7 +445,7 @@ struct ConvertTensorListResize : public ConversionPattern { // Infer result type of this op based on TF's shape inference result. Type elem_type = getElementTypeOrSelf(input_handle); auto handle_dtype = - getElementTypeOrSelf(resize_op.output_handle()->getType()) + getElementTypeOrSelf(resize_op.output_handle().getType()) .cast(); Type result_type = GetTensorTypeForTensorList(elem_type, handle_dtype, &rewriter); @@ -463,8 +462,8 @@ struct ConvertTensorListResize : public ConversionPattern { auto input_shape = rewriter.create( loc, RankedTensorType::get({-1}, shape_dtype), input_handle); - Type branch_args_type[] = {input_handle->getType(), input_shape.getType(), - size_diff.getType(), size->getType()}; + Type branch_args_type[] = {input_handle.getType(), input_shape.getType(), + size_diff.getType(), size.getType()}; Type branch_result_type[] = {result_type}; auto func_type = FunctionType::get(branch_args_type, branch_result_type, rewriter.getContext()); @@ -491,7 +490,7 @@ struct ConvertTensorListResize : public ConversionPattern { rewriter.replaceOpWithNewOp( op, result_type, if_cond, /*input=*/ - ArrayRef({input_handle, input_shape, size_diff, size}), + ArrayRef({input_handle, input_shape, size_diff, size}), /*then_branch=*/rewriter.getSymbolRefAttr(then_branch_op), /*else_branch=*/rewriter.getSymbolRefAttr(else_branch_op), /*output_shapes=*/rewriter.getStrArrayAttr({"{}"}), @@ -517,14 +516,14 @@ struct ConvertTensorListResize : public ConversionPattern { Location loc = resize_op.getLoc(); // Get the element shape by slicing from index 1 in the input shape. - Value *slice_size = CreateI32SplatConst(loc, rewriter, {1}, -1); - Value *scalar_zero = CreateI32SplatConst(loc, rewriter, {}, 0); - Value *slice_start = CreateI32SplatConst(loc, rewriter, {1}, 1); + Value slice_size = CreateI32SplatConst(loc, rewriter, {1}, -1); + Value scalar_zero = CreateI32SplatConst(loc, rewriter, {}, 0); + Value slice_start = CreateI32SplatConst(loc, rewriter, {1}, 1); auto elem_shape = rewriter->create( loc, RankedTensorType::get({-1}, shape_dtype), input_shape, slice_start, slice_size); auto extended_part = rewriter->create( - loc, resize_op.output_handle()->getType(), elem_shape, size_diff); + loc, resize_op.output_handle().getType(), elem_shape, size_diff); // `ConcatOp` expects non-variant-typed input. Insert a // `TensorListStackOp` here to convert type from variant to non-variant. 
// Note that we are using the same `result_type` for both the @@ -536,8 +535,8 @@ struct ConvertTensorListResize : public ConversionPattern { /*num_elements=*/rewriter->getI32IntegerAttr(-1)); auto concat_op = rewriter->create( loc, result_type, scalar_zero, - ArrayRef({input, stacked_extended_part})); - rewriter->create(loc, ArrayRef({concat_op})); + ArrayRef({input, stacked_extended_part})); + rewriter->create(loc, ArrayRef({concat_op})); } void CreateCondFalseBranch(Location loc, Type shape_dtype, Type result_type, @@ -550,8 +549,8 @@ struct ConvertTensorListResize : public ConversionPattern { Block *block = branch_func.addEntryBlock(); rewriter->setInsertionPointToStart(block); - Value *scalar_zero = CreateI32SplatConst(loc, rewriter, {}, 0); - Value *vector_one = CreateI32SplatConst(loc, rewriter, {1}, 1); + Value scalar_zero = CreateI32SplatConst(loc, rewriter, {}, 0); + Value vector_one = CreateI32SplatConst(loc, rewriter, {1}, 1); auto input = block->getArgument(0); auto size = block->getArgument(3); @@ -566,7 +565,7 @@ struct ConvertTensorListResize : public ConversionPattern { /*start_index=*/scalar_zero, /*size=*/size, /*item_rank=*/partial_position_shape, /*result_type=*/result_type, rewriter); - rewriter->create(loc, ArrayRef({slice_op})); + rewriter->create(loc, ArrayRef({slice_op})); } }; @@ -576,11 +575,11 @@ struct ConvertTensorListGetItem : public ConversionPattern { context) {} PatternMatchResult matchAndRewrite( - Operation *operation, ArrayRef operands, + Operation *operation, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto op = llvm::cast(operation); - Value *input = operands[0]; - Value *index = operands[1]; + Value input = operands[0]; + Value index = operands[1]; rewriter.replaceOpWithNewOp( operation, op.getType(), input, index, rewriter.getBoolAttr(true)); return matchSuccess(); @@ -593,11 +592,11 @@ struct ConvertTensorListLength : public ConversionPattern { context) {} PatternMatchResult matchAndRewrite( - Operation *operation, ArrayRef operands, + Operation *operation, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto op = llvm::cast(operation); Location loc = op.getLoc(); - Value *input_handle = operands[0]; + Value input_handle = operands[0]; BoolAttr true_attr = rewriter.getBoolAttr(true); auto shape = rewriter.create(loc, input_handle, @@ -615,19 +614,19 @@ struct ConvertTensorListStack : public ConversionPattern { context) {} PatternMatchResult matchAndRewrite( - Operation *operation, ArrayRef operands, + Operation *operation, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto op = llvm::cast(operation); Location loc = op.getLoc(); - Value *input = operands[0]; - Value *element_shape = operands[1]; + Value input = operands[0]; + Value element_shape = operands[1]; // If the `element_shape` is a known constant (which is defined when calling // `tensor_list_stack`) and also valid (not scalar), we rewrite this op to a // trivial Reshape op (that doesn't actually change the input's shape) and // also populate the shape info to the op result. The shape of the // tensorlist is inferred from `num_elements` and `element_shape`. 
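
The check that follows only fires this rewrite when `element_shape` is genuinely known. The same gating decision, sketched in plain C++ with std::optional standing in for "matched a compile-time constant" (types here are stand-ins, not MLIR's):

    #include <cstdint>
    #include <optional>
    #include <vector>

    // Mirrors the rank-0 / non-constant bail-out in ConvertTensorListStack:
    // the shape must be a constant and must carry at least one dimension.
    bool CanRewriteToReshape(
        const std::optional<std::vector<int64_t>>& element_shape) {
      if (!element_shape.has_value()) return false;  // not a known constant
      if (element_shape->empty()) return false;      // scalar: no dims to use
      return true;
    }
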
- auto ranked_type = element_shape->getType().dyn_cast(); + auto ranked_type = element_shape.getType().dyn_cast(); DenseIntElementsAttr dense_elem_attr; if ((ranked_type && ranked_type.getRank() == 0) || !matchPattern(element_shape, m_Constant(&dense_elem_attr))) { @@ -655,11 +654,11 @@ struct ConvertIdentity : public ConversionPattern { : ConversionPattern(TF::IdentityOp::getOperationName(), 1, context) {} PatternMatchResult matchAndRewrite( - Operation *operation, ArrayRef operands, + Operation *operation, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto op = llvm::cast(operation); - Value *input = operands[0]; - rewriter.replaceOpWithNewOp(op, input->getType(), operands, + Value input = operands[0]; + rewriter.replaceOpWithNewOp(op, input.getType(), operands, op.getAttrs()); return matchSuccess(); } @@ -687,7 +686,7 @@ static LogicalResult UpdateFunctionTypes(TF::WhileOp op) { Type arg_type = func_type.getInput(i); if (getElementTypeOrSelf(arg_type).isa()) { arg_type = UnrankedTensorType::get( - getElementTypeOrSelf(op.getOperand(i)->getType())); + getElementTypeOrSelf(op.getOperand(i).getType())); } updated_argument_types.push_back(arg_type); } @@ -703,7 +702,7 @@ static LogicalResult UpdateFunctionTypes(TF::WhileOp op) { // from the corresponding input operand. This is correct because while // body's inputs and results have the same type. result_type = UnrankedTensorType::get( - getElementTypeOrSelf(op.getOperand(i)->getType())); + getElementTypeOrSelf(op.getOperand(i).getType())); } updated_result_types.push_back(result_type); } @@ -717,7 +716,7 @@ static LogicalResult UpdateFunctionTypes(TF::WhileOp op) { // Change the argument type for the first block. Block &body_first_bb = func.front(); for (int i = 0; i < body_first_bb.getNumArguments(); ++i) { - body_first_bb.getArgument(i)->setType(updated_argument_types[i]); + body_first_bb.getArgument(i).setType(updated_argument_types[i]); } } return success(); @@ -728,19 +727,19 @@ struct ConvertWhile : public ConversionPattern { : ConversionPattern(TF::WhileOp::getOperationName(), 1, context) {} PatternMatchResult matchAndRewrite( - Operation *operation, ArrayRef operands, + Operation *operation, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto op = llvm::cast(operation); llvm::SmallVector result_types; result_types.reserve(op.getNumOperands()); for (int i = 0, e = operands.size(); i != e; ++i) { - Type result_ty = op.getResult(i)->getType(); + Type result_ty = op.getResult(i).getType(); // If we notice the result type is a DT_VARIANT, we change the // corresponding result type to unranked tensor type. if (getElementTypeOrSelf(result_ty).isa()) { - Type element_ty = getElementTypeOrSelf(operands[i]->getType()); + Type element_ty = getElementTypeOrSelf(operands[i].getType()); result_ty = UnrankedTensorType::get(element_ty); } result_types.push_back(result_ty); diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc index 1313bae97a1..69b767068ff 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc @@ -30,14 +30,14 @@ limitations under the License. 
#include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/Support/Casting.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Matchers.h" // TF:local_config_mlir -#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir -#include "mlir/Support/Functional.h" // TF:local_config_mlir -#include "mlir/Support/LLVM.h" // TF:local_config_mlir +#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Matchers.h" // TF:llvm-project +#include "mlir/IR/PatternMatch.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Support/Functional.h" // TF:llvm-project +#include "mlir/Support/LLVM.h" // TF:llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/utils/validators.h" @@ -50,16 +50,16 @@ namespace TFL { // The actual Optimize Pass. namespace { -bool L2NormalizeReduceAxis(Value *sq_op, DenseElementsAttr axis) { - if (sq_op->getType().cast().getRank() - 1 == +bool L2NormalizeReduceAxis(Value sq_op, DenseElementsAttr axis) { + if (sq_op.getType().cast().getRank() - 1 == *axis.getValues().begin() || *axis.getValues().begin() == -1) { return true; } - if (sq_op->getType().cast().getRank() != axis.getNumElements()) { + if (sq_op.getType().cast().getRank() != axis.getNumElements()) { return false; } - auto shape = sq_op->getType().cast(); + auto shape = sq_op.getType().cast(); SmallVector elems{axis.getValues().begin(), axis.getValues().end()}; for (int i = 0; i < shape.getRank(); ++i) { @@ -142,8 +142,8 @@ ElementsAttr ExpandTo4DForDepthwiseConv(Attribute a) { // Returns shape of a ranked tensor. // Precondition: output_val's is ranked tensor. -DenseElementsAttr GetShape(Value *output_val) { - auto output_type = output_val->getType().cast(); +DenseElementsAttr GetShape(Value output_val) { + auto output_type = output_val.getType().cast(); auto shape_vector = output_type.getShape(); std::vector shape(shape_vector.size()); for (int i = 0; i < shape_vector.size(); ++i) { @@ -152,7 +152,7 @@ DenseElementsAttr GetShape(Value *output_val) { return mlir::DenseElementsAttr::get( RankedTensorType::get( {static_cast(shape.size())}, - mlir::IntegerType::get(32, output_val->getContext())), + mlir::IntegerType::get(32, output_val.getContext())), llvm::makeArrayRef(shape)); } @@ -167,19 +167,19 @@ struct FuseFullyConnectedAndAdd : public OpRewritePattern { PatternRewriter &rewriter) const override { // Add. DenseElementsAttr added_value; - Value *constant_val = add_op.rhs(); + Value constant_val = add_op.rhs(); if (!matchPattern(constant_val, m_Constant(&added_value))) return matchFailure(); // Fully Connected. 
auto fc_op = - dyn_cast_or_null(add_op.lhs()->getDefiningOp()); + dyn_cast_or_null(add_op.lhs().getDefiningOp()); if (!fc_op) return matchFailure(); - Value *filter = fc_op.filter(); - Value *bias = fc_op.bias(); + Value filter = fc_op.filter(); + Value bias = fc_op.bias(); ElementsAttr bias_value; - const bool is_none_bias = bias->getType().isa(); + const bool is_none_bias = bias.getType().isa(); if (!is_none_bias && !matchPattern(bias, m_Constant(&bias_value))) return matchFailure(); if (fc_op.fused_activation_function() != "NONE") return matchFailure(); @@ -213,7 +213,7 @@ struct FuseFullyConnectedAndRelu : public OpRewritePattern { PatternMatchResult matchAndRewrite(TFL::ReluOp relu_op, PatternRewriter &rewriter) const override { - Operation *input = relu_op.getOperand()->getDefiningOp(); + Operation *input = relu_op.getOperand().getDefiningOp(); if (!isa_and_nonnull(input)) return matchFailure(); auto fully_connected_op = cast(input); if (fully_connected_op.fused_activation_function() != "NONE") @@ -242,18 +242,18 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern { PatternRewriter &rewriter) const override { // Mul. DenseElementsAttr cst; - Value *constant_val = mul_op.rhs(); + Value constant_val = mul_op.rhs(); if (!matchPattern(constant_val, m_Constant(&cst))) return matchFailure(); // Fully Connected. auto fc_op = - dyn_cast_or_null(mul_op.lhs()->getDefiningOp()); + dyn_cast_or_null(mul_op.lhs().getDefiningOp()); if (!fc_op) return matchFailure(); - Value *filter = fc_op.filter(); - Value *bias = fc_op.bias(); + Value filter = fc_op.filter(); + Value bias = fc_op.bias(); ElementsAttr cst_tmp; if (!matchPattern(filter, m_Constant(&cst_tmp))) return matchFailure(); - if (!bias->getType().isa() && + if (!bias.getType().isa() && !matchPattern(bias, m_Constant(&cst_tmp))) return matchFailure(); if (fc_op.fused_activation_function().equals("None")) return matchFailure(); @@ -261,8 +261,8 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern { // Broadcast the constant operand of Mul if it isn't compatible to the // filter input. We only support broadcasting the operand along the depth // dimension, when the operand's depth is 1. - Value *new_const_val = constant_val; - if (!IsBroadcastableElementsAttrAndType(cst.getType(), filter->getType())) { + Value new_const_val = constant_val; + if (!IsBroadcastableElementsAttrAndType(cst.getType(), filter.getType())) { auto original_shape = cst.getType().getShape(); llvm::SmallVector normalized_shape(original_shape.begin(), original_shape.end()); @@ -270,7 +270,7 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern { auto new_cst = cst.reshape(RankedTensorType::get( normalized_shape, cst.getType().getElementType())); Type new_type = new_cst.getType(); - if (!IsBroadcastableElementsAttrAndType(new_type, filter->getType())) { + if (!IsBroadcastableElementsAttrAndType(new_type, filter.getType())) { return matchFailure(); } auto new_op = @@ -285,7 +285,7 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern { auto new_filter = rewriter.create(loc, filter, new_const_val).z(); // If bias isn't None, it needs to be multiplied as well. - if (!bias->getType().isa()) { + if (!bias.getType().isa()) { bias = rewriter.create(loc, bias, constant_val).z(); } @@ -311,7 +311,7 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern { PatternMatchResult matchAndRewrite(AffineOpType fc_op, PatternRewriter &rewriter) const override { // Binary op. 
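
Before the binary-op case continues below, the algebra behind all of these fusions is worth spelling out: fc(x; w, b) followed by "+ c" equals fc(x; w, b + c), so the add disappears once its constant folds into the bias. A self-contained sketch of that fold with plain vectors instead of MLIR attributes (the mul fusion above is the same move, scaling filter and bias by c):

    #include <vector>

    // New bias is constant-folded offline, once; the elementwise add op is
    // then dropped from the graph.
    std::vector<float> FoldAddIntoBias(std::vector<float> bias, float added) {
      for (float& b : bias) b += added;
      return bias;
    }
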
- Operation *binary_op = fc_op.input()->getDefiningOp(); + Operation *binary_op = fc_op.input().getDefiningOp(); if (!binary_op || binary_op->getNumOperands() != 2) return this->matchFailure(); // We only handle the cases the RHS is a scalar. @@ -325,20 +325,20 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern { APFloat cst_value = *cst.float_value_begin(); // Affine op. - Value *filter = fc_op.filter(); - Value *bias = fc_op.bias(); + Value filter = fc_op.filter(); + Value bias = fc_op.bias(); DenseFPElementsAttr filter_cst, bias_cst; if (!matchPattern(filter, m_Constant(&filter_cst))) { // The filter maybe quantized, then we should set it to the real constant. - auto dq = llvm::dyn_cast_or_null(filter->getDefiningOp()); + auto dq = llvm::dyn_cast_or_null(filter.getDefiningOp()); if (!dq) return this->matchFailure(); - auto q = llvm::dyn_cast_or_null(dq.input()->getDefiningOp()); + auto q = llvm::dyn_cast_or_null(dq.input().getDefiningOp()); if (!q || !matchPattern(q.input(), m_Constant(&filter_cst))) { return this->matchFailure(); } filter = q.input(); } - if (!bias->getType().isa() && + if (!bias.getType().isa() && !matchPattern(bias, m_Constant(&bias_cst))) return this->matchFailure(); ShapedType filter_type = filter_cst.getType(); @@ -362,7 +362,7 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern { // The new bias should be a 1-D tensor with length equals to the bias // dimension of the weight. SmallVector new_bias_values; - if (bias->getType().isa()) { // none bias, a list of zeros + if (bias.getType().isa()) { // none bias, a list of zeros new_bias_values.resize(bias_size, APFloat(0.0)); } else if (bias_cst.getNumElements() == 1) { // scalar bias, broadcast it new_bias_values.resize(bias_size, *bias_cst.float_value_begin()); @@ -401,12 +401,12 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern { // We recreate the constant op in case it is shared by the other ops. This // might increase the model size. auto new_filter_op = rewriter.create( - fc_op.getLoc(), filter->getType(), new_filter); + fc_op.getLoc(), filter.getType(), new_filter); fc_op.setOperand(0, binary_op->getOperand(0)); if (fc_op.filter() != filter) { // This filter goes through quantize and dequantize ops. Then we just // need to update the weight to the quantize op. - filter->replaceAllUsesWith(new_filter_op); + filter.replaceAllUsesWith(new_filter_op); } else { // This filter doesn't go through quantize and dequantize ops, Then // we update the weight of the affine op directly. diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc b/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc index 59dc271400e..6761abf36ec 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc @@ -17,15 +17,15 @@ limitations under the License. 
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallSet.h" #include "llvm/Support/Casting.h" -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/BlockAndValueMapping.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/Module.h" // TF:local_config_mlir -#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir -#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/IR/PatternMatch.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/TypeUtilities.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" namespace mlir { @@ -98,13 +98,13 @@ class FoldIfOp : public OpRewritePattern { for (int i = 0, e = func.getNumArguments(); i != e; ++i) mapper.map(func.getArgument(i), op.getOperand(i + 1)); - llvm::SmallVector updated_results; + llvm::SmallVector updated_results; for (auto& op_to_inline : func.getBody().front()) { // If this is a terminator, identify the values to use to replace the // original If op. if (op_to_inline.isKnownTerminator()) { updated_results.reserve(op_to_inline.getNumOperands()); - for (Value* operand : op_to_inline.getOperands()) + for (Value operand : op_to_inline.getOperands()) updated_results.push_back(mapper.lookup(operand)); break; } diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td index a91f6de1971..c0e49bfb49a 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td @@ -18,6 +18,7 @@ limitations under the License. include "mlir/IR/OpBase.td" include "mlir/Dialect/StandardOps/Ops.td" include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td" +include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" def F32ElementsAttr : ElementsAttrBase< CPred<"$_self.cast().getType().getElementType().isF32()">, "float constant tensor">; @@ -53,13 +54,15 @@ foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu], [TFL_Relu1Op, TFL_AF_Relu1]] in defm : FuseActFnIntoConvOpPat; +// Checks if the value has only one user. +def HasOneUse : Constraint>; // If we see a binary op (add, sub) op adding a constant value to a convolution // op with constant bias, we can fuse the binary op into the convolution op by // constant folding the bias and the binary op's constant operand. The following // pattern restricts to float constant values for now. 
multiclass FuseBinaryOpToPrecedingAffine { - def : Pat<(binaryOp (TFL_Conv2DOp $input, $filter, + def : Pat<(binaryOp (TFL_Conv2DOp:$output $input, $filter, (ConstantOp F32ElementsAttr:$bias), $h_factor, $w_factor, TFL_AF_None, $padding, $stride_h, $stride_w), @@ -68,8 +71,9 @@ multiclass FuseBinaryOpToPrecedingAffine { (binaryOp (ConstantOp $bias), (ConstantOp $value), TFL_AF_None), $h_factor, $w_factor, $act_fn, - $padding, $stride_h, $stride_w)>; - def : Pat<(binaryOp (TFL_DepthwiseConv2DOp $input, $filter, + $padding, $stride_h, $stride_w), + [(HasOneUse $output)]>; + def : Pat<(binaryOp (TFL_DepthwiseConv2DOp:$output $input, $filter, (ConstantOp F32ElementsAttr:$bias), $h_factor, $w_factor, TFL_AF_None, $padding, $stride_h, $stride_w, @@ -81,7 +85,8 @@ multiclass FuseBinaryOpToPrecedingAffine { TFL_AF_None), $h_factor, $w_factor, $act_fn, $padding, $stride_h, $stride_w, - $multiplier)>; + $multiplier), + [(HasOneUse $output)]>; } foreach binaryOp = [TFL_AddOp, TFL_SubOp] in defm : FuseBinaryOpToPrecedingAffine; @@ -101,7 +106,7 @@ def ExpandTo4DForDepthwiseConv: NativeCodeCall< // The following pattern restricts to float constant values for now. multiclass FuseMulOrDivWithConv2dOrDepthwiseConv2d { - def : Pat<(BinaryOp (TFL_DepthwiseConv2DOp $input, + def : Pat<(BinaryOp (TFL_DepthwiseConv2DOp:$output $input, (ConstantOp F32ElementsAttr:$filter), (ConstantOp F32ElementsAttr:$bias), $h_factor, $w_factor, TFL_AF_None, @@ -119,8 +124,9 @@ multiclass FuseMulOrDivWithConv2dOrDepthwiseConv2d { $h_factor, $w_factor, $act_fn, $padding, $stride_h, $stride_w, $multiplier), - [(CanFuseConvOrDepthwiseConv<"true"> $filter, $value)]>; - def : Pat<(BinaryOp (TFL_Conv2DOp $input, + [(CanFuseConvOrDepthwiseConv<"true"> $filter, $value), + (HasOneUse $output)]>; + def : Pat<(BinaryOp (TFL_Conv2DOp:$conv_output $input, (ConstantOp F32ElementsAttr:$filter), (ConstantOp F32ElementsAttr:$bias), $h_factor, $w_factor, TFL_AF_None, @@ -135,7 +141,8 @@ multiclass FuseMulOrDivWithConv2dOrDepthwiseConv2d { TFL_AF_None), $h_factor, $w_factor, $act_fn, $padding, $stride_h, $stride_w), - [(CanFuseConvOrDepthwiseConv<"false"> $filter, $value)]>; + [(CanFuseConvOrDepthwiseConv<"false"> $filter, $value), + (HasOneUse $conv_output)]>; } foreach BinaryOp = [TFL_DivOp, TFL_MulOp] in @@ -154,7 +161,7 @@ def EqualOperands : Constraint>; // Checks if the operand has rank == n class OperandHasRank : Constraint< - CPred<"$0->getType().cast().getRank() == " # n>>; + CPred<"$0.getType().cast().getRank() == " # n>>; // Matching HardSwish def : Pat< @@ -249,7 +256,7 @@ foreach L2NormalizePairs = [[TFL_MulOp, TFL_RsqrtOp], [TFL_DivOp, TFL_SqrtOp]] in defm : L2NormalizePatterns; def AreBroadcastableTypes : ConstraintgetType(), $1->getType())">>; + "TFL::IsBroadcastableElementsAttrAndType($0.getType(), $1.getType())">>; // Pattern for skipping Tile if it is mainly for broadcasting and the // Op is already supporting broadcasting. @@ -307,3 +314,7 @@ multiclass FusedBinaryActivationFuncOpPat { foreach BinaryOps = [TFL_AddOp, TFL_DivOp, TFL_MulOp, TFL_SubOp] in defm : FusedBinaryActivationFuncOpPat; + +// The constant folding in this pass might produce constant in the tf dialect. +// This rule is to legalize these constant to the tfl dialect. 
+def : Pat<(TF_ConstOp ElementsAttr:$value), (TFL_ConstOp $value)>; diff --git a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc index 4f56de26864..267901f69f3 100644 --- a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc @@ -16,8 +16,8 @@ limitations under the License. // This transformation pass applies some clean up steps after quantization. #include "llvm/Support/Casting.h" -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" @@ -67,33 +67,33 @@ void RemoveQuantizationAdaptorOps(FuncOp func) { // In each iteration, a new argument is appended to the end of the list // and the current argument is erased, so here we always process the first // argument in the list. - auto* arg = bb.getArgument(0); + auto arg = bb.getArgument(0); auto remove_quantize_op = [&](QuantizeOp quantize_op) { auto quantize_output = quantize_op.output(); - auto quantize_type = quantize_output->getType(); + auto quantize_type = quantize_output.getType(); input_types.push_back(quantize_type); - auto* new_arg = bb.addArgument(quantize_type); - quantize_output->replaceAllUsesWith(new_arg); + auto new_arg = bb.addArgument(quantize_type); + quantize_output.replaceAllUsesWith(new_arg); quantize_op.erase(); - arg->dropAllUses(); + arg.dropAllUses(); bb.eraseArgument(0); }; // This is looking for a pattern: arg -> tfl.quantize - if (arg->hasOneUse() && llvm::isa(*arg->user_begin())) { - auto quantize_op = llvm::cast(*arg->user_begin()); + if (arg.hasOneUse() && llvm::isa(*arg.user_begin())) { + auto quantize_op = llvm::cast(*arg.user_begin()); remove_quantize_op(quantize_op); continue; } // Make a copy of current argument and append it to the end of the list if // the pattern isn't found. 
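
The argument loop above leans on a small trick: always process argument 0, append its replacement at the back, and erase the front, so one pass rebuilds the whole argument list in order. The same rotation on a plain std::deque (illustrative only, not the MLIR block-argument API):

    #include <cctype>
    #include <deque>
    #include <iostream>

    int main() {
      std::deque<char> args = {'a', 'b', 'c'};
      for (size_t n = args.size(); n > 0; --n) {
        // "Rewrite" the front argument, append the result, drop the original
        // -- mirroring bb.addArgument(...) followed by bb.eraseArgument(0).
        args.push_back(static_cast<char>(std::toupper(args.front())));
        args.pop_front();
      }
      for (char c : args) std::cout << c;  // ABC: same order, all updated
      std::cout << '\n';
    }
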
- Type arg_type = arg->getType(); + Type arg_type = arg.getType(); input_types.push_back(arg_type); - auto* new_arg = bb.addArgument(arg_type); - arg->replaceAllUsesWith(new_arg); - arg->dropAllUses(); + auto new_arg = bb.addArgument(arg_type); + arg.replaceAllUsesWith(new_arg); + arg.dropAllUses(); bb.eraseArgument(0); } @@ -102,16 +102,16 @@ void RemoveQuantizationAdaptorOps(FuncOp func) { llvm::SmallVector output_types; output_types.reserve(num_return_operands); for (int i = 0; i != num_return_operands; ++i) { - auto* returned_value = terminator->getOperand(i); - Operation* returned_op = returned_value->getDefiningOp(); + auto returned_value = terminator->getOperand(i); + Operation* returned_op = returned_value.getDefiningOp(); if (returned_op && llvm::isa(returned_op)) { auto dequantize_op = llvm::cast(returned_op); - Value* dequantized_result = dequantize_op.input(); - output_types.push_back(dequantized_result->getType()); + Value dequantized_result = dequantize_op.input(); + output_types.push_back(dequantized_result.getType()); terminator->setOperand(i, dequantized_result); returned_op->erase(); } else { - output_types.push_back(returned_value->getType()); + output_types.push_back(returned_value.getType()); } } auto new_func_type = builder.getFunctionType(input_types, output_types); diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc index c299064a136..a1fb78ac38b 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc @@ -22,19 +22,19 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Function.h" // TF:local_config_mlir -#include "mlir/IR/Identifier.h" // TF:local_config_mlir -#include "mlir/IR/Location.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/Module.h" // TF:local_config_mlir -#include "mlir/IR/Operation.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/IR/SymbolTable.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir -#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir +#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/Identifier.h" // TF:llvm-project +#include "mlir/IR/Location.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/SymbolTable.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/utils/lstm_utils.h" @@ -53,8 +53,8 @@ class ConvertEmbeddedLookupFunc { void RewriteFunc() { func_.setAttr(kTFImplements, 
StringAttr::get("embedding_lookup", func_.getContext())); - Value* lookup = func_.getArgument(1); - Value* value = func_.getArgument(0); + Value lookup = func_.getArgument(1); + Value value = func_.getArgument(0); auto output_type = func_.getType().getResult(0); OpBuilder builder(func_.getBody()); diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td b/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td index a2dc2e93746..40bf54935c4 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td @@ -135,10 +135,10 @@ def : Pat<(TF_ReshapeOp // Casts result type of $1 to a quantized type by using the quantization // parameters from the type in $0. class UpdateShapeWithAxis : NativeCodeCall< - "CastQuantizedTypeAttrFromExpressedType($_builder, $0, $1->getType(), " # i # ")">; + "CastQuantizedTypeAttrFromExpressedType($_builder, $0, $1.getType(), " # i # ")">; class UsedBy : Constraint< - CPred<"llvm::isa(*$0->getUsers().begin())">>; + CPred<"llvm::isa(*$0.getUsers().begin())">>; // When the op is passing-through, the output types of the quantized ops need // to be updated as well. Since the quantize op manages its own type by the diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc index 5d139f83933..0b6da59ca6e 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc @@ -21,10 +21,10 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/CommandLine.h" -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir -#include "mlir/IR/Value.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/PatternMatch.h" // TF:llvm-project +#include "mlir/IR/Value.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" @@ -139,7 +139,7 @@ bool PrepareQuantizePass::SetInputNodesQuantizationParams(FuncOp func) { BoolAttr narrow_range = builder.getBoolAttr(false); auto add_quantize_op = [&](Location loc, Type input_type, Block* block, - Block::iterator insertion_point, Value* arg, + Block::iterator insertion_point, Value arg, int i) { if (auto shaped = input_type.dyn_cast()) { if (shaped.getElementType().isa()) { @@ -153,16 +153,16 @@ bool PrepareQuantizePass::SetInputNodesQuantizationParams(FuncOp func) { params); auto dq_op = builder.create(loc, input_type, q_op.output()); - arg->replaceAllUsesWith(dq_op.output()); + arg.replaceAllUsesWith(dq_op.output()); q_op.setOperand(arg); } } }; for (int i = 0, e = func.getNumArguments(); i != e; ++i) { - BlockArgument* arg = func.getArgument(i); - auto* arg_block = arg->getOwner(); - add_quantize_op(arg->getLoc(), arg->getType(), arg_block, + BlockArgument arg = func.getArgument(i); + auto* arg_block = arg.getOwner(); + add_quantize_op(arg.getLoc(), arg.getType(), arg_block, std::next(arg_block->begin(), i), arg, i); } diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc index 
45248ddc01c..ab4d30e1170 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc @@ -38,17 +38,17 @@ limitations under the License. #include "llvm/ADT/StringSwitch.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" -#include "mlir/Analysis/LoopAnalysis.h" // TF:local_config_mlir -#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:local_config_mlir -#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:local_config_mlir -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir -#include "mlir/Support/Functional.h" // TF:local_config_mlir -#include "mlir/Support/LLVM.h" // TF:local_config_mlir -#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir +#include "mlir/Analysis/LoopAnalysis.h" // TF:llvm-project +#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project +#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/PatternMatch.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Support/Functional.h" // TF:llvm-project +#include "mlir/Support/LLVM.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" @@ -115,17 +115,17 @@ struct InsertTFLQuantOpsAfterTFFakeQuantOp PatternRewriter &rewriter) const override { // We don't want to insert quantize/dequantize if the quantize op exists. auto res = tf_op.outputs(); - if (!res->hasOneUse() || isa(*res->user_begin())) + if (!res.hasOneUse() || isa(*res.user_begin())) return this->matchFailure(); // Extract the min/max constant values from the operands. We also consider // a special case that there are tf.Identity ops between the min/max // constants and the tf.FakeQuantWithMinMaxVarsOp. - Value *min = tf_op.min(), *max = tf_op.max(); + Value min = tf_op.min(), max = tf_op.max(); DenseFPElementsAttr min_value, max_value; - if (auto id1 = dyn_cast_or_null(min->getDefiningOp())) + if (auto id1 = dyn_cast_or_null(min.getDefiningOp())) min = id1.input(); - if (auto id2 = dyn_cast_or_null(max->getDefiningOp())) + if (auto id2 = dyn_cast_or_null(max.getDefiningOp())) max = id2.input(); if (!matchPattern(min, m_Constant(&min_value))) return this->matchFailure(); if (!matchPattern(max, m_Constant(&max_value))) return this->matchFailure(); @@ -133,7 +133,7 @@ struct InsertTFLQuantOpsAfterTFFakeQuantOp int quant_dim = -1; if (PerAxis) { // This is a special case that the quant_dim is the last dimensions. - quant_dim = res->getType().template cast().getRank() - 1; + quant_dim = res.getType().template cast().getRank() - 1; } // Use the min/max from the operands and the num_bits and narrow_range // attribute to create the quantization parameter for the new quantize op. 
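
Turning a (min, max) range plus num_bits/narrow_range into a quantization parameter boils down to a scale and a zero point. A standalone sketch of the standard asymmetric formula (the textbook computation, not the quant:: helper this pass actually calls; the unsigned qmin/qmax convention here is an assumption):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    struct QuantParams {
      double scale;
      int32_t zero_point;
    };

    // Map [rmin, rmax] onto the integer range [qmin, qmax] implied by
    // num_bits; narrow_range drops the lowest bucket (qmin = 1).
    QuantParams ParamsFromMinMax(double rmin, double rmax, int num_bits,
                                 bool narrow_range) {
      const int32_t qmin = narrow_range ? 1 : 0;
      const int32_t qmax = (1 << num_bits) - 1;
      rmin = std::min(rmin, 0.0);  // representable range must contain zero
      rmax = std::max(rmax, 0.0);
      double scale = (rmax - rmin) / (qmax - qmin);
      if (scale == 0.0) scale = 1.0;  // degenerate all-zero range
      const int32_t zero_point =
          static_cast<int32_t>(std::lround(qmin - rmin / scale));
      return {scale, zero_point};
    }
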
@@ -150,12 +150,12 @@ struct InsertTFLQuantOpsAfterTFFakeQuantOp // Finally, use the quantization parameter to create the quantize and // dequantize ops, and insert them between the tf.FakeQuantWithMinMaxVarsOp // and its users. - Value *value = tf_op.outputs(); + Value value = tf_op.outputs(); auto quantize = rewriter.create( tf_op.getLoc(), qtype.getValue(), value, qtype); auto dequantize = rewriter.create( tf_op.getLoc(), res_type, quantize.output()); - value->replaceAllUsesWith(dequantize); + value.replaceAllUsesWith(dequantize); quantize.getOperation()->replaceUsesOfWith(dequantize, value); return this->matchSuccess(); @@ -177,8 +177,8 @@ using PreparePerChannelFakeQuant = // // TFL::[op] createTFLOp(ConvertTFConvOpMatchState *state, // PatternRewriter &rewriter, Location loc, -// Type result_type, Value *input, -// Value *filter, Value *bias) const; +// Type result_type, Value input, +// Value filter, Value bias) const; // // And also the following method for getting the dimension for bias tensor: // @@ -240,7 +240,7 @@ struct ConvertTFConvOp : public RewritePattern { // that we can extract info from the shape (e.g., for constructing bias // tensor, for setting depth_multiplier attribute, etc.). auto filter_type = - tf_op.filter()->getType().template dyn_cast(); + tf_op.filter().getType().template dyn_cast(); if (filter_type && filter_type.getRank() == 4) return matchSuccess(std::move(state)); @@ -262,7 +262,7 @@ struct ConvertTFConvOp : public RewritePattern { // Get a splat zero tensor with the expected dimension for the bias tensor auto filter = tf_op.filter(); - auto filter_type = filter->getType().template cast(); + auto filter_type = filter.getType().template cast(); auto elem_type = filter_type.getElementType(); auto bias_dim = static_cast(this)->getBiasDim( filter_type.getShape()); @@ -294,8 +294,8 @@ class ConvertTFConv2D : public ConvertTFConvOp { TFL::Conv2DOp createTFLOp(ConvertTFConvOpMatchState *state, PatternRewriter &rewriter, Location loc, - Type result_type, Value *input, Value *filter, - Value *bias) const { + Type result_type, Value input, Value filter, + Value bias) const { filter = legalizeFilter(rewriter, loc, filter); return rewriter.create( loc, result_type, input, filter, bias, @@ -312,8 +312,8 @@ class ConvertTFConv2D : public ConvertTFConvOp { // format HWIO to TFLite Conv2D op filter data format OHWI and return Value // for the converted filter. Requires that filter is verified by the match // method that it is a 4-D RankedTensorType. - Value *legalizeFilter(PatternRewriter &rewriter, Location loc, - Value *filter) const { + Value legalizeFilter(PatternRewriter &rewriter, Location loc, + Value filter) const { // Create a constant op for HWIO to OHWI transpose permutation. SmallVector perm = {3, 0, 1, 2}; auto perm_type = RankedTensorType::get({static_cast(perm.size())}, @@ -323,7 +323,7 @@ class ConvertTFConv2D : public ConvertTFConvOp { auto perm_op = rewriter.create(loc, perm_type, perm_attr); // Create tensor type for the transpose result. 
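
The filter legalization above is pure index shuffling: with perm = {3, 0, 1, 2}, a HWIO shape is read out as OHWI. A sketch of just the shape permutation the transpose performs (plain C++):

    #include <array>
    #include <cstdint>

    // e.g. {kH, kW, in, out} -> {out, kH, kW, in}
    std::array<int64_t, 4> PermuteHWIOToOHWI(
        const std::array<int64_t, 4>& hwio) {
      const std::array<int, 4> perm = {3, 0, 1, 2};
      std::array<int64_t, 4> ohwi;
      for (int i = 0; i < 4; ++i) ohwi[i] = hwio[perm[i]];
      return ohwi;
    }
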
- auto filter_type = filter->getType().cast(); + auto filter_type = filter.getType().cast(); auto result_shape = functional::map( [filter_type](int64_t dim) { return filter_type.getDimSize(dim); }, perm); @@ -349,14 +349,14 @@ class ConvertTFDepthwiseConv2dNative TFL::DepthwiseConv2DOp createTFLOp(ConvertTFConvOpMatchState *state, PatternRewriter &rewriter, Location loc, - Type result_type, Value *input, - Value *filter, Value *bias) const { + Type result_type, Value input, + Value filter, Value bias) const { // Compared to tfl.conv_2d, tfl.depthwise_conv_2d has an additional // 'depth_multiplier' attribute. However, tf.DepthwiseConv2dNative does not // have a corresponding 'depth_multiplier' attribute; the multiplier is the // fourth dimension in the 4-D filter tensor. We query the multiplier from // tf.DepthwiseConv2dNative and set it as the attribute value accordingly. - auto multiplier = filter->getType().cast().getDimSize(3); + auto multiplier = filter.getType().cast().getDimSize(3); filter = legalizeFilter(rewriter, loc, filter); return rewriter.create( @@ -378,9 +378,9 @@ class ConvertTFDepthwiseConv2dNative /// filter data format is [1, filter_height, filter_width, out_channels]. /// Requires that filter is verified by the match method that it is a 4-D /// RankedTensorType. - Value *legalizeFilter(PatternRewriter &rewriter, Location loc, - Value *filter) const { - auto filter_type = filter->getType().cast(); + Value legalizeFilter(PatternRewriter &rewriter, Location loc, + Value filter) const { + auto filter_type = filter.getType().cast(); auto filterShape = filter_type.getShape(); SmallVector result_shape = {1, filterShape[0], filterShape[1], filterShape[2] * filterShape[3]}; @@ -430,13 +430,13 @@ struct ConvertTFStridedSlice : public RewritePattern { if (new_axis_mask == 0) return matchFailure(); // Insert a new reshape op. - Value *original_input = strided_slice_op.input(); + Value original_input = strided_slice_op.input(); RankedTensorType original_input_type = - original_input->getType().cast(); + original_input.getType().cast(); const ArrayRef &original_input_shape = original_input_type.getShape(); RankedTensorType begin_type = - strided_slice_op.begin()->getType().cast(); + strided_slice_op.begin().getType().cast(); const int dim_size = begin_type.getShape()[0]; SmallVector new_shape; int mask = 1; diff --git a/tensorflow/compiler/mlir/lite/transforms/quantize.cc b/tensorflow/compiler/mlir/lite/transforms/quantize.cc index e47e97a60e8..6842621db70 100644 --- a/tensorflow/compiler/mlir/lite/transforms/quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/quantize.cc @@ -19,17 +19,17 @@ limitations under the License. 
#include "llvm/ADT/StringSwitch.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" -#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/Matchers.h" // TF:local_config_mlir -#include "mlir/IR/Module.h" // TF:local_config_mlir -#include "mlir/IR/Operation.h" // TF:local_config_mlir -#include "mlir/IR/OperationSupport.h" // TF:local_config_mlir -#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir -#include "mlir/Support/Functional.h" // TF:local_config_mlir +#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Matchers.h" // TF:llvm-project +#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/OperationSupport.h" // TF:llvm-project +#include "mlir/IR/PatternMatch.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Support/Functional.h" // TF:llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" diff --git a/tensorflow/compiler/mlir/lite/transforms/split_merged_operands.cc b/tensorflow/compiler/mlir/lite/transforms/split_merged_operands.cc index 123d1f86319..17125bffd85 100644 --- a/tensorflow/compiler/mlir/lite/transforms/split_merged_operands.cc +++ b/tensorflow/compiler/mlir/lite/transforms/split_merged_operands.cc @@ -18,24 +18,24 @@ limitations under the License. 
#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/StringMap.h" #include "llvm/Support/Casting.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Block.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Function.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/Matchers.h" // TF:local_config_mlir -#include "mlir/IR/Module.h" // TF:local_config_mlir -#include "mlir/IR/Operation.h" // TF:local_config_mlir -#include "mlir/IR/OperationSupport.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/IR/SymbolTable.h" // TF:local_config_mlir -#include "mlir/IR/Types.h" // TF:local_config_mlir -#include "mlir/IR/Value.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir -#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir -#include "mlir/Support/LLVM.h" // TF:local_config_mlir -#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir +#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Block.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Matchers.h" // TF:llvm-project +#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/OperationSupport.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/SymbolTable.h" // TF:llvm-project +#include "mlir/IR/Types.h" // TF:llvm-project +#include "mlir/IR/Value.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Pass/PassRegistry.h" // TF:llvm-project +#include "mlir/Support/LLVM.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h" @@ -71,19 +71,19 @@ struct SplitMergedOperandsPass : public FunctionPass { }; LogicalResult DuplicateValueIfNeeded(Operation* op, - llvm::DenseSet* values, + llvm::DenseSet* values, OpBuilder* builder) { std::vector stateful_operands_index; if (!IsStatefulOp(op, &stateful_operands_index)) return success(); for (int index : stateful_operands_index) { - Value* operand = op->getOperand(index); + Value operand = op->getOperand(index); auto inserted_value = values->insert(operand).second; if (inserted_value) continue; // We can only clone the constant op at this point. 
   // Since all ops have been legalized to TFLite ops, we only care about
   // ConstOp, QConstOp, or an MLIR constant op.
-  Operation* input_op = operand->getDefiningOp();
+  Operation* input_op = operand.getDefiningOp();
   if (input_op == nullptr) return failure();

   Attribute attr;
@@ -102,7 +102,7 @@ LogicalResult DuplicateValueIfNeeded(Operation* op,
 }

 void SplitMergedOperandsPass::runOnFunction() {
-  llvm::DenseSet<Value*> stateful_values;
+  llvm::DenseSet<Value> stateful_values;
   auto func = getFunction();
   OpBuilder builder(func);
   for (auto& bb : func.getBody()) {
diff --git a/tensorflow/compiler/mlir/lite/transforms/trim_functions_tf.cc b/tensorflow/compiler/mlir/lite/transforms/trim_functions_tf.cc
index 87b96de762a..5a7397ed9c9 100644
--- a/tensorflow/compiler/mlir/lite/transforms/trim_functions_tf.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/trim_functions_tf.cc
@@ -20,13 +20,13 @@ limitations under the License.
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/Support/CommandLine.h"
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:local_config_mlir
-#include "mlir/IR/Builders.h"  // TF:local_config_mlir
-#include "mlir/IR/Identifier.h"  // TF:local_config_mlir
-#include "mlir/IR/Location.h"  // TF:local_config_mlir
-#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
-#include "mlir/IR/SymbolTable.h"  // TF:local_config_mlir
-#include "mlir/Pass/Pass.h"  // TF:local_config_mlir
+#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/IR/Builders.h"  // TF:llvm-project
+#include "mlir/IR/Identifier.h"  // TF:llvm-project
+#include "mlir/IR/Location.h"  // TF:llvm-project
+#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
+#include "mlir/IR/SymbolTable.h"  // TF:llvm-project
+#include "mlir/Pass/Pass.h"  // TF:llvm-project
 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"

 // The cmd line flag to specify the whitelist of functions. Rest are trimmed
diff --git a/tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.cc b/tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.cc
index 61d33a5233e..e245bb801b7 100644
--- a/tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.cc
@@ -24,17 +24,17 @@ limitations under the License.
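Editor's note: the split_merged_operands pass above dedups stateful operands by cloning the defining constant whenever the same value already feeds another stateful operand. A minimal Python sketch of that dedup-and-clone loop (the class and helper names are illustrative stand-ins, not the TFLite API):

```
from dataclasses import dataclass, field

@dataclass
class Op:
    name: str
    operands: list
    stateful_indices: list = field(default_factory=list)

def split_merged_operands(op, seen, clone_constant):
    # For every stateful operand, make sure the value is not shared with
    # another stateful operand; clone the constant producer on a repeat.
    for index in op.stateful_indices:
        operand = op.operands[index]
        if id(operand) not in seen:
            seen.add(id(operand))
            continue
        cloned = clone_constant(operand)
        if cloned is None:
            return False  # mirrors failure(): only constants can be cloned
        op.operands[index] = cloned
    return True

const = object()  # one constant value feeding two stateful operands
op = Op("tfl.svdf", [const, const], stateful_indices=[0, 1])
assert split_merged_operands(op, set(), lambda v: object())
assert op.operands[0] is not op.operands[1]  # each user has its own copy
```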
#include "llvm/ADT/StringSwitch.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" -#include "mlir/Analysis/LoopAnalysis.h" // TF:local_config_mlir -#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:local_config_mlir -#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:local_config_mlir -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/OpImplementation.h" // TF:local_config_mlir -#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir -#include "mlir/Support/Functional.h" // TF:local_config_mlir -#include "mlir/Support/LLVM.h" // TF:local_config_mlir -#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir +#include "mlir/Analysis/LoopAnalysis.h" // TF:llvm-project +#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project +#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/OpImplementation.h" // TF:llvm-project +#include "mlir/IR/PatternMatch.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Support/Functional.h" // TF:llvm-project +#include "mlir/Support/LLVM.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" @@ -67,7 +67,7 @@ void UnrollBatchMatMulPass::runOnFunction() { template TF::ReshapeOp ConvertTFBatchMatMulOp::createReshapeOp( - Value* value, ArrayRef shape, Type element_type, Location loc, + Value value, ArrayRef shape, Type element_type, Location loc, PatternRewriter& rewriter) { int64_t shape_rank = shape.size(); auto shape_spec_type = @@ -81,9 +81,9 @@ TF::ReshapeOp ConvertTFBatchMatMulOp::createReshapeOp( } template -std::vector ConvertTFBatchMatMulOp::sliceInput( - Value* value, int batch_size, Location loc, PatternRewriter& rewriter) { - RankedTensorType tensorType = value->getType().cast(); +std::vector ConvertTFBatchMatMulOp::sliceInput( + Value value, int batch_size, Location loc, PatternRewriter& rewriter) { + RankedTensorType tensorType = value.getType().cast(); Type element_type = tensorType.getElementType(); int rank = tensorType.getShape().size(); @@ -96,7 +96,7 @@ std::vector ConvertTFBatchMatMulOp::sliceInput( SmallVector slice_size = {1, num_rows, num_cols}; - std::vector sliced; + std::vector sliced; Type int64_type = rewriter.getIntegerType(64); Type slice_result_type = RankedTensorType::get(slice_size, element_type); @@ -126,8 +126,8 @@ std::vector ConvertTFBatchMatMulOp::sliceInput( template TF::TransposeOp ConvertTFBatchMatMulOp::createTransposeOp( - Value* value, Location loc, PatternRewriter& rewriter) { - auto value_type = value->getType().cast(); + Value value, Location loc, PatternRewriter& rewriter) { + auto value_type = value.getType().cast(); auto shape = value_type.getShape(); int dims = shape.size(); @@ -158,13 +158,12 @@ TF::TransposeOp ConvertTFBatchMatMulOp::createTransposeOp( template TF::PackOp ConvertTFBatchMatMulOp::createMatMulOps( - const std::vector& sliced_lhs, - const std::vector& sliced_rhs, const tensorflow::MatMulBCast& bcast, - int rows, int cols, Type element_type, Location loc, - PatternRewriter& rewriter) { + const std::vector& sliced_lhs, const 
std::vector& sliced_rhs, + const tensorflow::MatMulBCast& bcast, int rows, int cols, Type element_type, + Location loc, PatternRewriter& rewriter) { auto matmul_type = RankedTensorType::get({rows, cols}, element_type); - std::vector matmuls; + std::vector matmuls; for (int batch_idx = 0; batch_idx < bcast.output_batch_size(); ++batch_idx) { int lhs_batch_idx, rhs_batch_idx; if (bcast.IsBroadcastingRequired()) { @@ -195,20 +194,20 @@ TF::PackOp ConvertTFBatchMatMulOp::createMatMulOps( template PatternMatchResult ConvertTFBatchMatMulOp::matchAndRewrite( BatchMatMulOpType op, PatternRewriter& rewriter) const { - Value* input_lhs = op.x(); - Value* input_rhs = op.y(); + Value input_lhs = op.x(); + Value input_rhs = op.y(); - if (!input_lhs->getType().isa()) { + if (!input_lhs.getType().isa()) { // LHS must be a ranked tensor type return this->matchFailure(); } - if (!input_rhs->getType().isa()) { + if (!input_rhs.getType().isa()) { // RHS must be a ranked tensor type return this->matchFailure(); } - auto lhs_type = input_lhs->getType().cast(); - auto rhs_type = input_rhs->getType().cast(); + auto lhs_type = input_lhs.getType().cast(); + auto rhs_type = input_rhs.getType().cast(); auto element_type = lhs_type.getElementType(); @@ -234,7 +233,7 @@ PatternMatchResult ConvertTFBatchMatMulOp::matchAndRewrite( if (op.adj_x()) { input_lhs = createTransposeOp(input_lhs, loc, rewriter); - lhs_type = input_lhs->getType().cast(); + lhs_type = input_lhs.getType().cast(); lhs_shape = lhs_type.getShape(); } @@ -242,7 +241,7 @@ PatternMatchResult ConvertTFBatchMatMulOp::matchAndRewrite( if (op.adj_y()) { input_rhs = createTransposeOp(input_rhs, loc, rewriter); - rhs_type = input_rhs->getType().cast(); + rhs_type = input_rhs.getType().cast(); rhs_shape = rhs_type.getShape(); } @@ -276,9 +275,9 @@ PatternMatchResult ConvertTFBatchMatMulOp::matchAndRewrite( } // Compute slices for each batch in the LHS and RHS. - std::vector sliced_lhs = + std::vector sliced_lhs = sliceInput(input_lhs, bcast.x_batch_size(), loc, rewriter); - std::vector sliced_rhs = + std::vector sliced_rhs = sliceInput(input_rhs, bcast.y_batch_size(), loc, rewriter); // Compute (single batch) MatMul for each output batch. The MatMul outputs diff --git a/tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.h b/tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.h index 19b75963ebf..4aae05bde60 100644 --- a/tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.h +++ b/tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.h @@ -17,9 +17,9 @@ limitations under the License. 
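Editor's note: the transformation implemented above by ConvertTFBatchMatMulOp (sliceInput cuts each batch out of the LHS and RHS, one rank-2 MatMul runs per output batch, and createMatMulOps packs the results) can be expressed in a few lines of NumPy. This is only a sketch of the semantics for same-batch-size 3-D inputs, without the adj_x/adj_y transposes or broadcast handling:

```
import numpy as np

def unrolled_batch_matmul(lhs, rhs):
    """Equivalent of BatchMatMul for 3-D inputs with matching batch dims."""
    assert lhs.ndim == rhs.ndim == 3 and lhs.shape[0] == rhs.shape[0]
    matmuls = []
    for batch_idx in range(lhs.shape[0]):
        # sliceInput: one (rows x depth) and one (depth x cols) slice per batch.
        matmuls.append(lhs[batch_idx] @ rhs[batch_idx])
    # createMatMulOps ends with a Pack of the per-batch results.
    return np.stack(matmuls)

lhs = np.random.rand(4, 2, 3)
rhs = np.random.rand(4, 3, 5)
assert np.allclose(unrolled_batch_matmul(lhs, rhs), lhs @ rhs)
```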
 #define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_UNROLL_BATCH_MATMUL_H_

 #include "llvm/ADT/ArrayRef.h"
-#include "mlir/IR/Location.h"  // TF:local_config_mlir
-#include "mlir/IR/PatternMatch.h"  // TF:local_config_mlir
-#include "mlir/IR/TypeUtilities.h"  // TF:local_config_mlir
+#include "mlir/IR/Location.h"  // TF:llvm-project
+#include "mlir/IR/PatternMatch.h"  // TF:llvm-project
+#include "mlir/IR/TypeUtilities.h"  // TF:llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/core/util/matmul_bcast.h"

@@ -33,19 +33,18 @@ template <typename BatchMatMulOpType>
 class ConvertTFBatchMatMulOp : public OpRewritePattern<BatchMatMulOpType> {
   using OpRewritePattern<BatchMatMulOpType>::OpRewritePattern;

-  static TF::ReshapeOp createReshapeOp(Value* value, ArrayRef<int64_t> shape,
+  static TF::ReshapeOp createReshapeOp(Value value, ArrayRef<int64_t> shape,
                                        Type element_type, Location loc,
                                        PatternRewriter& rewriter);

-  static std::vector<Value*> sliceInput(Value* value, int batch_size,
-                                        Location loc,
-                                        PatternRewriter& rewriter);
+  static std::vector<Value> sliceInput(Value value, int batch_size,
+                                       Location loc, PatternRewriter& rewriter);

-  static TF::TransposeOp createTransposeOp(Value* value, Location loc,
+  static TF::TransposeOp createTransposeOp(Value value, Location loc,
                                            PatternRewriter& rewriter);

-  static TF::PackOp createMatMulOps(const std::vector<Value*>& sliced_lhs,
-                                    const std::vector<Value*>& sliced_rhs,
+  static TF::PackOp createMatMulOps(const std::vector<Value>& sliced_lhs,
+                                    const std::vector<Value>& sliced_rhs,
                                     const tensorflow::MatMulBCast& bcast,
                                     int rows, int cols, Type element_type,
                                     Location loc, PatternRewriter& rewriter);
diff --git a/tensorflow/compiler/mlir/lite/utils/attribute_utils.cc b/tensorflow/compiler/mlir/lite/utils/attribute_utils.cc
index 33da9929711..a9cc483df76 100644
--- a/tensorflow/compiler/mlir/lite/utils/attribute_utils.cc
+++ b/tensorflow/compiler/mlir/lite/utils/attribute_utils.cc
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

-#include "mlir/IR/Attributes.h"  // TF:local_config_mlir
-#include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
+#include "mlir/IR/Attributes.h"  // TF:llvm-project
+#include "mlir/IR/StandardTypes.h"  // TF:llvm-project

 namespace mlir {
 namespace TFL {
diff --git a/tensorflow/compiler/mlir/lite/utils/attribute_utils.h b/tensorflow/compiler/mlir/lite/utils/attribute_utils.h
index 263a0a8dc93..5a11690d15f 100644
--- a/tensorflow/compiler/mlir/lite/utils/attribute_utils.h
+++ b/tensorflow/compiler/mlir/lite/utils/attribute_utils.h
@@ -19,7 +19,7 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_ATTRIBUTE_UTILS_H_
 #define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_ATTRIBUTE_UTILS_H_

-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:local_config_mlir
+#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project

 namespace mlir {
 namespace TFL {
diff --git a/tensorflow/compiler/mlir/lite/utils/convert_type.cc b/tensorflow/compiler/mlir/lite/utils/convert_type.cc
index 167749d5f2e..85bd6a18764 100644
--- a/tensorflow/compiler/mlir/lite/utils/convert_type.cc
+++ b/tensorflow/compiler/mlir/lite/utils/convert_type.cc
@@ -15,15 +15,20 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/utils/convert_type.h" -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/IR/Types.h" // TF:local_config_mlir +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/Types.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/lite/schema/schema_generated.h" namespace tflite { +using xla::StatusOr; + +namespace errors = tensorflow::errors; + mlir::Type ConvertElementType(tflite::TensorType type, mlir::Builder builder) { switch (type) { case tflite::TensorType_FLOAT32: @@ -74,4 +79,31 @@ tensorflow::DataType TflTypeToTfType(tflite::TensorType type) { } } +StatusOr TfTypeToTflType(tensorflow::DataType type) { + switch (type) { + case tensorflow::DT_BOOL: + return tflite::TensorType_BOOL; + case tensorflow::DT_COMPLEX64: + return tflite::TensorType_COMPLEX64; + case tensorflow::DT_HALF: + return tflite::TensorType_FLOAT16; + case tensorflow::DT_FLOAT: + return tflite::TensorType_FLOAT32; + case tensorflow::DT_INT8: + return tflite::TensorType_INT8; + case tensorflow::DT_INT16: + return tflite::TensorType_INT16; + case tensorflow::DT_INT32: + return tflite::TensorType_INT32; + case tensorflow::DT_INT64: + return tflite::TensorType_INT64; + case tensorflow::DT_STRING: + return tflite::TensorType_STRING; + case tensorflow::DT_UINT8: + return tflite::TensorType_UINT8; + default: + return errors::InvalidArgument("unsupported tensor data type", type); + } +} + } // namespace tflite diff --git a/tensorflow/compiler/mlir/lite/utils/convert_type.h b/tensorflow/compiler/mlir/lite/utils/convert_type.h index ff4ccb325a8..90600c423bd 100644 --- a/tensorflow/compiler/mlir/lite/utils/convert_type.h +++ b/tensorflow/compiler/mlir/lite/utils/convert_type.h @@ -16,11 +16,13 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONVERT_TYPE_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONVERT_TYPE_H_ -#include "mlir/IR/Types.h" // TF:local_config_mlir +#include "mlir/IR/Types.h" // TF:llvm-project +#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/lite/schema/schema_generated.h" namespace mlir { + class Builder; } @@ -32,5 +34,8 @@ mlir::Type ConvertElementType(tflite::TensorType type, mlir::Builder builder); // Tensorflow type tensorflow::DataType TflTypeToTfType(tflite::TensorType type); +// Convert the Tensorflow scalar type to the corresponding TFLite type +xla::StatusOr TfTypeToTflType(tensorflow::DataType type); + } // namespace tflite #endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONVERT_TYPE_H_ diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc b/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc index 92a8ad49bf4..132448c58bd 100644 --- a/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc @@ -20,20 +20,20 @@ limitations under the License. 
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Function.h" // TF:local_config_mlir -#include "mlir/IR/Identifier.h" // TF:local_config_mlir -#include "mlir/IR/Location.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/OpDefinition.h" // TF:local_config_mlir -#include "mlir/IR/Operation.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/IR/Types.h" // TF:local_config_mlir -#include "mlir/IR/Value.h" // TF:local_config_mlir -#include "mlir/Support/LLVM.h" // TF:local_config_mlir -#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir +#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/Identifier.h" // TF:llvm-project +#include "mlir/IR/Location.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/OpDefinition.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/Types.h" // TF:llvm-project +#include "mlir/IR/Value.h" // TF:llvm-project +#include "mlir/Support/LLVM.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -42,35 +42,35 @@ namespace TFL { namespace { -Value* CreateI32SplatConst(OpBuilder* builder, ArrayRef shape, - int32_t val, mlir::Location location) { +Value CreateI32SplatConst(OpBuilder* builder, ArrayRef shape, + int32_t val, mlir::Location location) { auto type = RankedTensorType::get(shape, builder->getIntegerType(32)); auto attr = DenseElementsAttr::get(type, val); return builder->create(location, type, attr); } -Value* CreateF32SplatConst(OpBuilder* builder, ArrayRef shape, - float val, mlir::Location location) { +Value CreateF32SplatConst(OpBuilder* builder, ArrayRef shape, + float val, mlir::Location location) { auto type = RankedTensorType::get(shape, builder->getF32Type()); auto attr = DenseElementsAttr::get(type, val); return builder->create(location, type, attr); } -Value* CreateI64DenseConst(OpBuilder* builder, ArrayRef shape, - ArrayRef values, mlir::Location location) { +Value CreateI64DenseConst(OpBuilder* builder, ArrayRef shape, + ArrayRef values, mlir::Location location) { auto type = RankedTensorType::get(static_cast(shape.size()), builder->getIntegerType(64)); auto attr = DenseElementsAttr::get(type, values); return builder->create(location, type, attr); } -Value* CreateNoneValue(OpBuilder* builder, mlir::Location location) { +Value CreateNoneValue(OpBuilder* builder, mlir::Location location) { return builder->create(location, builder->getNoneType(), builder->getUnitAttr()); } -Value* Transpose2D(OpBuilder* builder, Value* value_to_transpose, - RankedTensorType type, mlir::Location location) { +Value Transpose2D(OpBuilder* builder, Value value_to_transpose, + RankedTensorType type, mlir::Location location) { // Create a constant op for transpose permutation. 
SmallVector perm = {1, 0}; auto perm_op = CreateI64DenseConst(builder, perm, perm, location); @@ -87,16 +87,16 @@ Value* Transpose2D(OpBuilder* builder, Value* value_to_transpose, value_to_transpose, perm_op); } -ArrayRef GetRankedTensorShape(Value* value) { - return value->getType().cast().getShape(); +ArrayRef GetRankedTensorShape(Value value) { + return value.getType().cast().getShape(); } -Value* SliceRankedTensor(OpBuilder* builder, Value* input, - ArrayRef begin_shape, - ArrayRef begin_values, - ArrayRef size_shape, - ArrayRef size_values, - mlir::Location location) { +Value SliceRankedTensor(OpBuilder* builder, Value input, + ArrayRef begin_shape, + ArrayRef begin_values, + ArrayRef size_shape, + ArrayRef size_values, + mlir::Location location) { // If the size of the tensor to be sliced from the input overflows // the input tensor's dimensions, return 0-valued tensor of the requested // shape. @@ -120,7 +120,7 @@ Value* SliceRankedTensor(OpBuilder* builder, Value* input, location, RankedTensorType::get( size_values, - input->getType().cast().getElementType()), + input.getType().cast().getElementType()), input, slice_i2c_begin, slice_i2c_size); } @@ -327,8 +327,7 @@ void ConvertLSTMCellSimpleToFusedLSTM::UpdateFuncSignature() { SmallVector output_shape{1, -1}; auto input_types = fused_func_op_.getType().getInputs(); auto output_type = mlir::RankedTensorType::get( - output_shape, - input_->getType().cast().getElementType()); + output_shape, input_.getType().cast().getElementType()); fused_func_op_.setType(mlir::FunctionType::get(input_types, output_type, fused_func_op_.getContext())); } @@ -351,8 +350,7 @@ LogicalResult ConvertLSTMCellSimpleToFusedLSTM::RewriteFunc() { // Create the fused LSTM op. SmallVector output_shape = {1, n_output_}; auto result_type = mlir::RankedTensorType::get( - output_shape, - input_->getType().cast().getElementType()); + output_shape, input_.getType().cast().getElementType()); lstm_ = builder_.create( fused_func_op_.getLoc(), result_type, input_, input2input_, input2forget_, input2cell_, input2output_, rec2input_, rec2forget_, rec2cell_, @@ -371,7 +369,7 @@ LogicalResult ConvertLSTMCellSimpleToFusedLSTM::RewriteFunc() { SmallVector func_output_shape = {1, -1}; auto func_result_type = mlir::RankedTensorType::get( func_output_shape, - input_->getType().cast().getElementType()); + input_.getType().cast().getElementType()); auto tensor_cast = builder_.create( fused_func_op_.getLoc(), lstm_.getResult(), func_result_type); @@ -426,7 +424,7 @@ LogicalResult ConvertLSTMCellSimpleToFusedLSTM::Initialize() { bias_ = fused_func_op_.getArgument(2); weight_ = fused_func_op_.getArgument(1); - weight_type_ = weight_->getType().cast(); + weight_type_ = weight_.getType().cast(); if (weight_type_.getRank() != 2) { return fused_func_op_.emitError() << "The weight tensor was not of rank 2"; @@ -440,7 +438,7 @@ LogicalResult ConvertLSTMCellSimpleToFusedLSTM::Initialize() { n_cell_ = weight_type_.getDimSize(1) / num_gates_; projection_ = fused_func_op_.getArgument(3); - projection_type_ = projection_->getType().cast(); + projection_type_ = projection_.getType().cast(); if (projection_type_.getRank() != 2) { n_output_ = n_cell_; } else { @@ -467,8 +465,7 @@ LogicalResult ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM::Initialize() { } layer_norm_scale_ = fused_func_op_.getArgument(4); - layer_norm_scale_type_ = - layer_norm_scale_->getType().cast(); + layer_norm_scale_type_ = layer_norm_scale_.getType().cast(); if (layer_norm_scale_type_.getRank() != 1) { return 
fused_func_op_.emitError() << "The layer_norm_scale tensor was not of rank 1"; diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils.h b/tensorflow/compiler/mlir/lite/utils/lstm_utils.h index 235d4387faf..f6a2991ca4c 100644 --- a/tensorflow/compiler/mlir/lite/utils/lstm_utils.h +++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils.h @@ -20,12 +20,12 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_LSTM_UTILS_H_ #include "llvm/ADT/StringRef.h" -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Function.h" // TF:local_config_mlir -#include "mlir/IR/Location.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/IR/Value.h" // TF:local_config_mlir -#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/Location.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/Value.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" namespace mlir { @@ -102,15 +102,15 @@ class ConvertLSTMCellSimpleToFusedLSTM { // specified state FuncOp fused_func_op_; - Value* input_; - Value* weight_; - Value* bias_; - Value* projection_; + Value input_; + Value weight_; + Value bias_; + Value projection_; bool couple_input_forget_gates_; // internal state - Value* weight_transposed_; - Value* projection_transposed_; + Value weight_transposed_; + Value projection_transposed_; RankedTensorType weight_type_; RankedTensorType projection_type_; int num_gates_; @@ -121,40 +121,40 @@ class ConvertLSTMCellSimpleToFusedLSTM { int num_cols_projection_transposed_; // input -> cifg - Value* input2input_; - Value* input2forget_; - Value* input2cell_; - Value* input2output_; + Value input2input_; + Value input2forget_; + Value input2cell_; + Value input2output_; // recurrent -> cifg - Value* rec2input_; - Value* rec2forget_; - Value* rec2cell_; - Value* rec2output_; + Value rec2input_; + Value rec2forget_; + Value rec2cell_; + Value rec2output_; // bias -> cifg - Value* bias2input_; - Value* bias2forget_; - Value* bias2cell_; - Value* bias2output_; + Value bias2input_; + Value bias2forget_; + Value bias2cell_; + Value bias2output_; // projection - Value* proj_weight_; - Value* proj_bias_; + Value proj_weight_; + Value proj_bias_; // state - Value* input_activation_state_; - Value* input_cell_state_; + Value input_activation_state_; + Value input_cell_state_; // layer norm coefficients - Value* input_layer_norm_coefficients_; - Value* forget_layer_norm_coefficients_; - Value* cell_layer_norm_coefficients_; - Value* output_layer_norm_coefficients_; + Value input_layer_norm_coefficients_; + Value forget_layer_norm_coefficients_; + Value cell_layer_norm_coefficients_; + Value output_layer_norm_coefficients_; mlir::TFL::LSTMOp lstm_; - Value* none_; + Value none_; SmallVector bias_slice_shape_; SmallVector bias_size_values_; SmallVector weight_slice_shape_; @@ -199,7 +199,7 @@ class ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM private: // specified state - Value* layer_norm_scale_; + Value layer_norm_scale_; // internal state RankedTensorType layer_norm_scale_type_; diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc b/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc index 798c6db5355..b229206a4e4 100644 --- 
a/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc +++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc @@ -24,17 +24,17 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/Casting.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Function.h" // TF:local_config_mlir -#include "mlir/IR/Location.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/IR/Types.h" // TF:local_config_mlir -#include "mlir/IR/Value.h" // TF:local_config_mlir -#include "mlir/Support/LLVM.h" // TF:local_config_mlir -#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir +#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/Location.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/Types.h" // TF:llvm-project +#include "mlir/IR/Value.h" // TF:llvm-project +#include "mlir/Support/LLVM.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project #include "tensorflow/core/platform/test.h" namespace mlir { @@ -128,22 +128,20 @@ TEST_F(LstmUtilsTest, ConvertLSTMCellSimple) { auto transpose_op = fused_lstm_func_.getBody().front().begin(); transpose_op++; - EXPECT_EQ(transpose_op->getOperand(0) - ->getType() - .cast() - .getDimSize(0), - 3); - EXPECT_EQ(transpose_op->getOperand(0) - ->getType() - .cast() - .getDimSize(1), - 12); EXPECT_EQ( - transpose_op->getResult(0)->getType().cast().getDimSize( + transpose_op->getOperand(0).getType().cast().getDimSize( + 0), + 3); + EXPECT_EQ( + transpose_op->getOperand(0).getType().cast().getDimSize( + 1), + 12); + EXPECT_EQ( + transpose_op->getResult(0).getType().cast().getDimSize( 0), 12); EXPECT_EQ( - transpose_op->getResult(0)->getType().cast().getDimSize( + transpose_op->getResult(0).getType().cast().getDimSize( 1), 3); @@ -156,12 +154,12 @@ TEST_F(LstmUtilsTest, ConvertLSTMCellSimple) { EXPECT_EQ(it->getNumOperands(), 24); EXPECT_EQ(it->getNumResults(), 1); // cifg = false, so input2input is not None. - EXPECT_FALSE(it->getOperand(1)->getType().isa()); + EXPECT_FALSE(it->getOperand(1).getType().isa()); // input layer norm is None - EXPECT_TRUE(it->getOperand(20)->getType().isa()); + EXPECT_TRUE(it->getOperand(20).getType().isa()); // proj_bias is F32 EXPECT_TRUE(it->getOperand(17) - ->getType() + .getType() .cast() .getElementType() .isF32()); @@ -169,7 +167,7 @@ TEST_F(LstmUtilsTest, ConvertLSTMCellSimple) { // output gate bias is 0 since it is out of bounds of the bias tensor, so // we set its value as a const tensor of specified size and value 0. EXPECT_TRUE( - mlir::cast(it->getOpOperand(15).get()->getDefiningOp()) + mlir::cast(it->getOpOperand(15).get().getDefiningOp()) .getValue() .cast() .getValue(0) @@ -209,7 +207,7 @@ TEST_F(LstmUtilsTest, ConvertLSTMCellSimpleToFusedLSTMCoupleInputForget) { EXPECT_EQ(it->getNumOperands(), 24); EXPECT_EQ(it->getNumResults(), 1); // cifg = true, so input2input is None. 
- EXPECT_TRUE(it->getOperand(1)->getType().isa()); + EXPECT_TRUE(it->getOperand(1).getType().isa()); } TEST_F(LstmUtilsTest, ConvertLayerNormLSTMCellSimpleToFusedLSTM) { @@ -235,15 +233,15 @@ TEST_F(LstmUtilsTest, ConvertLayerNormLSTMCellSimpleToFusedLSTM) { EXPECT_EQ(it->getNumOperands(), 24); EXPECT_EQ(it->getNumResults(), 1); // cifg = false, so input2input is not None. - EXPECT_FALSE(it->getOperand(1)->getType().isa()); + EXPECT_FALSE(it->getOperand(1).getType().isa()); // input layer norm - EXPECT_FALSE(it->getOperand(20)->getType().isa()); + EXPECT_FALSE(it->getOperand(20).getType().isa()); EXPECT_EQ( - it->getOperand(20)->getType().cast().getShape().size(), + it->getOperand(20).getType().cast().getShape().size(), 1); - EXPECT_EQ( - it->getOperand(20)->getType().cast().getDimSize(0), 3); + EXPECT_EQ(it->getOperand(20).getType().cast().getDimSize(0), + 3); EXPECT_EQ(fused_ln_lstm_func_.getType().getNumResults(), 1); auto output_types = fused_ln_lstm_func_.getType().getResults(); diff --git a/tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.cc b/tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.cc index 45b8fc96361..f830f67bc10 100644 --- a/tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.cc @@ -17,7 +17,7 @@ limitations under the License. #include -#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir +#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h b/tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h index b1d24284acc..917ae93f6a8 100644 --- a/tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h +++ b/tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_STATEFUL_OPS_UTILS_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_STATEFUL_OPS_UTILS_H_ -#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir +#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project namespace mlir { namespace TFL { diff --git a/tensorflow/compiler/mlir/lite/utils/validators.cc b/tensorflow/compiler/mlir/lite/utils/validators.cc index f00f8b489d0..f8e3dd12c8b 100644 --- a/tensorflow/compiler/mlir/lite/utils/validators.cc +++ b/tensorflow/compiler/mlir/lite/utils/validators.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/utils/validators.h" -#include "mlir/Dialect/Traits.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir +#include "mlir/Dialect/Traits.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project namespace mlir { namespace TFL { diff --git a/tensorflow/compiler/mlir/lite/utils/validators.h b/tensorflow/compiler/mlir/lite/utils/validators.h index 0a5d790a6eb..e1ae4392881 100644 --- a/tensorflow/compiler/mlir/lite/utils/validators.h +++ b/tensorflow/compiler/mlir/lite/utils/validators.h @@ -19,8 +19,8 @@ limitations under the License. 
 #ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_VALIDATORS_H_
 #define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_VALIDATORS_H_

-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:local_config_mlir
-#include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
+#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/IR/StandardTypes.h"  // TF:llvm-project

 namespace mlir {
 namespace TFL {
@@ -51,8 +51,8 @@ bool TFIntListIsAllOnes(const ArrayAttr &attr);

 // Returns true iff the given value is a float tensor whose element type
 // is "DT_FLOAT".
-inline bool TFTypeIsFloatTensor(Value *value) {
-  auto tensorType = value->getType().dyn_cast<TensorType>();
+inline bool TFTypeIsFloatTensor(Value value) {
+  auto tensorType = value.getType().dyn_cast<TensorType>();
   if (!tensorType) return false;
   return tensorType.getElementType().isa<FloatType>();
 }
diff --git a/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc b/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc
index 6b8dd7b0c14..fdaddcfb318 100644
--- a/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc
+++ b/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc
@@ -25,9 +25,9 @@ limitations under the License.
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/FormatVariadic.h"
-#include "mlir/IR/Location.h"  // TF:local_config_mlir
-#include "mlir/IR/Operation.h"  // TF:local_config_mlir
-#include "mlir/IR/Value.h"  // TF:local_config_mlir
+#include "mlir/IR/Location.h"  // TF:llvm-project
+#include "mlir/IR/Operation.h"  // TF:llvm-project
+#include "mlir/IR/Value.h"  // TF:llvm-project

 static inline absl::string_view StringRefToView(llvm::StringRef ref) {
   return absl::string_view(ref.data(), ref.size());
@@ -148,18 +148,18 @@ std::string OpOrArgLocNameMapper::GetName(OpOrVal op_or_val) {
     // generated using the op type.
     return op->getName().getStringRef();
   }
-  auto* val = op_or_val.dyn_cast<Value*>();
-  auto name_from_loc = GetNameFromLoc(val->getLoc());
+  auto val = op_or_val.dyn_cast<Value>();
+  auto name_from_loc = GetNameFromLoc(val.getLoc());
   if (!name_from_loc.empty()) return name_from_loc;
   // If the location is none of the expected types, then simply use name
   // generated using the op type. Follow TF convention and append the result
   // index unless 0.
-  if (auto* result = llvm::dyn_cast<mlir::OpResult>(val)) {
-    if (result->getResultNumber() > 0)
+  if (auto result = val.dyn_cast<mlir::OpResult>()) {
+    if (result.getResultNumber() > 0)
       return llvm::formatv("{0}:{1}",
-                           result->getOwner()->getName().getStringRef(),
-                           result->getResultNumber());
-    return result->getOwner()->getName().getStringRef();
+                           result.getOwner()->getName().getStringRef(),
+                           result.getResultNumber());
+    return result.getOwner()->getName().getStringRef();
   }
   return "";
 }
diff --git a/tensorflow/compiler/mlir/op_or_arg_name_mapper.h b/tensorflow/compiler/mlir/op_or_arg_name_mapper.h
index 6517349146e..db83a8dfd7c 100644
--- a/tensorflow/compiler/mlir/op_or_arg_name_mapper.h
+++ b/tensorflow/compiler/mlir/op_or_arg_name_mapper.h
@@ -23,14 +23,14 @@ limitations under the License.
 #include "llvm/ADT/PointerUnion.h"
 #include "llvm/ADT/StringMap.h"
 #include "llvm/ADT/StringRef.h"
-#include "mlir/IR/Operation.h"  // TF:local_config_mlir
-#include "mlir/IR/Value.h"  // TF:local_config_mlir
+#include "mlir/IR/Operation.h"  // TF:llvm-project
+#include "mlir/IR/Value.h"  // TF:llvm-project

 namespace tensorflow {

 // PointerUnion for operation and value.
 // TODO(jpienaar): Rename the files.
-using OpOrVal = llvm::PointerUnion<mlir::Operation*, mlir::Value*>;
+using OpOrVal = llvm::PointerUnion<mlir::Operation*, mlir::Value>;

 // Mapper from operation or value to name.
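Editor's note: OpOrArgLocNameMapper::GetName above follows the TensorFlow naming convention for op results, the bare op name for result 0 and op_name:index for every later result. A tiny sketch of just that convention (names in the asserts are illustrative):

```
def result_name(op_name, result_number):
    # TF convention: append the result index unless it is 0.
    if result_number > 0:
        return f"{op_name}:{result_number}"
    return op_name

assert result_name("strided_slice", 0) == "strided_slice"
assert result_name("tf.Split", 2) == "tf.Split:2"
```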
class OpOrArgNameMapper { diff --git a/tensorflow/compiler/mlir/python/mlir.i b/tensorflow/compiler/mlir/python/mlir.i index 2ecea47b3d3..b1d53288204 100644 --- a/tensorflow/compiler/mlir/python/mlir.i +++ b/tensorflow/compiler/mlir/python/mlir.i @@ -108,6 +108,45 @@ string ExperimentalConvertSavedModelToMlir( return MlirModuleToString(*module_or.ConsumeValueOrDie(), show_debug_info); } +// Load a SavedModel V1 and return a textual MLIR string corresponding to it. +// +// Args: +// saved_model_path: File path from which to load the SavedModel. +// tags: Tags to identify MetaGraphDef that need to be loaded. +// +// Returns: +// A string of textual MLIR representing the raw imported SavedModel. +string ExperimentalConvertSavedModelV1ToMlir( + const string &saved_model_path, + const string &tags, + bool show_debug_info, + TF_Status* status) { + // Load the saved model into a SavedModelBundle. + + std::unordered_set tag_set + = absl::StrSplit(tags, ',', absl::SkipEmpty()); + + tensorflow::SavedModelBundle bundle; + auto load_status = tensorflow::LoadSavedModel( + {}, {}, + saved_model_path, tag_set, &bundle); + if (!load_status.ok()) { + Set_TF_Status_from_Status(status, load_status); + return "// error"; + } + + // Convert the SavedModelBundle to an MLIR module. + + mlir::MLIRContext context; + auto module_or = ConvertSavedModelV1ToMlir(bundle, &context); + if (!module_or.status().ok()) { + Set_TF_Status_from_Status(status, module_or.status()); + return "// error"; + } + + return MlirModuleToString(*module_or.ConsumeValueOrDie(), show_debug_info); +} + string ExperimentalRunPassPipeline( const string &mlir_txt, @@ -154,6 +193,7 @@ string ExperimentalRunPassPipeline( %unignore tensorflow::swig; %unignore tensorflow::swig::ImportGraphDef; %unignore tensorflow::swig::ExperimentalConvertSavedModelToMlir; +%unignore tensorflow::swig::ExperimentalConvertSavedModelV1ToMlir; %unignore tensorflow::swig::ExperimentalRunPassPipeline; // Wrap this function @@ -167,6 +207,11 @@ static string ExperimentalConvertSavedModelToMlir( const string &exported_names, bool show_debug_info, TF_Status* status); +static string ExperimentalConvertSavedModelV1ToMlir( + const string &saved_model_path, + const string &tags, + bool show_debug_info, + TF_Status* status); static string ExperimentalRunPassPipeline( const string &mlir_txt, const string &pass_pipeline, @@ -188,6 +233,14 @@ def experimental_convert_saved_model_to_mlir(saved_model_path, show_debug_info ).decode('utf-8'); +def experimental_convert_saved_model_v1_to_mlir(saved_model_path, + tags, show_debug_info): + return ExperimentalConvertSavedModelV1ToMlir( + str(saved_model_path).encode('utf-8'), + str(tags).encode('utf-8'), + show_debug_info + ).decode('utf-8'); + def experimental_run_pass_pipeline(mlir_txt, pass_pipeline, show_debug_info): return ExperimentalRunPassPipeline( mlir_txt.encode('utf-8'), diff --git a/tensorflow/compiler/mlir/runlit.site.cfg.py b/tensorflow/compiler/mlir/runlit.site.cfg.py index e14199ed43b..8f36de71c5f 100644 --- a/tensorflow/compiler/mlir/runlit.site.cfg.py +++ b/tensorflow/compiler/mlir/runlit.site.cfg.py @@ -24,10 +24,11 @@ import lit.llvm # file, instead config is injected by lit.py. The structure is common for lit # tests and intended to only persist temporarily (b/136126535). 
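Editor's note: the new ExperimentalConvertSavedModelV1ToMlir entry point above mirrors the existing V2 path: load the bundle for a comma-separated tag set, import it, and return textual MLIR. A usage sketch, assuming the generated wrapper is importable; the module path and pass-pipeline string below are illustrative assumptions, not part of this change:

```
# Assumption: the SWIG module is exposed as tensorflow.python.pywrap_mlir;
# adjust the import to wherever the build places the generated wrapper.
from tensorflow.python import pywrap_mlir

# Tags select the MetaGraphDef to load; multiple tags are comma-separated.
mlir_txt = pywrap_mlir.experimental_convert_saved_model_v1_to_mlir(
    "/tmp/saved_model", "serve", False)

# The raw import can then be run through any pass pipeline before inspection
# (pipeline name here is illustrative).
print(pywrap_mlir.experimental_run_pass_pipeline(mlir_txt, "canonicalize", False))
```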
# pylint: disable=undefined-variable -config.llvm_tools_dir = os.path.join(os.environ['TEST_SRCDIR'], 'llvm') +config.llvm_tools_dir = os.path.join(os.environ['TEST_SRCDIR'], 'llvm-project', + 'llvm') config.mlir_obj_root = os.path.join(os.environ['TEST_SRCDIR']) -config.mlir_tools_dir = os.path.join(os.environ['TEST_SRCDIR'], - 'local_config_mlir') +config.mlir_tools_dir = os.path.join(os.environ['TEST_SRCDIR'], 'llvm-project', + 'mlir') # TODO(jpienaar): Replace with suffices in build rule. config.suffixes = ['.td', '.mlir', '.pbtxt'] diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 93fad60614b..2888997c7b2 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -1,4 +1,4 @@ -load("@local_config_mlir//:tblgen.bzl", "gentbl") +load("//third_party/mlir:tblgen.bzl", "gentbl") load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_gen_op_wrapper_py", "tf_native_cc_binary") package( @@ -8,10 +8,10 @@ package( package_group( name = "friends", - includes = ["@local_config_mlir//:subpackages"], + includes = ["//third_party/mlir:subpackages"], packages = [ "//tensorflow/compiler/...", - "//tensorflow/core/tfrt_delegate/...", + "//tensorflow/lite/experimental/tf_runtime/...", "//tensorflow/python/...", ], ) @@ -22,8 +22,8 @@ filegroup( "ir/tf_generated_ops.td", "ir/tf_op_base.td", "ir/tf_ops.td", - "@local_config_mlir//:OpBaseTdFiles", - "@local_config_mlir//:include/mlir/Analysis/CallInterfaces.td", + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:include/mlir/Analysis/CallInterfaces.td", ], ) @@ -43,7 +43,7 @@ gentbl( "g3doc/tf_ops.md", ), ], - tblgen = "@local_config_mlir//:mlir-tblgen", + tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/tf_ops.td", td_srcs = [ ":tensorflow_ops_td_files", @@ -66,11 +66,11 @@ gentbl( "g3doc/tf_saved_model.md", ), ], - tblgen = "@local_config_mlir//:mlir-tblgen", + tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/tf_saved_model_ops.td", td_srcs = [ - "@local_config_mlir//:include/mlir/IR/OpBase.td", - "@local_config_mlir//:include/mlir/Dialect/StandardOps/Ops.td", + "@llvm-project//mlir:include/mlir/IR/OpBase.td", + "@llvm-project//mlir:include/mlir/Dialect/StandardOps/Ops.td", ], ) @@ -90,11 +90,11 @@ gentbl( "g3doc/tf_executor.md", ), ], - tblgen = "@local_config_mlir//:mlir-tblgen", + tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/tf_executor_ops.td", td_srcs = [ - "@local_config_mlir//:include/mlir/IR/OpBase.td", - "@local_config_mlir//:include/mlir/Dialect/StandardOps/Ops.td", + "@llvm-project//mlir:include/mlir/IR/OpBase.td", + "@llvm-project//mlir:include/mlir/Dialect/StandardOps/Ops.td", ], ) @@ -114,11 +114,11 @@ gentbl( "g3doc/tf_device.md", ), ], - tblgen = "@local_config_mlir//:mlir-tblgen", + tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/tf_device_ops.td", td_srcs = [ - "@local_config_mlir//:include/mlir/IR/OpBase.td", - "@local_config_mlir//:include/mlir/Dialect/StandardOps/Ops.td", + "@llvm-project//mlir:include/mlir/IR/OpBase.td", + "@llvm-project//mlir:include/mlir/Dialect/StandardOps/Ops.td", ], ) @@ -130,7 +130,7 @@ gentbl( "transforms/generated_canonicalize.inc", ), ], - tblgen = "@local_config_mlir//:mlir-tblgen", + tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/canonicalize.td", td_srcs = [ ":tensorflow_ops_td_files", @@ -162,8 +162,8 @@ cc_library( "ir/tf_types.h", "transforms/bridge.h", "transforms/passes.h", - 
"@local_config_mlir//:include/mlir/Analysis/CallInterfaces.h", - "@local_config_mlir//:include/mlir/Transforms/InliningUtils.h", + "@llvm-project//mlir:include/mlir/Analysis/CallInterfaces.h", + "@llvm-project//mlir:include/mlir/Transforms/InliningUtils.h", ], includes = ["include"], deps = [ @@ -177,17 +177,17 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core/platform:logging", - "@llvm//:support", - "@local_config_mlir//:Analysis", - "@local_config_mlir//:CallOpInterfacesIncGen", - "@local_config_mlir//:Dialect", - "@local_config_mlir//:IR", - "@local_config_mlir//:Parser", - "@local_config_mlir//:Pass", - "@local_config_mlir//:StandardOps", - "@local_config_mlir//:Support", - "@local_config_mlir//:TransformUtils", - "@local_config_mlir//:Transforms", + "@llvm-project//llvm:support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:CallOpInterfacesIncGen", + "@llvm-project//mlir:Dialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", ], # TODO(jpienaar): Merge in the dialect registration. alwayslink = 1, @@ -201,11 +201,11 @@ gentbl( "transforms/generated_decompose_resource_ops.inc", ), ], - tblgen = "@local_config_mlir//:mlir-tblgen", + tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/decompose_resource_ops.td", td_srcs = [ ":tensorflow_ops_td_files", - "@local_config_mlir//:StdOpsTdFiles", + "@llvm-project//mlir:StdOpsTdFiles", ], ) @@ -220,7 +220,7 @@ cc_library( deps = [ ":decompose_resource_ops_inc_gen", ":tensorflow", - "@local_config_mlir//:IR", + "@llvm-project//mlir:IR", ], ) @@ -290,15 +290,15 @@ cc_library( "//tensorflow/core/platform:logging", "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc", "//tensorflow/core/protobuf/tpu:dynamic_padding_proto_cc", - "@llvm//:support", - "@local_config_mlir//:Analysis", - "@local_config_mlir//:IR", - "@local_config_mlir//:Parser", - "@local_config_mlir//:Pass", - "@local_config_mlir//:StandardOps", - "@local_config_mlir//:Support", - "@local_config_mlir//:TransformUtils", - "@local_config_mlir//:Transforms", + "@llvm-project//llvm:support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", ], # TODO(jpienaar): Merge in the dialect registration. 
alwayslink = 1, @@ -311,8 +311,8 @@ cc_library( ], deps = [ ":lower_tf_lib", - "@local_config_mlir//:IR", - "@local_config_mlir//:Pass", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", ], alwayslink = 1, ) @@ -323,7 +323,7 @@ cc_library( srcs = ["ir/dialect_registration.cc"], deps = [ ":tensorflow", - "@local_config_mlir//:IR", + "@llvm-project//mlir:IR", ], alwayslink = 1, ) @@ -348,15 +348,18 @@ cc_library( ":tensorflow", ":tensorflow_passes", "//tensorflow/cc/saved_model:bundle_v2", + "//tensorflow/cc/saved_model:loader_lite", "//tensorflow/compiler/jit:shape_inference_helpers", "//tensorflow/compiler/mlir:op_or_arg_name_mapper", "//tensorflow/compiler/tf2xla:functionalize_control_flow", "//tensorflow/compiler/xla:status_macros", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler/utils:transitive_fanin", "//tensorflow/core/platform:types", "//tensorflow/stream_executor/lib", "@com_google_absl//absl/algorithm:container", @@ -365,12 +368,12 @@ cc_library( "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", - "@llvm//:support", - "@local_config_mlir//:IR", - "@local_config_mlir//:Pass", - "@local_config_mlir//:StandardDialectRegistration", - "@local_config_mlir//:StandardOps", - "@local_config_mlir//:Support", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:StandardDialectRegistration", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", ], ) @@ -387,7 +390,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/strings", - "@llvm//:support", + "@llvm-project//llvm:support", ], ) @@ -417,11 +420,11 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", - "@llvm//:support", - "@local_config_mlir//:IR", - "@local_config_mlir//:StandardDialectRegistration", - "@local_config_mlir//:StandardOps", - "@local_config_mlir//:Support", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:StandardDialectRegistration", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", ], ) @@ -444,8 +447,8 @@ cc_library( "//tensorflow/stream_executor/lib", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", - "@llvm//:support", - "@local_config_mlir//:IR", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", ], ) @@ -454,10 +457,10 @@ cc_library( srcs = ["translate/translate_tf_dialect_op.cc"], deps = [ ":export_tf_dialect_op", - "@llvm//:support", - "@local_config_mlir//:IR", - "@local_config_mlir//:Support", - "@local_config_mlir//:Translation", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Translation", ], alwayslink = 1, ) @@ -474,9 +477,9 @@ cc_library( "//tensorflow/core:core_cpu_lib", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", - "@local_config_mlir//:Analysis", - "@local_config_mlir//:IR", - "@local_config_mlir//:StandardOps", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:StandardOps", ], alwayslink = 1, ) @@ -513,7 +516,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", 
"@com_google_absl//absl/strings", - "@llvm//:support", + "@llvm-project//llvm:support", ], ) @@ -529,9 +532,9 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/stream_executor/lib", "@com_google_absl//absl/strings", - "@llvm//:support", - "@local_config_mlir//:IR", - "@local_config_mlir//:Support", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", ], ) @@ -546,8 +549,8 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/stream_executor/lib", - "@llvm//:support", - "@local_config_mlir//:IR", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", ], ) @@ -565,8 +568,8 @@ cc_library( "//tensorflow/stream_executor/lib", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", - "@llvm//:support", - "@local_config_mlir//:IR", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", ], ) @@ -581,7 +584,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/stream_executor/lib", - "@local_config_mlir//:IR", + "@llvm-project//mlir:IR", ], ) @@ -603,8 +606,8 @@ cc_library( hdrs = ["utils/error_util.h"], deps = [ "//tensorflow/core:lib", - "@llvm//:support", - "@local_config_mlir//:IR", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", ], ) @@ -626,13 +629,14 @@ cc_library( "//tensorflow/c:tf_status", "//tensorflow/c/eager:c_api", "//tensorflow/core:framework", + "//tensorflow/core:lib", "//tensorflow/stream_executor", "//tensorflow/stream_executor/lib", - "@llvm//:support", - "@local_config_mlir//:Analysis", - "@local_config_mlir//:IR", - "@local_config_mlir//:Pass", - "@local_config_mlir//:Support", + "@llvm-project//llvm:support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", ], alwayslink = 1, ) @@ -642,7 +646,7 @@ cc_library( deps = [ ":tensorflow_dialect_registration", ":tf_dialect_passes", - "@local_config_mlir//:StandardDialectRegistration", + "@llvm-project//mlir:StandardDialectRegistration", ], ) @@ -661,9 +665,9 @@ cc_library( "//tensorflow/core:ops", "//tensorflow/core:protos_all_cc", "//tensorflow/stream_executor/lib", - "@llvm//:support", - "@local_config_mlir//:IR", - "@local_config_mlir//:Pass", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", ], alwayslink = 1, ) @@ -690,26 +694,23 @@ cc_library( "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", - "@llvm//:support", - "@local_config_mlir//:IR", - "@local_config_mlir//:Support", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", ], ) cc_library( name = "translate_lib", - srcs = [ - "translate/tf_mlir_translate.cc", - ], - hdrs = [ - "translate/tf_mlir_translate.h", - ], + srcs = ["translate/tf_mlir_translate.cc"], + hdrs = ["translate/tf_mlir_translate.h"], deps = [ ":convert_graphdef", ":error_util", ":import_utils", ":mangling_util", ":mlir_roundtrip_flags", + "//tensorflow/cc/saved_model:bundle_v2", "//tensorflow/core:graph", "//tensorflow/core:lib_proto_parsing", "//tensorflow/core:ops", @@ -718,10 +719,10 @@ cc_library( "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@llvm//:support", - "@local_config_mlir//:IR", - "@local_config_mlir//:Parser", - "@local_config_mlir//:Pass", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + 
"@llvm-project//mlir:Pass", ], ) @@ -734,7 +735,7 @@ cc_library( "translate/tf_mlir_translate_cl.h", ], deps = [ - "@llvm//:support", + "@llvm-project//llvm:support", ], alwayslink = 1, ) @@ -751,9 +752,9 @@ cc_library( ":translate_lib", "//tensorflow/core:protos_all_cc", "//tensorflow/stream_executor/lib", - "@llvm//:support", - "@local_config_mlir//:IR", - "@local_config_mlir//:Translation", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Translation", ], alwayslink = 1, ) @@ -768,8 +769,8 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", - "@llvm//:support", - "@local_config_mlir//:IR", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", ], ) @@ -779,17 +780,17 @@ tf_native_cc_binary( "translate/derived_attr_populator_gen.cc", ], deps = [ - "@llvm//:support", - "@llvm//:tablegen", - "@local_config_mlir//:TableGen", + "@llvm-project//llvm:support", + "@llvm-project//llvm:tablegen", + "@llvm-project//mlir:TableGen", ], ) genrule( name = "derived_attr_populator_inc", srcs = [ - "@local_config_mlir//:include/mlir/Analysis/CallInterfaces.td", - "@local_config_mlir//:include/mlir/IR/OpBase.td", + "@llvm-project//mlir:include/mlir/Analysis/CallInterfaces.td", + "@llvm-project//mlir:include/mlir/IR/OpBase.td", "ir/tf_generated_ops.td", "ir/tf_op_base.td", "ir/tf_ops.td", @@ -798,7 +799,7 @@ genrule( "translate/derived_attr_populator.inc", ], cmd = ("$(location :derived_attr_populator_gen) " + - "-I external/local_config_mlir/include " + + "-I external/llvm-project/mlir/include " + "-I external/org_tensorflow " + "$(location //tensorflow/compiler/mlir/tensorflow:ir/tf_ops.td) " + " -o $@"), tools = [":derived_attr_populator_gen"], @@ -819,11 +820,11 @@ gentbl( "transforms/generated_optimize.inc", ), ], - tblgen = "@local_config_mlir//:mlir-tblgen", + tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/optimize.td", td_srcs = [ ":tensorflow_ops_td_files", - "@local_config_mlir//:StdOpsTdFiles", + "@llvm-project//mlir:StdOpsTdFiles", ], ) @@ -846,13 +847,13 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core/platform:logging", "//tensorflow/stream_executor/lib", - "@llvm//:support", - "@local_config_mlir//:IR", - "@local_config_mlir//:Parser", - "@local_config_mlir//:Pass", - "@local_config_mlir//:StandardOps", - "@local_config_mlir//:TransformUtils", - "@local_config_mlir//:Transforms", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", ], ) @@ -903,11 +904,11 @@ gentbl( "transforms/generated_lower_tf.inc", ), ], - tblgen = "@local_config_mlir//:mlir-tblgen", + tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/lower_tf.td", td_srcs = [ ":tensorflow_ops_td_files", - "@local_config_mlir//:StdOpsTdFiles", + "@llvm-project//mlir:StdOpsTdFiles", ], ) @@ -923,7 +924,8 @@ cc_library( ":lower_tf_inc_gen", ":tensorflow", "//tensorflow/core:framework", - "@local_config_mlir//:IR", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", ], alwayslink = 1, ) @@ -936,7 +938,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "@com_google_absl//absl/strings", - "@llvm//:support", + "@llvm-project//llvm:support", ], ) @@ -949,7 +951,7 @@ tf_cc_test( "//tensorflow/core:framework", "//tensorflow/core:test", "//tensorflow/core:test_main", - "@llvm//:support", + 
"@llvm-project//llvm:support", ], ) @@ -960,9 +962,9 @@ cc_library( deps = [ "//tensorflow/core:core_cpu_lib", "//tensorflow/core:framework", - "@llvm//:support", - "@local_config_mlir//:IR", - "@local_config_mlir//:Support", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", ], ) @@ -978,9 +980,9 @@ tf_cc_test( "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", - "@llvm//:support", - "@local_config_mlir//:IR", - "@local_config_mlir//:Support", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", ], ) @@ -991,8 +993,8 @@ cc_library( deps = [ "//tensorflow/core:lib", "//tensorflow/core/platform:logging", - "@llvm//:support", - "@local_config_mlir//:IR", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", ], ) @@ -1006,8 +1008,8 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/platform:test", - "@llvm//:support", - "@local_config_mlir//:IR", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", ], ) @@ -1017,9 +1019,9 @@ cc_library( hdrs = ["utils/bridge_logger.h"], deps = [ ":dump_mlir_util", - "@llvm//:support", - "@local_config_mlir//:IR", - "@local_config_mlir//:Pass", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", ], ) @@ -1032,9 +1034,9 @@ cc_library( "//tensorflow/compiler/tf2xla:resource_operation_table", "//tensorflow/core:framework", "@com_google_absl//absl/strings", - "@llvm//:support", - "@local_config_mlir//:IR", - "@local_config_mlir//:Pass", - "@local_config_mlir//:Support", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", ], ) diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc index 36a2560b7c8..785f8e7f966 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc +++ b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc @@ -26,16 +26,16 @@ limitations under the License. 
#include "llvm/ADT/iterator_range.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Block.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Location.h" // TF:local_config_mlir -#include "mlir/IR/Module.h" // TF:local_config_mlir -#include "mlir/IR/Operation.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/IR/Value.h" // TF:local_config_mlir -#include "mlir/Support/LLVM.h" // TF:local_config_mlir -#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Block.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Location.h" // TF:llvm-project +#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/Value.h" // TF:llvm-project +#include "mlir/Support/LLVM.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -84,17 +84,17 @@ int64_t FindPassthroughArgumentForReturnValue(int64_t return_index, FuncOp func_op) { auto value = func_op.getBody().front().getTerminator()->getOperand(return_index); - assert(mlir::getElementTypeOrSelf(value->getType()).isa()); + assert(mlir::getElementTypeOrSelf(value.getType()).isa()); int64_t arg_index = -1; - auto try_parse_arg_index = [&arg_index](Value* v) { - auto resource_arg = llvm::dyn_cast(v); - if (resource_arg) arg_index = resource_arg->getArgNumber(); + auto try_parse_arg_index = [&arg_index](Value v) { + auto resource_arg = v.dyn_cast(); + if (resource_arg) arg_index = resource_arg.getArgNumber(); return arg_index; }; while (try_parse_arg_index(value) == -1) { - auto op = value->getDefiningOp(); + auto op = value.getDefiningOp(); assert(op); - int64_t res_num = llvm::dyn_cast(value)->getResultNumber(); + int64_t res_num = value.cast().getResultNumber(); if (auto graph = llvm::dyn_cast(op)) { value = graph.GetFetch().getOperand(res_num); } else if (auto island = llvm::dyn_cast(op)) { @@ -126,13 +126,13 @@ void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) { // Before having that, we assume function arguments do not alias each other. int64_t next_unique_id = 0; for (auto arg : func_op.getArguments()) { - if (!mlir::getElementTypeOrSelf(arg->getType()).isa()) + if (!mlir::getElementTypeOrSelf(arg.getType()).isa()) continue; resource_value_to_ids_[arg].insert(next_unique_id++); } llvm::StringMap var_handle_name_id_map; - auto forward_input_to_output = [&](Value* operand, Value* result) { - if (!mlir::getElementTypeOrSelf(result->getType()).isa()) + auto forward_input_to_output = [&](Value operand, Value result) { + if (!mlir::getElementTypeOrSelf(result.getType()).isa()) return; auto& result_ids = resource_value_to_ids_[result]; auto operand_it = resource_value_to_ids_.find(operand); @@ -161,8 +161,7 @@ void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) { // analysis. Inside that block, we can still treat its block arguments as // different resources. 
for (auto arg : replicate.GetBody().getArguments()) { - if (mlir::getElementTypeOrSelf(arg->getType()) - .isa()) { + if (mlir::getElementTypeOrSelf(arg.getType()).isa()) { resource_value_to_ids_[arg].insert(next_unique_id++); } } @@ -171,7 +170,7 @@ void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) { // If a result is a passthrough of the body input, use the corresponding // operand's resource IDs. for (auto result : llvm::enumerate(while_op.getResults())) { - if (!mlir::getElementTypeOrSelf(result.value()->getType()) + if (!mlir::getElementTypeOrSelf(result.value().getType()) .isa()) { continue; } @@ -192,7 +191,7 @@ void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) { // If a result is a passthrough of both branches' inputs, merge the // resource IDs of corresponding operands for the two inputs. for (auto result : llvm::enumerate(if_op.getResults())) { - if (!mlir::getElementTypeOrSelf(result.value()->getType()) + if (!mlir::getElementTypeOrSelf(result.value().getType()) .isa()) { continue; } @@ -211,7 +210,7 @@ void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) { } } else { for (auto result : op->getResults()) { - if (!mlir::getElementTypeOrSelf(result->getType()) + if (!mlir::getElementTypeOrSelf(result.getType()) .isa()) continue; resource_value_to_ids_[result].insert(kUnknownResourceId); @@ -220,7 +219,7 @@ void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) { }); } -bool ResourceAliasAnalysis::IsUnknownResource(const Value* resource) const { +bool ResourceAliasAnalysis::IsUnknownResource(const Value resource) const { auto it = resource_value_to_ids_.find(resource); assert(it != resource_value_to_ids_.end() && !it->getSecond().empty()); // The set is sorted so we only need to check the first element since @@ -231,7 +230,7 @@ bool ResourceAliasAnalysis::IsUnknownResource(const Value* resource) const { } const llvm::SmallSet& ResourceAliasAnalysis::GetResourceUniqueIds( - const Value* resource) const { + const Value resource) const { auto it = resource_value_to_ids_.find(resource); assert(it != resource_value_to_ids_.end() && "Unseen resource was queried"); return it->getSecond(); @@ -253,14 +252,14 @@ llvm::SmallDenseSet FindAccessedResources( llvm::SmallDenseSet resources; for (auto operand : op->getOperands()) { - if (!mlir::getElementTypeOrSelf(operand->getType()).isa()) + if (!mlir::getElementTypeOrSelf(operand.getType()).isa()) continue; if (alias_analysis.IsUnknownResource(operand)) return UnknownResourceSet(); const auto& ids = alias_analysis.GetResourceUniqueIds(operand); resources.insert(ids.begin(), ids.end()); } for (auto result : op->getResults()) { - if (!mlir::getElementTypeOrSelf(result->getType()).isa()) + if (!mlir::getElementTypeOrSelf(result.getType()).isa()) continue; if (alias_analysis.IsUnknownResource(result)) return UnknownResourceSet(); const auto& ids = alias_analysis.GetResourceUniqueIds(result); @@ -310,7 +309,21 @@ bool OpIsKnownToHaveNoSideEffect(Operation* op) { if (auto while_op = llvm::dyn_cast(op)) { return while_op.is_stateless(); } - return false; + + // Try to get the statefulness flag from the registry. + // + // TODO(yuanzx): Remove this after all ops are defined in the dialect. + if (op->getName().getDialect() != + TF::TensorFlowDialect::getDialectNamespace()) { + return false; + } + StringRef op_name = op->getName().getStringRef(); + // Drop the `tf.` prefix to query TF registry. 
+ auto node_name = + op_name.drop_front(TensorFlowDialect::getDialectNamespace().size() + 1); + const tensorflow::OpRegistrationData* op_reg_data = + tensorflow::OpRegistry::Global()->LookUp(node_name.data()); + return op_reg_data && !op_reg_data->op_def.is_stateful(); } } // namespace diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h index 98df0941340..9457a3e8c6d 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h +++ b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h @@ -22,10 +22,10 @@ limitations under the License. #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringMap.h" -#include "mlir/IR/Function.h" // TF:local_config_mlir -#include "mlir/IR/Operation.h" // TF:local_config_mlir -#include "mlir/IR/Region.h" // TF:local_config_mlir -#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir +#include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/Region.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project namespace mlir { namespace TF { @@ -42,12 +42,12 @@ class ResourceAliasAnalysis { ResourceAliasAnalysis(ResourceAliasAnalysis&&) = default; // Returns if the analysis fails to resolve a resource-type value. - bool IsUnknownResource(const Value* resource) const; + bool IsUnknownResource(const Value resource) const; // Returns the set unique IDs which `resource` could alias. Requires that // IsUnknownResource(resource) == true. const llvm::SmallSet& GetResourceUniqueIds( - const Value* resource) const; + const Value resource) const; private: ResourceAliasAnalysis() = default; @@ -56,7 +56,7 @@ class ResourceAliasAnalysis { void AnalyzeFunction(FuncOp func_op); // Maps each resource-type value to a set of unique IDs that it could alias. - llvm::SmallDenseMap, 8> + llvm::SmallDenseMap, 8> resource_value_to_ids_; }; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.cc index 08712a7929b..e4b797d349a 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.cc @@ -18,9 +18,9 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h" -#include "mlir/IR/DialectImplementation.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/OpImplementation.h" // TF:local_config_mlir +#include "mlir/IR/DialectImplementation.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/OpImplementation.h" // TF:llvm-project namespace mlir { namespace TFControlFlow { diff --git a/tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h b/tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h index d3cf173473b..59a1cc21b28 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h @@ -23,9 +23,9 @@ limitations under the License. 
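A note on the `OpIsKnownToHaveNoSideEffect` fallback added above: the same statefulness query can be reproduced from Python. A minimal sketch, assuming the private `op_def_registry.get` helper available in recent TF builds (names and the `tf.`-prefix handling mirror the C++ change, which strips the dialect namespace before the registry lookup):

```python
from tensorflow.python.framework import op_def_registry  # private API; sketch only

def is_known_side_effect_free(dialect_op_name: str) -> bool:
    """Rough Python analogue of the registry fallback in side_effect_analysis.cc."""
    # Drop the `tf.` dialect prefix, matching the C++ drop_front logic.
    prefix = "tf."
    if not dialect_op_name.startswith(prefix):
        return False
    op_def = op_def_registry.get(dialect_op_name[len(prefix):])  # None if unregistered
    return op_def is not None and not op_def.is_stateful

print(is_known_side_effect_free("tf.Add"))               # True: Add is stateless
print(is_known_side_effect_free("tf.AssignVariableOp"))  # False: stateful
```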
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_CONTROL_FLOW_OPS_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_CONTROL_FLOW_OPS_H_ -#include "mlir/IR/Dialect.h" // TF:local_config_mlir -#include "mlir/IR/OpDefinition.h" // TF:local_config_mlir -#include "mlir/IR/Types.h" // TF:local_config_mlir +#include "mlir/IR/Dialect.h" // TF:llvm-project +#include "mlir/IR/OpDefinition.h" // TF:llvm-project +#include "mlir/IR/Types.h" // TF:llvm-project namespace mlir { namespace TFControlFlow { @@ -90,8 +90,8 @@ class EnterOp static StringRef getOperationName() { return "_tf.Enter"; } - Value *getData() { return getOperand(0); } - void setData(Value *value) { setOperand(0, value); } + Value getData() { return getOperand(0); } + void setData(Value value) { setOperand(0, value); } LogicalResult verify(); }; @@ -172,8 +172,8 @@ class NextIterationSinkOp static StringRef getOperationName() { return "_tf.NextIteration.sink"; } - Value *getData() { return getOperand(0); } - void setData(Value *value) { setOperand(0, value); } + Value getData() { return getOperand(0); } + void setData(Value value) { setOperand(0, value); } LogicalResult verify(); }; @@ -202,8 +202,8 @@ class LoopCondOp using Op::Op; static StringRef getOperationName() { return "_tf.LoopCond"; } - Value *getData() { return getOperand(0); } - void setData(Value *value) { setOperand(0, value); } + Value getData() { return getOperand(0); } + void setData(Value value) { setOperand(0, value); } LogicalResult verify(); }; @@ -233,11 +233,11 @@ class SwitchOp : public Op::Impl, static StringRef getOperationName() { return "_tf.Switch"; } - Value *getData() { return getOperand(0); } - void setData(Value *value) { setOperand(0, value); } + Value getData() { return getOperand(0); } + void setData(Value value) { setOperand(0, value); } - Value *getPredicate() { return getOperand(1); } - void setPredicate(Value *value) { setOperand(1, value); } + Value getPredicate() { return getOperand(1); } + void setPredicate(Value value) { setOperand(1, value); } LogicalResult verify(); }; @@ -266,8 +266,8 @@ class ExitOp : public Op::Impl, using Op::Op; static StringRef getOperationName() { return "_tf.Exit"; } - Value *getData() { return getOperand(0); } - void setData(Value *value) { setOperand(0, value); } + Value getData() { return getOperand(0); } + void setData(Value value) { setOperand(0, value); } LogicalResult verify(); }; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc index ffba86e78ff..b313b06bd3b 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc @@ -25,19 +25,19 @@ limitations under the License. 
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/SMLoc.h" -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/OpDefinition.h" // TF:local_config_mlir -#include "mlir/IR/OpImplementation.h" // TF:local_config_mlir -#include "mlir/IR/OperationSupport.h" // TF:local_config_mlir -#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir -#include "mlir/IR/Types.h" // TF:local_config_mlir -#include "mlir/IR/Value.h" // TF:local_config_mlir -#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir -#include "mlir/Support/STLExtras.h" // TF:local_config_mlir +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/OpDefinition.h" // TF:llvm-project +#include "mlir/IR/OpImplementation.h" // TF:llvm-project +#include "mlir/IR/OperationSupport.h" // TF:llvm-project +#include "mlir/IR/PatternMatch.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/TypeUtilities.h" // TF:llvm-project +#include "mlir/IR/Types.h" // TF:llvm-project +#include "mlir/IR/Value.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/Support/STLExtras.h" // TF:llvm-project #include "tensorflow/core/platform/logging.h" namespace mlir { @@ -183,12 +183,12 @@ void Print(ReplicateOp op, OpAsmPrinter* p) { if (op.getNumOperands()) { *p << '('; Block& block = op.body().front(); - interleaveComma(block.getArguments(), *p, [&](BlockArgument* arg) { - const int block_arg_num = arg->getArgNumber(); + interleaveComma(block.getArguments(), *p, [&](BlockArgument arg) { + const int block_arg_num = arg.getArgNumber(); *p << '['; p->printOperands(std::next(op.operand_begin(), block_arg_num * n), std::next(op.operand_begin(), (block_arg_num + 1) * n)); - *p << "] as " << *arg << ": " << arg->getType(); + *p << "] as " << arg << ": " << arg.getType(); }); *p << ')'; } @@ -229,13 +229,13 @@ LogicalResult Verify(ReplicateOp op) { // Check replicated input types match block argument types. 
for (auto block_arg : block.getArguments()) { - Type block_arg_type = block_arg->getType(); - for (int i = n * block_arg->getArgNumber(), e = i + n; i < e; ++i) + Type block_arg_type = block_arg.getType(); + for (int i = n * block_arg.getArgNumber(), e = i + n; i < e; ++i) if (failed(VerifyCompatibleTypes(block_arg_type, - op.getOperand(i)->getType()))) + op.getOperand(i).getType()))) return op.emitOpError() << "incompatible types for operand " << i - << " and block argument " << block_arg->getArgNumber(); + << " and block argument " << block_arg.getArgNumber(); } Operation& terminator = block.back(); @@ -280,9 +280,9 @@ void BuildReplicateOp( for (auto& replicated_input : replicated_inputs) { DCHECK_EQ(llvm::size(replicated_input.first), n); - for (auto* input : replicated_input.first) { + for (auto input : replicated_input.first) { DCHECK(succeeded( - VerifyCompatibleTypes(input->getType(), replicated_input.second))); + VerifyCompatibleTypes(input.getType(), replicated_input.second))); state->addOperands(input); } block.addArgument(replicated_input.second); @@ -296,7 +296,7 @@ void BuildReplicateOp( void ReplicateOp::build( Builder* builder, OperationState& state, int n, llvm::ArrayRef devices, - llvm::ArrayRef, Type>> replicated_inputs, + llvm::ArrayRef, Type>> replicated_inputs, llvm::ArrayRef replica_output_types) { BuildReplicateOp(builder, &state, n, devices, replicated_inputs, replica_output_types); diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.h index 91370bc6501..a500af45c44 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.h @@ -19,8 +19,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_DEVICE_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_DEVICE_H_ -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Dialect.h" // TF:local_config_mlir +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Dialect.h" // TF:llvm-project namespace mlir { namespace tf_device { diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td index 403932ed9a8..88cc08aca6d 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td @@ -185,7 +185,7 @@ For example: let builders = [ OpBuilder<"Builder* builder, OperationState& state, int n, " "llvm::ArrayRef devices, " - "llvm::ArrayRef, Type>>" + "llvm::ArrayRef, Type>>" " replicated_inputs, " "llvm::ArrayRef replica_output_types">, OpBuilder<"Builder* builder, OperationState& state, int n, " diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc index 5a018a39fd7..13dc2993371 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc @@ -26,23 +26,23 @@ limitations under the License. 
#include "llvm/ADT/StringSwitch.h" #include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir -#include "mlir/Dialect/Traits.h" // TF:local_config_mlir -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/DialectImplementation.h" // TF:local_config_mlir -#include "mlir/IR/Function.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/Matchers.h" // TF:local_config_mlir -#include "mlir/IR/OpDefinition.h" // TF:local_config_mlir -#include "mlir/IR/OpImplementation.h" // TF:local_config_mlir -#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/IR/Types.h" // TF:local_config_mlir -#include "mlir/IR/Value.h" // TF:local_config_mlir -#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir -#include "mlir/Transforms/FoldUtils.h" // TF:local_config_mlir -#include "mlir/Transforms/InliningUtils.h" // TF:local_config_mlir +#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/Traits.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/DialectImplementation.h" // TF:llvm-project +#include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Matchers.h" // TF:llvm-project +#include "mlir/IR/OpDefinition.h" // TF:llvm-project +#include "mlir/IR/OpImplementation.h" // TF:llvm-project +#include "mlir/IR/PatternMatch.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/Types.h" // TF:llvm-project +#include "mlir/IR/Value.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/Transforms/FoldUtils.h" // TF:llvm-project +#include "mlir/Transforms/InliningUtils.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" namespace mlir { @@ -167,7 +167,7 @@ namespace { LogicalResult VerifyControlOperandsAfterAllData(Operation *op) { bool found_control = false; for (int operand_idx : llvm::seq(0, op->getNumOperands())) { - if (op->getOperand(operand_idx)->getType().isa()) { + if (op->getOperand(operand_idx).getType().isa()) { found_control = true; continue; } @@ -216,9 +216,9 @@ LogicalResult Verify(GraphOp graph) { return fetch.emitOpError() << "does not have enough operands to cover the " "graph returned values"; for (int i : llvm::seq(0, fetch.getNumOperands())) { - Value *operand = fetch.getOperand(i); + Value operand = fetch.getOperand(i); // Break out of the loop at the first control operand encountered. 
- if (operand->getType().isa()) { + if (operand.getType().isa()) { if (i != graph.getNumResults()) return fetch.emitOpError() << "operand #" << i @@ -228,7 +228,7 @@ LogicalResult Verify(GraphOp graph) { if (i >= graph.getNumResults()) return fetch.emitOpError() << "operand #" << i << " does not have a graph results to bind"; - if (graph.getResult(i)->getType() != operand->getType()) + if (graph.getResult(i).getType() != operand.getType()) return fetch.emitOpError() << "operand #" << i << " type mismatch graph results"; } @@ -331,8 +331,8 @@ LogicalResult Verify(IslandOp island) { << "has " << yield.getNumOperands() << " operand, but island returns " << result_count; for (int operand_idx : llvm::seq(0, yield.getNumOperands())) { - if (island.getResult(operand_idx)->getType() != - yield.getOperand(operand_idx)->getType()) + if (island.getResult(operand_idx).getType() != + yield.getOperand(operand_idx).getType()) return yield.emitOpError() << "operand #" << operand_idx << " type mismatch island results"; } @@ -340,7 +340,7 @@ LogicalResult Verify(IslandOp island) { // Check that there aren't any control results other than the last one. Type control_type = ControlType::get(island.getContext()); for (int operand_idx : llvm::seq(0, island.getNumResults() - 1)) { - if (island.getResult(operand_idx)->getType() == control_type) + if (island.getResult(operand_idx).getType() == control_type) return yield.emitOpError() << "unexpected control type for operand #" << operand_idx; } @@ -503,12 +503,12 @@ ParseResult ParseSwitchOp(OpAsmParser &parser, OperationState &result) { void Print(SwitchOp switch_op, OpAsmPrinter &p) { p << switch_op.getOperationName() << ' '; p.printOperands(switch_op.getOperands()); - Type data_operand_ty = switch_op.data()->getType(); + Type data_operand_ty = switch_op.data().getType(); // If the types aren't perfectly matching, print the functional type syntax // else print the shorter single type. p << " : "; - if (switch_op.trueOutput()->getType() != data_operand_ty || - switch_op.falseOutput()->getType() != data_operand_ty) { + if (switch_op.trueOutput().getType() != data_operand_ty || + switch_op.falseOutput().getType() != data_operand_ty) { p.printFunctionalType(switch_op.getOperation()); } else { p << switch_op.getType(0); @@ -535,12 +535,12 @@ LogicalResult Verify(SwitchNOp switchn) { << "expect `num_outs` (" << num_outs.getInt() << ") results but got " << (switchn.getNumResults() - 1); - auto operand0_type = switchn.getOperand(0)->getType(); - for (Value *result : switchn.outputs()) - if (operand0_type != result->getType()) + auto operand0_type = switchn.getOperand(0).getType(); + for (Value result : switchn.outputs()) + if (operand0_type != result.getType()) return switchn.emitOpError() << "type mismatch between data operand and result: " - << operand0_type << " vs " << result->getType(); + << operand0_type << " vs " << result.getType(); return success(); } @@ -616,12 +616,12 @@ LogicalResult Verify(MergeOp merge) { if (!merge.getNumOperands()) return merge.emitOpError() << "expects at least one operand"; - Type data_type = merge.getOperand(0)->getType(); + Type data_type = merge.getOperand(0).getType(); if (data_type.isa()) return merge.emitOpError() << "expects a non-control input"; // Check that each operand can be individually broadcasted to the output type. 
- Type output_type = merge.output()->getType(); + Type output_type = merge.output().getType(); TensorType output_tensor_ty = output_type.dyn_cast(); if (!output_tensor_ty) { return merge.emitOpError() @@ -666,7 +666,7 @@ void Print(MergeOp merge, OpAsmPrinter &p) { bool use_short_form = true; int num_data_operands = 0; - Type output_type = merge.output()->getType(); + Type output_type = merge.output().getType(); for (Type operand_type : merge.getOperandTypes()) { if (operand_type.isa()) break; num_data_operands++; @@ -750,7 +750,7 @@ void Print(EnterOp enter, OpAsmPrinter &p) { // If the types aren't perfectly matching, print the functional type syntax // else print the shorter single type. p << " : "; - if (enter.data()->getType() != enter.output()->getType()) { + if (enter.data().getType() != enter.output().getType()) { p.printFunctionalType(enter.getOperation()); } else { p << enter.getType(0); @@ -824,10 +824,10 @@ ParseResult ParseEnterOp(OpAsmParser &parser, OperationState &result) { namespace { LogicalResult Verify(NextIterationSourceOp source) { - Value *token = source.token(); - if (!token->hasOneUse()) + Value token = source.token(); + if (!token.hasOneUse()) return source.emitOpError() << "expects a single user for produced token"; - if (!isa(*token->user_begin())) + if (!isa(*token.user_begin())) return source.emitOpError() << "token should be consumed by a sink op"; return success(); } @@ -858,8 +858,8 @@ ParseResult ParseNextIterationSourceOp(OpAsmParser &parser, namespace { LogicalResult Verify(NextIterationSinkOp sink) { - Value *token = sink.token(); - Operation *definingOp = token->getDefiningOp(); + Value token = sink.token(); + Operation *definingOp = token.getDefiningOp(); if (!definingOp) return sink.emitOpError() << "expects a token directly produced by a " "tf_executor.NextIteration.Source op: "; @@ -867,11 +867,11 @@ LogicalResult Verify(NextIterationSinkOp sink) { if (!source) return sink.emitOpError() << "expects a token produced by a " "tf_executor.NextIteration.Source op: "; - if (source.output()->getType() != sink.input()->getType()) + if (source.output().getType() != sink.input().getType()) return sink.emitOpError() - << "input type " << sink.input()->getType() + << "input type " << sink.input().getType() << " mismatch the tf_executor.NextIteration.Source output type: " - << source.output()->getType(); + << source.output().getType(); return success(); } @@ -880,7 +880,7 @@ void Print(NextIterationSinkOp next_iteration, OpAsmPrinter &p) { p.printOperand(next_iteration.getOperand(0)); p << "] "; p.printOperands(llvm::drop_begin(next_iteration.getOperands(), 1)); - p << " : " << next_iteration.getOperand(1)->getType(); + p << " : " << next_iteration.getOperand(1).getType(); p.printOptionalAttrDict(next_iteration.getAttrs()); } @@ -980,11 +980,11 @@ void Print(LoopCondOp loop_cond, OpAsmPrinter &p) { p.printOperands(loop_cond.getOperands()); // If the types aren't matching (broadcast), print the functional type syntax. - if (loop_cond.input()->getType() != loop_cond.output()->getType()) { + if (loop_cond.input().getType() != loop_cond.output().getType()) { p << " : "; p.printFunctionalType(loop_cond.getOperation()); } else { - p << " : " << loop_cond.input()->getType(); + p << " : " << loop_cond.input().getType(); } p.printOptionalAttrDict(loop_cond.getAttrs()); @@ -1087,18 +1087,18 @@ struct HoistInnerOpsSingleIslandGraph : public OpRewritePattern { YieldOp yield_op = island_op.GetYield(); // Map graph results to inner ops results of single island. 
- llvm::SmallVector new_rets; - for (Value *operand : fetch_op.fetches()) { + llvm::SmallVector new_rets; + for (Value operand : fetch_op.fetches()) { // Control results should not be propagated out. - if (operand->getType().isa()) break; + if (operand.getType().isa()) break; - if (operand->getDefiningOp() != island_op) { + if (operand.getDefiningOp() != island_op) { // Operand is not from island, simply propagate it out. new_rets.push_back(operand); } else { // Lookup yield operand in island for inner op result. - auto result = llvm::cast(operand); - new_rets.push_back(yield_op.getOperand(result->getResultNumber())); + auto result = operand.cast(); + new_rets.push_back(yield_op.getOperand(result.getResultNumber())); } } @@ -1138,7 +1138,7 @@ struct DropEmptyIslandNoOperandNoDataResult !HasSingleOpInBlock(&op.GetBody())) return matchFailure(); - for (auto &use : llvm::make_early_inc_range(op.control()->getUses())) + for (auto &use : llvm::make_early_inc_range(op.control().getUses())) use.getOwner()->eraseOperand(use.getOperandNumber()); rewriter.eraseOp(op); @@ -1158,7 +1158,7 @@ struct DropEmptyIslandNoOperandOneDataResult PatternMatchResult matchAndRewrite(IslandOp op, PatternRewriter &rewriter) const override { if (op.getNumOperands() != 0 || op.getNumResults() != 2 || - !op.control()->use_empty() || + !op.control().use_empty() || !HasSingleOpInBlock(&op.GetBody())) return matchFailure(); @@ -1193,7 +1193,7 @@ struct DropEmptyControlTrigger : public OpRewritePattern { PatternRewriter &rewriter) const override { if (op.getNumOperands() != 0) return matchFailure(); - for (auto &use : llvm::make_early_inc_range(op.control()->getUses())) + for (auto &use : llvm::make_early_inc_range(op.control().getUses())) use.getOwner()->eraseOperand(use.getOperandNumber()); rewriter.eraseOp(op); diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h index 8df3ecb2559..b7d8549ece7 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h @@ -21,13 +21,13 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_EXECUTOR_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_EXECUTOR_H_ -#include "mlir/Dialect/Traits.h" // TF:local_config_mlir -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Dialect.h" // TF:local_config_mlir -#include "mlir/IR/Matchers.h" // TF:local_config_mlir -#include "mlir/IR/OpImplementation.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir +#include "mlir/Dialect/Traits.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Dialect.h" // TF:llvm-project +#include "mlir/IR/Matchers.h" // TF:llvm-project +#include "mlir/IR/OpImplementation.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td index 0f243957869..3922981bd50 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td @@ -460,7 +460,7 @@ def TfExecutor_NextIterationSourceOp : TfExecutor_Op<"NextIteration.Source", let extraClassDeclaration = [{ NextIterationSinkOp GetSink() { - return cast(*token()->user_begin()); + return cast(*token().user_begin()); } }]; @@ -514,8 +514,8 @@ def TfExecutor_NextIterationSinkOp : TfExecutor_Op<"NextIteration.Sink", ); let builders = [OpBuilder< - "Builder *builder, OperationState &result, Value *token, " - "ArrayRef operands, ArrayRef attributes = {}", + "Builder *builder, OperationState &result, Value token, " + "ArrayRef operands, ArrayRef attributes = {}", [{ assert(operands.size() >= 1 && "tf_executor.NextIteration.Sink builder " "expects at least one operand"); @@ -594,7 +594,7 @@ def TfExecutor_ControlTriggerOp : TfExecutor_Op<"ControlTrigger", let builders = [OpBuilder< "Builder *builder, OperationState &result, " - "ArrayRef operands, ArrayRef attributes = {}", + "ArrayRef operands, ArrayRef attributes = {}", [{ assert(operands.size() >= 1 && "tf_executor.ControlTrigger builder " "expects at least one operand"); diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index 7257a1ba8f0..9b3d749864c 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -821,6 +821,40 @@ def TF_ConcatOp : TF_Op<"Concat", [NoSideEffect]> { }]; } +def TF_ConcatOffsetOp : TF_Op<"ConcatOffset", [NoSideEffect]> { + let summary = "Computes offsets of concat inputs within its output."; + + let description = [{ +For example: + +``` +# 'x' is [2, 2, 7] +# 'y' is [2, 3, 7] +# 'z' is [2, 5, 7] +concat_offset(2, [x, y, z]) => [0, 0, 0], [0, 2, 0], [0, 5, 0] +``` + +This is typically used by gradient computations for a concat operation. 
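A runnable counterpart to the `ConcatOffset` example above, sketched against the `tf.raw_ops.ConcatOffset` endpoint (assumed available; note the offsets shown correspond to concatenation along axis 1, the axis where the three shapes differ):

```python
import tensorflow as tf

# Shapes of the inputs being concatenated, as in the description above.
x_shape = tf.constant([2, 2, 7])
y_shape = tf.constant([2, 3, 7])
z_shape = tf.constant([2, 5, 7])

# Each offset is the cumulative size of the preceding inputs (0, 2, 2+3=5),
# placed in the concat-axis slot.
offsets = tf.raw_ops.ConcatOffset(concat_dim=1,
                                  shape=[x_shape, y_shape, z_shape])
print([o.numpy().tolist() for o in offsets])
# [[0, 0, 0], [0, 2, 0], [0, 5, 0]]
```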
+ }]; + + let arguments = (ins + I32Tensor:$concat_dim, + Variadic:$shape + ); + + let results = (outs + Variadic:$offset + ); + + TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<1>; + + let verifier = [{ + return Verify(*this); + }]; + + let hasFolder = 1; +} + def TF_ConcatV2Op : TF_Op<"ConcatV2", [NoSideEffect]> { let summary = "Concatenates tensors along one dimension."; @@ -1350,6 +1384,10 @@ as illustrated on the following example: TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>; TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>; + + let verifier = [{ + return Verify(*this); + }]; } def TF_EinsumOp : TF_Op<"Einsum", [NoSideEffect]> { @@ -1506,8 +1544,8 @@ tf.math.equal(x, y) ==> array([True, True]) TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; let builders = [ - OpBuilder<"Builder* builder, OperationState& result, Value* x, " - "Value* y, BoolAttr incompatible_shape_error"> + OpBuilder<"Builder* builder, OperationState& result, Value x, " + "Value y, BoolAttr incompatible_shape_error"> ]; let verifier = [{ @@ -1607,6 +1645,11 @@ size 1. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; TF_DerivedOperandTypeAttr Tdim = TF_DerivedOperandTypeAttr<1>; + + let builders = [ + OpBuilder<"Builder* builder, OperationState& result, Value condition, " + "Value dim"> + ]; } def TF_FakeQuantWithMinMaxArgsOp : TF_Op<"FakeQuantWithMinMaxArgs", [NoSideEffect, SameOperandsAndResultType]> { @@ -1883,6 +1926,102 @@ The size of 1D Tensors matches the dimension C of the 4D Tensors. }]; } +def TF_FusedBatchNormGradOp : TF_Op<"FusedBatchNormGrad", [NoSideEffect]> { + let summary = "Gradient for batch normalization."; + + let description = [{ +Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". +The size of 1D Tensors matches the dimension C of the 4D Tensors. + }]; + + let arguments = (ins + F32Tensor:$y_backprop, + F32Tensor:$x, + F32Tensor:$scale, + F32Tensor:$reserve_space_1, + F32Tensor:$reserve_space_2, + + DefaultValuedAttr:$epsilon, + DefaultValuedAttr:$data_format, + DefaultValuedAttr:$is_training + ); + + let results = (outs + F32Tensor:$x_backprop, + F32Tensor:$scale_backprop, + F32Tensor:$offset_backprop, + F32Tensor:$reserve_space_3, + F32Tensor:$reserve_space_4 + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_FusedBatchNormGradV2Op : TF_Op<"FusedBatchNormGradV2", [NoSideEffect]> { + let summary = "Gradient for batch normalization."; + + let description = [{ +Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". +The size of 1D Tensors matches the dimension C of the 4D Tensors. + }]; + + let arguments = (ins + TensorOf<[BF16, F16, F32]>:$y_backprop, + TensorOf<[BF16, F16, F32]>:$x, + F32Tensor:$scale, + F32Tensor:$reserve_space_1, + F32Tensor:$reserve_space_2, + + DefaultValuedAttr:$epsilon, + DefaultValuedAttr:$data_format, + DefaultValuedAttr:$is_training + ); + + let results = (outs + TensorOf<[BF16, F16, F32]>:$x_backprop, + F32Tensor:$scale_backprop, + F32Tensor:$offset_backprop, + F32Tensor:$reserve_space_3, + F32Tensor:$reserve_space_4 + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<3>; +} + +def TF_FusedBatchNormGradV3Op : TF_Op<"FusedBatchNormGradV3", [NoSideEffect]> { + let summary = "Gradient for batch normalization."; + + let description = [{ +Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". +The size of 1D Tensors matches the dimension C of the 4D Tensors. 
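To make the operand/result layout of the new `FusedBatchNormGrad*` entries concrete, here is a sketch through the raw-op endpoints. It assumes the `tf.raw_ops.FusedBatchNormV3`/`FusedBatchNormGradV3` keyword names match the ODS above and that empty `mean`/`variance` tensors are accepted when `is_training=True`:

```python
import tensorflow as tf

x = tf.random.normal([2, 4, 4, 3])         # NHWC input, C = 3
scale = tf.ones([3])
offset = tf.zeros([3])
empty = tf.constant([], dtype=tf.float32)  # batch stats are computed in training

# Forward pass produces the reserve spaces the gradient op consumes.
y, bm, bv, r1, r2, r3 = tf.raw_ops.FusedBatchNormV3(
    x=x, scale=scale, offset=offset, mean=empty, variance=empty,
    is_training=True)

# Gradient: operands and results mirror the ODS (y_backprop, x, scale,
# reserve_space_1..3 -> x/scale/offset backprops plus two reserve outputs).
dx, dscale, doffset, _, _ = tf.raw_ops.FusedBatchNormGradV3(
    y_backprop=tf.ones_like(y), x=x, scale=scale,
    reserve_space_1=r1, reserve_space_2=r2, reserve_space_3=r3,
    is_training=True)
print(dx.shape, dscale.shape, doffset.shape)  # (2, 4, 4, 3) (3,) (3,)
```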
+ }]; + + let arguments = (ins + TensorOf<[BF16, F16, F32]>:$y_backprop, + TensorOf<[BF16, F16, F32]>:$x, + F32Tensor:$scale, + F32Tensor:$reserve_space_1, + F32Tensor:$reserve_space_2, + F32Tensor:$reserve_space_3, + + DefaultValuedAttr:$epsilon, + DefaultValuedAttr:$data_format, + DefaultValuedAttr:$is_training + ); + + let results = (outs + TensorOf<[BF16, F16, F32]>:$x_backprop, + F32Tensor:$scale_backprop, + F32Tensor:$offset_backprop, + F32Tensor:$reserve_space_4, + F32Tensor:$reserve_space_5 + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<3>; +} + def TF_FusedBatchNormV3Op : TF_Op<"FusedBatchNormV3", [NoSideEffect]> { let summary = "Batch normalization."; @@ -2455,6 +2594,55 @@ def TF_LeakyReluOp : TF_Op<"LeakyRelu", [NoSideEffect, SameOperandsAndResultType let hasFolder = 1; } +def TF_LeftShiftOp : TF_Op<"LeftShift", [Broadcastable, NoSideEffect]>, + WithBroadcastableBinOpBuilder { + let summary = "Elementwise computes the bitwise left-shift of `x` and `y`."; + + let description = [{ +If `y` is negative, or greater than or equal to the width of `x` in bits the +result is implementation defined. + +Example: + +```python +import tensorflow as tf +from tensorflow.python.ops import bitwise_ops +import numpy as np +dtype_list = [tf.int8, tf.int16, tf.int32, tf.int64] + +for dtype in dtype_list: + lhs = tf.constant([-1, -5, -3, -14], dtype=dtype) + rhs = tf.constant([5, 0, 7, 11], dtype=dtype) + + left_shift_result = bitwise_ops.left_shift(lhs, rhs) + + print(left_shift_result) + +# This will print: +# tf.Tensor([ -32 -5 -128 0], shape=(4,), dtype=int8) +# tf.Tensor([ -32 -5 -384 -28672], shape=(4,), dtype=int16) +# tf.Tensor([ -32 -5 -384 -28672], shape=(4,), dtype=int32) +# tf.Tensor([ -32 -5 -384 -28672], shape=(4,), dtype=int64) + +lhs = np.array([-2, 64, 101, 32], dtype=np.int8) +rhs = np.array([-1, -5, -3, -14], dtype=np.int8) +bitwise_ops.left_shift(lhs, rhs) +# +``` + }]; + + let arguments = (ins + TF_IntTensor:$x, + TF_IntTensor:$y + ); + + let results = (outs + TF_IntTensor:$z + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_LessOp : TF_Op<"Less", [Broadcastable, NoSideEffect]>, WithBroadcastableCmpOpBuilder { let summary = "Returns the truth value of (x < y) element-wise."; @@ -2548,6 +2736,31 @@ tf.math.log(x) ==> [-inf, -0.6931472, 0. , 1.609438] let hasCanonicalizer = 1; } +def TF_Log1pOp : TF_Op<"Log1p", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes natural logarithm of (1 + x) element-wise."; + + let description = [{ +I.e., \\(y = \log_e (1 + x)\\). + +Example: + +```python +x = tf.constant([0, 0.5, 1, 5]) +tf.math.log1p(x) ==> [0., 0.4054651, 0.6931472, 1.7917595] +``` + }]; + + let arguments = (ins + TF_FpOrComplexTensor:$x + ); + + let results = (outs + TF_FpOrComplexTensor:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_LogSoftmaxOp : TF_Op<"LogSoftmax", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Computes log softmax activations."; @@ -3165,8 +3378,8 @@ retained with length 1. 
TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>; let builders = [OpBuilder< - "Builder *builder, OperationState &result, Value *input, " - "Value *reduction_indices, BoolAttr keep_dims" + "Builder *builder, OperationState &result, Value input, " + "Value reduction_indices, BoolAttr keep_dims" >]; } @@ -3577,8 +3790,8 @@ def TF_NotEqualOp : TF_Op<"NotEqual", [Commutative, NoSideEffect]> { TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; let builders = [ - OpBuilder<"Builder* builder, OperationState& result, Value* x, " - "Value* y, BoolAttr incompatible_shape_error"> + OpBuilder<"Builder* builder, OperationState& result, Value x, " + "Value y, BoolAttr incompatible_shape_error"> ]; let verifier = [{ @@ -3695,6 +3908,12 @@ output = TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<2>; TF_DerivedOperandTypeAttr TI = TF_DerivedOperandTypeAttr<0>; + let builders = [ + OpBuilder<"Builder* builder, OperationState& result, Value indices, " + "Value depth, Value on_value, Value off_value, " + "IntegerAttr axis"> + ]; + let verifier = [{ return Verify(*this); }]; @@ -4125,8 +4344,8 @@ tf.range(start, limit, delta) ==> [3, 6, 9, 12, 15] TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<0>; let builders = [ - OpBuilder<"Builder* builder, OperationState& result, Value* start, " - "Value* limit, Value* delta"> + OpBuilder<"Builder* builder, OperationState& result, Value start, " + "Value limit, Value delta"> ]; } @@ -4160,7 +4379,7 @@ of the tensor. Rank is also known as "order", "degree", or "ndims." TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; let builders = [ - OpBuilder<"Builder* builder, OperationState& result, Value* input"> + OpBuilder<"Builder* builder, OperationState& result, Value input"> ]; } @@ -4396,7 +4615,7 @@ reshape(t, []) ==> 7 let builders = [ OpBuilder< - "Builder* builder, OperationState& result, Value* tensor, Value* shape"> + "Builder* builder, OperationState& result, Value tensor, Value shape"> ]; let verifier = [{ @@ -4666,6 +4885,58 @@ reverse(t, dims) ==> [[[[8, 9, 10, 11], TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>; } +def TF_RightShiftOp : TF_Op<"RightShift", [Broadcastable, NoSideEffect]>, + WithBroadcastableBinOpBuilder { + let summary = "Elementwise computes the bitwise right-shift of `x` and `y`."; + + let description = [{ +Performs a logical shift for unsigned integer types, and an arithmetic shift +for signed integer types. + +If `y` is negative, or greater than or equal to than the width of `x` in bits +the result is implementation defined. 
+ +Example: + +```python +import tensorflow as tf +from tensorflow.python.ops import bitwise_ops +import numpy as np +dtype_list = [tf.int8, tf.int16, tf.int32, tf.int64] + +for dtype in dtype_list: + lhs = tf.constant([-1, -5, -3, -14], dtype=dtype) + rhs = tf.constant([5, 0, 7, 11], dtype=dtype) + + right_shift_result = bitwise_ops.right_shift(lhs, rhs) + + print(right_shift_result) + +# This will print: +# tf.Tensor([-1 -5 -1 -1], shape=(4,), dtype=int8) +# tf.Tensor([-1 -5 -1 -1], shape=(4,), dtype=int16) +# tf.Tensor([-1 -5 -1 -1], shape=(4,), dtype=int32) +# tf.Tensor([-1 -5 -1 -1], shape=(4,), dtype=int64) + +lhs = np.array([-2, 64, 101, 32], dtype=np.int8) +rhs = np.array([-1, -5, -3, -14], dtype=np.int8) +bitwise_ops.right_shift(lhs, rhs) +# +``` + }]; + + let arguments = (ins + TF_IntTensor:$x, + TF_IntTensor:$y + ); + + let results = (outs + TF_IntTensor:$z + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_RoundOp : TF_Op<"Round", [NoSideEffect, SameOperandsAndResultType]> { let summary = [{ Rounds the values of a tensor to the nearest integer, element-wise. @@ -4725,6 +4996,212 @@ is the corresponding input gradient. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_SegmentMaxOp : TF_Op<"SegmentMax", [NoSideEffect]> { + let summary = "Computes the maximum along segments of a tensor."; + + let description = [{ +Read +[the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation) +for an explanation of segments. + +Computes a tensor such that +\\(output_i = \max_j(data_j)\\) where `max` is over `j` such +that `segment_ids[j] == i`. + +If the max is empty for a given segment ID `i`, `output[i] = 0`. + +
+ +For example: + +``` +c = tf.constant([[1,2,3,4], [4, 3, 2, 1], [5,6,7,8]]) +tf.segment_max(c, tf.constant([0, 0, 1])) +# ==> [[4, 3, 3, 4], +# [5, 6, 7, 8]] +``` + }]; + + let arguments = (ins + TF_IntOrFpTensor:$data, + TF_I32OrI64Tensor:$segment_ids + ); + + let results = (outs + TF_IntOrFpTensor:$output + ); + + TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>; + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_SegmentMeanOp : TF_Op<"SegmentMean", [NoSideEffect]> { + let summary = "Computes the mean along segments of a tensor."; + + let description = [{ +Read +[the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation) +for an explanation of segments. + +Computes a tensor such that +\\(output_i = \frac{\sum_j data_j}{N}\\) where `mean` is +over `j` such that `segment_ids[j] == i` and `N` is the total number of +values summed. + +If the mean is empty for a given segment ID `i`, `output[i] = 0`. + +
+ +For example: + +``` +c = tf.constant([[1.0,2,3,4], [4, 3, 2, 1], [5,6,7,8]]) +tf.segment_mean(c, tf.constant([0, 0, 1])) +# ==> [[2.5, 2.5, 2.5, 2.5], +# [5, 6, 7, 8]] +``` + }]; + + let arguments = (ins + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$data, + TF_I32OrI64Tensor:$segment_ids + ); + + let results = (outs + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + ); + + TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>; + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_SegmentMinOp : TF_Op<"SegmentMin", [NoSideEffect]> { + let summary = "Computes the minimum along segments of a tensor."; + + let description = [{ +Read +[the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation) +for an explanation of segments. + +Computes a tensor such that +\\(output_i = \min_j(data_j)\\) where `min` is over `j` such +that `segment_ids[j] == i`. + +If the min is empty for a given segment ID `i`, `output[i] = 0`. + +
+ +For example: + +``` +c = tf.constant([[1,2,3,4], [4, 3, 2, 1], [5,6,7,8]]) +tf.segment_min(c, tf.constant([0, 0, 1])) +# ==> [[1, 2, 2, 1], +# [5, 6, 7, 8]] +``` + }]; + + let arguments = (ins + TF_IntOrFpTensor:$data, + TF_I32OrI64Tensor:$segment_ids + ); + + let results = (outs + TF_IntOrFpTensor:$output + ); + + TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>; + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_SegmentProdOp : TF_Op<"SegmentProd", [NoSideEffect]> { + let summary = "Computes the product along segments of a tensor."; + + let description = [{ +Read +[the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation) +for an explanation of segments. + +Computes a tensor such that +\\(output_i = \prod_j data_j\\) where the product is over `j` such +that `segment_ids[j] == i`. + +If the product is empty for a given segment ID `i`, `output[i] = 1`. + +
+ +For example: + +``` +c = tf.constant([[1,2,3,4], [4, 3, 2, 1], [5,6,7,8]]) +tf.segment_prod(c, tf.constant([0, 0, 1])) +# ==> [[4, 6, 6, 4], +# [5, 6, 7, 8]] +``` + }]; + + let arguments = (ins + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$data, + TF_I32OrI64Tensor:$segment_ids + ); + + let results = (outs + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + ); + + TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>; + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_SegmentSumOp : TF_Op<"SegmentSum", [NoSideEffect]> { + let summary = "Computes the sum along segments of a tensor."; + + let description = [{ +Read +[the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation) +for an explanation of segments. + +Computes a tensor such that +\\(output_i = \sum_j data_j\\) where sum is over `j` such +that `segment_ids[j] == i`. + +If the sum is empty for a given segment ID `i`, `output[i] = 0`. + +
+ +For example: + +``` +c = tf.constant([[1,2,3,4], [4, 3, 2, 1], [5,6,7,8]]) +tf.segment_sum(c, tf.constant([0, 0, 1])) +# ==> [[5, 5, 5, 5], +# [5, 6, 7, 8]] +``` + }]; + + let arguments = (ins + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$data, + TF_I32OrI64Tensor:$segment_ids + ); + + let results = (outs + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + ); + + TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>; + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_SelectOp : TF_Op<"Select", [NoSideEffect]> { let summary = "Selects elements from `x` or `y`, depending on `condition`."; @@ -4799,6 +5276,10 @@ def TF_SelectV2Op : TF_Op<"SelectV2", [NoSideEffect]> { ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>; + + let builders = [ + OpBuilder<"Builder* builder, OperationState& result, Value condition, Value e, Value t"> + ]; } def TF_ShapeOp : TF_Op<"Shape", [NoSideEffect]> { @@ -4831,7 +5312,7 @@ shape(t) ==> [2, 2, 3] }]; let builders = [ - OpBuilder<"Builder* builder, OperationState& result, Value* input, BoolAttr use32Bit"> + OpBuilder<"Builder* builder, OperationState& result, Value input, BoolAttr use32Bit"> ]; let hasFolder = 1; @@ -5207,6 +5688,36 @@ x = [[[[1, 2, 3, 4], TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_SparseSoftmaxCrossEntropyWithLogitsOp : TF_Op<"SparseSoftmaxCrossEntropyWithLogits", [NoSideEffect]> { + let summary = [{ +Computes softmax cross entropy cost and gradients to backpropagate. + }]; + + let description = [{ +Unlike `SoftmaxCrossEntropyWithLogits`, this operation does not accept +a matrix of label probabilities, but rather a single label per row +of features. This label is considered to have probability 1.0 for the +given row. + +Inputs are the logits, not probabilities. + }]; + + let arguments = (ins + TF_FpTensor:$features, + TF_I32OrI64Tensor:$labels + ); + + let results = (outs + TF_FpTensor:$loss, + TF_FpTensor:$backprop + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr Tlabels = TF_DerivedOperandTypeAttr<1>; + + let verifier = [{ return Verify(*this); }]; +} + def TF_SparseToDenseOp : TF_Op<"SparseToDense", [NoSideEffect]> { let summary = "Converts a sparse representation into a dense tensor."; @@ -5541,6 +6052,17 @@ receive 0, 0, and 1, respectively. The appropriate bits in `begin_mask` and TF_DerivedOperandTypeAttr Index = TF_DerivedOperandTypeAttr<1>; let verifier = [{ return VerifyStridedSliceBase(*this); }]; + + let extraClassDeclaration = [{ + // If sliced shape is able to be deduced, returns true, updates + // `begin_indices`, `end_indices`, and `strides` with their canonical + // values, respectively. + bool GetSlicedBoundRanges( + ::llvm::ArrayRef shape, + ::llvm::SmallVectorImpl *begin_indices, + ::llvm::SmallVectorImpl *end_indices, + ::llvm::SmallVectorImpl *strides); + }]; } def TF_StridedSliceGradOp : TF_Op<"StridedSliceGrad", [NoSideEffect]> { @@ -5641,8 +6163,8 @@ retained with length 1. 
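Stepping back to the `SparseSoftmaxCrossEntropyWithLogits` op registered above: its semantics are easiest to check through the public endpoint (the raw op additionally returns the `backprop` result described in the ODS):

```python
import tensorflow as tf

logits = tf.constant([[2.0, 0.5, -1.0],
                      [0.1, 0.2, 3.0]])
labels = tf.constant([0, 2])  # one class index per row, not a probability matrix

# Per-row loss: equivalent to -log(softmax(logits)[i, labels[i]]).
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels,
                                                      logits=logits)
print(loss.numpy())  # two non-negative values, one per row
```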
TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>; let builders = [OpBuilder< - "Builder *builder, OperationState &result, Value *input, " - "Value *reduction_indices, BoolAttr keep_dims" + "Builder *builder, OperationState &result, Value input, " + "Value reduction_indices, BoolAttr keep_dims" >]; } @@ -5969,6 +6491,103 @@ num_elements: optional. If not -1, the number of elements in the list. }]; } +def TF_TensorScatterUpdateOp : TF_Op<"TensorScatterUpdate", [NoSideEffect]> { + let summary = [{ +Scatter `updates` into an existing tensor according to `indices`. + }]; + + let description = [{ +This operation creates a new tensor by applying sparse `updates` to the passed +in `tensor`. +This operation is very similar to `tf.scatter_nd`, except that the updates are +scattered onto an existing tensor (as opposed to a zero-tensor). If the memory +for the existing tensor cannot be re-used, a copy is made and updated. + +If `indices` contains duplicates, then their updates are accumulated (summed). + +**WARNING**: The order in which updates are applied is nondeterministic, so the +output will be nondeterministic if `indices` contains duplicates -- because +of some numerical approximation issues, numbers summed in different order +may yield different results. + +`indices` is an integer tensor containing indices into a new tensor of shape +`shape`. The last dimension of `indices` can be at most the rank of `shape`: + + indices.shape[-1] <= shape.rank + +The last dimension of `indices` corresponds to indices into elements +(if `indices.shape[-1] = shape.rank`) or slices +(if `indices.shape[-1] < shape.rank`) along dimension `indices.shape[-1]` of +`shape`. `updates` is a tensor with shape + + indices.shape[:-1] + shape[indices.shape[-1]:] + +The simplest form of scatter is to insert individual elements in a tensor by +index. For example, say we want to insert 4 scattered elements in a rank-1 +tensor with 8 elements. + +
+ +In Python, this scatter operation would look like this: + + >>> indices = tf.constant([[4], [3], [1], [7]]) + >>> updates = tf.constant([9, 10, 11, 12]) + >>> tensor = tf.ones([8], dtype=tf.int32) + >>> print(tf.tensor_scatter_nd_update(tensor, indices, updates)) + tf.Tensor([ 1 11 1 10 9 1 1 12], shape=(8,), dtype=int32) + +We can also, insert entire slices of a higher rank tensor all at once. For +example, if we wanted to insert two slices in the first dimension of a +rank-3 tensor with two matrices of new values. + +In Python, this scatter operation would look like this: + + >>> indices = tf.constant([[0], [2]]) + >>> updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6], + ... [7, 7, 7, 7], [8, 8, 8, 8]], + ... [[5, 5, 5, 5], [6, 6, 6, 6], + ... [7, 7, 7, 7], [8, 8, 8, 8]]]) + >>> tensor = tf.ones([4, 4, 4], dtype=tf.int32) + >>> print(tf.tensor_scatter_nd_update(tensor, indices, updates).numpy()) + [[[5 5 5 5] + [6 6 6 6] + [7 7 7 7] + [8 8 8 8]] + [[1 1 1 1] + [1 1 1 1] + [1 1 1 1] + [1 1 1 1]] + [[5 5 5 5] + [6 6 6 6] + [7 7 7 7] + [8 8 8 8]] + [[1 1 1 1] + [1 1 1 1] + [1 1 1 1] + [1 1 1 1]]] + +Note that on CPU, if an out of bound index is found, an error is returned. +On GPU, if an out of bound index is found, the index is ignored. + }]; + + let arguments = (ins + TF_Tensor:$tensor, + TF_I32OrI64Tensor:$indices, + TF_Tensor:$updates + ); + + let results = (outs + TF_Tensor:$output + ); + + TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>; + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + + let verifier = [{ return Verify(*this); }]; +} + def TF_TileOp : TF_Op<"Tile", [NoSideEffect]> { let summary = "Constructs a tensor by tiling a given tensor."; @@ -6075,7 +6694,7 @@ The output `y` has the same rank as `x`. The shapes of `x` and `y` satisfy: let builders = [ OpBuilder< - "Builder* builder, OperationState& result, Value* x, Value* perm"> + "Builder* builder, OperationState& result, Value x, Value perm"> ]; let verifier = [{ @@ -6119,7 +6738,7 @@ def TF_UniqueOp : TF_Op<"Unique", [NoSideEffect]> { let description = [{ This operation returns a tensor `y` containing all of the unique elements of `x` sorted in the same order that they occur in `x`; `x` does not need to be sorted. -This operation also returns a tensor `idx` the same size as `x` that contains +This operation also returns a tensor `idx` the same size as `x` that contains the index of each value of `x` in the unique output `y`. In other words: `y[idx[i]] = x[i] for i in [0, 1,...,rank(x) - 1]` diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td index c3a51613357..5505b8980e3 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td @@ -171,6 +171,8 @@ def TF_FpOrI32OrI64Tensor : TensorOf<[AnyFloat, TF_I32Or64]>; // Any integer or floating-point tensor types def TF_IntOrFpTensor : TensorOf<[TF_Int, AnyFloat]>; +def TF_SintOrFpTensor : TensorOf<[TF_SInt, AnyFloat]>; + def TF_FpOrComplexTensor : TensorOf<[AnyFloat, TF_AnyComplex]>; def TF_AnyNumber : AnyTypeOf<[TF_Int, AnyFloat, TF_AnyQuantized, TF_AnyComplex], @@ -297,10 +299,10 @@ def TF_IntTypeAttr : TypeAttrBase<"IntegerType", "integer type"> { // behavior. The result type has the same element type as both operands. 
class WithBroadcastableBinOpBuilder {
  list<OpBuilder> builders = [OpBuilder<
-"Builder *builder, OperationState &result, Value* x, Value* y",
+"Builder *builder, OperationState &result, Value x, Value y",
 [{
   auto resultType =
-      OpTrait::util::getBroadcastedType(x->getType(), y->getType());
+      OpTrait::util::getBroadcastedType(x.getType(), y.getType());
   if (!resultType)
     mlir::emitError(result.location, "non-broadcastable operands");
   return build(builder, result, resultType, x, y);
@@ -312,17 +314,17 @@ class WithBroadcastableBinOpBuilder {
 // behavior. The result type has bool element type.
 class WithBroadcastableCmpOpBuilder {
   list<OpBuilder> builders = [OpBuilder<
-"Builder *builder, OperationState &result, Value* x, Value* y",
+"Builder *builder, OperationState &result, Value x, Value y",
 [{
   Type resultType;
-  if (x->getType().isa<UnrankedTensorType>() ||
-      y->getType().isa<UnrankedTensorType>()) {
+  if (x.getType().isa<UnrankedTensorType>() ||
+      y.getType().isa<UnrankedTensorType>()) {
     resultType = UnrankedTensorType::get(builder->getI1Type());
   } else {
     SmallVector<int64_t, 4> resultShape;
     if (!OpTrait::util::getBroadcastedShape(
-            x->getType().cast<ShapedType>().getShape(),
-            y->getType().cast<ShapedType>().getShape(), resultShape)) {
+            x.getType().cast<ShapedType>().getShape(),
+            y.getType().cast<ShapedType>().getShape(), resultShape)) {
       mlir::emitError(result.location,
                       "operands have no broadcastable shapes");
     }

diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
index 3744cdeb66e..9b07b2f0c92 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
@@ -21,9 +21,12 @@ limitations under the License.

 #include
 #include
 #include
+#include

 #include "llvm/ADT/APInt.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/Optional.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/SmallVector.h"
@@ -32,27 +35,28 @@
#include "llvm/ADT/StringSwitch.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Support/FormatVariadic.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir -#include "mlir/Dialect/Traits.h" // TF:local_config_mlir -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Diagnostics.h" // TF:local_config_mlir -#include "mlir/IR/DialectImplementation.h" // TF:local_config_mlir -#include "mlir/IR/Function.h" // TF:local_config_mlir -#include "mlir/IR/Location.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/Matchers.h" // TF:local_config_mlir -#include "mlir/IR/OpImplementation.h" // TF:local_config_mlir -#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir -#include "mlir/IR/Types.h" // TF:local_config_mlir -#include "mlir/IR/Value.h" // TF:local_config_mlir -#include "mlir/Parser.h" // TF:local_config_mlir -#include "mlir/Support/LLVM.h" // TF:local_config_mlir -#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir -#include "mlir/Support/STLExtras.h" // TF:local_config_mlir -#include "mlir/Transforms/InliningUtils.h" // TF:local_config_mlir +#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/Traits.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Diagnostics.h" // TF:llvm-project +#include "mlir/IR/DialectImplementation.h" // TF:llvm-project +#include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/Location.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Matchers.h" // TF:llvm-project +#include "mlir/IR/OpDefinition.h" // TF:llvm-project +#include "mlir/IR/OpImplementation.h" // TF:llvm-project +#include "mlir/IR/PatternMatch.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/TypeUtilities.h" // TF:llvm-project +#include "mlir/IR/Types.h" // TF:llvm-project +#include "mlir/IR/Value.h" // TF:llvm-project +#include "mlir/Parser.h" // TF:llvm-project +#include "mlir/Support/LLVM.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/Support/STLExtras.h" // TF:llvm-project +#include "mlir/Transforms/InliningUtils.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/util/tensor_format.h" @@ -68,17 +72,17 @@ namespace TF { // may have non-static shape because the shape is not propagated during constant // folding. If the defining op for the given operand is a constant op, this // routine uses the constant op's attribute to get the actual shape. -static RankedTensorType GetRankedTensorTypeForOperand(Value *operand) { +static RankedTensorType GetRankedTensorTypeForOperand(Value operand) { DenseElementsAttr attr; if (matchPattern(operand, m_Constant(&attr))) { return attr.getType().dyn_cast(); } - return operand->getType().dyn_cast(); + return operand.getType().dyn_cast(); } // Returns true if the given `value` is of ranked float tensor type with the // given `rank`. 
-static inline bool isOfRankedFloatTensorType(Value *value, int rank) {
+static inline bool isOfRankedFloatTensorType(Value value, int rank) {
   RankedTensorType type = GetRankedTensorTypeForOperand(value);
   return type && type.getRank() == rank &&
          type.getElementType().isa<FloatType>();
@@ -86,21 +90,21 @@ static inline bool isOfRankedFloatTensorType(Value *value, int rank) {

 // Returns true if the given `value` has the specified rank or has unranked
 // type.
-static inline bool IsOfRankOrUnranked(Value *value, int64_t rank) {
+static inline bool IsOfRankOrUnranked(Value value, int64_t rank) {
   RankedTensorType type = GetRankedTensorTypeForOperand(value);
   return !type || type.getRank() == rank;
 }

 // Returns true if the given `value` has at least the specified rank or has
 // unranked type.
-static inline bool HasRankAtLeast(Value *value, int64_t rank) {
+static inline bool HasRankAtLeast(Value value, int64_t rank) {
   RankedTensorType type = GetRankedTensorTypeForOperand(value);
   return !type || type.getRank() >= rank;
 }

 // Returns true if the given `value` has at most the specified rank or has
 // unranked type.
-static inline bool HasRankAtMost(Value *value, int64_t rank) {
+static inline bool HasRankAtMost(Value value, int64_t rank) {
   RankedTensorType type = GetRankedTensorTypeForOperand(value);
   return !type || type.getRank() <= rank;
 }
@@ -154,10 +158,10 @@ static bool IsUnknownDimOrRank(int64_t dim_or_rank) {
 // Returns the tf.Equal/tf.NotEqual result type given inputs `x` and `y`. If
 // `incompatible_shape_error` is true, reports an error if `x` and `y` have
 // incompatible shapes. Otherwise, returns a tensor type with unknown rank.
-static Type DeduceEqualCmpOpType(Builder *builder, Location loc, Value *x,
-                                 Value *y, BoolAttr incompatible_shape_error) {
+static Type DeduceEqualCmpOpType(Builder *builder, Location loc, Value x,
+                                 Value y, BoolAttr incompatible_shape_error) {
   auto result_type =
-      OpTrait::util::getBroadcastedType(x->getType(), y->getType());
+      OpTrait::util::getBroadcastedType(x.getType(), y.getType());
   if (!result_type) {
     if (incompatible_shape_error.getValue()) {
       mlir::emitError(loc, "non-broadcastable operands");
@@ -181,9 +185,9 @@ static int64_t GetDimForAxis(int64_t axis, int64_t rank) {
 // Infers output type for reduction ops such as SumOp, MaxOp etc.
 // TODO(b/e667204a): Move this logic to shape inference once it supports custom
 // inference functions.
-static Type InferReductionOpType(Value *input, Value *reduction_indices,
+static Type InferReductionOpType(Value input, Value reduction_indices,
                                  BoolAttr keep_dims, Builder *builder) {
-  Type input_ty = input->getType();
+  Type input_ty = input.getType();
   Type element_ty = getElementTypeOrSelf(input_ty);

   // Output type is unranked if input type is not ranked.
@@ -324,14 +328,14 @@ void AddV2Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
 //===----------------------------------------------------------------------===//
 // Verifies a reduction op's `input` and reduction `dims`.
-static LogicalResult VerifyReductionInputAndDims(Value *input, Value *dims,
+static LogicalResult VerifyReductionInputAndDims(Value input, Value dims,
                                                  Location loc) {
-  auto dims_type = dims->getType().dyn_cast<RankedTensorType>();
+  auto dims_type = dims.getType().dyn_cast<RankedTensorType>();
   if (!dims_type) return success();
   if (dims_type.getRank() > 1)
     return emitError(loc, "dimensions can only be 0D or 1D tensor");

-  auto input_type = input->getType().dyn_cast<RankedTensorType>();
+  auto input_type = input.getType().dyn_cast<RankedTensorType>();
   if (!input_type) return success();

   int64_t rank = input_type.getRank();
@@ -437,9 +441,8 @@ static LogicalResult Verify(BiasAddOp op) {
   if (!IsOfRankOrUnranked(op.bias(), 1))
     return op.emitOpError("requires bias operand to have rank exactly one");

-  RankedTensorType value_ty =
-      op.value()->getType().dyn_cast<RankedTensorType>();
-  RankedTensorType bias_ty = op.bias()->getType().dyn_cast<RankedTensorType>();
+  RankedTensorType value_ty = op.value().getType().dyn_cast<RankedTensorType>();
+  RankedTensorType bias_ty = op.bias().getType().dyn_cast<RankedTensorType>();
   if (!bias_ty || !value_ty) return success();

   // TODO(hinsu): Leverage tensor_format.h utility in TensorFlow to compute
@@ -524,7 +527,7 @@ static LogicalResult Verify(OpT op) {
   Operation::operand_range values = op.values();

   int axis_idx = std::is_same<OpT, ConcatOp>() ? 0 : 1;
-  Value *axis = *op.getODSOperands(axis_idx).begin();
+  Value axis = *op.getODSOperands(axis_idx).begin();
   if (!HasRankAtMost(axis, 1)) {
     return op.emitOpError(
         "requires axis to be of scalar type (or vector type for older "
@@ -535,6 +538,118 @@ static LogicalResult Verify(OpT op) {
                                   /*mask_one_dim=*/true, op.getOperation());
 }

+//===----------------------------------------------------------------------===//
+// ConcatOffsetOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult Verify(ConcatOffsetOp op) {
+  if (op.N() < 2)
+    return op.emitOpError() << "requires N to be at least 2, got " << op.N();
+
+  if (op.shape().size() != op.offset().size())
+    return op.emitOpError()
+           << "requires sizes of shapes and offsets to be the same, got sizes "
+           << op.shape().size() << " and " << op.offset().size();
+
+  auto ranked_dim = op.concat_dim().getType().dyn_cast<RankedTensorType>();
+  if (ranked_dim && ranked_dim.getRank() != 0)
+    return op.emitOpError()
+           << "requires concat_dim to be a scalar, got tensor of rank "
+           << ranked_dim.getRank();
+
+  int64_t num_dims = -1;
+  for (auto shape_offset_idx :
+       llvm::enumerate(llvm::zip(op.shape(), op.offset()))) {
+    Value shape = std::get<0>(shape_offset_idx.value());
+    Value offset = std::get<1>(shape_offset_idx.value());
+    const size_t idx = shape_offset_idx.index();
+
+    if (failed(verifyCompatibleShape(shape.getType(), offset.getType())))
+      return op.emitOpError() << "requires operand and result " << idx
+                              << " to have compatible shapes";
+
+    auto ranked_shape = shape.getType().dyn_cast<RankedTensorType>();
+    if (!ranked_shape) continue;
+
+    if (ranked_shape.getRank() != 1)
+      return op.emitOpError() << "requires shape tensor operand " << idx
+                              << " to be of rank 1, got tensor of rank "
+                              << ranked_shape.getRank();
+
+    if (!ranked_shape.hasStaticShape()) continue;
+
+    int64_t ranked_shape_dim = ranked_shape.getDimSize(0);
+    if (num_dims == -1)
+      num_dims = ranked_shape_dim;
+    else if (ranked_shape_dim != num_dims)
+      return op.emitOpError()
+             << "requires shape tensor (rank 1) operand " << idx
+             << " to be of length " << num_dims
+             << ", got tensor (rank 1) of length " << ranked_shape_dim;
+  }
+
+  return success();
+}
+
+LogicalResult ConcatOffsetOp::fold(ArrayRef<Attribute> operands,
+                                   SmallVectorImpl<OpFoldResult> &results) {
+  // ConcatOffset must have its first operand be concat_dim and at least two
+  // shape tensors in variadic shapes operand.
+  if (operands.size() < 3) return failure();
+
+  // Check concat_dim is a scalar.
+  auto concat_dim_attr = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
+  if (!concat_dim_attr || concat_dim_attr.getType().getRank() != 0)
+    return failure();
+
+  llvm::SmallVector<DenseIntElementsAttr, 4> shapes;
+  shapes.reserve(operands.size() - 1);
+  for (Attribute shape : llvm::drop_begin(operands, 1))
+    if (auto shape_attr = shape.dyn_cast_or_null<DenseIntElementsAttr>())
+      shapes.push_back(shape_attr);
+    else
+      return failure();
+
+  // Check all shapes are vectors of the same length.
+  if (shapes.front().getType().getRank() != 1) return failure();
+  const int64_t num_dims = shapes.front().getNumElements();
+  for (DenseIntElementsAttr shape : llvm::drop_begin(shapes, 1))
+    if (shape.getType().getRank() != 1 || shape.getNumElements() != num_dims)
+      return failure();
+
+  // Check concat_dim is within [-num_dims, num_dims).
+  int32_t concat_dim = (*concat_dim_attr.getValues<int32_t>().begin());
+  if (concat_dim < 0) concat_dim += num_dims;
+  if (concat_dim >= num_dims || concat_dim < 0) return failure();
+
+  // Check all elements besides at concat_dim match across all shape tensors.
+  SmallVector<int32_t, 4> shape0;
+  shape0.reserve(num_dims);
+  for (int32_t dim : shapes.front().getValues<int32_t>()) shape0.push_back(dim);
+
+  for (DenseIntElementsAttr shape : llvm::drop_begin(shapes, 1)) {
+    for (auto dims_and_idx : llvm::enumerate(llvm::zip(shape0, shape))) {
+      if (dims_and_idx.index() == concat_dim) continue;
+
+      if (std::get<0>(dims_and_idx.value()) !=
+          std::get<1>(dims_and_idx.value()).getSExtValue())
+        return failure();
+    }
+  }
+
+  // Compute an exclusive cumulative sum of elements at concat_dim.
+  results.reserve(shapes.size());
+  SmallVector<int32_t, 4> cumulative_sum(num_dims, 0);
+  RankedTensorType offset_type =
+      RankedTensorType::get({num_dims}, IntegerType::get(32, getContext()));
+  for (DenseIntElementsAttr shape : shapes) {
+    results.push_back(DenseIntElementsAttr::get(offset_type, cumulative_sum));
+    cumulative_sum[concat_dim] += shape.getValue<int32_t>(concat_dim);
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ConjOp
 //===----------------------------------------------------------------------===//
@@ -670,7 +785,7 @@ static LogicalResult Verify(OpT op) {
   }

   int64_t input_channels = -1;
-  if (auto ty = op.input()->getType().template dyn_cast<RankedTensorType>()) {
+  if (auto ty = op.input().getType().template dyn_cast<RankedTensorType>()) {
     std::string data_format = op.data_format().str();
     tensorflow::TensorFormat format;
     auto is_valid = FormatFromString(data_format, &format);
@@ -680,7 +795,7 @@
   }

   int64_t filter_channels = -1;
-  if (auto ty = op.filter()->getType().template dyn_cast<RankedTensorType>()) {
+  if (auto ty = op.filter().getType().template dyn_cast<RankedTensorType>()) {
     int idx = tensorflow::GetFilterTensorInputChannelsDimIndex(
         num_dims, tensorflow::FORMAT_HWIO);
     filter_channels = ty.getDimSize(idx);
@@ -726,6 +841,101 @@ void DivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
   results.insert<DivWithSqrtDivisor>(context);
 }

+//===----------------------------------------------------------------------===//
+// DynamicStitchOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult Verify(DynamicStitchOp op) {
+  if (op.N() < 1) return op.emitOpError("requires attribute N with value >= 1");
+
+  if (RankedTensorType out_ty = op.getType().dyn_cast<RankedTensorType>()) {
+    if (out_ty.getRank() == 0) {
+      return op.emitOpError("requires non scalar output");
op.emitOpError("requires non scalar output"); + } + } + + llvm::SmallDenseSet index_values; + bool all_indices_const = true; + int32_t max_index = -1; + llvm::Optional> inferred_item_shape; + for (auto it : llvm::zip(op.indices(), op.data())) { + Value index = std::get<0>(it); + + DenseIntElementsAttr index_attr; + if (matchPattern(index, m_Constant(&index_attr))) { + for (int32_t index : index_attr.getValues()) { + if (index < 0) + return op.emitOpError() + << "requires non-negative index values; found " << index; + max_index = std::max(index, max_index); + index_values.insert(index); + } + } else { + all_indices_const = false; + } + + Value data = std::get<1>(it); + RankedTensorType index_ty = index.getType().dyn_cast(); + RankedTensorType data_ty = data.getType().dyn_cast(); + if (!index_ty || !data_ty) continue; + + int64_t index_rank = index_ty.getRank(); + ArrayRef data_shape = data_ty.getShape(); + ArrayRef index_shape = index_ty.getShape(); + if (failed(mlir::verifyCompatibleShape(index_shape, + data_shape.take_front(index_rank)))) + return op.emitOpError() << "requires shape of data with type " << data_ty + << " to have prefix matching with shape of the " + "corresponding index type " + << index_ty; + + ArrayRef item_shape = data_shape.drop_front(index_rank); + if (!inferred_item_shape) { + inferred_item_shape = llvm::to_vector<4>(item_shape); + continue; + } + + if (failed(mlir::verifyCompatibleShape(item_shape, *inferred_item_shape))) + return op.emitOpError() << "has inconsistent shaped data and index " + "pairs; inferred item shapes [" + << llvm::makeArrayRef(*inferred_item_shape) + << "] and [" << item_shape << "] don't match"; + for (int i = 0, e = item_shape.size(); i < e; ++i) { + int64_t &inferred_dim = (*inferred_item_shape)[i]; + int64_t dim = item_shape[i]; + if (ShapedType::isDynamic(inferred_dim)) inferred_dim = dim; + } + } + + // If all indices are constants, then verify that they cover all indices in + // the range [0, max_index] and the output type is legal. 
+ if (all_indices_const) { + for (int32_t i = 0; i <= max_index; i++) { + if (!index_values.count(i)) + return op.emitOpError() << "missing index " << i; + } + + if (inferred_item_shape) { + SmallVector expected_shape; + expected_shape.push_back(max_index + 1); + expected_shape.append(inferred_item_shape->begin(), + inferred_item_shape->end()); + + auto out_ty = op.getType().cast(); + auto expected_out_ty = + RankedTensorType::get(expected_shape, out_ty.getElementType()); + + if (!AreCastCompatible(out_ty, expected_out_ty)) { + return op.emitOpError() << "has invalid output type; should be " + "compatible with inferred type " + << expected_out_ty; + } + } + } + + return success(); +} + //===----------------------------------------------------------------------===// // EinsumOp //===----------------------------------------------------------------------===// @@ -770,13 +980,44 @@ static LogicalResult Verify(EqualOp op) { op.getOperation()); } -void EqualOp::build(Builder *builder, OperationState &result, Value *x, - Value *y, BoolAttr incompatible_shape_error) { +void EqualOp::build(Builder *builder, OperationState &result, Value x, Value y, + BoolAttr incompatible_shape_error) { auto result_type = DeduceEqualCmpOpType(builder, result.location, x, y, incompatible_shape_error); return build(builder, result, result_type, x, y, incompatible_shape_error); } +//===----------------------------------------------------------------------===// +// ExpandDimsOp +//===----------------------------------------------------------------------===// + +Type InferExpandDimsOpType(Value input, Value dim) { + Type element_ty = input.getType().cast().getElementType(); + auto unranked_ty = UnrankedTensorType::get(element_ty); + + auto input_ty = input.getType().dyn_cast(); + if (!input_ty) return unranked_ty; + + DenseIntElementsAttr dim_attr; + if (!matchPattern(dim, m_Constant(&dim_attr)) || + dim_attr.getNumElements() != 1) + return unranked_ty; + int64_t dim_val = (*dim_attr.begin()).getSExtValue(); + int64_t input_rank = input_ty.getRank(); + + if (dim_val < -input_rank - 1 || dim_val > input_rank + 1) return unranked_ty; + if (dim_val < 0) dim_val += input_rank + 1; + + SmallVector shape = llvm::to_vector<4>(input_ty.getShape()); + shape.insert(shape.begin() + dim_val, 1); + return RankedTensorType::get(shape, element_ty); +} + +void ExpandDimsOp::build(Builder *builder, OperationState &result, Value input, + Value dim) { + return build(builder, result, InferExpandDimsOpType(input, dim), input, dim); +} + //===----------------------------------------------------------------------===// // FakeQuantWithMinMaxArgsOp //===----------------------------------------------------------------------===// @@ -832,16 +1073,16 @@ static LogicalResult Verify(FakeQuantWithMinMaxVarsPerChannelOp op) { if (!isOfRankedFloatTensorType(op.max(), 1)) return op.emitOpError("requires max to be a 1d float tensor"); - Value *inputs = op.inputs(); + Value inputs = op.inputs(); if (!HasRankAtLeast(inputs, 1) || - inputs->getType().isa()) { + inputs.getType().isa()) { return op.emitError("requires inputs to be at least 1d float tensor"); } - auto inputsType = inputs->getType().cast(); + auto inputsType = inputs.getType().cast(); int depth = inputsType.getDimSize(inputsType.getRank() - 1); - if (op.min()->getType().cast().getDimSize(0) != depth || - op.max()->getType().cast().getDimSize(0) != depth) { + if (op.min().getType().cast().getDimSize(0) != depth || + op.max().getType().cast().getDimSize(0) != depth) { return op.emitOpError( 
"requires min and max to have same size as last dimension of inputs"); } @@ -897,7 +1138,7 @@ static LogicalResult Verify(FusedBatchNormOp op) { static LogicalResult Verify(GatherV2Op op) { int64_t batch_dims = op.batch_dims().getSExtValue(); - if (auto ty = op.indices()->getType().dyn_cast()) { + if (auto ty = op.indices().getType().dyn_cast()) { int64_t rank = ty.getRank(); if (batch_dims > rank || batch_dims < -rank) return op.emitOpError() @@ -912,7 +1153,7 @@ static LogicalResult Verify(GatherV2Op op) { DenseIntElementsAttr axis_attr; if (matchPattern(op.axis(), m_Constant(&axis_attr))) { int64_t axis = (*axis_attr.begin()).getSExtValue(); - if (auto ty = op.params()->getType().dyn_cast()) { + if (auto ty = op.params().getType().dyn_cast()) { int64_t rank = ty.getRank(); if (axis >= rank || axis < -rank) return op.emitOpError() << "axis (" << axis << ") must be in range [" @@ -955,7 +1196,7 @@ static LogicalResult Verify(IfOp op) { " inputs"); for (unsigned i = 0; i < expectedNumInputs; ++i) { - auto operandType = op.getOperand(i + 1)->getType().cast(); + auto operandType = op.getOperand(i + 1).getType().cast(); auto thenInputType = thenFuncType.getInput(i).cast(); if (!AreCastCompatible(operandType, thenInputType)) return op.emitError( @@ -986,7 +1227,7 @@ static LogicalResult Verify(IfOp op) { " results"); for (unsigned i = 0; i < expectedNumResults; ++i) { - auto resultType = op.getResult(i)->getType().cast(); + auto resultType = op.getResult(i).getType().cast(); auto thenResultType = thenFuncType.getResult(i).cast(); if (!AreCastCompatible(thenResultType, resultType)) return op.emitError( @@ -1062,8 +1303,8 @@ void LogicalNotOp::getCanonicalizationPatterns( // MaxOp //===----------------------------------------------------------------------===// -void MaxOp::build(Builder *builder, OperationState &result, Value *input, - Value *reduction_indices, BoolAttr keep_dims) { +void MaxOp::build(Builder *builder, OperationState &result, Value input, + Value reduction_indices, BoolAttr keep_dims) { Type out_ty = InferReductionOpType(input, reduction_indices, keep_dims, builder); build(builder, result, out_ty, input, reduction_indices, keep_dims); @@ -1108,8 +1349,8 @@ static LogicalResult Verify(NotEqualOp op) { op.getOperation()); } -void NotEqualOp::build(Builder *builder, OperationState &result, Value *x, - Value *y, BoolAttr incompatible_shape_error) { +void NotEqualOp::build(Builder *builder, OperationState &result, Value x, + Value y, BoolAttr incompatible_shape_error) { auto result_type = DeduceEqualCmpOpType(builder, result.location, x, y, incompatible_shape_error); return build(builder, result, result_type, x, y, incompatible_shape_error); @@ -1122,7 +1363,7 @@ void NotEqualOp::build(Builder *builder, OperationState &result, Value *x, static LogicalResult Verify(OneHotOp op) { int64_t axis = op.axis().getSExtValue(); - auto indices_ty = op.indices()->getType().dyn_cast(); + auto indices_ty = op.indices().getType().dyn_cast(); if (indices_ty && !(axis == -1 || (axis >= 0 && axis <= indices_ty.getShape().size()))) { return op.emitOpError() @@ -1147,9 +1388,8 @@ static LogicalResult Verify(OneHotOp op) { DenseIntElementsAttr depth_attr; if (matchPattern(op.depth(), m_Constant(&depth_attr))) { - if (depth_attr.getType().getRank() != 0) { + if (depth_attr.getType().getRank() != 0) return op.emitOpError() << "requires depth to be a scalar"; - } int64_t depth = depth_attr.getValue({}).getSExtValue(); if (depth < 0) { return op.emitOpError() << "depth must be non-negative, got: " << depth; 
@@ -1159,6 +1399,36 @@ static LogicalResult Verify(OneHotOp op) { return success(); } +static TensorType InferOneHotOpType(Value indices, Value depth, Value on_value, + Value off_value, IntegerAttr axis) { + int64_t axis_val = axis.getInt(); + Type element_ty = on_value.getType().cast().getElementType(); + auto unranked_ty = UnrankedTensorType::get(element_ty); + if (axis_val < -1) return unranked_ty; + + auto indices_ty = indices.getType().dyn_cast(); + if (!indices_ty) return unranked_ty; + + auto shape = llvm::to_vector<2>(indices_ty.getShape()); + if (axis_val == -1) axis_val = shape.size(); + + int64_t depth_val = ShapedType::kDynamicSize; + DenseIntElementsAttr depth_attr; + if (matchPattern(depth, m_Constant(&depth_attr)) && + depth_attr.getNumElements() == 1) + depth_val = (*depth_attr.begin()).getSExtValue(); + shape.insert(shape.begin() + axis_val, depth_val); + return RankedTensorType::get(shape, element_ty); +} + +void OneHotOp::build(Builder *builder, OperationState &result, Value indices, + Value depth, Value on_value, Value off_value, + IntegerAttr axis) { + build(builder, result, + InferOneHotOpType(indices, depth, on_value, off_value, axis), indices, + depth, on_value, off_value, axis); +} + //===----------------------------------------------------------------------===// // PackOp //===----------------------------------------------------------------------===// @@ -1174,8 +1444,8 @@ static LogicalResult Verify(PackOp op) { } int64_t inputs_rank = -1; - for (Value *value : values) { - if (auto ty = value->getType().dyn_cast()) { + for (Value value : values) { + if (auto ty = value.getType().dyn_cast()) { // Exit early as input types are verified to be compatible so all ranked // tensors have the same rank. inputs_rank = ty.getRank(); @@ -1199,6 +1469,59 @@ static LogicalResult Verify(PackOp op) { return success(); } +//===----------------------------------------------------------------------===// +// ParseExampleV2Op +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(ParseExampleV2Op op) { + // NOTE(mrry): This validates properties of an op that would previously be + // validated by the TensorFlow OpDef type checker. In addition to these + // checks, the shape inference function for ParseExampleV2 validates the + // consistency of the argument and result types. + + // Validate dense variadic input and output lengths. + // NOTE(mrry): The Tdense attr is derived from dense_defaults, so we + // do not need to validate dense_defaults. + auto dense_types_count = + std::distance(op.Tdense().begin(), op.Tdense().end()); + auto dense_values_count = + std::distance(op.dense_values().begin(), op.dense_values().end()); + if (dense_values_count != dense_types_count) { + return op.emitError() << "output 'dense_values' should have same length " + << "as attribute 'Tdense'"; + } + + // Validate sparse variadic output lengths. + // NOTE(mrry): The sparse_types attr is derived from sparse_values, so we + // do not need to validate sparse_values. 
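+  // For example, a ParseExampleV2 op with num_sparse = 2 must carry two
+  // entries in sparse_types and produce exactly two sparse_indices, two
+  // sparse_values, and two sparse_shapes results.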
+ auto sparse_types_count = + std::distance(op.sparse_types().begin(), op.sparse_types().end()); + if (op.num_sparse() != sparse_types_count) { + return op.emitError() << "attribute 'num_sparse' should be the same as " + << "the length of attribute 'sparse_types'"; + } + if (op.sparse_indices().size() != sparse_types_count) { + return op.emitError() << "output 'sparse_indices' should have same length " + << "as attribute 'sparse_types'"; + } + if (op.sparse_shapes().size() != sparse_types_count) { + return op.emitError() << "output 'sparse_shapes' should have same length " + << "as attribute 'sparse_types'"; + } + + // Validate ragged variadic output lengths. + auto ragged_value_types_count = std::distance(op.ragged_value_types().begin(), + op.ragged_value_types().end()); + auto ragged_split_types_count = std::distance(op.ragged_split_types().begin(), + op.ragged_split_types().end()); + if (ragged_value_types_count != ragged_split_types_count) { + return op.emitError() << "attribute 'ragged_value_types' should have same " + << "length as attribute 'ragged_split_types'"; + } + + return success(); +} + //===----------------------------------------------------------------------===// // ReciprocalOp //===----------------------------------------------------------------------===// @@ -1222,10 +1545,10 @@ static LogicalResult Verify(RandomUniformOp op) { // RangeOp //===----------------------------------------------------------------------===// -void RangeOp::build(Builder *builder, OperationState &result, Value *start, - Value *limit, Value *delta) { - assert(start->getType() == limit->getType()); - assert(start->getType() == delta->getType()); +void RangeOp::build(Builder *builder, OperationState &result, Value start, + Value limit, Value delta) { + assert(start.getType() == limit.getType()); + assert(start.getType() == delta.getType()); DenseIntElementsAttr start_val; DenseIntElementsAttr limit_val; DenseIntElementsAttr delta_val; @@ -1239,20 +1562,20 @@ void RangeOp::build(Builder *builder, OperationState &result, Value *start, builder, result, RankedTensorType::get( size.getSExtValue(), - start->getType().cast().getElementType()), + start.getType().cast().getElementType()), start, limit, delta); } return RangeOp::build( builder, result, RankedTensorType::get( - {-1}, start->getType().cast().getElementType()), + {-1}, start.getType().cast().getElementType()), start, limit, delta); } //===----------------------------------------------------------------------===// // RankOp //===----------------------------------------------------------------------===// -void RankOp::build(Builder *builder, OperationState &result, Value *input) { +void RankOp::build(Builder *builder, OperationState &result, Value input) { return RankOp::build(builder, result, RankedTensorType::get({}, builder->getIntegerType(32)), input); @@ -1274,17 +1597,17 @@ void RealDivOp::getCanonicalizationPatterns(OwningRewritePatternList &results, // TODO(b/128020684): Verify the rank of the output and change to use // m_Constant. static LogicalResult Verify(ReshapeOp op) { - auto shapeType = op.shape()->getType().cast(); + auto shapeType = op.shape().getType().cast(); if (!shapeType.hasRank()) return success(); if (shapeType.getRank() != 1) return op.emitOpError("shape must be 1D tensor"); auto rankByShape = shapeType.getShape()[0]; - auto typeOfTensor = op.tensor()->getType().cast(); + auto typeOfTensor = op.tensor().getType().cast(); // No compile time verification for unknown sized shape. 
 if (rankByShape == -1 || !typeOfTensor.hasStaticShape()) return success();
   // Check values if constant shape. No compile time verification for
   // non-constant shape.
-  auto *shapeOp = op.shape()->getDefiningOp();
+  auto *shapeOp = op.shape().getDefiningOp();
   if (!shapeOp) return success();
   Attribute shapeCst;
   if (auto shapeStdOp = dyn_cast<ConstantOp>(shapeOp)) {
@@ -1336,9 +1659,9 @@ static LogicalResult Verify(ReshapeOp op) {
   return success();
 }

-void ReshapeOp::build(Builder *builder, OperationState &result, Value *tensor,
-                      Value *shape) {
-  auto ttype = tensor->getType().cast<ShapedType>();
+void ReshapeOp::build(Builder *builder, OperationState &result, Value tensor,
+                      Value shape) {
+  auto ttype = tensor.getType().cast<ShapedType>();
   auto etype = ttype.getElementType();

   auto unranked = [builder, etype, &result, shape, tensor]() {
@@ -1394,6 +1717,37 @@ void ReshapeOp::build(Builder *builder, OperationState &result, Value *tensor,
   return unranked();
 }

+//===----------------------------------------------------------------------===//
+// SelectV2Op
+//===----------------------------------------------------------------------===//
+
+static Type InferSelectV2OpType(Value condition, Value e, Value t) {
+  Type element_ty = e.getType().cast<TensorType>().getElementType();
+  auto unranked_ty = UnrankedTensorType::get(element_ty);
+
+  Type broadcasted_ty =
+      OpTrait::util::getBroadcastedType(e.getType(), t.getType());
+  if (!broadcasted_ty) return unranked_ty;
+
+  auto cond_ranked_ty = condition.getType().dyn_cast<RankedTensorType>();
+  auto broadcasted_ranked_ty = broadcasted_ty.dyn_cast<RankedTensorType>();
+  if (!cond_ranked_ty || !broadcasted_ranked_ty) return unranked_ty;
+
+  // Explicitly get the broadcasted output type as the element type of
+  // condition may not be the same as the broadcasted type's element type.
+  SmallVector<int64_t, 4> result_shape;
+  if (!OpTrait::util::getBroadcastedShape(cond_ranked_ty.getShape(),
+                                          broadcasted_ranked_ty.getShape(),
+                                          result_shape))
+    return unranked_ty;
+  return RankedTensorType::get(result_shape, element_ty);
+}
+
+void SelectV2Op::build(Builder *builder, OperationState &result,
+                       Value condition, Value e, Value t) {
+  build(builder, result, InferSelectV2OpType(condition, e, t), condition, e, t);
+}
+
 //===----------------------------------------------------------------------===//
 // ShapeOp
 //===----------------------------------------------------------------------===//
@@ -1436,7 +1790,7 @@ LogicalResult VerifyShapeOperandAndResult(Operation *op, Type operand_type,
 }  // anonymous namespace

 static LogicalResult Verify(ShapeOp op) {
-  return VerifyShapeOperandAndResult(op, op.input()->getType(), op.getType());
+  return VerifyShapeOperandAndResult(op, op.input().getType(), op.getType());
 }

 // Converts shape of the given type to attribute if it is of ranked tensor type.
@@ -1461,12 +1815,12 @@ static Attribute ConvertShapeToAttr(Type input_ty, int out_width) {

 OpFoldResult ShapeOp::fold(ArrayRef<Attribute> operands) {
   int width =
       getType().cast<ShapedType>().getElementType().getIntOrFloatBitWidth();
-  return ConvertShapeToAttr(getOperand()->getType(), width);
+  return ConvertShapeToAttr(getOperand().getType(), width);
 }

-void ShapeOp::build(Builder *builder, OperationState &result, Value *input,
+void ShapeOp::build(Builder *builder, OperationState &result, Value input,
                     BoolAttr use32Bit) {
-  auto rankedTensorType = input->getType().dyn_cast<RankedTensorType>();
+  auto rankedTensorType = input.getType().dyn_cast<RankedTensorType>();
   int64_t rank = rankedTensorType ? rankedTensorType.getRank() : -1;
   auto out_type = use32Bit.getValue() ?
builder->getIntegerType(32) : builder->getIntegerType(64); @@ -1491,7 +1845,7 @@ static LogicalResult Verify(ShapeNOp op) { for (auto i : llvm::seq(0, num_tensors)) { auto verification = VerifyShapeOperandAndResult( - op, op.getOperand(i)->getType(), op.getResult(i)->getType(), i); + op, op.getOperand(i).getType(), op.getResult(i).getType(), i); if (failed(verification)) return verification; } @@ -1564,7 +1918,7 @@ static LogicalResult Verify(SliceOp op) { " same number of elements"; } - auto input_ty = op.input()->getType().dyn_cast(); + auto input_ty = op.input().getType().dyn_cast(); if (input_ty && begin_ty.getNumElements() != input_ty.getRank()) { return op.emitOpError() << "requires number of elements in begin and size" "are equal to input rank"; @@ -1618,7 +1972,7 @@ static LogicalResult Verify(SoftmaxOp op) { // static LogicalResult Verify(SoftmaxCrossEntropyWithLogitsOp op) { auto broadcasted_ty = OpTrait::util::getBroadcastedType( - op.features()->getType(), op.labels()->getType()) + op.features().getType(), op.labels().getType()) .dyn_cast_or_null(); if (!broadcasted_ty || (broadcasted_ty.hasRank() && broadcasted_ty.getRank() != 2)) @@ -1628,6 +1982,31 @@ static LogicalResult Verify(SoftmaxCrossEntropyWithLogitsOp op) { return success(); } +//===----------------------------------------------------------------------===// +// SparseSoftmaxCrossEntropyWithLogitsOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(SparseSoftmaxCrossEntropyWithLogitsOp op) { + if (!IsOfRankOrUnranked(op.features(), 2)) { + return op.emitOpError("requires features operand of rank two"); + } + if (!IsOfRankOrUnranked(op.labels(), 1)) { + return op.emitOpError("requires labels operand of rank one"); + } + auto features_ty = op.features().getType().dyn_cast(); + auto labels_ty = op.labels().getType().dyn_cast(); + if (features_ty && labels_ty) { + int64_t features_batches = features_ty.getDimSize(0); + int64_t labels_batches = labels_ty.getDimSize(0); + if (!ShapedType::isDynamic(features_batches) && + !ShapedType::isDynamic(labels_batches) && + features_batches != labels_batches) + return op.emitOpError( + "requires features and labels with matching first dimension"); + } + return success(); +} + //===----------------------------------------------------------------------===// // SplitOp //===----------------------------------------------------------------------===// @@ -1639,8 +2018,8 @@ template LogicalResult VerifySplitInputAndSplitDim(Op op, Optional *dim_index) { *dim_index = llvm::None; - Value *split_dim = op.split_dim(); - if (auto split_dim_type = split_dim->getType().dyn_cast()) + Value split_dim = op.split_dim(); + if (auto split_dim_type = split_dim.getType().dyn_cast()) if (split_dim_type.getRank() != 0) return op.emitOpError( "split dimension should be an integer scalar tensor"); @@ -1648,7 +2027,7 @@ LogicalResult VerifySplitInputAndSplitDim(Op op, Optional *dim_index) { // We can perform further verification if the input tensor to be split has // known rank and the split dimension tensor is a constant. 
- auto input_type = op.value()->getType().template dyn_cast(); + auto input_type = op.value().getType().template dyn_cast(); if (!input_type) return success(); int64_t input_rank = input_type.getRank(); @@ -1677,7 +2056,7 @@ static LogicalResult Verify(SplitOp op) { if (!dim_index) return success(); int64_t input_dim_size = - op.value()->getType().cast().getDimSize(*dim_index); + op.value().getType().cast().getDimSize(*dim_index); if (input_dim_size == ShapedType::kDynamicSize) return success(); if (input_dim_size % op.getNumResults() != 0) @@ -1693,7 +2072,7 @@ static LogicalResult Verify(SplitOp op) { static LogicalResult Verify(SplitVOp op) { auto split_sizes_type = - op.size_splits()->getType().dyn_cast(); + op.size_splits().getType().dyn_cast(); if (!split_sizes_type) return success(); if (split_sizes_type.getRank() != 1 || @@ -1706,7 +2085,7 @@ static LogicalResult Verify(SplitVOp op) { if (!dim_index) return success(); int64_t input_dim_size = - op.value()->getType().cast().getDimSize(*dim_index); + op.value().getType().cast().getDimSize(*dim_index); if (input_dim_size == ShapedType::kDynamicSize) return success(); // If split sizes come from a constant, they must sum to the dimension size @@ -1773,8 +2152,8 @@ void SubOp::getCanonicalizationPatterns(OwningRewritePatternList &results, // SumOp //===----------------------------------------------------------------------===// -void SumOp::build(Builder *builder, OperationState &result, Value *input, - Value *reduction_indices, BoolAttr keep_dims) { +void SumOp::build(Builder *builder, OperationState &result, Value input, + Value reduction_indices, BoolAttr keep_dims) { Type out_ty = InferReductionOpType(input, reduction_indices, keep_dims, builder); build(builder, result, out_ty, input, reduction_indices, keep_dims); @@ -1797,8 +2176,8 @@ static LogicalResult VerifyStridedSliceBase(OpTy op) { // Expected size for operands begin, end and strides vector operands. int64_t expected_size = -1; - for (Value *val : {op.begin(), op.end(), op.strides()}) { - auto operand_ty = val->getType().dyn_cast(); + for (Value val : {op.begin(), op.end(), op.strides()}) { + auto operand_ty = val.getType().dyn_cast(); if (!operand_ty || !operand_ty.hasStaticShape()) { // TensorFlow constant ops may have non-static shape because the shape is // not propagated during constant folding. If the defining op for this @@ -1912,12 +2291,51 @@ static void CalculateSlicedShapeAndBoundRanges( } } +bool StridedSliceOp::GetSlicedBoundRanges( + ArrayRef shape, SmallVectorImpl *begin_indices, + SmallVectorImpl *end_indices, SmallVectorImpl *strides) { + if (this->ellipsis_mask().getZExtValue() || + this->new_axis_mask().getZExtValue() || + this->shrink_axis_mask().getZExtValue()) + return false; // TODO(antiagainst): support these masks + + // TODO(hinsu): Support lowering for ops with dynamic begin and end values + // when it is possible to derive indices based on mask attributes. 
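+  // For instance (illustrative values): slicing a shape [4, 4] input with
+  // constant begin = [1, 0], end = [3, 4], and strides = [1, 2] fills the
+  // output vectors with begin_indices = [1, 0], end_indices = [3, 4], and
+  // strides = [1, 2] before the mask-aware normalization below.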
+ DenseIntElementsAttr begin_indices_attr, end_indices_attr, strides_attr; + if (!matchPattern(this->begin(), m_Constant(&begin_indices_attr)) || + !matchPattern(this->end(), m_Constant(&end_indices_attr)) || + !matchPattern(this->strides(), m_Constant(&strides_attr))) + return false; + + auto input_shape = llvm::to_vector<4>(shape); + int rank = input_shape.size(); + + begin_indices->clear(); + begin_indices->reserve(rank); + end_indices->clear(); + end_indices->reserve(rank); + strides->clear(); + strides->reserve(rank); + + for (const APInt &index : begin_indices_attr) + begin_indices->push_back(index.getSExtValue()); + for (const APInt &index : end_indices_attr) + end_indices->push_back(index.getSExtValue()); + for (const APInt &stride : strides_attr) + strides->push_back(stride.getSExtValue()); + + CalculateSlicedShapeAndBoundRanges( + input_shape, this->begin_mask().getZExtValue(), + this->end_mask().getZExtValue(), *begin_indices, *end_indices, *strides); + return true; +} + //===----------------------------------------------------------------------===// // StridedSliceGradOp //===----------------------------------------------------------------------===// static LogicalResult Verify(StridedSliceGradOp op) { - auto shape_type = op.shape()->getType().dyn_cast(); + auto shape_type = op.shape().getType().dyn_cast(); if (shape_type && shape_type.getRank() != 1) return op.emitOpError("'shape' operand must be 1D tensor, but got ") << shape_type.getRank() << "D tensor"; @@ -1999,6 +2417,35 @@ static LogicalResult Verify(TensorListStackOp op) { return success(); } +//===----------------------------------------------------------------------===// +// TensorScatterUpdateOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(TensorScatterUpdateOp op) { + if (!HasRankAtLeast(op.tensor(), 1)) + return op.emitOpError( + "requires tensor operand to have at least 1 dimension"); + if (!HasRankAtLeast(op.indices(), 1)) + return op.emitOpError( + "requires indices operand to have at least 1 dimension"); + if (!HasRankAtLeast(op.updates(), 1)) + return op.emitOpError( + "requires updates operand to have at least 1 dimension"); + + auto tensor_ty = op.tensor().getType().dyn_cast(); + auto indices_ty = op.indices().getType().dyn_cast(); + if (!tensor_ty || !indices_ty) return success(); + + int64_t num_index_dims = indices_ty.getShape().back(); + if (ShapedType::isDynamic(num_index_dims)) return success(); + + if (num_index_dims > tensor_ty.getRank()) + return op.emitOpError( + "requires tensor operand with rank greater than or equal to the " + "indices operand's last dimensions"); + return success(); +} + //===----------------------------------------------------------------------===// // TopKV2Op //===----------------------------------------------------------------------===// @@ -2028,9 +2475,9 @@ static LogicalResult Verify(TransposeOp op) { } // TODO(jpienaar): perm could be optional too. -void TransposeOp::build(Builder *builder, OperationState &result, Value *x, - Value *perm) { - auto x_type = x->getType().cast(); +void TransposeOp::build(Builder *builder, OperationState &result, Value x, + Value perm) { + auto x_type = x.getType().cast(); // If value is unranked, then so is results. 
if (!x_type.hasRank()) return TransposeOp::build(builder, result, @@ -2061,7 +2508,7 @@ void TransposeOp::build(Builder *builder, OperationState &result, Value *x, } OpFoldResult TransposeOp::fold(ArrayRef operands) { - auto const_perm = dyn_cast_or_null(perm()->getDefiningOp()); + auto const_perm = dyn_cast_or_null(perm().getDefiningOp()); if (!const_perm) { return {}; @@ -2093,7 +2540,7 @@ void TruncateDivOp::getCanonicalizationPatterns( //===----------------------------------------------------------------------===// static LogicalResult Verify(UnpackOp op) { - auto value_type = op.value()->getType().dyn_cast(); + auto value_type = op.value().getType().dyn_cast(); if (!value_type) return success(); int64_t value_rank = value_type.getRank(); @@ -2121,9 +2568,9 @@ static LogicalResult VerifyUnsortedSegmentReduction(Op op) { if (!HasRankAtMost(op.num_segments(), 0)) return op.emitOpError("number of segments should be a 0-D tensor"); - auto data_type = op.data()->getType().template dyn_cast(); + auto data_type = op.data().getType().template dyn_cast(); auto segment_ids_type = - op.segment_ids()->getType().template dyn_cast(); + op.segment_ids().getType().template dyn_cast(); if (data_type && segment_ids_type) { if (data_type.getRank() < segment_ids_type.getRank()) return op.emitOpError( @@ -2161,7 +2608,7 @@ static LogicalResult VerifyUnsortedSegmentReduction(Op op) { static LogicalResult Verify(VariableShapeOp op) { auto resource_operand_type = op.input() - ->getType() + .getType() .cast() .getElementType() .cast(); @@ -2312,10 +2759,10 @@ struct TFInlinerInterface : public DialectInlinerInterface { // operation that takes 'input' as the only operand, and produces a single // result of 'resultType'. If a conversion can not be generated, nullptr // should be returned. - Operation *materializeCallConversion(OpBuilder &builder, Value *input, + Operation *materializeCallConversion(OpBuilder &builder, Value input, Type result_type, Location conversion_loc) const final { - if (!result_type.isa() || !input->getType().isa()) + if (!result_type.isa() || !input.getType().isa()) return nullptr; return builder.create(conversion_loc, result_type, input, /*truncate=*/builder.getBoolAttr(false)); diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h index e9aaed56afc..b6f1f76782f 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h @@ -19,16 +19,16 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_H_
 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_H_

-#include "mlir/Analysis/CallInterfaces.h"  // TF:local_config_mlir
-#include "mlir/Dialect/Traits.h"  // TF:local_config_mlir
-#include "mlir/IR/Attributes.h"  // TF:local_config_mlir
-#include "mlir/IR/Builders.h"  // TF:local_config_mlir
-#include "mlir/IR/Dialect.h"  // TF:local_config_mlir
-#include "mlir/IR/Matchers.h"  // TF:local_config_mlir
-#include "mlir/IR/Module.h"  // TF:local_config_mlir
-#include "mlir/IR/OpImplementation.h"  // TF:local_config_mlir
-#include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
-#include "mlir/IR/TypeUtilities.h"  // TF:local_config_mlir
+#include "mlir/Analysis/CallInterfaces.h"  // TF:llvm-project
+#include "mlir/Dialect/Traits.h"  // TF:llvm-project
+#include "mlir/IR/Attributes.h"  // TF:llvm-project
+#include "mlir/IR/Builders.h"  // TF:llvm-project
+#include "mlir/IR/Dialect.h"  // TF:llvm-project
+#include "mlir/IR/Matchers.h"  // TF:llvm-project
+#include "mlir/IR/Module.h"  // TF:llvm-project
+#include "mlir/IR/OpImplementation.h"  // TF:llvm-project
+#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
+#include "mlir/IR/TypeUtilities.h"  // TF:llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"

diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
index 9b6196cda5b..8444ec783f0 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
@@ -29,6 +29,7 @@ limitations under the License.

 include "tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td"
 include "mlir/Analysis/CallInterfaces.td"
+include "mlir/IR/OpBase.td"

 class TF_TensorListInitOp<string mnemonic> : TF_Op<mnemonic, [NoSideEffect]> {
   let results = (outs
@@ -56,7 +57,7 @@ class TF_TensorListInitOp<string mnemonic> : TF_Op<mnemonic, [NoSideEffect]> {
     // Returns data type of the result handle. Returned type contains type of
     // the TensorList element as a subtype.
     VariantType handle_dtype() {
-      return getElementTypeOrSelf(handle()->getType()).cast<VariantType>();
+      return getElementTypeOrSelf(handle().getType()).cast<VariantType>();
     }
   }];
 }
@@ -232,6 +233,50 @@ def TF_LegacyCallOp : TF_Op<"LegacyCall",
   }];
 }

+def TF_ParseExampleV2Op : TF_Op<"ParseExampleV2",
+                                [NoSideEffect,
+                                 AttrSizedResultSegments]> {
+
+  let summary =
+    "Transforms a vector of tf.Example protos (as strings) into typed tensors.";
+
+  let arguments = (ins
+    TF_StrTensor:$serialized,
+    TF_StrTensor:$names,
+    TF_StrTensor:$sparse_keys,
+    TF_StrTensor:$dense_keys,
+    TF_StrTensor:$ragged_keys,
+    Variadic<TensorOf<[F32, I64, TF_Str]>>:$dense_defaults,
+
+    Confined<I64Attr, [IntMinValue<0>]>:$num_sparse,
+    I32ElementsAttr:$result_segment_sizes
+  );
+
+  let results = (outs
+    Variadic<I64Tensor>:$sparse_indices,                    // len(sparse_types)
+    Variadic<TensorOf<[F32, I64, TF_Str]>>:$sparse_values,  // len(sparse_types)
+    Variadic<I64Tensor>:$sparse_shapes,                     // len(sparse_types)
+    Variadic<TensorOf<[F32, I64, TF_Str]>>:$dense_values,   // len(Tdense)
+    Variadic<TensorOf<[F32, I64, TF_Str]>>:$ragged_values,  // len(ragged_value_types)
+                                                            //   = len(ragged_split_types)
+    Variadic<TensorOf<[I32, I64]>>:$ragged_row_splits       // len(ragged_split_types)
+                                                            //   = len(ragged_value_types)
+  );
+
+  // The Verify(ParseExampleV2Op) function validates that the lengths and types
+  // of these attrs are compatible.
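+  // For example, an instance of this op with three sparse features has
+  // num_sparse = 3, derives three entries in sparse_types, and must produce
+  // three sparse_indices, sparse_values, and sparse_shapes results.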
+ TF_DerivedOperandTypeListAttr Tdense = TF_DerivedOperandTypeListAttr<5>; + TF_DerivedResultTypeListAttr sparse_types = TF_DerivedResultTypeListAttr<1>; + TF_DerivedResultTypeListAttr ragged_value_types = + TF_DerivedResultTypeListAttr<4>; + TF_DerivedResultTypeListAttr ragged_split_types = + TF_DerivedResultTypeListAttr<5>; + + let verifier = [{ + return Verify(*this); + }]; +} + def TF_PartitionedCallOp : TF_Op<"PartitionedCall", [CallOpInterface, NoSideEffect]> { let summary = diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc index c672d624944..17cc4cdfbe5 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc @@ -22,16 +22,16 @@ limitations under the License. #include "llvm/ADT/Twine.h" #include "llvm/Support/Casting.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Function.h" // TF:local_config_mlir -#include "mlir/IR/Identifier.h" // TF:local_config_mlir -#include "mlir/IR/Module.h" // TF:local_config_mlir -#include "mlir/IR/OpImplementation.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/IR/SymbolTable.h" // TF:local_config_mlir -#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir -#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/Identifier.h" // TF:llvm-project +#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/IR/OpImplementation.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/SymbolTable.h" // TF:llvm-project +#include "mlir/IR/TypeUtilities.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project namespace mlir { namespace tf_saved_model { diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h index 9998858356d..6f4b2061628 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h @@ -16,10 +16,10 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_SAVED_MODEL_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_SAVED_MODEL_H_ -#include "mlir/IR/Dialect.h" // TF:local_config_mlir -#include "mlir/IR/Function.h" // TF:local_config_mlir -#include "mlir/IR/Module.h" // TF:local_config_mlir -#include "mlir/IR/OpDefinition.h" // TF:local_config_mlir +#include "mlir/IR/Dialect.h" // TF:llvm-project +#include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/IR/OpDefinition.h" // TF:llvm-project namespace mlir { namespace tf_saved_model { diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h index c600f1445c5..51315c4f90c 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h @@ -18,10 +18,10 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TRAITS_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TRAITS_H_ -#include "mlir/IR/OpDefinition.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir -#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir +#include "mlir/IR/OpDefinition.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/TypeUtilities.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" namespace mlir { @@ -47,7 +47,7 @@ class OperandsSameAsResultsTypeOrRef LogicalResult shapeMatch = impl::verifySameOperandsAndResultShape(op); if (failed(shapeMatch)) return shapeMatch; - auto type = getElementTypeOrSelf(op->getResult(0)->getType()); + auto type = getElementTypeOrSelf(op->getResult(0).getType()); // Verify that the first result type is same as the rest of the results. // We skip the comparison against itself. diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc index ff43728928a..539605d6ccc 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "llvm/Support/ErrorHandling.h" -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/TypeUtilities.h" // TF:llvm-project namespace mlir { namespace TF { diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h index 6c97253ef33..7ff54e0c7f4 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h @@ -18,10 +18,10 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TYPES_H_
 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TYPES_H_

-#include "mlir/IR/Diagnostics.h"  // TF:local_config_mlir
-#include "mlir/IR/Location.h"  // TF:local_config_mlir
-#include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
-#include "mlir/IR/Types.h"  // TF:local_config_mlir
+#include "mlir/IR/Diagnostics.h"  // TF:llvm-project
+#include "mlir/IR/Location.h"  // TF:llvm-project
+#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
+#include "mlir/IR/Types.h"  // TF:llvm-project

 namespace mlir {
 namespace TF {

diff --git a/tensorflow/compiler/mlir/tensorflow/tests/BUILD b/tensorflow/compiler/mlir/tensorflow/tests/BUILD
index ef93af93b40..a4ebc997991 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/tests/BUILD
@@ -4,7 +4,7 @@ package(licenses = ["notice"])

 glob_lit_tests(
     data = [":test_utilities"],
-    driver = "@local_config_mlir//:run_lit.sh",
+    driver = "@llvm-project//mlir:run_lit.sh",
     test_file_exts = ["mlir"],
 )

@@ -14,7 +14,7 @@ filegroup(
     testonly = True,
     data = [
         "//tensorflow/compiler/mlir:tf-opt",
-        "@llvm//:FileCheck",
-        "@llvm//:not",
+        "@llvm-project//llvm:FileCheck",
+        "@llvm-project//llvm:not",
     ],
 )

diff --git a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir
index b7d438b38ed..2a17ec16898 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir
@@ -67,6 +67,116 @@ func @testAdd() -> tensor<2x2xi32> {
   return %2: tensor<2x2xi32>
 }

+// CHECK-LABEL: testSimpleConcatOffset
+func @testSimpleConcatOffset() -> (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) {
+  %concat_dim = constant dense<1> : tensor<i32>
+  %shape0 = constant dense<[2, 2, 7]> : tensor<3xi32>
+  %shape1 = constant dense<[2, 3, 7]> : tensor<3xi32>
+  %shape2 = constant dense<[2, 5, 7]> : tensor<3xi32>
+
+  // CHECK: [[OFFSET_0:%.*]] = "tf.Const{{.*}} dense<0> : tensor<3xi32>
+  // CHECK: [[OFFSET_1:%.*]] = "tf.Const{{.*}} dense<[0, 2, 0]> : tensor<3xi32>
+  // CHECK: [[OFFSET_2:%.*]] = "tf.Const{{.*}} dense<[0, 5, 0]> : tensor<3xi32>
+
+  %offset:3 = "tf.ConcatOffset"(%concat_dim, %shape0, %shape1, %shape2) : (tensor<i32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>)
+
+  // CHECK: return [[OFFSET_0]], [[OFFSET_1]], [[OFFSET_2]]
+  return %offset#0, %offset#1, %offset#2: tensor<3xi32>, tensor<3xi32>, tensor<3xi32>
+}
+
+// CHECK-LABEL: testConcatOffsetWithZeros
+func @testConcatOffsetWithZeros() -> (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) {
+  %concat_dim = constant dense<1> : tensor<i32>
+  %shape0 = constant dense<0> : tensor<3xi32>
+  %shape1 = constant dense<[0, 3, 0]> : tensor<3xi32>
+  %shape2 = constant dense<[0, 5, 0]> : tensor<3xi32>
+  %shape3 = constant dense<0> : tensor<3xi32>
+
+  // CHECK: [[OFFSET_0:%.*]] = "tf.Const{{.*}} dense<0> : tensor<3xi32>
+  // CHECK: [[OFFSET_2:%.*]] = "tf.Const{{.*}} dense<[0, 3, 0]> : tensor<3xi32>
+  // CHECK: [[OFFSET_3:%.*]] = "tf.Const{{.*}} dense<[0, 8, 0]> : tensor<3xi32>
+
+  %offset:4 = "tf.ConcatOffset"(%concat_dim, %shape0, %shape1, %shape2, %shape3) : (tensor<i32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>)
+
+  // CHECK: return [[OFFSET_0]], [[OFFSET_0]], [[OFFSET_2]], [[OFFSET_3]]
+  return %offset#0, %offset#1, %offset#2, %offset#3: tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>
tensor<3xi32> +} + +// CHECK-LABEL: testConcatOffsetNegativeConcatDim +func @testConcatOffsetNegativeConcatDim() -> (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) { + %concat_dim = constant dense<-1> : tensor + %shape0 = constant dense<[2, 8, 3]> : tensor<3xi32> + %shape1 = constant dense<[2, 8, 5]> : tensor<3xi32> + %shape2 = constant dense<[2, 8, 7]> : tensor<3xi32> + + // CHECK: [[OFFSET_0:%.*]] = "tf.Const{{.*}} dense<0> : tensor<3xi32> + // CHECK: [[OFFSET_1:%.*]] = "tf.Const{{.*}} dense<[0, 0, 3]> : tensor<3xi32> + // CHECK: [[OFFSET_2:%.*]] = "tf.Const{{.*}} dense<[0, 0, 8]> : tensor<3xi32> + + %offset:3 = "tf.ConcatOffset"(%concat_dim, %shape0, %shape1, %shape2) : (tensor, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) + + // CHECK: return [[OFFSET_0]], [[OFFSET_1]], [[OFFSET_2]] + return %offset#0, %offset#1, %offset#2: tensor<3xi32>, tensor<3xi32>, tensor<3xi32> +} + +// CHECK-LABEL: testConcatOffsetNonConstConcatDim +func @testConcatOffsetNonConstConcatDim(%concat_dim: tensor) -> (tensor<3xi32>, tensor<3xi32>) { + %shape0 = constant dense<[2, 2, 7]> : tensor<3xi32> + %shape1 = constant dense<[2, 3, 7]> : tensor<3xi32> + + // CHECK: tf.ConcatOffset + %offset:2 = "tf.ConcatOffset"(%concat_dim, %shape0, %shape1) : (tensor, tensor<3xi32>, tensor<3xi32>) -> (tensor<3xi32>, tensor<3xi32>) + + return %offset#0, %offset#1: tensor<3xi32>, tensor<3xi32> +} + +// CHECK-LABEL: testConcatOffsetNonConstShape +func @testConcatOffsetNonConstShape(%shape1: tensor<3xi32>) -> (tensor<3xi32>, tensor<3xi32>) { + %concat_dim = constant dense<1> : tensor + %shape0 = constant dense<[2, 2, 7]> : tensor<3xi32> + + // CHECK: tf.ConcatOffset + %offset:2 = "tf.ConcatOffset"(%concat_dim, %shape0, %shape1) : (tensor, tensor<3xi32>, tensor<3xi32>) -> (tensor<3xi32>, tensor<3xi32>) + + return %offset#0, %offset#1: tensor<3xi32>, tensor<3xi32> +} + +// CHECK-LABEL: testConcatOffsetBadNegativeConcatDim +func @testConcatOffsetBadNegativeConcatDim() -> (tensor<3xi32>, tensor<3xi32>) { + %concat_dim = constant dense<-4> : tensor + %shape0 = constant dense<[2, 2, 7]> : tensor<3xi32> + %shape1 = constant dense<[2, 3, 7]> : tensor<3xi32> + + // CHECK: tf.ConcatOffset + %offset:2 = "tf.ConcatOffset"(%concat_dim, %shape0, %shape1) : (tensor, tensor<3xi32>, tensor<3xi32>) -> (tensor<3xi32>, tensor<3xi32>) + + return %offset#0, %offset#1: tensor<3xi32>, tensor<3xi32> +} + +// CHECK-LABEL: testConcatOffsetBadPositiveConcatDim +func @testConcatOffsetBadPositiveConcatDim() -> (tensor<3xi32>, tensor<3xi32>) { + %concat_dim = constant dense<3> : tensor + %shape0 = constant dense<[2, 2, 7]> : tensor<3xi32> + %shape1 = constant dense<[2, 3, 7]> : tensor<3xi32> + + // CHECK: tf.ConcatOffset + %offset:2 = "tf.ConcatOffset"(%concat_dim, %shape0, %shape1) : (tensor, tensor<3xi32>, tensor<3xi32>) -> (tensor<3xi32>, tensor<3xi32>) + + return %offset#0, %offset#1: tensor<3xi32>, tensor<3xi32> +} + +// CHECK-LABEL: testConcatOffsetDifferentNonConcatDimElements +func @testConcatOffsetDifferentNonConcatDimElements() -> (tensor<3xi32>, tensor<3xi32>) { + %concat_dim = constant dense<1> : tensor + %shape0 = constant dense<[2, 2, 7]> : tensor<3xi32> + %shape1 = constant dense<[2, 3, 8]> : tensor<3xi32> + + // CHECK: tf.ConcatOffset + %offset:2 = "tf.ConcatOffset"(%concat_dim, %shape0, %shape1) : (tensor, tensor<3xi32>, tensor<3xi32>) -> (tensor<3xi32>, tensor<3xi32>) + + return %offset#0, %offset#1: tensor<3xi32>, tensor<3xi32> +} + // Ops with side effects should not get constant folded. 
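// (Editorial note, not part of the original test file: a sketch of the rule
// exercised below, assuming the folder evaluates candidate ops eagerly
// through TFE, as the comment further down about TFE_Op suggests. A stateful
// op may yield a different value on each evaluation, so a fold like
//   %r = "tf.SomeStatefulOp"()            // hypothetical op name
//   ==> %r = "tf.Const"() {value = ...}   // invalid fold
// would silently change program semantics and must be rejected; only the
// pure ops around it may fold.)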
// CHECK-LABEL: func @testSideEffectOp() -> tensor<3xf32> func @testSideEffectOp() -> tensor<3xf32> { @@ -77,7 +187,7 @@ func @testSideEffectOp() -> tensor<3xf32> { return %1: tensor<3xf32> } -// Ops with unimplemnted attributes which couldn't be added to the TFE_Op. +// Ops with unimplemented attributes which couldn't be added to the TFE_Op. // CHECK-LABEL: func @testUnimplementedOp() -> (tensor, tensor) func @testUnimplementedOp() -> (tensor, tensor) { %0 = constant dense<1> : tensor diff --git a/tensorflow/compiler/mlir/tensorflow/tests/executor_to_control_dialect.mlir b/tensorflow/compiler/mlir/tensorflow/tests/executor_to_control_dialect.mlir index 60117552c8e..5ecef050055 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/executor_to_control_dialect.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/executor_to_control_dialect.mlir @@ -121,7 +121,7 @@ func @ref_tf_executor_ops(%arg0: tensor<4x!tf.f32ref>, %arg1: tensor<4x!tf.f32re // ----- -// Tests if empty island with just control dependency inputs and output is +// Tests if empty island with just one control dependency input and output is // handled correctly. // CHECK-LABEL: func @empty_island_control_dep_only func @empty_island_control_dep_only() -> tensor { @@ -138,10 +138,10 @@ func @empty_island_control_dep_only() -> tensor { } // CHECK-NEXT: %[[CONST2:[0-9]*]]:2 = "_tf.Const"() // CHECK-SAME: () -> (tensor, !_tf.control) - %2 = tf_executor.island(%0#1, %1#1) { + %2 = tf_executor.island(%0#1) { tf_executor.yield } - %3:2 = tf_executor.island(%2) { + %3:2 = tf_executor.island(%2, %1#1) { %6 = "tf.Add"(%0#0, %1#0) : (tensor, tensor) -> tensor tf_executor.yield %6 : tensor } @@ -151,3 +151,38 @@ func @empty_island_control_dep_only() -> tensor { } return %fetch : tensor } + +// ----- + +// Tests if empty island with multiple control inputs will be replaced with a +// no-op. 
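+// (Editorial sketch derived from the CHECK lines below, not part of the
+// original test: an island that only merges control tokens,
+//   %2 = tf_executor.island(%c0, %c1) { tf_executor.yield }
+// has more than one control operand, so instead of being dropped it lowers
+// to a single control-only op in the control dialect,
+//   %noop = "_tf.NoOp"(%c0, %c1) : (!_tf.control, !_tf.control) -> !_tf.control
+// and downstream ops then take %noop as their one control input. The names
+// %c0, %c1, and %noop are illustrative.)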
+// CHECK-LABEL: func @empty_island_multi_control_inputs +func @empty_island_multi_control_inputs() -> tensor { + %fetch = tf_executor.graph { + %0:2 = tf_executor.island { + %4 = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Const", value = dense<1> : tensor} : () -> tensor + tf_executor.yield %4 : tensor + } + // CHECK-NEXT: %[[CONST1:[0-9]*]]:2 = "_tf.Const"() + // CHECK-SAME: () -> (tensor, !_tf.control) + %1:2 = tf_executor.island { + %5 = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Const", value = dense<1> : tensor} : () -> tensor + tf_executor.yield %5 : tensor + } + // CHECK-NEXT: %[[CONST2:[0-9]*]]:2 = "_tf.Const"() + // CHECK-SAME: () -> (tensor, !_tf.control) + %2 = tf_executor.island(%0#1, %1#1) { + tf_executor.yield + } + // CHECK-NEXT: %[[NOOP:[0-9]*]] = "_tf.NoOp"(%[[CONST1]]#1, %[[CONST2]]#1) + // CHECK-SAME: (!_tf.control, !_tf.control) -> !_tf.control + %3:2 = tf_executor.island(%2) { + %6 = "tf.Add"(%0#0, %1#0) : (tensor, tensor) -> tensor + tf_executor.yield %6 : tensor + } + // CHECK-NEXT: %[[ADD:[0-9]*]]:2 = "_tf.Add"(%[[CONST1]]#0, %[[CONST2]]#0, %[[NOOP]]) + // CHECK-SAME: (tensor, tensor, !_tf.control) -> (tensor, !_tf.control) + tf_executor.fetch %3#0 : tensor + } + return %fetch : tensor +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning.mlir b/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning.mlir index 771ad5e30d8..8585790564b 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning.mlir @@ -167,16 +167,3 @@ func @control_fetch(%arg0 : i32) { } return } - -// Check that @main function is pruned. -// CHECK-LABEL: func @main -func @main() { - tf_executor.graph { - // CHECK-NOT: tf_executor.island - %0 = tf_executor.island { - tf_executor.yield - } - tf_executor.fetch - } - return -} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning_skip_main.mlir b/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning_skip_main.mlir deleted file mode 100644 index 86568cccd0f..00000000000 --- a/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning_skip_main.mlir +++ /dev/null @@ -1,14 +0,0 @@ -// RUN: tf-opt %s -tf-executor-graph-pruning=skip-main-func | FileCheck %s --dump-input=fail - -// Check that @main function is skipped by default. 
-// CHECK-LABEL: func @main -func @main() { - tf_executor.graph { - // CHECKT: tf_executor.island - %0 = tf_executor.island { - tf_executor.yield - } - tf_executor.fetch - } - return -} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/BUILD b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/BUILD index 6c4d6d2b2ab..5880245cc2d 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/BUILD @@ -4,7 +4,7 @@ licenses(["notice"]) glob_lit_tests( data = [":test_utilities"], - driver = "@local_config_mlir//:run_lit.sh", + driver = "@llvm-project//mlir:run_lit.sh", test_file_exts = ["pbtxt"], ) @@ -14,7 +14,7 @@ filegroup( testonly = True, data = [ "//tensorflow/compiler/mlir:tf-mlir-translate", - "@llvm//:FileCheck", - "@llvm//:not", + "@llvm-project//llvm:FileCheck", + "@llvm-project//llvm:not", ], ) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/functional-if-ops.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/functional-if-ops.pbtxt index cbfa973fd64..8eca30802ef 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/functional-if-ops.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/functional-if-ops.pbtxt @@ -1,11 +1,11 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=a,b -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-input-shapes=':' -tf-output-arrays=StatefulIf,StatelessIf -o - | FileCheck %s +# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=a,b -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-input-shapes=':' -tf-output-arrays=StatefulIf,StatelessIf -o - -mlir-print-debuginfo | FileCheck %s # Verify that TensorFlow If and StatelessIf ops are mapped to the # composite If op in MLIR with is_stateless attribute set accordingly to # distinguish between them. -# CHECK-DAG: "tf.If"{{.*}} is_stateless = false, name = "StatefulIf" -# CHECK-DAG: "tf.If"{{.*}} is_stateless = true, name = "StatelessIf" +# CHECK-DAG: "tf.If"{{.*}} is_stateless = false{{.*}} loc("StatefulIf") +# CHECK-DAG: "tf.If"{{.*}} is_stateless = true{{.*}} loc("StatelessIf") node { name: "tf.Less" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/functional-while-ops.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/functional-while-ops.pbtxt index 953f83a9f68..ede01ebf62b 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/functional-while-ops.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/functional-while-ops.pbtxt @@ -1,11 +1,11 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=iter,val -tf-input-data-types=DT_INT32,DT_FLOAT -tf-input-shapes=':' -tf-output-arrays=StatefulWhile:1,StatelessWhile:1 -o - | FileCheck %s +# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=iter,val -tf-input-data-types=DT_INT32,DT_FLOAT -tf-input-shapes=':' -tf-output-arrays=StatefulWhile:1,StatelessWhile:1 -o - -mlir-print-debuginfo | FileCheck %s # Verify that TensorFlow While and StatelessWhile ops are mapped to the # composite While op in MLIR with is_stateless attribute set accordingly to # distinguish between them. 
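# (Editorial note: with -mlir-print-debuginfo added to the RUN line above, the
# importer carries each node's name as an MLIR location rather than a `name`
# attribute, so the updated checks below match a trailing location, e.g.
#   "tf.While"(...) {is_stateless = false, ...} ... loc("StatefulWhile")
# instead of a name = "StatefulWhile" attribute.)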
-# CHECK-DAG: "tf.While"{{.*}} is_stateless = false, name = "StatefulWhile" -# CHECK-DAG: "tf.While"{{.*}} is_stateless = true, name = "StatelessWhile" +# CHECK-DAG: "tf.While"{{.*}} is_stateless = false{{.*}} loc("StatefulWhile") +# CHECK-DAG: "tf.While"{{.*}} is_stateless = true{{.*}} loc("StatelessWhile") node { name: "StatefulWhile" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/mlir_passthrough_op.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/mlir_passthrough_op.pbtxt index 1df903d46ce..da79023093c 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/mlir_passthrough_op.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/mlir_passthrough_op.pbtxt @@ -1,7 +1,7 @@ # RUN: tf-mlir-translate -graphdef-to-mlir %s | FileCheck %s # CHECK:"tf.MlirPassthroughOp" -# CHECK: mlir_module = "\0Afunc @main(%arg0 : tensor<10xf32>, %arg1 : tensor<10xf32>) -> tensor<10x10xf32> {\0A %add = \22tf.Add\22(%arg0, %arg1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>\0A %ret = \22magic.op\22(%add, %add) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10x10xf32>\0A return %ret : tensor<10x10xf32>\0A}\0A", name = "MlirPassthroughOp"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<*xf32> +# CHECK: mlir_module = "\0Afunc @main(%arg0 : tensor<10xf32>, %arg1 : tensor<10xf32>) -> tensor<10x10xf32> {\0A %add = \22tf.Add\22(%arg0, %arg1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>\0A %ret = \22magic.op\22(%add, %add) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10x10xf32>\0A return %ret : tensor<10x10xf32>\0A}\0A"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<*xf32> node { name: "x" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/node-locations.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/node-locations.pbtxt index a8f58c427fd..fdf279f3887 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/node-locations.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/node-locations.pbtxt @@ -90,6 +90,6 @@ library { } # TODO(b/142400497): What is the semantic contract for locations? 
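# (Editorial note: a fused location stacks several source names into one loc.
# The updated check below appears to append the outermost node name,
# "fused_node_outside_function", after the per-function entries "n1@f1" and
# "n2@f2", so the full provenance chain survives the round trip.)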
-# CHECK: "tf.Const"{{.*}}value = dense<2>{{.*}}loc(fused["n1@f1", "n2@f2"]) +# CHECK: "tf.Const"{{.*}}value = dense<2>{{.*}}loc(fused["n1@f1", "n2@f2", "fused_node_outside_function"]) # CHECK: "tf.Const"{{.*}}value = dense<0>{{.*}}loc("node_outside_function") # CHECK: "tf.Const"{{.*}}value = dense<1>{{.*}}loc("node_inside_function@foo") diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/parse_example.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/parse_example.pbtxt new file mode 100644 index 00000000000..7411a5ea4d7 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/parse_example.pbtxt @@ -0,0 +1,225 @@ +# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=input0 -tf-input-data-types=DT_STRING -tf-input-shapes=32 -tf-output-arrays=ParseExample/ParseExampleV2:0,ParseExample/ParseExampleV2:7 -o - | FileCheck %s + +# CHECK: %[[parse_example:.*]]:8, %[[parse_example_control:.*]] = tf_executor.island wraps "tf.ParseExampleV2"(%arg0, +# CHECK: result_segment_sizes = dense<[2, 2, 2, 2, 0, 0]> : vector<6xi32> +# CHECK: tf_executor.fetch %[[parse_example]]#0, %[[parse_example]]#7 : tensor, tensor<32xf32> + +node { + name: "input0" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "shape" + value { + shape { + unknown_rank: true + } + } + } +} +node { + name: "ParseExample/Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + } + } + } + } + } +} +node { + name: "ParseExample/Const_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + } + } + } + } + } +} +node { + name: "ParseExample/ParseExampleV2/names" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + } + } + } + } + } +} +node { + name: "ParseExample/ParseExampleV2/sparse_keys" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 2 + } + } + string_val: "feature_key3" + string_val: "feature_key4" + } + } + } +} +node { + name: "ParseExample/ParseExampleV2/dense_keys" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 2 + } + } + string_val: "feature_key1" + string_val: "feature_key2" + } + } + } +} +node { + name: "ParseExample/ParseExampleV2/ragged_keys" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + } + } + } + } + } +} +node { + name: "ParseExample/ParseExampleV2" + op: "ParseExampleV2" + input: "input0" + input: "ParseExample/ParseExampleV2/names" + input: "ParseExample/ParseExampleV2/sparse_keys" + input: "ParseExample/ParseExampleV2/dense_keys" + input: "ParseExample/ParseExampleV2/ragged_keys" + input: "ParseExample/Const" + input: "ParseExample/Const_1" + attr { + key: "Tdense" + value { + list { + type: DT_FLOAT + type: DT_FLOAT + } + } + } + attr { + key: "dense_shapes" + value { + list { + shape { + } + shape { + } + } + } + } + attr { + key: "num_sparse" + value { + i: 2 + } + } + attr { + key: "ragged_split_types" + value { + list 
{ + } + } + } + attr { + key: "ragged_value_types" + value { + list { + } + } + } + attr { + key: "sparse_types" + value { + list { + type: DT_STRING + type: DT_INT64 + } + } + } +} +versions { + producer: 175 +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/quint8-const.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/quint8-const.pbtxt index 748bc996f36..cf8051f7aaa 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/quint8-const.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/quint8-const.pbtxt @@ -1,4 +1,4 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s +# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - -mlir-print-debuginfo | FileCheck %s node { name: "Quantized_Constant" @@ -28,5 +28,5 @@ versions { } # CHECK: tf.Const -# CHECK-SAME: name = "Quantized_Constant" # CHECK-SAME: value = opaque<"tf", "{{0[xX][0-9a-fA-F]*}}"> : tensor +# CHECK-SAME: loc("Quantized_Constant") diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/switch_n.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/switch_n.pbtxt index 3dd5ce58ed2..e819efcddd1 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/switch_n.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/switch_n.pbtxt @@ -1,13 +1,13 @@ -# RUN: tf-mlir-translate -graphdef-to-splatted-mlir %s -o - | FileCheck %s --dump-input-on-failure +# RUN: tf-mlir-translate -graphdef-to-splatted-mlir %s -o - -mlir-print-debuginfo | FileCheck %s --dump-input-on-failure # CHECK: tf_executor.SwitchN # CHECK-SAME: of 3 : tensor # CHECK-SAME: T = i32 -# CHECK-SAME: name = "Case/branch_index/_3" +# CHECK-SAME: loc("Case/branch_index/_3") # CHECK: tf_executor.SwitchN # CHECK-SAME: of 2 : tensor # CHECK-SAME: T = f32 -# CHECK-SAME: name = "Case/Case/input_0/_7" +# CHECK-SAME: loc("Case/Case/input_0/_7") node { name: "Case/branch_index" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir index 3448c8c2005..c1c5f419ca9 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir @@ -182,7 +182,6 @@ func @rsqrt_grad_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor< return %0 : tensor<*xf32> } - // CHECK-LABEL: SoftmaxCrossEntropyWithLogits // CHECK-SAME: %[[FEATURES:.*]]: tensor<2x3xf32>, %[[LABELS:.*]]: tensor<2x3xf32> func @SoftmaxCrossEntropyWithLogits(%features: tensor<2x3xf32>, %labels: tensor<2x3xf32>) -> (tensor<2xf32>, tensor<2x3xf32>) { @@ -222,6 +221,66 @@ func @scalar_SoftmaxCrossEntropyWithLogits(%features: tensor, %labels: tens return %0#0, %0#1 : tensor, tensor } +// CHECK-LABEL: SparseSoftmaxCrossEntropyWithLogits +// CHECK-SAME: %[[FEATURES:.*]]: tensor<2x3xf32>, %[[SPARSE_LABELS:.*]]: tensor<2xi32> +func @SparseSoftmaxCrossEntropyWithLogits(%features: tensor<2x3xf32>, %labels: tensor<2xi32>) -> (tensor<2xf32>, tensor<2x3xf32>) { + // Convert SPARSE_LABELS to dense LABELS. 
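+  // (Worked example, not part of the original test: with depth = 3,
+  // on_value = 1.0, off_value = 0.0, and axis = 1, sparse labels [1, 0]
+  // one-hot to
+  //   [[0, 1, 0],
+  //    [1, 0, 0]]
+  // which is the dense LABELS tensor consumed by the softmax expansion below.)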
+  // CHECK-DAG: %[[DEPTH:.*]] = "tf.Const"() {value = dense<3> : tensor} : () -> tensor
+  // CHECK-DAG: %[[ONE:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor} : () -> tensor
+  // CHECK-DAG: %[[ZERO:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor} : () -> tensor
+  // CHECK-DAG: %[[LABELS:.*]] = "tf.OneHot"(%[[SPARSE_LABELS]], %[[DEPTH]], %[[ONE]], %[[ZERO]]) {axis = 1 : i64} : (tensor<2xi32>, tensor, tensor, tensor) -> tensor<2x3xf32>
+
+  // Adjust labels to have NaN for out-of-range labels.
+  // CHECK-DAG: %[[ZERO_I32:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor
+  // CHECK-DAG: %[[IS_NEGATIVE:.*]] = "tf.LessEqual"(%[[ZERO_I32]], %arg1) : (tensor, tensor<2xi32>) -> tensor<2xi1>
+  // CHECK-DAG: %[[IS_LESS:.*]] = "tf.Less"(%arg1, %[[DEPTH]]) : (tensor<2xi32>, tensor) -> tensor<2xi1>
+  // CHECK-DAG: %[[IS_WITHIN_RANGE:.*]] = "tf.LogicalAnd"(%[[IS_NEGATIVE]], %[[IS_LESS]]) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1>
+  // CHECK-DAG: %[[NAN:.*]] = "tf.Const"() {value = dense<0x7FC00000> : tensor} : () -> tensor
+  // CHECK-DAG: %[[ZERO_OR_NAN:.*]] = "tf.SelectV2"(%[[IS_WITHIN_RANGE]], %[[ZERO]], %[[NAN]]) : (tensor<2xi1>, tensor, tensor) -> tensor<2xf32>
+  // CHECK-DAG: %[[NEG_ONE:.*]] = "tf.Const"() {value = dense<-1> : tensor<1xi64>} : () -> tensor<1xi64>
+  // CHECK-DAG: %[[RESHAPE:.*]] = "tf.ExpandDims"(%[[ZERO_OR_NAN]], %[[NEG_ONE]]) : (tensor<2xf32>, tensor<1xi64>) -> tensor<2x1xf32>
+  // CHECK-DAG: %[[ADJUSTED_LABELS:.*]] = "tf.AddV2"(%[[LABELS]], %[[RESHAPE]]) : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
+
+  // SoftmaxCrossEntropyWithLogits expansion
+  // CHECK-DAG: = "tf.Neg"({{.*}}) : (tensor<2x3xf32>) -> tensor<2x3xf32>
+  // CHECK-DAG: = "tf.LogSoftmax"({{.*}}) : (tensor<2x3xf32>) -> tensor<2x3xf32>
+  // CHECK-DAG: = "tf.Mul"({{.*}}) : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
+  // CHECK-DAG: = "tf.Sum"({{.*}}) {keep_dims = false} : (tensor<2x3xf32>, tensor<1xi64>) -> tensor<2xf32>
+  // CHECK-DAG: = "tf.Softmax"({{.*}}) : (tensor<2x3xf32>) -> tensor<2x3xf32>
+  // CHECK-DAG: = "tf.Sub"({{.*}}) : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
+
+  %0:2 = "tf.SparseSoftmaxCrossEntropyWithLogits"(%features, %labels) : (tensor<2x3xf32>, tensor<2xi32>) -> (tensor<2xf32>, tensor<2x3xf32>)
+  return %0#0, %0#1 : tensor<2xf32>, tensor<2x3xf32>
+}
+
+// CHECK-LABEL: SparseSoftmaxCrossEntropyWithLogits_with_bf16_i64
+func @SparseSoftmaxCrossEntropyWithLogits_with_bf16_i64(%features: tensor<2x3xbf16>, %labels: tensor<2xi64>) -> (tensor<2xbf16>, tensor<2x3xbf16>) {
+  // CHECK-NOT: tf.SparseSoftmaxCrossEntropyWithLogits
+  %0:2 = "tf.SparseSoftmaxCrossEntropyWithLogits"(%features, %labels) : (tensor<2x3xbf16>, tensor<2xi64>) -> (tensor<2xbf16>, tensor<2x3xbf16>)
+  return %0#0, %0#1 : tensor<2xbf16>, tensor<2x3xbf16>
+}
+
+// CHECK-LABEL: SparseSoftmaxCrossEntropyWithLogits_with_unranked_labels
+func @SparseSoftmaxCrossEntropyWithLogits_with_unranked_labels(%features: tensor<2x3xf32>, %labels: tensor) -> (tensor<2xf32>, tensor<2x3xf32>) {
+  // CHECK-NOT: tf.SparseSoftmaxCrossEntropyWithLogits
+  %0:2 = "tf.SparseSoftmaxCrossEntropyWithLogits"(%features, %labels) : (tensor<2x3xf32>, tensor) -> (tensor<2xf32>, tensor<2x3xf32>)
+  return %0#0, %0#1 : tensor<2xf32>, tensor<2x3xf32>
+}
+
+// CHECK-LABEL: SparseSoftmaxCrossEntropyWithLogits_with_dynamic_labels
+func @SparseSoftmaxCrossEntropyWithLogits_with_dynamic_labels(%features: tensor<2x3xf32>, %labels: tensor<*xi64>) -> (tensor<2xf32>, tensor<2x3xf32>) {
+  // CHECK-NOT:
tf.SparseSoftmaxCrossEntropyWithLogits + %0:2 = "tf.SparseSoftmaxCrossEntropyWithLogits"(%features, %labels) : (tensor<2x3xf32>, tensor<*xi64>) -> (tensor<2xf32>, tensor<2x3xf32>) + return %0#0, %0#1 : tensor<2xf32>, tensor<2x3xf32> +} + +// CHECK-LABEL: SparseSoftmaxCrossEntropyWithLogits_with_dynamic +func @SparseSoftmaxCrossEntropyWithLogits_with_dynamic(%features: tensor<*xbf16>, %labels: tensor<*xi64>) -> (tensor<2xbf16>, tensor<*xbf16>) { + // CHECK: tf.SparseSoftmaxCrossEntropyWithLogits + %0:2 = "tf.SparseSoftmaxCrossEntropyWithLogits"(%features, %labels) : (tensor<*xbf16>, tensor<*xi64>) -> (tensor<2xbf16>, tensor<*xbf16>) + return %0#0, %0#1 : tensor<2xbf16>, tensor<*xbf16> +} + // CHECK-LABEL: func @tanhgrad_float // CHECK-SAME: (%[[Y:.*]]: tensor<*xf32>, %[[DY:.*]]: tensor<*xf32>) func @tanhgrad_float(%y : tensor<*xf32>, %dy: tensor<*xf32>) -> tensor<*xf32> { @@ -276,3 +335,99 @@ func @addN_variant(%arg0: tensor>>, %arg1: tensor>>, tensor>>, tensor>>) -> tensor>> return %0 : tensor>> } + +// CHECK-LABEL: func @DynamicStitch_simple +func @DynamicStitch_simple(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { + // CHECK-DAG: %[[SHAPE:.*]] = "tf.Const"() {value = dense<[-1, 2]> : tensor<2xi64>} : () -> tensor<2xi64> + // CHECK-DAG: %[[INP:.*]] = "tf.Reshape"(%arg0, %[[SHAPE]]) : (tensor<2x2xf32>, tensor<2xi64>) -> tensor<2x2xf32> + // CHECK-DAG: %[[ITEMS:.*]]:2 = "tf.Unpack"(%[[INP]]) {axis = 0 : i64} : (tensor<2x2xf32>) -> (tensor<2xf32>, tensor<2xf32>) + // CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + // CHECK-DAG: %[[RESULT:.*]] = "tf.ConcatV2"(%[[ITEMS]]#1, %[[ITEMS]]#0, %[[AXIS]]) : (tensor<2xf32>, tensor<2xf32>, tensor) -> tensor<2x2xf32> + // CHECK: return %[[RESULT]] + + %indices = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32> + %0 = "tf.DynamicStitch"(%indices, %arg0) : (tensor<2xi32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + +// CHECK-LABEL: DynamicStitch_scalar_matrix_indices +func @DynamicStitch_scalar_matrix_indices(%arg0: tensor<2xf32>, %arg1: tensor<2x2x2xf32>) -> (tensor<5x2xf32>) { + // CHECK-DAG: %[[SHAPE:.*]] = "tf.Const"() {value = dense<[-1, 2]> : tensor<2xi64>} : () -> tensor<2xi64> + // CHECK-DAG: %[[INP0:.*]] = "tf.Reshape"(%arg0, %[[SHAPE]]) : (tensor<2xf32>, tensor<2xi64>) -> tensor<1x2xf32> + // CHECK-DAG: %[[ITEMS0:.*]] = "tf.Unpack"(%[[INP0]]) {axis = 0 : i64} : (tensor<1x2xf32>) -> tensor<2xf32> + // CHECK-DAG: %[[INP1:.*]] = "tf.Reshape"(%arg1, %[[SHAPE]]) : (tensor<2x2x2xf32>, tensor<2xi64>) -> tensor<4x2xf32> + // CHECK-DAG: %[[ITEMS1:.*]]:4 = "tf.Unpack"(%[[INP1]]) {axis = 0 : i64} : (tensor<4x2xf32>) -> (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) + // CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + // CHECK-DAG: %6 = "tf.ConcatV2"(%[[ITEMS1]]#3, %[[ITEMS1]]#2, %[[ITEMS1]]#1, %[[ITEMS1]]#0, %[[ITEMS0]], %[[AXIS]]) : (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor) -> tensor<5x2xf32> + + %indices0 = "tf.Const"() {value = dense<4> : tensor} : () -> tensor + %indices1 = "tf.Const"() {value = dense<[[3, 2], [1, 0]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32> + %0 = "tf.DynamicStitch"(%indices0, %indices1, %arg0, %arg1) : (tensor, tensor<2x2xi32>, tensor<2xf32>, tensor<2x2x2xf32>) -> tensor<5x2xf32> + return %0 : tensor<5x2xf32> +} + +// Verify that custom types are lowered and have legal output. 
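+// (Editorial recap of the expansion checked above: a DynamicStitch with
+// constant indices is rewritten to Reshape -> Unpack -> ConcatV2, with the
+// unpacked items re-concatenated in the order the indices dictate, e.g.
+// indices [1, 0] emit item 1 first and item 0 second. The test below only
+// asserts that the op disappears for the non-standard !tf.uint8 element
+// type.)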
+// CHECK-LABEL: func @DynamicStitch_uint8
+func @DynamicStitch_uint8(%arg0: tensor<2x2x!tf.uint8>) -> tensor<2x2x!tf.uint8> {
+  // CHECK-NOT: tf.DynamicStitch
+
+  %indices = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+  %0 = "tf.DynamicStitch"(%indices, %arg0) : (tensor<2xi32>, tensor<2x2x!tf.uint8>) -> tensor<2x2x!tf.uint8>
+  return %0 : tensor<2x2x!tf.uint8>
+}
+
+// CHECK-LABEL: func @DynamicStitch_scalar_item
+func @DynamicStitch_scalar_item(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+  // CHECK-DAG: %[[SHAPE:.*]] = "tf.Const"() {value = dense<-1> : tensor<1xi64>} : () -> tensor<1xi64>
+  // CHECK-DAG: %[[INP:.*]] = "tf.Reshape"(%arg0, %[[SHAPE]]) : (tensor<2xf32>, tensor<1xi64>) -> tensor<2xf32>
+  // CHECK-DAG: %[[ITEMS:.*]]:2 = "tf.Unpack"(%[[INP]]) {axis = 0 : i64} : (tensor<2xf32>) -> (tensor, tensor)
+  // CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor
+  // CHECK-DAG: %[[RESULT:.*]] = "tf.ConcatV2"(%[[ITEMS]]#1, %[[ITEMS]]#0, %[[AXIS]]) : (tensor, tensor, tensor) -> tensor<2xf32>
+  // CHECK: return %[[RESULT]]
+
+  %indices = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+  %0 = "tf.DynamicStitch"(%indices, %arg0) : (tensor<2xi32>, tensor<2xf32>) -> tensor<2xf32>
+  return %0 : tensor<2xf32>
+}
+
+// CHECK-LABEL: func @DynamicStitch_matrix_item
+func @DynamicStitch_matrix_item(%arg0: tensor<2x2x2xf32>) -> tensor<2x2x2xf32> {
+  // CHECK-DAG: %[[SHAPE:.*]] = "tf.Const"() {value = dense<[-1, 2, 2]> : tensor<3xi64>} : () -> tensor<3xi64>
+  // CHECK-DAG: %[[INP:.*]] = "tf.Reshape"(%arg0, %[[SHAPE]]) : (tensor<2x2x2xf32>, tensor<3xi64>) -> tensor<2x2x2xf32>
+  // CHECK-DAG: %[[ITEMS:.*]]:2 = "tf.Unpack"(%[[INP]]) {axis = 0 : i64} : (tensor<2x2x2xf32>) -> (tensor<2x2xf32>, tensor<2x2xf32>)
+  // CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor
+  // CHECK-DAG: %[[RESULT:.*]] = "tf.ConcatV2"(%[[ITEMS]]#1, %[[ITEMS]]#0, %[[AXIS]]) : (tensor<2x2xf32>, tensor<2x2xf32>, tensor) -> tensor<2x2x2xf32>
+  // CHECK: return %[[RESULT]]
+
+  %indices = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+  %0 = "tf.DynamicStitch"(%indices, %arg0) : (tensor<2xi32>, tensor<2x2x2xf32>) -> tensor<2x2x2xf32>
+  return %0 : tensor<2x2x2xf32>
+}
+
+// CHECK-LABEL: func @DynamicStitch_dynamic
+func @DynamicStitch_dynamic(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
+  // CHECK: tf.DynamicStitch
+  %0 = "tf.DynamicStitch"(%arg0, %arg1) : (tensor<*xi32>, tensor<*xf32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// CHECK-LABEL: func @DynamicStitch_duplicates
+func @DynamicStitch_duplicates(%arg0: tensor<2x2xf32>) -> tensor<1x2xf32> {
+  // CHECK-DAG: %[[SHAPE:.*]] = "tf.Const"() {value = dense<[-1, 2]> : tensor<2xi64>} : () -> tensor<2xi64>
+  // CHECK-DAG: %[[INP:.*]] = "tf.Reshape"(%arg0, %[[SHAPE]]) : (tensor<2x2xf32>, tensor<2xi64>) -> tensor<2x2xf32>
+  // CHECK-DAG: %[[ITEMS:.*]]:2 = "tf.Unpack"(%[[INP]]) {axis = 0 : i64} : (tensor<2x2xf32>) -> (tensor<2xf32>, tensor<2xf32>)
+  // CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor
+  // CHECK-DAG: %[[RESULT:.*]] = "tf.ConcatV2"(%[[ITEMS]]#1, %[[AXIS]]) : (tensor<2xf32>, tensor) -> tensor<1x2xf32>
+  // CHECK: return %[[RESULT]]
+
+  %indices = "tf.Const"() {value = dense<[0, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+  %0 = "tf.DynamicStitch"(%indices, %arg0) : (tensor<2xi32>, tensor<2x2xf32>) -> tensor<1x2xf32>
+  return %0 : tensor<1x2xf32>
+}
+
+// CHECK-LABEL: func @Reciprocal
+func @Reciprocal(%arg0:
tensor<*xf32>) -> tensor<*xf32> { + // CHECK: %[[ONE:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor} : () -> tensor + // CHECK: "tf.Div"(%[[ONE]], %arg0) : (tensor, tensor<*xf32>) -> tensor<*xf32> + %0 = "tf.Reciprocal"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/BUILD b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/BUILD index 976ad56a895..cbdf5d96d0e 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/BUILD @@ -4,7 +4,7 @@ licenses(["notice"]) glob_lit_tests( data = [":test_utilities"], - driver = "@local_config_mlir//:run_lit.sh", + driver = "@llvm-project//mlir:run_lit.sh", test_file_exts = ["mlir"], ) @@ -14,7 +14,7 @@ filegroup( testonly = True, data = [ "//tensorflow/compiler/mlir:tf-mlir-translate", - "@llvm//:FileCheck", - "@llvm//:not", + "@llvm-project//llvm:FileCheck", + "@llvm-project//llvm:not", ], ) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/convert_tensor.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/convert_tensor.mlir index 52e4c529815..e6e22722aec 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/convert_tensor.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/convert_tensor.mlir @@ -1,8 +1,8 @@ // RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s func @main() -> (tensor<1x2xf16>, tensor<2xf16>) { - %0:2 = "_tf.Const"() {device = "", name = "foo", dtype = "tfdtype$DT_HALF", value = dense<1.0> : tensor<1x2xf16>} : () -> (tensor<1x2xf16>, !_tf.control) - %1:2 = "_tf.Const"() {device = "", name = "bar", dtype = "tfdtype$DT_HALF", value = dense<[1.0, 2.0]> : tensor<2xf16>} : () -> (tensor<2xf16>, !_tf.control) + %0:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_HALF", value = dense<1.0> : tensor<1x2xf16>} : () -> (tensor<1x2xf16>, !_tf.control) loc("foo") + %1:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_HALF", value = dense<[1.0, 2.0]> : tensor<2xf16>} : () -> (tensor<2xf16>, !_tf.control) loc("bar") return %0#0, %1#0 : tensor<1x2xf16>, tensor<2xf16> // CHECK: node { @@ -13,4 +13,4 @@ func @main() -> (tensor<1x2xf16>, tensor<2xf16>) { // CHECK-NEXT: op: "Const" // CHECK: half_val: 15360 // CHECK: half_val: 16384 -} \ No newline at end of file +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/function-resource-args.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/function-resource-args.mlir index 24cb7b703c6..515e03ac2d2 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/function-resource-args.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/function-resource-args.mlir @@ -2,7 +2,7 @@ func @main() -> tensor<*x!tf.resource> attributes {tf.entry_function = {inputs = "", outputs = "func_call"}} { %0 = tf_executor.graph { - %outputs, %control = tf_executor.island wraps "tf.VarHandleOp"() {container = "a", device = "/CPU:0", dtype = i64, name = "x", shape = "tfshape$", shared_name = "x"} : () -> tensor>> + %outputs, %control = tf_executor.island wraps "tf.VarHandleOp"() {container = "a", device = "/CPU:0", dtype = i64, shape = "tfshape$", shared_name = "x"} : () -> tensor>> loc("x") %outputs_0, %control_1 = tf_executor.island wraps "tf.LegacyCall"(%outputs, %outputs) {_disable_call_shape_inference = true, f = @test_func_name0} : (tensor>>, tensor>>) -> tensor<*x!tf.resource> tf_executor.fetch %outputs_0 : 
tensor<*x!tf.resource> } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/graph-as-function.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/graph-as-function.mlir index 40ddad90aec..cb9c5c380ba 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/graph-as-function.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/graph-as-function.mlir @@ -2,15 +2,15 @@ func @main(%arg0: tensor<*x!tf.resource>, %arg1: tensor<*x!tf.resource>>, %arg2: tensor<*xf32>, %arg3: tensor<2x4x6x8xi32>) -> (tensor, tensor) attributes {tf.entry_function = {inputs = "args_0,args_1,args_2,args_3", outputs = "rets_0_RetVal,rets_1_RetVal"}} { - %0:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "const", value = dense<0.000000e+00> : tensor} : () -> (tensor, !_tf.control) - %1:2 = "_tf.Identity"(%0#0) {T = "tfdtype$DT_FLOAT", device = "", name = "identity"} : (tensor) -> (tensor, !_tf.control) - %2:2 = "_tf.StatefulPartitionedCall"(%0#0, %arg1) {Tin = ["tfdtype$DT_FLOAT", "tfdtype$DT_RESOURCE"], Tout = ["tfdtype$DT_FLOAT"], _gradient_op_type = "PartitionedCall-1205", config = "", config_proto = "\0A\07\0A\03GPU\10\00\0A\07\0A\03CPU\10\012\02J\008\01", device = "", executor_type = "", f = @function0, name = "statefulpartitionedcall"} : (tensor, tensor<*x!tf.resource>>) -> (tensor, !_tf.control) - return %1#0, %2#0 : tensor, tensor + %0 = "tf.Const"() {device = "", dtype = "tfdtype$DT_FLOAT", value = dense<0.000000e+00> : tensor} : () -> tensor loc("const") + %1 = "tf.Identity"(%0) {T = "tfdtype$DT_FLOAT", device = ""} : (tensor) -> tensor loc("identity") + %2 = "tf.StatefulPartitionedCall"(%0, %arg1) {Tin = ["tfdtype$DT_FLOAT", "tfdtype$DT_RESOURCE"], Tout = ["tfdtype$DT_FLOAT"], _gradient_op_type = "PartitionedCall-1205", config = "", config_proto = "\0A\07\0A\03GPU\10\00\0A\07\0A\03CPU\10\012\02J\008\01", device = "", executor_type = "", f = @function0} : (tensor, tensor<*x!tf.resource>>) -> tensor loc("statefulpartitionedcall") + return %1, %2 : tensor, tensor } func @function0(%arg0: tensor<*xf32>, %arg1: tensor<*x!tf.resource>) -> tensor<*xf32> attributes {tf.signature.is_stateful} { - %0:2 = "_tf.Identity"(%arg0) {T = "tfdtype$DT_FLOAT", device = "", name = "Identity"} : (tensor<*xf32>) -> (tensor<*xf32>, !_tf.control) + %0 = "tf.Identity"(%arg0) {T = "tfdtype$DT_FLOAT", device = ""} : (tensor<*xf32>) -> tensor<*xf32> loc("Identity@function0") return %0#0 : tensor<*xf32> } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/legalized_name.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/legalized_name.mlir index 67ccf52b62f..60b239aee14 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/legalized_name.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/legalized_name.mlir @@ -6,7 +6,7 @@ func @main() { %0 = "tf.Const"() {dtype = "tfdtype$DT_INT32", value = dense<0> : tensor} : () -> (tensor) loc("^foo") // CHECK: name: "fo.o" %1 = "tf.Const"() {dtype = "tfdtype$DT_INT32", value = dense<1> : tensor} : () -> (tensor) loc("fo{o") - // CHECK: name: "foo.1" + // CHECK: name: "foo" %2 = "tf.Const"() {dtype = "tfdtype$DT_INT32", value = dense<2> : tensor} : () -> (tensor) loc("foo@1") // CHECK: name: "ba.r" %3 = "tf.Const"() {dtype = "tfdtype$DT_INT32", value = dense<2> : tensor} : () -> (tensor) loc("ba r") diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/parse_example.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/parse_example.mlir new 
file mode 100644 index 00000000000..ec51fdc8e11 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/parse_example.mlir @@ -0,0 +1,86 @@ +// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 175 : i32}} { + func @main(%arg0: tensor<32x!tf.string>) -> (tensor) attributes {tf.entry_function = {inputs = "input0", outputs = "ParseExample/ParseExampleV2"}} { + + %0 = tf_executor.graph { + // NOTE(mrry): This dummy input was manually added because the exporter expects it and fails otherwise. + %dummy_input, %control_dummy = tf_executor.island wraps "tf.Placeholder.input"(%arg0) {device = "", dtype = "tfdtype$DT_STRING", shape = "tfshape$dim { size: 32 }"} : (tensor<32x!tf.string>) -> tensor<32x!tf.string> + + %outputs, %control = tf_executor.island wraps "tf.Const"() {device = "", dtype = f32, value = dense<[]> : tensor<0xf32>} : () -> tensor<0xf32> + %outputs_0, %control_1 = tf_executor.island wraps "tf.Const"() {device = "", dtype = f32, value = dense<[]> : tensor<0xf32>} : () -> tensor<0xf32> + %outputs_2, %control_3 = tf_executor.island wraps "tf.Const"() {device = "", dtype = !tf.string, value = opaque<"tf", "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B2073697A653A2032207D207D2074656E736F725F636F6E74656E743A20225C3031345C303134666561747572655F6B657931666561747572655F6B65793222"> : tensor<2x!tf.string>} : () -> tensor<2x!tf.string> + %outputs_4, %control_5 = tf_executor.island wraps "tf.Const"() {device = "", dtype = !tf.string, value = opaque<"tf", "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B207D207D"> : tensor<0x!tf.string>} : () -> tensor<0x!tf.string> + %outputs_6, %control_7 = tf_executor.island wraps "tf.Const"() {device = "", dtype = !tf.string, value = opaque<"tf", "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B207D207D"> : tensor<0x!tf.string>} : () -> tensor<0x!tf.string> + %outputs_8, %control_9 = tf_executor.island wraps "tf.Const"() {device = "", dtype = !tf.string, value = opaque<"tf", "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B2073697A653A2032207D207D2074656E736F725F636F6E74656E743A20225C3031345C303134666561747572655F6B657933666561747572655F6B65793422"> : tensor<2x!tf.string>} : () -> tensor<2x!tf.string> + + %outputs_10:8, %control_11 = tf_executor.island wraps "tf.ParseExampleV2"(%dummy_input, %outputs_4, %outputs_8, %outputs_2, %outputs_6, %outputs, %outputs_0) {Tdense = ["tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT"], dense_shapes = ["tfshape$", "tfshape$"], device = "", name = "ParseExample/ParseExampleV2", num_sparse = 2 : i64, ragged_split_types = [], ragged_value_types = [], result_segment_sizes = dense<[2, 2, 2, 2, 0, 0]> : vector<6xi32>, sparse_types = ["tfdtype$DT_STRING", "tfdtype$DT_INT64"]} : (tensor<32x!tf.string>, tensor<0x!tf.string>, tensor<2x!tf.string>, tensor<2x!tf.string>, tensor<0x!tf.string>, tensor<0xf32>, tensor<0xf32>) -> (tensor, tensor, tensor, tensor, tensor<2xi64>, tensor<2xi64>, tensor<32xf32>, tensor<32xf32>) + // CHECK: name: "ParseExample/ParseExampleV2" + // CHECK-NEXT: op: "ParseExampleV2" + // CHECK-NEXT: input: "input0" + // CHECK-NEXT: input: "_tf.Const3" + // CHECK-NEXT: input: "_tf.Const5" + // CHECK-NEXT: input: "_tf.Const2" + // CHECK-NEXT: input: "_tf.Const4" + // CHECK-NEXT: input: "_tf.Const" + // CHECK-NEXT: 
input: "_tf.Const1" + // CHECK-NEXT: attr { + // CHECK-NEXT: key: "Tdense" + // CHECK-NEXT: value { + // CHECK-NEXT: list { + // CHECK-NEXT: type: DT_FLOAT + // CHECK-NEXT: type: DT_FLOAT + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK-NEXT: attr { + // CHECK-NEXT: key: "dense_shapes" + // CHECK-NEXT: value { + // CHECK-NEXT: list { + // CHECK-NEXT: shape { + // CHECK-NEXT: } + // CHECK-NEXT: shape { + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK-NEXT: attr { + // CHECK-NEXT: key: "num_sparse" + // CHECK-NEXT: value { + // CHECK-NEXT: i: 2 + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK-NEXT: attr { + // CHECK-NEXT: key: "ragged_split_types" + // CHECK-NEXT: value { + // CHECK-NEXT: list { + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK-NEXT: attr { + // CHECK-NEXT: key: "ragged_value_types" + // CHECK-NEXT: value { + // CHECK-NEXT: list { + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK-NEXT: attr { + // CHECK-NEXT: key: "sparse_types" + // CHECK-NEXT: value { + // CHECK-NEXT: list { + // CHECK-NEXT: type: DT_STRING + // CHECK-NEXT: type: DT_INT64 + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK-NEXT: } + + tf_executor.fetch %outputs_10#0 : tensor + } + return %0#0 : tensor + // CHECK: name: "main" + // CHECK-NEXT: op: "_Retval" + // CHECK-NEXT: input: "ParseExample/ParseExampleV2" + + } +} + diff --git a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir index 2c3c72869b0..582f2237d01 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir @@ -17,8 +17,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr func @simple_chain(%arg0: tensor<1xf32>) -> tensor<*xf32> { // CHECK: %[[MUL:.*]] = "tf.Mul"{{.*}} (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> // CHECK: %[[ADD:.*]] = "tf.Add"(%[[MUL]], %[[MUL]]) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> -// CHECK: %[[CAST:.*]] = "tf.Cast"(%[[ADD]]) {{.*}} : (tensor<1xf32>) -> tensor<*xf32> -// CHECK: return %[[CAST]] : tensor<*xf32> +// CHECK: return %[[ADD]] : tensor<1xf32> %0 = "tf.Mul"(%arg0, %arg0) : (tensor<1xf32>, tensor<1xf32>) -> tensor<*xf32> %1 = "tf.Add"(%0, %0) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> return %1 : tensor<*xf32> @@ -29,10 +28,12 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr // CHECK: %[[MUL:.*]] = "tf.Mul"{{.*}} (tensor<1xf32>, tensor<10xf32>) -> tensor<10xf32> // CHECK: %[[ADD:.*]] = "tf.Add"(%[[MUL]], %[[MUL]]) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> // CHECK: %[[CAST:.*]] = "tf.Cast"(%[[ADD]]) {{.*}} : (tensor<10xf32>) -> tensor<*xf32> -// CHECK: return %[[CAST]] : tensor<*xf32> +// CHECK: %[[UNKNOWN:.*]] = "unknown.A"(%[[CAST]]) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return %[[UNKNOWN]] : tensor<*xf32> %0 = "tf.Mul"(%arg0, %arg1) : (tensor<1xf32>, tensor<10xf32>) -> tensor<*xf32> %1 = "tf.Add"(%0, %0) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - return %1 : tensor<*xf32> + %2 = "unknown.A"(%1) : (tensor<*xf32>) -> tensor<*xf32> + return %2 : tensor<*xf32> } // CHECK-LABEL: func @unknown_op @@ -52,8 +53,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr // CHECK: %[[CST:.*]] = "tf.Const"{{.*}} {value = dense<1> : tensor<4xi32>} : () -> tensor<4xi32> // CHECK: %[[CONV:.*]] = "tf.Conv2DBackpropInput"(%[[CST]] // 
CHECK-SAME: (tensor<4xi32>, tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32>
-// CHECK: %[[CAST:.*]] = "tf.Cast"(%[[CONV]]) {{.*}} : (tensor<1x1x1x1xf32>) -> tensor
-// CHECK: return %[[CAST]] : tensor
+// CHECK: return %[[CONV]] : tensor<1x1x1x1xf32>
   %0 = "tf.Shape"(%arg0) : (tensor<1x1x1x1xi32>) -> tensor<4xi32>
   %1 = "tf.Conv2DBackpropInput"(%0, %arg1, %arg1) {
     padding = "VALID", strides = [1, 1, 1, 1]
@@ -105,14 +105,16 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
   }

   // CHECK-LABEL: func @shape_from_while_to_cond_body_functions
-  func @shape_from_while_to_cond_body_functions(%arg0: tensor<4xf32>) -> tensor<4xf32> {
-    %0 = "tf.While"(%arg0) {cond = @while_cond_func, body = @while_body_func, is_stateless = true} : (tensor<4xf32>) -> tensor<4xf32>
-    return %0 : tensor<4xf32>
+  func @shape_from_while_to_cond_body_functions(%arg0: tensor<4xf32>, %arg1: tensor>>, %arg2: tensor>>) -> tensor<4xf32> {
+    // CHECK: "tf.While"
+    // CHECK-SAME: (tensor<4xf32>, tensor>>, tensor>>) -> (tensor<4xf32>, tensor>>, tensor>>)
+    %0:3 = "tf.While"(%arg0, %arg1, %arg2) {cond = @while_cond_func, body = @while_body_func, is_stateless = true} : (tensor<4xf32>, tensor>>, tensor>>) -> (tensor<4xf32>, tensor<*x!tf.resource>, tensor>>)
+    return %0#0 : tensor<4xf32>
   }

   // CHECK-LABEL: func @while_cond_func
-  // CHECK-SAME: %arg0: tensor<4xf32>) -> tensor
-  func @while_cond_func(%arg0: tensor<*xf32>) -> tensor {
+  // CHECK-SAME: (%arg0: tensor<4xf32>, %arg1: tensor>>, %arg2: tensor>>) -> tensor
+  func @while_cond_func(%arg0: tensor<*xf32>, %arg1: tensor<*x!tf.resource>, %arg2: tensor>>) -> tensor {
     %0 = "tf.Const"() {value = dense<[1.000000e-04,2.000000e-04,3.000000e-04,4.000000e-04]> : tensor<4xf32>} : () -> tensor<4xf32>
     %1 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor
     // CHECK: tf.Equal
@@ -124,14 +126,27 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
   }

   // CHECK-LABEL: func @while_body_func
-  func @while_body_func(%arg0: tensor<*xf32>) -> tensor<*xf32> {
+  func @while_body_func(%arg0: tensor<*xf32>, %arg1: tensor<*x!tf.resource>, %arg2: tensor>>) -> (tensor<*xf32>, tensor<*x!tf.resource>, tensor>>) {
     %0 = "tf.Const"() {value = dense<1.000000e-04> : tensor} : () -> tensor
     // CHECK: tf.AddV2
     // CHECK-SAME: (tensor<4xf32>, tensor) -> tensor<4xf32>
     %1 = "tf.AddV2"(%arg0, %0) : (tensor<*xf32>, tensor) -> tensor<*xf32>
+    // CHECK: "tf.Identity"
+    // CHECK-SAME: (tensor>>) -> tensor>>
+    %2 = "tf.Identity"(%arg1) : (tensor<*x!tf.resource>) -> tensor<*x!tf.resource>
+    // CHECK: "tf.TPUReplicatedInput"
+    // CHECK-SAME: (tensor>>) -> tensor>>
+    %ri = "tf.TPUReplicatedInput"(%2) : (tensor<*x!tf.resource>) -> tensor<*x!tf.resource>
+    // CHECK: "tf.ReadVariableOp"
+    // CHECK-SAME: (tensor>>) -> tensor<4xf32>
+    %read = "tf.ReadVariableOp"(%ri) : (tensor<*x!tf.resource>) -> tensor<*xf32>
+    // CHECK: "tf.ReadVariableOp"
+    // CHECK-SAME: (tensor>>) -> tensor<*xf32>
+    %read1 = "tf.ReadVariableOp"(%arg2) : (tensor>>) -> tensor<*xf32>
     // CHECK: return
     // CHECK-SAME: tensor<4xf32>
-    return %1 : tensor<*xf32>
+    // CHECK-SAME: tensor>>
+    return %1, %arg1, %arg2 : tensor<*xf32>, tensor<*x!tf.resource>, tensor>>
   }

   // CHECK-LABEL: func @invalid_function_reused_by_control_flows
@@ -162,4 +177,28 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
     // CHECK-SAME: tensor<*xf32>
     return %0 : tensor<*xf32>
   }
+
+  // CHECK-LABEL: func @with_graph_and_islands
+  // CHECK-SAME: %[[ARG_0:.*]]: tensor>>
+  // CHECK-SAME: ->
tensor<4xf32> + func @with_graph_and_islands(%arg0: tensor>>) -> tensor<*xf32> { + %graph = tf_executor.graph { + %island:2 = tf_executor.island { + // CHECK: %[[ID_0:.*]] = "tf.IdentityN"(%[[ARG_0]]) + %id0 = "tf.IdentityN"(%arg0) + : (tensor>>) -> tensor>> + // CHECK-NEXT: %[[READ_0:.*]] = "tf.ReadVariableOp"(%[[ID_0]]) + // CHECK-SAME: (tensor>>) -> tensor<4xf32> + %read = "tf.ReadVariableOp"(%id0) : (tensor>>) -> tensor<*xf32> + // CHECK-NEXT: tf_executor.yield %[[READ_0]] : tensor<4xf32> + tf_executor.yield %read : tensor<*xf32> + } + // CHECK: tf_executor.fetch + // CHECK-SAME: tensor<4xf32> + tf_executor.fetch %island#0 : tensor<*xf32> + } + // CHECK: return + // CHECK-SAME: tensor<4xf32> + return %graph : tensor<*xf32> + } } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir b/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir index 9b17956f399..5ff3212db65 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir @@ -737,3 +737,43 @@ func @while_cond( // expected-remark@above {{ID: 6}} // expected-remark@above {{Predecessors: {5}}} } + +// ----- + +// Tests that the pass tracks control dependencies based on TF op registry +// statefulness flag, for ops not yet defined in ODS. + +// CHECK-LABEL: func @tf_registry_ops +func @tf_registry_ops( + // expected-remark@above {{ID: 8}} + %arg0: tensor, %arg1: tensor) { + tf_executor.graph { + // expected-remark@above {{ID: 6}} + // expected-remark@above {{Successors: {7}}} + %island = tf_executor.island { + // expected-remark@above {{ID: 4}} + // expected-remark@above {{Successors: {5}}} + "tf.PrintV2"(%arg0) { output_stream = "stderr", end = "\n" } + // expected-remark@above {{ID: 0}} + // expected-remark@above {{Successors: {2}}} + : (tensor) -> () + %merge_summary = "tf.MergeSummary"(%arg0, %arg1) { N = 2 } + // expected-remark@above {{ID: 1}} + : (tensor, tensor) -> (tensor) + "tf.PrintV2"(%merge_summary) { output_stream = "stderr", end = "\n" } + // expected-remark@above {{ID: 2}} + // expected-remark@above {{Predecessors: {0}}} + // expected-remark@above {{Successors: {3}}} + : (tensor) -> () + tf_executor.yield + // expected-remark@above {{ID: 3}} + // expected-remark@above {{Predecessors: {2}}} + } + tf_executor.fetch %island : !tf_executor.control + // expected-remark@above {{ID: 5}} + // expected-remark@above {{Predecessors: {4}}} + } + return + // expected-remark@above {{ID: 7}} + // expected-remark@above {{Predecessors: {6}}} +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir index 9db1ae27837..d58a0b86df5 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir @@ -898,6 +898,39 @@ func @testSoftmaxCrossEntropyWithLogits(%arg0: tensor<3xf32>, %arg1: tensor<3xf3 // ----- +// Test valid tf.SparseSoftmaxCrossEntropyWithLogits +// CHECK-LABEL: func @testSparseSoftmaxCrossEntropyWithLogits +func @testSparseSoftmaxCrossEntropyWithLogits(%arg0: tensor<2x3xf32>, %arg1: tensor<2xi32>) -> (tensor<3xf32>, tensor<2x3xf32>) { + %0:2 = "tf.SparseSoftmaxCrossEntropyWithLogits"(%arg0, %arg1) : (tensor<2x3xf32>, tensor<2xi32>) -> (tensor<3xf32>, tensor<2x3xf32>) + return %0#0, %0#1 : tensor<3xf32>, tensor<2x3xf32> +} + +// ----- + +func @testSparseSoftmaxCrossEntropyWithLogits(%arg0: tensor<3xf32>, %arg1: tensor<3xi32>) -> (tensor<3xf32>, 
tensor<2x3xf32>) { + // expected-error @+1 {{requires features operand of rank two}} + %0:2 = "tf.SparseSoftmaxCrossEntropyWithLogits"(%arg0, %arg1) : (tensor<3xf32>, tensor<3xi32>) -> (tensor<3xf32>, tensor<2x3xf32>) + return %0#0, %0#1 : tensor<3xf32>, tensor<2x3xf32> +} + +// ----- + +func @testSparseSoftmaxCrossEntropyWithLogits(%arg0: tensor<2x3xf32>, %arg1: tensor<2x3xi32>) -> (tensor<2xf32>, tensor<2x3xf32>) { + // expected-error @+1 {{requires labels operand of rank one}} + %0:2 = "tf.SparseSoftmaxCrossEntropyWithLogits"(%arg0, %arg1) : (tensor<2x3xf32>, tensor<2x3xi32>) -> (tensor<2xf32>, tensor<2x3xf32>) + return %0#0, %0#1 : tensor<2xf32>, tensor<2x3xf32> +} + +// ----- + +func @testSparseSoftmaxCrossEntropyWithLogits(%arg0: tensor<2x3xf32>, %arg1: tensor<3xi32>) -> (tensor<2xf32>, tensor<2x3xf32>) { + // expected-error @+1 {{requires features and labels with matching first dimension}} + %0:2 = "tf.SparseSoftmaxCrossEntropyWithLogits"(%arg0, %arg1) : (tensor<2x3xf32>, tensor<3xi32>) -> (tensor<2xf32>, tensor<2x3xf32>) + return %0#0, %0#1 : tensor<2xf32>, tensor<2x3xf32> +} + +// ----- + func @testWhileCond(tensor<*xf32>) -> (tensor) func @testWhileBody(tensor<*xf32>) -> (tensor<*xf32>) @@ -2009,3 +2042,246 @@ func @stridedSliceGrad(%dy: tensor<4x8xf32>, %begin: tensor<2xi64>, %end: tensor %0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %dy) : (tensor<1x2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<4x8xf32>) -> tensor return %0 : tensor } + +// ----- + +func @testDynamicStitch(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { + %indices = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32> + %0 = "tf.DynamicStitch"(%indices, %arg0) : (tensor<2xi32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + +// ----- + +func @testDynamicStitch() -> tensor<2x2xf32> { + // expected-error @+1 {{requires attribute N with value >= 1}} + %0 = "tf.DynamicStitch"() : () -> (tensor<2x2xf32>) + return %0 : tensor<2x2xf32> +} + +// ----- + +func @testDynamicStitch(%arg0: tensor<2x2xf32>) -> tensor { + %indices = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32> + // expected-error @+1 {{requires non scalar output}} + %0 = "tf.DynamicStitch"(%indices, %arg0) : (tensor<2xi32>, tensor<2x2xf32>) -> tensor + return %0 : tensor +} + +// ----- + +func @testDynamicStitch(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { + %indices = "tf.Const"() {value = dense<[-1, 0]> : tensor<2xi32>} : () -> tensor<2xi32> + // expected-error @+1 {{requires non-negative index values; found -1}} + %0 = "tf.DynamicStitch"(%indices, %arg0) : (tensor<2xi32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + +// ----- + +func @testDynamicStitch(%arg0: tensor<3x2xf32>) -> tensor<2x2xf32> { + %indices = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32> + // expected-error @+1 {{requires shape of data with type 'tensor<3x2xf32>' to have prefix matching with shape of the corresponding index type 'tensor<2xi32>'}} + %0 = "tf.DynamicStitch"(%indices, %arg0) : (tensor<2xi32>, tensor<3x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + +// ----- + +func @testDynamicStitch(%arg0: tensor<2xf32>, %arg1: tensor<2x2x3xf32>) -> (tensor<5x2xf32>) { + %indices0 = "tf.Const"() {value = dense<4> : tensor} : () -> tensor + %indices1 = "tf.Const"() {value = dense<[[3, 2], [1, 0]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32> + + // expected-error @+1 {{inconsistent shaped data and index pairs; inferred item 
shapes [2] and [3] don't match}} + %0 = "tf.DynamicStitch"(%indices0, %indices1, %arg0, %arg1) : (tensor, tensor<2x2xi32>, tensor<2xf32>, tensor<2x2x3xf32>) -> tensor<5x2xf32> + return %0 : tensor<5x2xf32> +} + +// ----- + +func @testDynamicStitch(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { + %indices = "tf.Const"() {value = dense<[2, 0]> : tensor<2xi32>} : () -> tensor<2xi32> + // expected-error @+1 {{missing index 1}} + %0 = "tf.DynamicStitch"(%indices, %arg0) : (tensor<2xi32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + +// ----- + +func @testDynamicStitch(%arg0: tensor<2x2xf32>) -> tensor<3x2xf32> { + %indices = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32> + // expected-error @+1 {{has invalid output type; should be compatible with inferred type 'tensor<2x2xf32>'}} + %0 = "tf.DynamicStitch"(%indices, %arg0) : (tensor<2xi32>, tensor<2x2xf32>) -> tensor<3x2xf32> + return %0 : tensor<3x2xf32> +} + +// ----- + +func @testDynamicStitch(%arg0: tensor, %arg1: tensor) -> (tensor<*xf32>) { + // expected-error @+1 {{requires shape of data with type 'tensor' to have prefix matching with shape of the corresponding index type 'tensor'}} + %0 = "tf.DynamicStitch"(%arg0, %arg1) : (tensor, tensor) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// ----- + +func @testDynamicStitch(%arg0: tensor, %arg1: tensor<2x?xf32>) -> (tensor<2x3x2xf32>) { + %indices0 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %indices1 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + + // expected-error @+1 {{has invalid output type; should be compatible with inferred type 'tensor<2x2x3xf32>'}} + %0 = "tf.DynamicStitch"(%indices0, %indices1, %arg0, %arg1) : (tensor, tensor, tensor, tensor<2x?xf32>) -> tensor<2x3x2xf32> + return %0 : tensor<2x3x2xf32> +} + +// ----- + +func @testConcatOffest(%concat_dim: tensor, %shape0: tensor<3xi32>) { + // expected-error @+1 {{'tf.ConcatOffset' op requires N to be at least 2, got 1}} + %0 = "tf.ConcatOffset"(%concat_dim, %shape0) : (tensor, tensor<3xi32>) -> tensor<3xi32> + return +} + +// ----- + +func @testConcatOffest(%concat_dim: tensor, %shape0: tensor<3xi32>, %shape1: tensor<3xi32>) { + // expected-error @+1 {{'tf.ConcatOffset' op requires sizes of shapes and offsets to be the same, got sizes 2 and 3}} + %0:3 = "tf.ConcatOffset"(%concat_dim, %shape0, %shape1) : (tensor, tensor<3xi32>, tensor<3xi32>) -> (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) + return +} + +// ----- + +func @testConcatOffest(%concat_dim: tensor<1xi32>, %shape0: tensor<3xi32>, %shape1: tensor<3xi32>) { + // expected-error @+1 {{'tf.ConcatOffset' op requires concat_dim to be a scalar, got tensor of rank 1}} + %0:2 = "tf.ConcatOffset"(%concat_dim, %shape0, %shape1) : (tensor<1xi32>, tensor<3xi32>, tensor<3xi32>) -> (tensor<3xi32>, tensor<3xi32>) + return +} + +// ----- + +func @testConcatOffest(%concat_dim: tensor, %shape0: tensor<3xi32>, %shape1: tensor<3xi32>) { + // expected-error @+1 {{'tf.ConcatOffset' op requires operand and result 1 to have compatible shapes}} + %0:2 = "tf.ConcatOffset"(%concat_dim, %shape0, %shape1) : (tensor, tensor<3xi32>, tensor<3xi32>) -> (tensor<3xi32>, tensor<8xi32>) + return +} + +// ----- + +func @testConcatOffest(%concat_dim: tensor, %shape0: tensor<3xi32>, %shape1: tensor<3x3xi32>) { + // expected-error @+1 {{'tf.ConcatOffset' op requires shape tensor operand 1 to be of rank 1, got tensor of rank 2}} + %0:2 = "tf.ConcatOffset"(%concat_dim, %shape0, %shape1) : (tensor, tensor<3xi32>, tensor<3x3xi32>) 
+  return
+}
+
+// -----
+
+func @testConcatOffset(%concat_dim: tensor<i32>, %shape0: tensor<3xi32>, %shape1: tensor<8xi32>) {
+  // expected-error @+1 {{'tf.ConcatOffset' op requires shape tensor (rank 1) operand 1 to be of length 3, got tensor (rank 1) of length 8}}
+  %0:2 = "tf.ConcatOffset"(%concat_dim, %shape0, %shape1) : (tensor<i32>, tensor<3xi32>, tensor<8xi32>) -> (tensor<3xi32>, tensor<8xi32>)
+  return
+}
+
+// -----
+
+func @tensor_scatter_update(%tensor: tensor<f32>, %indices: tensor<4x2xi32>, %updates: tensor<4x4xf32>) -> tensor<f32> {
+  // expected-error @+1 {{op requires tensor operand to have at least 1 dimension}}
+  %0 = "tf.TensorScatterUpdate"(%tensor, %indices, %updates) : (tensor<f32>, tensor<4x2xi32>, tensor<4x4xf32>) -> tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// -----
+
+func @tensor_scatter_update(%tensor: tensor<4x4x4xf32>, %indices: tensor<i32>, %updates: tensor<4x4xf32>) -> tensor<4x4x4xf32> {
+  // expected-error @+1 {{op requires indices operand to have at least 1 dimension}}
+  %0 = "tf.TensorScatterUpdate"(%tensor, %indices, %updates) : (tensor<4x4x4xf32>, tensor<i32>, tensor<4x4xf32>) -> tensor<4x4x4xf32>
+  return %0 : tensor<4x4x4xf32>
+}
+
+// -----
+
+func @tensor_scatter_update(%tensor: tensor<4x4x4xf32>, %indices: tensor<4x2xi32>, %updates: tensor<f32>) -> tensor<4x4x4xf32> {
+  // expected-error @+1 {{op requires updates operand to have at least 1 dimension}}
+  %0 = "tf.TensorScatterUpdate"(%tensor, %indices, %updates) : (tensor<4x4x4xf32>, tensor<4x2xi32>, tensor<f32>) -> tensor<4x4x4xf32>
+  return %0 : tensor<4x4x4xf32>
+}
+
+// -----
+
+func @tensor_scatter_update(%tensor: tensor<4xf32>, %indices: tensor<4x2xi32>, %updates: tensor<4x4xf32>) -> tensor<4x4x4xf32> {
+  // expected-error @+1 {{op requires tensor operand with rank greater than or equal to the indices operand's last dimensions}}
+  %0 = "tf.TensorScatterUpdate"(%tensor, %indices, %updates) : (tensor<4xf32>, tensor<4x2xi32>, tensor<4x4xf32>) -> tensor<4x4x4xf32>
+  return %0 : tensor<4x4x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @testParseExampleV2DenseOnlyValid
+func @testParseExampleV2DenseOnlyValid(%serialized: tensor<32x!tf.string>, %names : tensor<32x!tf.string>, %dense_keys : tensor<2x!tf.string>, %dense_default_0 : tensor<f32>, %dense_default_1 : tensor<f32>) -> (tensor<32xf32>) {
+  %empty_str_vector = "tf.Const"() {dtype = !tf.string, value = opaque<"tf", "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B207D207D"> : tensor<0x!tf.string>} : () -> tensor<0x!tf.string>
+  %result:2 = "tf.ParseExampleV2"(%serialized, %names, %empty_str_vector, %dense_keys, %empty_str_vector, %dense_default_0, %dense_default_1) {dense_shapes = ["tfshape$", "tfshape$"], num_sparse = 0 : i64, result_segment_sizes = dense<[0, 0, 0, 2, 0, 0]> : vector<6xi32>} : (tensor<32x!tf.string>, tensor<32x!tf.string>, tensor<0x!tf.string>, tensor<2x!tf.string>, tensor<0x!tf.string>, tensor<f32>, tensor<f32>) -> (tensor<32xf32>, tensor<32xf32>)
+  return %result#0 : tensor<32xf32>
+}
+
+// -----
+
+func @testParseExampleV2DenseMismatchedInputOutput(%serialized: tensor<32x!tf.string>, %names : tensor<32x!tf.string>, %dense_keys : tensor<2x!tf.string>, %dense_default_0 : tensor, %dense_default_1 : tensor) -> (tensor<32xf32>) {
+  %empty_str_vector = "tf.Const"() {dtype = !tf.string, value = opaque<"tf", "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B207D207D"> : tensor<0x!tf.string>} : () -> tensor<0x!tf.string>
+  // expected-error @+1 {{output 'dense_values' should
have same length as attribute 'Tdense'}} + %result:3 = "tf.ParseExampleV2"(%serialized, %names, %empty_str_vector, %dense_keys, %empty_str_vector, %dense_default_0, %dense_default_1) {dense_shapes = ["tfshape$", "tfshape$"], num_sparse = 0 : i64, result_segment_sizes = dense<[0, 0, 0, 3, 0, 0]> : vector<6xi32>} : (tensor<32x!tf.string>, tensor<32x!tf.string>, tensor<0x!tf.string>, tensor<2x!tf.string>, tensor<0x!tf.string>, tensor, tensor) -> (tensor<32xf32>, tensor<32xf32>, tensor<32xi64>) + return %result#0 : tensor<32xf32> +} + +// ----- + +// CHECK-LABEL: func @testParseExampleV2SparseOnlyValid +func @testParseExampleV2SparseOnlyValid(%serialized: tensor<32x!tf.string>, %names : tensor<32x!tf.string>, %sparse_keys : tensor<2x!tf.string>) -> (tensor) { + %empty_str_vector = "tf.Const"() {dtype = !tf.string, value = opaque<"tf", "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B207D207D"> : tensor<0x!tf.string>} : () -> tensor<0x!tf.string> + %result:6 = "tf.ParseExampleV2"(%serialized, %names, %sparse_keys, %empty_str_vector, %empty_str_vector) {dense_shapes = [], num_sparse = 2 : i64, result_segment_sizes = dense<[2, 2, 2, 0, 0, 0]> : vector<6xi32>} : (tensor<32x!tf.string>, tensor<32x!tf.string>, tensor<2x!tf.string>, tensor<0x!tf.string>, tensor<0x!tf.string>) -> (tensor, tensor, tensor, tensor, tensor<2xi64>, tensor<2xi64>) + return %result#0 : tensor +} + +// ----- + +func @testParseExampleV2SparseInvalidNumSparse(%serialized: tensor<32x!tf.string>, %names : tensor<32x!tf.string>, %sparse_keys : tensor<2x!tf.string>) -> (tensor) { + %empty_str_vector = "tf.Const"() {dtype = !tf.string, value = opaque<"tf", "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B207D207D"> : tensor<0x!tf.string>} : () -> tensor<0x!tf.string> + // expected-error @+1 {{attribute 'num_sparse' should be the same as the length of attribute 'sparse_types'}} + %result:6 = "tf.ParseExampleV2"(%serialized, %names, %sparse_keys, %empty_str_vector, %empty_str_vector) {dense_shapes = [], num_sparse = 3 : i64, result_segment_sizes = dense<[2, 2, 2, 0, 0, 0]> : vector<6xi32>} : (tensor<32x!tf.string>, tensor<32x!tf.string>, tensor<2x!tf.string>, tensor<0x!tf.string>, tensor<0x!tf.string>) -> (tensor, tensor, tensor, tensor, tensor<2xi64>, tensor<2xi64>) + return %result#0 : tensor +} + +// ----- + +func @testParseExampleV2SparseInvalidSparseIndicesOutput(%serialized: tensor<32x!tf.string>, %names : tensor<32x!tf.string>, %sparse_keys : tensor<2x!tf.string>) -> (tensor) { + %empty_str_vector = "tf.Const"() {dtype = !tf.string, value = opaque<"tf", "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B207D207D"> : tensor<0x!tf.string>} : () -> tensor<0x!tf.string> + // expected-error @+1 {{output 'sparse_indices' should have same length as attribute 'sparse_types'}} + %result:5 = "tf.ParseExampleV2"(%serialized, %names, %sparse_keys, %empty_str_vector, %empty_str_vector) {dense_shapes = [], num_sparse = 2 : i64, result_segment_sizes = dense<[1, 2, 2, 0, 0, 0]> : vector<6xi32>} : (tensor<32x!tf.string>, tensor<32x!tf.string>, tensor<2x!tf.string>, tensor<0x!tf.string>, tensor<0x!tf.string>) -> (tensor, tensor, tensor, tensor<2xi64>, tensor<2xi64>) + return %result#0 : tensor +} + +// ----- + +func @testParseExampleV2SparseOnlyValid(%serialized: tensor<32x!tf.string>, %names : tensor<32x!tf.string>, %sparse_keys : tensor<2x!tf.string>) -> (tensor) { + %empty_str_vector = "tf.Const"() {dtype = 
!tf.string, value = opaque<"tf", "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B207D207D"> : tensor<0x!tf.string>} : () -> tensor<0x!tf.string> + // expected-error @+1 {{output 'sparse_shapes' should have same length as attribute 'sparse_types'}} + %result:5 = "tf.ParseExampleV2"(%serialized, %names, %sparse_keys, %empty_str_vector, %empty_str_vector) {dense_shapes = [], num_sparse = 2 : i64, result_segment_sizes = dense<[2, 2, 1, 0, 0, 0]> : vector<6xi32>} : (tensor<32x!tf.string>, tensor<32x!tf.string>, tensor<2x!tf.string>, tensor<0x!tf.string>, tensor<0x!tf.string>) -> (tensor, tensor, tensor, tensor, tensor<2xi64>) + return %result#0 : tensor +} + +// ----- + +// CHECK-LABEL: func @testParseExampleV2RaggedOnlyValid +func @testParseExampleV2RaggedOnlyValid(%serialized: tensor<32x!tf.string>, %names : tensor<32x!tf.string>, %ragged_keys : tensor<2x!tf.string>) -> (tensor) { + %empty_str_vector = "tf.Const"() {dtype = !tf.string, value = opaque<"tf", "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B207D207D"> : tensor<0x!tf.string>} : () -> tensor<0x!tf.string> + %result:4 = "tf.ParseExampleV2"(%serialized, %names, %empty_str_vector, %empty_str_vector, %ragged_keys) {dense_shapes = [], num_sparse = 0 : i64, result_segment_sizes = dense<[0, 0, 0, 0, 2, 2]> : vector<6xi32>} : (tensor<32x!tf.string>, tensor<32x!tf.string>, tensor<0x!tf.string>, tensor<0x!tf.string>, tensor<2x!tf.string>) -> (tensor, tensor, tensor, tensor) + return %result#0 : tensor +} + +// ----- + +func @testParseExampleV2RaggedMismatchedOutputLengths(%serialized: tensor<32x!tf.string>, %names : tensor<32x!tf.string>, %ragged_keys : tensor<2x!tf.string>) -> (tensor) { + %empty_str_vector = "tf.Const"() {dtype = !tf.string, value = opaque<"tf", "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B207D207D"> : tensor<0x!tf.string>} : () -> tensor<0x!tf.string> + // expected-error @+1 {{attribute 'ragged_value_types' should have same length as attribute 'ragged_split_types'}} + %result:3 = "tf.ParseExampleV2"(%serialized, %names, %empty_str_vector, %empty_str_vector, %ragged_keys) {dense_shapes = [], num_sparse = 0 : i64, result_segment_sizes = dense<[0, 0, 0, 0, 2, 1]> : vector<6xi32>} : (tensor<32x!tf.string>, tensor<32x!tf.string>, tensor<0x!tf.string>, tensor<0x!tf.string>, tensor<2x!tf.string>) -> (tensor, tensor, tensor) + return %result#0 : tensor +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_device_ops_invalid.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_device_ops_invalid.mlir index 80fb5b98b67..8a546285f76 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_device_ops_invalid.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_device_ops_invalid.mlir @@ -94,7 +94,7 @@ func @verifier_replicate_terminator() { // Check that a replicate with 'n' attribute that is less than 2 is invalid. 
func @verifier_replicate_n() {
  "tf_device.replicate" () ({
-// expected-error@-1 {{'tf_device.replicate' op attribute 'n' failed to satisfy constraint: 32-bit integer attribute whose minimal value is 2}}
+// expected-error@-1 {{'tf_device.replicate' op attribute 'n' failed to satisfy constraint: 32-bit integer attribute whose minimum value is 2}}
  ^entry:
    tf_device.return
  }) {n = 1 : i32} : () -> ()
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/BUILD b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/BUILD
index 5ad0d96f79e..93ee05d478e 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/BUILD
@@ -13,18 +13,30 @@ py_library(
     ],
 )
 
+py_library(
+    name = "common_v1",
+    srcs = ["common_v1.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        "//tensorflow:tensorflow_py",
+    ],
+)
+
 filegroup(
     name = "test_utilities",
     testonly = True,
     data = [
-        "@llvm//:FileCheck",
+        "@llvm-project//llvm:FileCheck",
     ],
 )
 
 # Drop trailing ".py" from all test file names.
 all_test_basenames = [py[:-3] for py in glob(
     ["*.py"],
-    exclude = ["common.py"],
+    exclude = [
+        "common.py",
+        "common_v1.py",
+    ],
 )]
 
 # Instantiate all the tests.
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/basic_v1.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/basic_v1.py
new file mode 100644
index 00000000000..8fb8b4e6e2d
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/basic_v1.py
@@ -0,0 +1,64 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# RUN: %p/basic_v1 | FileCheck %s
+
+# pylint: disable=missing-docstring,line-too-long
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow.compat.v1 as tf
+from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common_v1
+
+# CHECK: "tf_saved_model.global_tensor"() {is_mutable, sym_name = "y", type = tensor<1x3xf32>, value = {{.*}} : tensor<1x3xf32>} : () -> ()
+# CHECK: func @basic([[ARG0:%.*]]: tensor<3x1xf32>,
+# CHECK-SAME: [[ARG1:%.*]]: tensor<!tf.resource<tensor<1x3xf32>>> {tf_saved_model.bound_input = @y}) -> tensor<3x3xf32>
+# CHECK-NEXT: [[R0:%.*]] = "tf.ReadVariableOp"([[ARG1]]) {{{.*}}} : (tensor<!tf.resource<tensor<1x3xf32>>>) -> tensor<1x3xf32>
+# CHECK-NEXT: [[R1:%.*]] = "tf.MatMul"([[ARG0]], [[R0]]) {{{.*}}} : (tensor<3x1xf32>, tensor<1x3xf32>) -> tensor<3x3xf32>
+# CHECK-NEXT: return [[R1]] : tensor<3x3xf32>
+
+
+def Test():
+
+  # Default TF1.x uses reference variables that are not supported by SavedModel
+  # v1 Importer. To use SavedModel V1 Importer, resource variables should be
+  # enabled.
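+  # (Reference variables carry their state on ref-typed edges with no resource
+  # handle, so there is nothing for the importer to bind as a
+  # tf_saved_model.global_tensor; resource variables provide that handle.)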
+  tf.compat.v1.enable_resource_variables()
+
+  tf.compat.v1.disable_eager_execution()
+
+  x = tf.constant([[1.0], [1.0], [1.0]])
+  y = tf.compat.v1.get_variable(
+      name='y',
+      shape=(1, 3),
+      initializer=tf.random_normal_initializer(),
+      trainable=True)
+  r = tf.matmul(x, y)
+
+  tensor_info_x = tf.compat.v1.saved_model.utils.build_tensor_info(x)
+  tensor_info_r = tf.compat.v1.saved_model.utils.build_tensor_info(r)
+
+  return {
+      'basic':
+          (tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
+              inputs={'x': tensor_info_x},
+              outputs={'r': tensor_info_r},
+              method_name=tf.saved_model.PREDICT_METHOD_NAME))
+  }
+
+
+if __name__ == '__main__':
+  common_v1.do_test(Test())
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/build_defs.bzl b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/build_defs.bzl
index e60d393bae8..0e83900d98c 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/build_defs.bzl
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/build_defs.bzl
@@ -11,6 +11,7 @@ def tf_saved_model_test(name, data):
         srcs = [name + ".py"],
         deps = [
             "//tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model:common",
+            "//tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model:common_v1",
         ],
     )
@@ -22,5 +23,5 @@ def tf_saved_model_test(name, data):
     lit_test(
         name = name + ".py",
         data = [name] + data,
-        driver = "@local_config_mlir//:run_lit.sh",
+        driver = "@llvm-project//mlir:run_lit.sh",
     )
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common.py
index 67725236f07..fd8221cd190 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common.py
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common.py
@@ -23,6 +23,7 @@ from __future__ import division
 from __future__ import print_function
 
 import tempfile
+
 from absl import app
 from absl import flags
 from absl import logging
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common_v1.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common_v1.py
new file mode 100644
index 00000000000..35858d2b38a
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common_v1.py
@@ -0,0 +1,93 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Serves as a common "main" function for all the SavedModel tests.
+
+There is a fair amount of setup needed to initialize tensorflow and get it
+into a proper TF1 session-based execution mode. This hides that boilerplate.
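+
+Concretely: the test hands us a signature_def_map, we write it out as a
+SavedModel v1 (under --save_model_path, or a temporary path), import it into
+MLIR with the experimental SavedModel V1 importer, and print the module for
+FileCheck to match.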
+""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tempfile +from absl import app +from absl import flags +from absl import logging +import tensorflow.compat.v1 as tf + +from tensorflow.python import pywrap_tensorflow + +# Use /tmp to make debugging the tests easier (see README.md) +flags.DEFINE_string('save_model_path', '', 'Path to save the model to.') +FLAGS = flags.FLAGS + + +# This function needs to take a "create_module_fn", as opposed to just the +# module itself, because the creation of the module has to be delayed until +# after absl and tensorflow have run various initialization steps. +def do_test(signature_def_map, show_debug_info=False): + """Runs test. + + 1. Performs absl and tf "main"-like initialization that must run before almost + anything else. + 2. Converts signature_def_map to SavedModel V1 + 3. Converts SavedModel V1 to MLIR + 4. Prints the textual MLIR to stdout (it is expected that the caller will have + FileCheck checks in its file to check this output). + + This is only for use by the MLIR SavedModel importer tests. + + Args: + signature_def_map: A map from string key to signature_def. The key will be + used as function name in the resulting MLIR. + show_debug_info: If true, shows debug locations in the resulting MLIR. + """ + + # Make LOG(ERROR) in C++ code show up on the console. + # All `Status` passed around in the C++ API seem to eventually go into + # `LOG(ERROR)`, so this makes them print out by default. + logging.set_stderrthreshold('error') + + def app_main(argv): + """Function passed to absl.app.run.""" + if len(argv) > 1: + raise app.UsageError('Too many command-line arguments.') + if FLAGS.save_model_path: + save_model_path = FLAGS.save_model_path + else: + save_model_path = tempfile.mktemp(suffix='.saved_model') + + sess = tf.Session() + sess.run(tf.initializers.global_variables()) + builder = tf.saved_model.builder.SavedModelBuilder(save_model_path) + builder.add_meta_graph_and_variables( + sess, [tf.saved_model.tag_constants.SERVING], + signature_def_map, + strip_default_attrs=True) + builder.save() + + logging.info('Saved model to: %s', save_model_path) + mlir = pywrap_tensorflow.experimental_convert_saved_model_v1_to_mlir( + save_model_path, ','.join([tf.saved_model.tag_constants.SERVING]), + show_debug_info) + # We don't strictly need this, but it serves as a handy sanity check + # for that API, which is otherwise a bit annoying to test. + # The canonicalization shouldn't affect these tests in any way. + mlir = pywrap_tensorflow.experimental_run_pass_pipeline( + mlir, 'tf-standard-pipeline', show_debug_info) + print(mlir) + + app.run(app_main) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/shared_variable_v1.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/shared_variable_v1.py new file mode 100644 index 00000000000..6ba51c2a325 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/shared_variable_v1.py @@ -0,0 +1,64 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# RUN: %p/shared_variable_v1 | FileCheck %s
+
+# pylint: disable=missing-docstring,line-too-long
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow.compat.v1 as tf
+from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common_v1
+
+# CHECK: "tf_saved_model.global_tensor"() {is_mutable, sym_name = "y", type = tensor<1x3xf32>, value = {{.*}} : tensor<1x3xf32>} : () -> ()
+# CHECK: func {{@.*}}([[ARG0:%.*]]: tensor<3x1xf32>,
+# CHECK-SAME: [[ARG1:%.*]]: tensor<!tf.resource<tensor<1x3xf32>>> {tf_saved_model.bound_input = @y}) -> tensor<3x3xf32>
+
+# CHECK: func {{@.*}}([[ARG2:%.*]]: tensor<3x1xf32>,
+# CHECK-SAME: [[ARG3:%.*]]: tensor<!tf.resource<tensor<1x3xf32>>> {tf_saved_model.bound_input = @y}) -> tensor<3x3xf32>
+
+
+def Test():
+
+  # Default TF1.x uses reference variables that are not supported by SavedModel
+  # v1 Importer. To use SavedModel V1 Importer, resource variables should be
+  # enabled.
+  tf.enable_resource_variables()
+
+  tf.compat.v1.disable_eager_execution()
+
+  x = tf.constant([[1.0], [1.0], [1.0]])
+  y = tf.get_variable(
+      name='y',
+      shape=(1, 3),
+      initializer=tf.random_normal_initializer(),
+      trainable=True)
+  r = tf.matmul(x, y)
+
+  tensor_info_x = tf.saved_model.utils.build_tensor_info(x)
+  tensor_info_r = tf.saved_model.utils.build_tensor_info(r)
+
+  signature_def = tf.saved_model.signature_def_utils.build_signature_def(
+      inputs={'x': tensor_info_x},
+      outputs={'r': tensor_info_r},
+      method_name=tf.saved_model.PREDICT_METHOD_NAME)
+
+  # Create two signatures that share the same variable.
+  return {'basic': signature_def, 'basic_2': signature_def}
+
+
+if __name__ == '__main__':
+  common_v1.do_test(Test())
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc
index 7ef3449e3e9..75e7d2daeeb 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc
@@ -17,8 +17,8 @@ limitations under the License.
 
 #include <memory>
 
-#include "mlir/Pass/PassManager.h"  // TF:local_config_mlir
-#include "mlir/Transforms/Passes.h"  // TF:local_config_mlir
+#include "mlir/Pass/PassManager.h"  // TF:llvm-project
+#include "mlir/Transforms/Passes.h"  // TF:llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
@@ -27,6 +27,9 @@ namespace mlir {
 namespace TFTPU {
 
 void CreateTPUBridge(OpPassManager &pm) {
+  // Run shape inference so that tf_executor/tf_device ops created later will
+  // likely inherit more concrete types.
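+  // (Running it this early refines unranked and partially-ranked result types
+  // before any of the clustering decisions below are made.)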
+  pm.addPass(TF::CreateTFShapeInferencePass());
   OpPassManager &func_pm = pm.nest<FuncOp>();
   func_pm.addPass(tf_executor::CreateTFExecutorIslandCoarseningPass());
   func_pm.addPass(CreateTPUClusterFormationPass());
@@ -35,8 +38,13 @@ void CreateTPUBridge(OpPassManager &pm) {
   // because DecomposeResourceOpsPass uses pattern rewriter which hoists
   // changed constants out of tf_device.Launch.
   func_pm.addPass(TFDevice::CreateDecomposeResourceOpsPass());
-  func_pm.addPass(tf_executor::CreateTFExecutorConstantSinkingPass());
-  func_pm.addPass(TFDevice::CreateResourceOpLiftingPass());
+
+  // Run another shape inference pass because resource decomposition might have
+  // created new partial types.
+  pm.addPass(TF::CreateTFShapeInferencePass());
+  OpPassManager &func_pm2 = pm.nest<FuncOp>();
+  func_pm2.addPass(tf_executor::CreateTFExecutorConstantSinkingPass());
+  func_pm2.addPass(TFDevice::CreateResourceOpLiftingPass());
 
   pm.addPass(TF::CreateResourceDeviceInferencePass());
   pm.addPass(TFDevice::CreateClusterOutliningPass());
@@ -56,7 +64,7 @@ tensorflow::Status TPUBridge(ModuleOp module, bool enable_logging) {
 
   // Add logger to bridge passmanager.
   if (enable_logging)
-    bridge.addInstrumentation(std::make_unique<tensorflow::BridgeLogger>());
+    bridge.enableIRPrinting(std::make_unique<tensorflow::BridgeLoggerConfig>());
 
   // Populate a passmanager with the list of passes that implement the bridge.
   CreateTPUBridge(bridge);
@@ -80,7 +88,7 @@ tensorflow::Status RunBridgeWithStandardPipeline(ModuleOp module,
 
   // Add logger to bridge passmanager.
   if (enable_logging)
-    bridge.addInstrumentation(std::make_unique<tensorflow::BridgeLogger>());
+    bridge.enableIRPrinting(std::make_unique<tensorflow::BridgeLoggerConfig>());
 
   StandardPipelineOptions pipeline_options;
   pipeline_options.enable_inliner.setValue(enable_inliner);
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.h b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.h
index ff446af24f5..34543069f5b 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.h
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.h
@@ -16,7 +16,7 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_BRIDGE_H_
 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_BRIDGE_H_
 
-#include "mlir/IR/Module.h"  // TF:local_config_mlir
+#include "mlir/IR/Module.h"  // TF:llvm-project
 #include "tensorflow/core/lib/core/status.h"
 
 namespace mlir {
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/bridge_pass.cc
index 0208dc2f579..3af20758207 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge_pass.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge_pass.cc
@@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License.
==============================================================================*/ -#include "mlir/Pass/Pass.h" // TF:local_config_mlir -#include "mlir/Pass/PassManager.h" // TF:local_config_mlir -#include "mlir/Transforms/Passes.h" // TF:local_config_mlir +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Pass/PassManager.h" // TF:llvm-project +#include "mlir/Transforms/Passes.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td index 7c38b78f239..bfe58397f22 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td @@ -23,7 +23,7 @@ def SingleResultAndOperandHaveSameElementType : Constraint< CPred<"getElementTypeOrSelf($0) == getElementTypeOrSelf($1)">>; def SingleResultAndOperandHaveSameType : Constraint< - CPred<"$0->getType() == $1->getType()">>; + CPred<"$0.getType() == $1.getType()">>; def IsRank2Tensor : Type, "Rank 2 tensor">; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_formation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_formation.cc index 165d1b2388b..feeddf4696e 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_formation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_formation.cc @@ -20,13 +20,13 @@ limitations under the License. #include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Block.h" // TF:local_config_mlir -#include "mlir/IR/BlockAndValueMapping.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Operation.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir -#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Block.h" // TF:llvm-project +#include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Pass/PassRegistry.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" @@ -68,11 +68,11 @@ StringRef GetDevice(Operation* op) { // re-ordered but forming clusters of non-continuous ops is effectively // re-ordering them.. bool CanMergeIntoCluster(const Cluster& c, Operation* to_merge) { - return llvm::all_of(to_merge->getOperands(), [&](Value* operand) { + return llvm::all_of(to_merge->getOperands(), [&](Value operand) { // Block arguments. - if (isa(operand)) return true; + if (operand.isa()) return true; - Operation* defining_op = operand->getDefiningOp(); + Operation* defining_op = operand.getDefiningOp(); // Operand produced by other islands. 
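      // (That is, its defining op lives outside the block containing this
      // cluster, so merging cannot reorder anything within that block.)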
if (defining_op->getBlock() != c.ops.front()->getBlock()) return true; @@ -95,12 +95,12 @@ bool CanMergeIntoCluster(const Cluster& c, Operation* to_merge) { }); } -void ReplaceLiveOutExternalUses(llvm::ArrayRef live_outs, +void ReplaceLiveOutExternalUses(llvm::ArrayRef live_outs, tf_device::LaunchOp launch_op) { Region* launch_op_region = &launch_op.body(); for (const auto& p : llvm::zip(live_outs, launch_op.getResults())) { - Value* from = std::get<0>(p); - for (auto& use : from->getUses()) { + Value from = std::get<0>(p); + for (auto& use : from.getUses()) { if (launch_op_region->isAncestor(use.getOwner()->getParentRegion())) continue; use.set(std::get<1>(p)); @@ -109,14 +109,14 @@ void ReplaceLiveOutExternalUses(llvm::ArrayRef live_outs, } // Get all escaped live-out values of a region. -void GetLiveOuts(Region* region, llvm::SmallVectorImpl* live_outs) { +void GetLiveOuts(Region* region, llvm::SmallVectorImpl* live_outs) { live_outs->clear(); for (Operation& op : region->front()) { - for (Value* v : op.getResults()) { + for (Value v : op.getResults()) { // A value is live-out if any of its users are not inside value producer's // region. - bool is_live_out = llvm::any_of(v->getUsers(), [&](Operation* user) { + bool is_live_out = llvm::any_of(v.getUsers(), [&](Operation* user) { return !region->isAncestor(user->getParentRegion()); }); @@ -145,7 +145,7 @@ void BuildLaunchForCluster(const Cluster& c, OpBuilder* builder) { // Get all escaped live-out values of region, they are used later to determine // return values and types of launch op. - llvm::SmallVector live_outs; + llvm::SmallVector live_outs; GetLiveOuts(®ion, &live_outs); // Build a `tf_device.return` op at end of region, with all live-out values @@ -157,8 +157,8 @@ void BuildLaunchForCluster(const Cluster& c, OpBuilder* builder) { llvm::SmallVector live_out_types; live_out_types.reserve(live_outs.size()); - for (Value* v : live_outs) { - live_out_types.emplace_back(v->getType()); + for (Value v : live_outs) { + live_out_types.emplace_back(v.getType()); } tf_device::LaunchOp launch_op = builder->create( diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc index 10337df1a66..1f082bd1137 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc @@ -17,15 +17,15 @@ limitations under the License. // `tf_device.launch` with equivalent `tf_device.launch_func` operations. 
#include "llvm/ADT/SmallVector.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Block.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Module.h" // TF:local_config_mlir -#include "mlir/IR/Operation.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir -#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir -#include "mlir/Transforms/RegionUtils.h" // TF:local_config_mlir +#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Block.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Pass/PassRegistry.h" // TF:llvm-project +#include "mlir/Transforms/RegionUtils.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" @@ -51,12 +51,12 @@ void ReplaceLaunchReturnWithReturn(tf_device::ReturnOp launch_return_op, // Builds a function that outlines region attached to launch_op and inserts // built function into given module. -FuncOp BuildFunction(StringRef device, llvm::ArrayRef live_ins, +FuncOp BuildFunction(StringRef device, llvm::ArrayRef live_ins, tf_device::LaunchOp launch_op, SymbolTable* symbol_table, OpBuilder* builder) { llvm::SmallVector operand_types; operand_types.reserve(live_ins.size()); - for (Value* v : live_ins) operand_types.emplace_back(v->getType()); + for (Value v : live_ins) operand_types.emplace_back(v.getType()); llvm::SmallVector result_types(launch_op.getResultTypes()); @@ -101,7 +101,7 @@ FuncOp BuildFunction(StringRef device, llvm::ArrayRef live_ins, // removed afterwards.` void OutlineLaunch(tf_device::LaunchOp launch_op, SymbolTable* symbol_table, OpBuilder* builder) { - llvm::SetVector live_ins; + llvm::SetVector live_ins; getUsedValuesDefinedAbove(launch_op.body(), launch_op.body(), live_ins); StringRef device = diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc index 0ef0072390d..11eafdede08 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/tf_status.h" #include "tensorflow/compiler/mlir/tensorflow/utils/eval_util.h" +#include "tensorflow/core/platform/mutex.h" namespace mlir { namespace TF { @@ -59,6 +60,10 @@ LogicalResult ConstantFoldFallbackHook( inputs.push_back(input.cast()); } + // Avoid overlapping folds with the same context. + // TODO(jpienaar): Avoid using global context & mutex here. 
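+  // (The function-local static below is intentionally never deleted; leaking
+  // it sidesteps destruction-order problems for folds that may still run
+  // during process teardown.)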
+ static auto* mu = new tensorflow::mutex(); + tensorflow::mutex_lock l(*mu); return tensorflow::EvaluateOperation(inst, inputs, ctx, &results); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.h b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.h index ad52ac66538..3718d4bd765 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.h @@ -18,9 +18,9 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Operation.h" // TF:local_config_mlir -#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project namespace mlir { namespace TF { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.cc b/tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.cc index f17a5cd8808..51c37b038d3 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.cc @@ -15,10 +15,10 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.h" -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Operation.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.h b/tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.h index 566d956ac85..ae8b4eace4d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_DECODE_CONSTANT_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_DECODE_CONSTANT_H_ -#include "mlir/Pass/Pass.h" // TF:local_config_mlir +#include "mlir/Pass/Pass.h" // TF:llvm-project namespace mlir { namespace TF { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.h b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.h index 3a816233fdf..6697a2181ad 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.h @@ -16,8 +16,8 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_DECOMPOSE_RESOURCE_OPS_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_DECOMPOSE_RESOURCE_OPS_H_ -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/PatternMatch.h" // TF:llvm-project namespace mlir { namespace TF { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.td b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.td index 3c98f30de7b..db82a71bf80 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.td @@ -25,7 +25,7 @@ def CreateTFReadVariableOp: NativeCodeCall< "$_builder.create(" " $0.getLoc()," " UnrankedTensorType::get(" - " $1->getType().cast().getElementType())," + " $1.getType().cast().getElementType())," " $2)" >; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops_pass.cc index 61fc12d6ab9..8d83b5c2fa2 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops_pass.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir +#include "mlir/IR/PatternMatch.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/delete_unused_funcs.cc b/tensorflow/compiler/mlir/tensorflow/transforms/delete_unused_funcs.cc index 50215b7163a..3b13633ed80 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/delete_unused_funcs.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/delete_unused_funcs.cc @@ -18,8 +18,8 @@ limitations under the License. #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" -#include "mlir/IR/Module.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir +#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/dialect_hooks.cc b/tensorflow/compiler/mlir/tensorflow/transforms/dialect_hooks.cc index af6476615bb..05b0fb20b62 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/dialect_hooks.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/dialect_hooks.cc @@ -16,13 +16,13 @@ limitations under the License. 
#include #include "llvm/ADT/ArrayRef.h" -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Dialect.h" // TF:local_config_mlir -#include "mlir/IR/DialectHooks.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/IR/Types.h" // TF:local_config_mlir -#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Dialect.h" // TF:llvm-project +#include "mlir/IR/DialectHooks.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/Types.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc b/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc index 918e6ac3078..837944ce0e7 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc @@ -27,12 +27,12 @@ limitations under the License. #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" -#include "mlir/IR/Block.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Location.h" // TF:local_config_mlir -#include "mlir/IR/Operation.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir -#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir +#include "mlir/IR/Block.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Location.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Pass/PassRegistry.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/core/platform/logging.h" @@ -49,11 +49,11 @@ enum IslandType { kParentIsland, kChildIsland }; // IslandResult is a helper struct holding an islands result and associated // inner op result. struct IslandResult { - IslandResult(Value* inner_op_result, Value* island_result) + IslandResult(Value inner_op_result, Value island_result) : inner_op_result(inner_op_result), island_result(island_result) {} - Value* inner_op_result; - Value* island_result; + Value inner_op_result; + Value island_result; }; struct ExecutorIslandCoarsening @@ -70,16 +70,16 @@ llvm::Optional GetOperandCandidateToMergeWith(IslandOp island) { Operation* candidate = nullptr; // Check island control operands. - for (Value* input : island.controlInputs()) { - Operation* def = input->getDefiningOp(); + for (Value input : island.controlInputs()) { + Operation* def = input.getDefiningOp(); DCHECK_EQ(def->getParentOp(), graph_op); if (!candidate || candidate->isBeforeInBlock(def)) candidate = def; } // Check island data operands. 
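  // Walk all ops nested in the island, since a value consumed deep inside the
  // island body may be produced by another island of the same graph.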
island.walk([graph_op, &candidate](Operation* op) { - for (Value* input : op->getOperands()) { - Operation* def = input->getDefiningOp(); + for (Value input : op->getOperands()) { + Operation* def = input.getDefiningOp(); if (!def || def->getParentOp() != graph_op) continue; if (!candidate || candidate->isBeforeInBlock(def)) candidate = def; } @@ -99,15 +99,15 @@ llvm::Optional GetResultCandidateToMergeWith(IslandOp island) { Operation* candidate = nullptr; // Check island control results. - for (Operation* user : island.control()->getUsers()) { + for (Operation* user : island.control().getUsers()) { DCHECK_EQ(user->getParentOp(), graph_op); if (!candidate || user->isBeforeInBlock(candidate)) candidate = user; } // Check island data results. Block& graph_body = llvm::cast(graph_op).GetBody(); - for (Value* result : island.outputs()) { - for (Operation* user : result->getUsers()) { + for (Value result : island.outputs()) { + for (Operation* user : result.getUsers()) { Operation* def = graph_body.findAncestorOpInBlock(*user); DCHECK_NE(def, nullptr); if (!candidate || def->isBeforeInBlock(candidate)) candidate = def; @@ -121,9 +121,9 @@ llvm::Optional GetResultCandidateToMergeWith(IslandOp island) { // Collects the operands for the new island by collecting all control inputs of // the islands being merged. -llvm::SmallSetVector GetNewIslandOperands(IslandOp parent, - IslandOp child) { - llvm::SmallSetVector operands; +llvm::SmallSetVector GetNewIslandOperands(IslandOp parent, + IslandOp child) { + llvm::SmallSetVector operands; operands.insert(parent.getOperands().begin(), parent.getOperands().end()); operands.insert(child.getOperands().begin(), child.getOperands().end()); operands.remove(parent.control()); @@ -145,9 +145,9 @@ llvm::SmallVector GetNewIslandResultsAndForwardResults( for (auto ret_vals : llvm::zip(parent.GetYield().getOperands(), parent.outputs())) { bool result_captured = false; - Value* inner_op_result = std::get<0>(ret_vals); - Value* island_result = std::get<1>(ret_vals); - for (auto& use : llvm::make_early_inc_range(island_result->getUses())) { + Value inner_op_result = std::get<0>(ret_vals); + Value island_result = std::get<1>(ret_vals); + for (auto& use : llvm::make_early_inc_range(island_result.getUses())) { if (child_body.findAncestorOpInBlock(*use.getOwner())) { // Forward result from inner op. use.set(inner_op_result); @@ -160,9 +160,9 @@ llvm::SmallVector GetNewIslandResultsAndForwardResults( for (auto ret_vals : llvm::zip(child.GetYield().getOperands(), child.outputs())) { - Value* inner_op_result = std::get<0>(ret_vals); - Value* island_result = std::get<1>(ret_vals); - if (!island_result->use_empty()) { + Value inner_op_result = std::get<0>(ret_vals); + Value island_result = std::get<1>(ret_vals); + if (!island_result.use_empty()) { results.emplace_back(inner_op_result, island_result); } } @@ -173,12 +173,12 @@ llvm::SmallVector GetNewIslandResultsAndForwardResults( // Creates the new merged island. IslandOp CreateNewIsland(IslandOp parent, IslandOp child, IslandType insert_position, - llvm::ArrayRef operands, + llvm::ArrayRef operands, llvm::ArrayRef results) { // Collect types from results. llvm::SmallVector result_types; for (const auto& result : results) - result_types.push_back(result.inner_op_result->getType()); + result_types.push_back(result.inner_op_result.getType()); // IslandOps always have a control result. 
result_types.push_back(ControlType::get(parent.getContext())); @@ -194,14 +194,14 @@ IslandOp CreateNewIsland(IslandOp parent, IslandOp child, // Creates respective YieldOp for the new merged island. YieldOp CreateNewIslandYieldOp(IslandOp new_island, llvm::ArrayRef results) { - llvm::SmallVector yield_operands; + llvm::SmallVector yield_operands; yield_operands.reserve(results.size()); for (auto ret_vals : llvm::zip(results, new_island.outputs())) { const auto& old_result = std::get<0>(ret_vals); // Replace original island result with new island result. - old_result.island_result->replaceAllUsesWith(std::get<1>(ret_vals)); + old_result.island_result.replaceAllUsesWith(std::get<1>(ret_vals)); // Add associated inner op result to operands of the YieldOp. yield_operands.push_back(old_result.inner_op_result); @@ -232,8 +232,7 @@ void MoveInnerOpsToNewIsland(IslandOp parent, IslandOp child, // Merges two islands and places new merged island before parent or child. void MergeIslands(IslandOp parent, IslandOp child, IslandType insert_position) { // Collect operands for the new merged island. - llvm::SmallSetVector operands = - GetNewIslandOperands(parent, child); + llvm::SmallSetVector operands = GetNewIslandOperands(parent, child); // Collect results for the new merged island. llvm::SmallVector results = @@ -250,8 +249,8 @@ void MergeIslands(IslandOp parent, IslandOp child, IslandType insert_position) { MoveInnerOpsToNewIsland(parent, child, new_yield_op.getOperation()); // Update control inputs to point to the new merged island. - child.control()->replaceAllUsesWith(new_island.control()); - parent.control()->replaceAllUsesWith(new_island.control()); + child.control().replaceAllUsesWith(new_island.control()); + parent.control().replaceAllUsesWith(new_island.control()); // Remove merged islands. child.erase(); @@ -288,15 +287,15 @@ bool MergeIslandWithResult(IslandOp parent) { // This allows our def-use based island coarsening algorithm to merge // islands that independently feed into a fetch. void InsertDummyIslandForFetch(FetchOp fetch) { - llvm::SmallVector data_fetches; + llvm::SmallVector data_fetches; llvm::SmallVector data_types; - llvm::SmallVector control_fetches; + llvm::SmallVector control_fetches; for (auto value : fetch.fetches()) { - if (value->getType().isa()) { + if (value.getType().isa()) { control_fetches.push_back(value); } else { data_fetches.push_back(value); - data_types.push_back(value->getType()); + data_types.push_back(value.getType()); } } auto island = OpBuilder(fetch).create( diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc b/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc index 52b425c4ee6..44309a5e019 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc @@ -30,24 +30,24 @@ limitations under the License. 
#include "llvm/ADT/StringSwitch.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" -#include "mlir/Analysis/LoopAnalysis.h" // TF:local_config_mlir -#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Block.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/Operation.h" // TF:local_config_mlir -#include "mlir/IR/OperationSupport.h" // TF:local_config_mlir -#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir -#include "mlir/IR/Types.h" // TF:local_config_mlir -#include "mlir/IR/Value.h" // TF:local_config_mlir -#include "mlir/IR/Visitors.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir -#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir -#include "mlir/Support/Functional.h" // TF:local_config_mlir -#include "mlir/Support/LLVM.h" // TF:local_config_mlir -#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir +#include "mlir/Analysis/LoopAnalysis.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Block.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/OperationSupport.h" // TF:llvm-project +#include "mlir/IR/PatternMatch.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/TypeUtilities.h" // TF:llvm-project +#include "mlir/IR/Types.h" // TF:llvm-project +#include "mlir/IR/Value.h" // TF:llvm-project +#include "mlir/IR/Visitors.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Pass/PassRegistry.h" // TF:llvm-project +#include "mlir/Support/Functional.h" // TF:llvm-project +#include "mlir/Support/LLVM.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -65,13 +65,13 @@ class SwitchFoldPass : public mlir::FunctionPass { } // namespace // Returns the defining op for a value looking through islands. -static Operation* GetDefiningOp(Value* val) { - Operation* op = val->getDefiningOp(); +static Operation* GetDefiningOp(Value val) { + Operation* op = val.getDefiningOp(); auto island_op = dyn_cast(op); if (!island_op) return op; auto yield_op = island_op.GetYield(); - auto index = cast(val)->getResultNumber(); - return yield_op.getOperand(index)->getDefiningOp(); + auto index = val.cast().getResultNumber(); + return yield_op.getOperand(index).getDefiningOp(); } // Returns either the value or input to an IdentityOp. @@ -81,7 +81,7 @@ static Operation* GetDefiningOp(Value* val) { // identity nodes are common so handle them specially when considering // predicate in a minimally invasive way until identity's are handled more // generally. -static Value* LookThroughIdentityOp(Value* pred_val) { +static Value LookThroughIdentityOp(Value pred_val) { if (!pred_val) return pred_val; auto op = GetDefiningOp(pred_val); if (auto id_op = dyn_cast(op)) pred_val = id_op.input(); @@ -114,7 +114,7 @@ class DeadQueue { // feeding into the Merge then we could have a null value here. 
count = 0; for (auto operand : op->getOperands()) { - if (operand && !operand->getType().isa()) + if (operand && !operand.getType().isa()) ++count; } } @@ -124,9 +124,9 @@ class DeadQueue { } // Enqueue users of a value. - void EnqueueUsers(Value* val) { - for (auto user : val->getUsers()) { - Enqueue(user, val->getType().isa()); + void EnqueueUsers(Value val) { + for (auto user : val.getUsers()) { + Enqueue(user, val.getType().isa()); } } @@ -175,7 +175,7 @@ class DeadQueue { // Enqueues values of foldable switch ops. static void MatchSwitchFoldOps(tf_executor::SwitchOp switch_op, DeadQueue* queue) { - Value* pred_val = LookThroughIdentityOp(switch_op.predicate()); + Value pred_val = LookThroughIdentityOp(switch_op.predicate()); // If predicate or input is null then enqueue entire op for deletion. if (pred_val == nullptr || switch_op.data() == nullptr) { @@ -187,9 +187,9 @@ static void MatchSwitchFoldOps(tf_executor::SwitchOp switch_op, if (!matchPattern(pred_val, m_Constant(&pred))) return; bool taken = pred.getSplatValue(); - Value* dead = taken ? switch_op.falseOutput() : switch_op.trueOutput(); - Value* live = !taken ? switch_op.falseOutput() : switch_op.trueOutput(); - live->replaceAllUsesWith(switch_op.data()); + Value dead = taken ? switch_op.falseOutput() : switch_op.trueOutput(); + Value live = !taken ? switch_op.falseOutput() : switch_op.trueOutput(); + live.replaceAllUsesWith(switch_op.data()); queue->EnqueueUsers(dead); // Delete switch op. @@ -210,15 +210,15 @@ static LogicalResult FoldMergeNodes(FuncOp function, const DeadQueue& queue) { for (auto it : queue.merge_nodes()) { // Find the valid input to merge node. - Value* val = nullptr; + Value val = nullptr; int index = -1; auto* merge = it.first; auto merge_op = cast(merge); for (auto e : llvm::enumerate(merge->getOperands())) { - Value* operand = e.value(); + Value operand = e.value(); if (!operand) continue; // Skip control operands. - if (operand->getType().isa()) break; + if (operand.getType().isa()) break; if (val != nullptr) { return merge->emitOpError("multiple valid inputs post switch folding"); } @@ -226,26 +226,26 @@ static LogicalResult FoldMergeNodes(FuncOp function, const DeadQueue& queue) { index = e.index(); } assert(val != nullptr && "merge node should have been deleted"); - merge_op.output()->replaceAllUsesWith(val); + merge_op.output().replaceAllUsesWith(val); // Build and insert value_index only if needed. - if (!merge_op.value_index()->use_empty()) { - merge_op.value_index()->replaceAllUsesWith( + if (!merge_op.value_index().use_empty()) { + merge_op.value_index().replaceAllUsesWith( build_index(merge->getLoc(), index)); } // Propagate control dependencies if used. - if (!merge_op.control()->use_empty()) { + if (!merge_op.control().use_empty()) { // Change control dependencies from the merge to being on the parent of // the value being propagated. 
- auto def_op = val->getDefiningOp(); + auto def_op = val.getDefiningOp(); #ifndef NDEBUG auto exec_dialect = function.getContext()->getRegisteredDialect("tf_executor"); assert(def_op->getDialect() == exec_dialect && "unable to forward control dependencies"); #endif - merge_op.control()->replaceAllUsesWith( + merge_op.control().replaceAllUsesWith( def_op->getResult(def_op->getNumResults() - 1)); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc index e9b3879c025..6e713570f75 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc @@ -16,14 +16,14 @@ limitations under the License. // This transformation pass transforms functional control flow operations in the // standard TensorFlow dialect to MLIR Control Flow Graph (CFG) form. -#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Operation.h" // TF:local_config_mlir -#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir -#include "mlir/IR/Value.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir -#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir +#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/TypeUtilities.h" // TF:llvm-project +#include "mlir/IR/Value.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Pass/PassRegistry.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" @@ -48,12 +48,12 @@ struct FunctionalControlFlowToCFG // non-empty means True and empty means False. If the tensor is not a scalar, // being empty means False and being non-empty means True. // -static Value* LowerCondition(Location loc, Value* value, OpBuilder* builder) { +static Value LowerCondition(Location loc, Value value, OpBuilder* builder) { // TODO: Right now we just handle zero-D tensors of boolean values. // FIXME: This is almost all wrong, but is a placeholder to unblock the one // testcases, later patches will build on this once I build the right infra to // support it. - TensorType type = value->getType().cast(); + TensorType type = value.getType().cast(); if (!type.hasRank() || type.getRank() != 0 || !type.getElementType().isInteger(1)) { return emitError(loc, "only supports zero-D bool tensors now"), nullptr; @@ -70,17 +70,16 @@ static Value* LowerCondition(Location loc, Value* value, OpBuilder* builder) { // Requires the function to provide arguments for each of the `fn` operands // that is compatible for tensor cast. 
 //
-static Operation* CallFn(Location loc,
-                         const std::function<Value*(int)>& get_arg, FuncOp fn,
-                         OpBuilder* builder) {
+static Operation* CallFn(Location loc, const std::function<Value(int)>& get_arg,
+                         FuncOp fn, OpBuilder* builder) {
   FunctionType fn_type = fn.getType();
-  llvm::SmallVector<Value*, 4> operands;
+  llvm::SmallVector<Value, 4> operands;
   int num_operands = fn_type.getNumInputs();
   operands.reserve(num_operands);
   for (int i = 0; i < num_operands; ++i) {
-    Value* val = get_arg(i);
+    Value val = get_arg(i);
     Type expected = fn_type.getInput(i);
-    if (val->getType() != expected) {
+    if (val.getType() != expected) {
       val =
           builder->create<TF::CastOp>(loc, expected, val,
                                       /*Truncate=*/builder->getBoolAttr(false));
@@ -95,16 +94,16 @@ static Operation* CallFn(Location loc,
 //
 // Requires the function to provide values for each of the block arguments and
 // they should be pair-wise compatible for tensor cast.
-static llvm::SmallVector<Value*, 4> PrepareValsForJump(
-    Location loc, const std::function<Value*(int)>& get_val, Block* block,
+static llvm::SmallVector<Value, 4> PrepareValsForJump(
+    Location loc, const std::function<Value(int)>& get_val, Block* block,
     OpBuilder* builder) {
-  llvm::SmallVector<Value*, 4> result;
+  llvm::SmallVector<Value, 4> result;
   int num_vals = block->getNumArguments();
   result.reserve(num_vals);
   for (int i = 0; i < num_vals; ++i) {
-    Value* val = get_val(i);
-    Type expected = block->getArgument(i)->getType();
-    if (val->getType() != expected) {
+    Value val = get_val(i);
+    Type expected = block->getArgument(i).getType();
+    if (val.getType() != expected) {
       val =
           builder->create<TF::CastOp>(loc, expected, val,
                                       /*Truncate=*/builder->getBoolAttr(false));
@@ -119,7 +118,7 @@ static llvm::SmallVector<Value*, 4> PrepareValsForJump(
 //
 // Requires the function to provide values for each of the block arguments and
 // they should be pair-wise compatible for tensor cast.
-static void JumpToBlock(Location loc, const std::function<Value*(int)>& get_arg,
+static void JumpToBlock(Location loc, const std::function<Value(int)>& get_arg,
                         Block* block, OpBuilder* builder) {
   auto operands = PrepareValsForJump(loc, get_arg, block, builder);
   builder->create<BranchOp>(loc, block, operands);
@@ -136,14 +135,14 @@ static void ReplaceOpResultWithBlockArgs(Location loc, Operation* op,
                                          Block* block, OpBuilder* builder) {
   assert(op->getNumResults() == block->getNumArguments());
   for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
-    Value* arg = block->getArgument(i);
-    Value* result = op->getResult(i);
-    if (arg->getType() != result->getType()) {
+    Value arg = block->getArgument(i);
+    Value result = op->getResult(i);
+    if (arg.getType() != result.getType()) {
       arg =
-          builder->create<TF::CastOp>(loc, result->getType(), arg,
+          builder->create<TF::CastOp>(loc, result.getType(), arg,
                                       /*Truncate=*/builder->getBoolAttr(false));
     }
-    result->replaceAllUsesWith(arg);
+    result.replaceAllUsesWith(arg);
   }
 }

@@ -160,7 +159,7 @@ static LogicalResult LowerIfOp(IfOp op) {
   OpBuilder builder(op_inst);

   // Lower the condition to a boolean value (i1).
-  Value* cond_i1 = LowerCondition(loc, op.cond(), &builder);
+  Value cond_i1 = LowerCondition(loc, op.cond(), &builder);
   if (!cond_i1) return failure();

   auto module = op_inst->getParentOfType<ModuleOp>();
@@ -174,8 +173,8 @@ static LogicalResult LowerIfOp(IfOp op) {
   // Add the block arguments to the merge point, and replace all uses of the
   // original operation results with them.
-  for (Value* value : op_inst->getResults())
-    merge_block->addArgument(value->getType());
+  for (Value value : op_inst->getResults())
+    merge_block->addArgument(value.getType());

   ReplaceOpResultWithBlockArgs(loc, op_inst, merge_block, &builder);

   // Get arguments to the branches after dropping the condition which is the
@@ -200,8 +199,8 @@ static LogicalResult LowerIfOp(IfOp op) {
   // orig_block with a conditional branch.
   builder.setInsertionPointToEnd(orig_block);
   builder.create<CondBranchOp>(loc, cond_i1, then_block,
-                               llvm::ArrayRef<Value*>(), else_block,
-                               llvm::ArrayRef<Value*>());
+                               llvm::ArrayRef<Value>(), else_block,
+                               llvm::ArrayRef<Value>());

   // Finally, delete the op in question.
   op_inst->erase();
@@ -277,7 +276,7 @@ static LogicalResult LowerWhileOp(WhileOp op) {
   Operation* cond_call_op = CallFn(loc, get_cond_arg, cond_fn, &builder);

   assert(cond_call_op->getNumResults() == 1);
-  Value* condition = LowerCondition(loc, cond_call_op->getResult(0), &builder);
+  Value condition = LowerCondition(loc, cond_call_op->getResult(0), &builder);
   auto br_operands =
       PrepareValsForJump(loc, get_cond_arg, body_block, &builder);
   builder.create<CondBranchOp>(loc, condition, body_block, br_operands,
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc b/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc
index 23cdebc4323..c7dac93101b 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc
@@ -18,10 +18,10 @@ limitations under the License.
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/iterator_range.h"
 #include "llvm/Support/Casting.h"
-#include "mlir/IR/Operation.h"  // TF:local_config_mlir
-#include "mlir/IR/Value.h"  // TF:local_config_mlir
-#include "mlir/Pass/Pass.h"  // TF:local_config_mlir
-#include "mlir/Pass/PassRegistry.h"  // TF:local_config_mlir
+#include "mlir/IR/Operation.h"  // TF:llvm-project
+#include "mlir/IR/Value.h"  // TF:llvm-project
+#include "mlir/Pass/Pass.h"  // TF:llvm-project
+#include "mlir/Pass/PassRegistry.h"  // TF:llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"

@@ -38,8 +38,8 @@ void PruneGraph(GraphOp graph) {

   // Visit an op's operands if it is output of an Operation in same graph.
   auto visit_op = [&](Operation* op) {
-    for (Value* operand : op->getOperands()) {
-      Operation* def = operand->getDefiningOp();
+    for (Value operand : op->getOperands()) {
+      Operation* def = operand.getDefiningOp();
       if (def && def->getParentOp() == graph &&
           reachable_ops.insert(def).second) {
         // Op has not been visited, add to queue to visit later.
@@ -86,36 +86,17 @@ namespace {

 // This transformation pass prunes a TF graph eliminating dead-nodes.
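The pruning performed by the pass defined just below is plain backward reachability: start from the ops feeding the graph's fetches and walk to defining ops; anything never visited is dead. A self-contained C++ sketch of the same worklist idea, with toy integer node ids standing in for MLIR operations:

#include <unordered_map>
#include <unordered_set>
#include <vector>

// Each node maps to the nodes it reads from (its operands' defining nodes).
using Defs = std::unordered_map<int, std::vector<int>>;

// Walks backwards from the fetch targets; any node never inserted into the
// reachable set is dead and may be erased, mirroring PruneGraph's
// reachable_ops set and visit queue.
std::unordered_set<int> ReachableFromFetches(const Defs& defs,
                                             std::vector<int> queue) {
  std::unordered_set<int> reachable(queue.begin(), queue.end());
  while (!queue.empty()) {
    int node = queue.back();
    queue.pop_back();
    auto it = defs.find(node);
    if (it == defs.end()) continue;
    for (int def : it->second)
      if (reachable.insert(def).second)  // first visit: enqueue
        queue.push_back(def);
  }
  return reachable;
}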
 struct GraphPruning : public FunctionPass<GraphPruning> {
   void runOnFunction() override {
-    FuncOp func = getFunction();
-    if (func.getName() == "main" && skip_main_func) return;
-    func.walk([](tf_executor::GraphOp graph) { PruneGraph(graph); });
+    getFunction().walk([](tf_executor::GraphOp graph) { PruneGraph(graph); });
   }
-
-  struct Options : public PassOptions<Options> {
-    Option<bool> skip_main_func{
-        *this, "skip-main-func",
-        llvm::cl::desc("skip graph pruning for main function"),
-        llvm::cl::init(false)};
-  };
-
-  explicit GraphPruning(bool skip_main_func)
-      : FunctionPass<GraphPruning>(), skip_main_func(skip_main_func) {}
-
-  explicit GraphPruning(const Options& option)
-      : GraphPruning(option.skip_main_func) {}
-
- private:
-  bool skip_main_func;
 };

 }  // namespace

-std::unique_ptr<OpPassBase<FuncOp>> CreateTFExecutorGraphPruningPass(
-    bool skip_main_func) {
-  return std::make_unique<GraphPruning>(skip_main_func);
+std::unique_ptr<OpPassBase<FuncOp>> CreateTFExecutorGraphPruningPass() {
+  return std::make_unique<GraphPruning>();
 }

-static PassRegistration<GraphPruning, GraphPruning::Options> pass(
+static PassRegistration<GraphPruning> pass(
     "tf-executor-graph-pruning",
     "Prune unreachable nodes in a TensorFlow Graph.");
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/inline_global_tensors.cc b/tensorflow/compiler/mlir/tensorflow/transforms/inline_global_tensors.cc
index c994ccf498b..6d780d08d6b 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/inline_global_tensors.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/inline_global_tensors.cc
@@ -24,10 +24,10 @@ limitations under the License.
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/Sequence.h"
-#include "mlir/IR/Attributes.h"  // TF:local_config_mlir
-#include "mlir/IR/Builders.h"  // TF:local_config_mlir
-#include "mlir/IR/Module.h"  // TF:local_config_mlir
-#include "mlir/Pass/Pass.h"  // TF:local_config_mlir
+#include "mlir/IR/Attributes.h"  // TF:llvm-project
+#include "mlir/IR/Builders.h"  // TF:llvm-project
+#include "mlir/IR/Module.h"  // TF:llvm-project
+#include "mlir/Pass/Pass.h"  // TF:llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"

@@ -55,7 +55,7 @@ void InlineGlobalTensorsPass::runOnModule() {
     // Replace the arg with a tf.Const op in the function body.
     auto const_op = builder.create<TF::ConstOp>(global_tensor.getLoc(),
                                                 global_tensor.value());
-    func.getArgument(i)->replaceAllUsesWith(const_op.getResult());
+    func.getArgument(i).replaceAllUsesWith(const_op.getResult());
     args_to_erase.push_back(i);
   }
   func.eraseArguments(args_to_erase);
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc
index e06831ceb21..e9434ab4d5d 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc
@@ -17,12 +17,14 @@ limitations under the License.
 #include <numeric>

-#include "mlir/IR/Attributes.h"  // TF:local_config_mlir
-#include "mlir/IR/Diagnostics.h"  // TF:local_config_mlir
-#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
-#include "mlir/IR/PatternMatch.h"  // TF:local_config_mlir
-#include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
-#include "mlir/IR/TypeUtilities.h"  // TF:local_config_mlir
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/IR/Attributes.h"  // TF:llvm-project
+#include "mlir/IR/Diagnostics.h"  // TF:llvm-project
+#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
+#include "mlir/IR/PatternMatch.h"  // TF:llvm-project
+#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
+#include "mlir/IR/TypeUtilities.h"  // TF:llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
 #include "tensorflow/core/util/tensor_format.h"
@@ -67,6 +69,14 @@ static DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value) {
   return DenseElementsAttr::get(scalar_ty, attr);
 }

+// Returns float DenseElementsAttr with scalar shape with the specified value.
+static DenseElementsAttr GetScalarOfFloatType(Type ty, double raw_value) {
+  auto float_ty = ty.cast<FloatType>();
+  FloatAttr attr = FloatAttr::get(float_ty, raw_value);
+  RankedTensorType scalar_ty = RankedTensorType::get({}, ty);
+  return DenseElementsAttr::get(scalar_ty, attr);
+}
+
 // Returns reduction indices to use while lowering tf.BiasAddGrad op to tf.Sum
 // op.
 DenseIntElementsAttr GetBiasAddGradReductionIndices(int64_t rank,
@@ -124,8 +134,8 @@ class LowerAddNOp : public OpRewritePattern<AddNOp> {

     // TODO(hinsu): Improve parallelism by splitting operands in two halves and
     // accumulating them first.
-    Value *result = *op.inputs().begin();
-    for (Value *operand : llvm::drop_begin(op.inputs(), 1)) {
+    Value result = *op.inputs().begin();
+    for (Value operand : llvm::drop_begin(op.inputs(), 1)) {
       result = rewriter.create<AddV2Op>(op.getLoc(), result, operand);
     }

@@ -134,6 +144,101 @@ class LowerAddNOp : public OpRewritePattern<AddNOp> {
   }
 };

+// Lowers DynamicStitch op with constant indices and with static input and
+// output shapes using Reshape, UnPack and ConcatV2 op.
+//
+//   %indices0 = "tf.Const"() {value = dense<4> : tensor<i32>}
+//   %indices1 = "tf.Const"() {value = dense<[[3, 2], [1, 0]]> :
+//   tensor<2x2xi32>}
+//   %0 = "tf.DynamicStitch"(%indices0, %indices1, %arg0, %arg1)
+//     : (tensor<i32>, tensor<2x2xi32>, tensor<2xf32>, tensor<2x2x2xf32>)
+//     -> tensor<5x2xf32>
+//
+// is lowered to
+//
+//   %shape = "tf.Const"() {value = dense<[-1, 2]> : tensor<2xi64>}
+//   %inp0 = "tf.Reshape"(%arg0, %shape)
+//     : (tensor<2xf32>, tensor<2xi64>) -> tensor<1x2xf32>
+//   %inp1 = "tf.Reshape"(%arg1, %shape)
+//     : (tensor<2x2x2xf32>, tensor<2xi64>) -> tensor<4x2xf32>
+//   %items0 = "tf.Unpack"(%[[INP0]]) {axis = 0 : i64}
+//     : (tensor<1x2xf32>) -> tensor<2xf32>
+//   %items1:4 = "tf.Unpack"(%[[INP1]]) {axis = 0 : i64}
+//     : (tensor<4x2xf32>) -> (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>,
+//     tensor<2xf32>)
+//   %axis = "tf.Const"() {value = dense<0> : tensor<i64>}
+//   %0 = "tf.ConcatV2"(items1#3, items1#2, items1#1, items1#0, %items0, %axis)
+//     : (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>,
+//     tensor<2xf32>, tensor<i64>) -> tensor<5x2xf32>
+//
+class LowerDynamicStitchOp : public OpRewritePattern<DynamicStitchOp> {
+ public:
+  explicit LowerDynamicStitchOp(MLIRContext *context)
+      : OpRewritePattern<DynamicStitchOp>(context) {}
+
+  PatternMatchResult matchAndRewrite(DynamicStitchOp op,
+                                     PatternRewriter &rewriter) const override {
+    // Static output type is used to compute intermediate values. Note that the
+    // output type doesn't have to be static but if input types and indices are
+    // constant, then the output type can be statically determined.
+    RankedTensorType out_ty = op.getType().dyn_cast<RankedTensorType>();
+    if (!out_ty || !out_ty.hasStaticShape()) return matchFailure();
+
+    // Extract out all the constant indices' attributes and verify that data
+    // types are static.
+    SmallVector<DenseIntElementsAttr, 4> indices;
+    indices.reserve(op.N());
+    for (auto it : llvm::zip(op.indices(), op.data())) {
+      Value index = std::get<0>(it);
+      Value data = std::get<1>(it);
+
+      DenseIntElementsAttr index_attr;
+      if (!matchPattern(index, m_Constant(&index_attr))) return matchFailure();
+      indices.push_back(index_attr);
+
+      RankedTensorType data_ty = data.getType().dyn_cast<RankedTensorType>();
+      if (!data_ty || !data_ty.hasStaticShape()) return matchFailure();
+    }
+
+    // Compute type of each of the items and shape to use while reshaping
+    // inputs so that they can be unpacked to extract out individual items.
+    ArrayRef<int64_t> item_shape = out_ty.getShape().drop_front(1);
+    auto item_ty = RankedTensorType::get(item_shape, out_ty.getElementType());
+
+    SmallVector<int64_t, 4> packed_shape;
+    packed_shape.push_back(-1);
+    packed_shape.append(item_shape.begin(), item_shape.end());
+    Location loc = op.getLoc();
+    auto packed_shape_val = rewriter.create<ConstOp>(
+        loc, GetI64ElementsAttr(packed_shape, &rewriter));
+
+    // Prepare each of the output item by unpacking data and then putting it to
+    // the specified index.
+    SmallVector<Value, 8> values(out_ty.getDimSize(0));
+    for (auto it : llvm::zip(indices, op.data())) {
+      DenseIntElementsAttr index_attr = std::get<0>(it);
+      Value data = std::get<1>(it);
+
+      auto reshaped_data =
+          rewriter.create<ReshapeOp>(loc, data, packed_shape_val);
+      auto num_items =
+          reshaped_data.getType().cast<RankedTensorType>().getShape()[0];
+      auto items = rewriter.create<UnpackOp>(
+          loc, SmallVector<Type, 4>(num_items, item_ty), reshaped_data,
+          /*axis=*/APInt(64, 0));
+      for (auto index_item : llvm::zip(index_attr, items.getResults())) {
+        int64_t output_index = std::get<0>(index_item).getSExtValue();
+        Value item = std::get<1>(index_item);
+        values[output_index] = item;
+      }
+    }
+
+    auto axis = rewriter.create<ConstOp>(loc, rewriter.getI64IntegerAttr(0));
+    rewriter.replaceOpWithNewOp<ConcatV2Op>(op, op.getType(), values, axis);
+    return matchSuccess();
+  }
+};
+
 // Lowers Pack op to ConcatV2 op after changing shape of the inputs with
 // ExpandDims op.
 //
@@ -159,13 +264,13 @@ class LowerPackOp : public OpRewritePattern<PackOp> {
     int64_t axis = op.axis().getSExtValue();
     Type prev_input_ty, inferred_ty;
-    SmallVector<Value *, 4> expanded_inputs;
+    SmallVector<Value, 4> expanded_inputs;
     expanded_inputs.reserve(op.N());
-    for (Value *input : op.values()) {
+    for (Value input : op.values()) {
       // If input type is different than the previous input type, infer the
       // output type. Otherwise, use the already inferred output type from the
       // previous iteration.
-      Type input_ty = input->getType();
+      Type input_ty = input.getType();
       if (input_ty != prev_input_ty) {
         inferred_ty = InferExpandDimsType(input_ty, axis, &rewriter);
         prev_input_ty = input_ty;
@@ -184,8 +289,7 @@ class LowerPackOp : public OpRewritePattern<PackOp> {

 void PopulateLoweringTFPatterns(MLIRContext *context,
                                 OwningRewritePatternList *patterns) {
-  patterns->insert<LowerAddNOp>(context);
-  patterns->insert<LowerPackOp>(context);
+  patterns->insert<LowerAddNOp, LowerDynamicStitchOp, LowerPackOp>(context);
   populateWithGenerated(context, patterns);
 }
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h
index 4b85ac3b46a..b72b0f25938 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h
@@ -16,8 +16,8 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LOWER_TF_H_
 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LOWER_TF_H_

-#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
-#include "mlir/IR/PatternMatch.h"  // TF:local_config_mlir
+#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
+#include "mlir/IR/PatternMatch.h"  // TF:llvm-project

 namespace mlir {
 namespace TF {
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td
index 069bc07f4a1..ec0ac5e3c1e 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td
@@ -21,13 +21,23 @@ include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
 class GetScalarOfType<int value> : NativeCodeCall<
     "GetScalarOfType(getElementTypeOrSelf($0)," # value # ")">;

+class GetScalarOfFloatType<string value> : NativeCodeCall<
+    "GetScalarOfFloatType(getElementTypeOrSelf($0)," # value # ")">;
+
+def GetScalarNanOfType : NativeCodeCall<
+    "GetScalarOfFloatType(getElementTypeOrSelf($0), "
+    "std::numeric_limits<double>::quiet_NaN())">;
+
+class GetI64ScalarElementsAttr<int value> :
+  NativeCodeCall<"GetI64ElementsAttr({" # value # "}, &$_builder)">;
+
 //===----------------------------------------------------------------------===//
 // BiasAddGrad op patterns.
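For reference, the net semantics the LowerDynamicStitchOp pattern above reproduces can be modeled in a few lines of standalone C++. This is a toy model, not the pattern itself: nested vectors stand in for tensors, rows are the "items" the lowering unpacks, and num_rows is assumed to equal the stitched output's first dimension:

#include <cstddef>
#include <vector>

// indices[i][j] says where row j of input i lands in the stitched output,
// exactly as in the comment example above; later inputs overwrite earlier
// ones when indices repeat.
std::vector<std::vector<float>> DynamicStitch(
    const std::vector<std::vector<int>>& indices,
    const std::vector<std::vector<std::vector<float>>>& data,
    std::size_t num_rows) {
  std::vector<std::vector<float>> out(num_rows);
  for (std::size_t i = 0; i < indices.size(); ++i)
    for (std::size_t j = 0; j < indices[i].size(); ++j)
      out[indices[i][j]] = data[i][j];
  return out;
}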
 //===----------------------------------------------------------------------===//

 def GetBiasAddGradReductionIndices : NativeCodeCall<
     "GetBiasAddGradReductionIndices("
-    "$0->getType().cast<RankedTensorType>().getRank(), $1, &$_builder)">;
+    "$0.getType().cast<RankedTensorType>().getRank(), $1, &$_builder)">;

 def LowerBiasAddGradOp :
   Pat<(TF_BiasAddGradOp AnyRankedTensor:$out_backprop, $data_format),
@@ -56,19 +66,57 @@ def LowerBiasAddGradOp :
 // TODO(hinsu): Support scalar inputs by introducing reshape to 1D.
 def NonScalarType : Type<Neg<HasAnyRankOfPred<[0]>>, "Non scalar type">;

-def GetLastDimReductionAxis :
-    NativeCodeCall<"GetI64ElementsAttr({-1}, &$_builder)">;
-
 def LowerSoftmaxCrossEntropyWithLogitsOp : Pattern<
   (TF_SoftmaxCrossEntropyWithLogitsOp AnyRankedTensor:$features,
                                       AnyRankedTensor:$labels),
   [(TF_SumOp (TF_MulOp:$sum_input (TF_NegOp $labels),
                                   (TF_LogSoftmaxOp $features)),
-             (TF_ConstOp (GetLastDimReductionAxis)),
+             (TF_ConstOp (GetI64ScalarElementsAttr<-1>)),
              /*keep_dims=*/ConstBoolAttrFalse),
    (TF_SubOp (TF_SoftmaxOp $features), $labels)],
   [(NonScalarType $features), (NonScalarType $labels)]>;

+// Returns size of the specified dimension as scalar elements attribute of
+// type $1.
+// Requires $0 to be of RankedTensorType with rank greater than `dim` and the
+// dimension should be known.
+class GetDimSizeOfType<int dim> : NativeCodeCall<
+    "GetScalarOfType(getElementTypeOrSelf($1), "
+    "$0.getType().cast<RankedTensorType>().getDimSize(" # dim # "))">;
+
+// Same as the above with i32 element type.
+class GetDimSizeAsI32<int dim> : NativeCodeCall<
+    "GetScalarOfType($_builder.getIntegerType(32), "
+    "$0.getType().cast<RankedTensorType>().getDimSize(" # dim # "))">;
+
+// Sparse version of SoftmaxCrossEntropyWithLogits is lowered to dense by
+// expanding the sparse labels using:
+//
+//   labels = OneHotOp(sparse_labels, depth, 1.0, 0.0)
+//
+// If any of the indices are out of range, we must populate the labels with
+// NaNs to follow the semantics of the op.
+def LowerSparseSoftmaxCrossEntropyWithLogitsOp : Pattern<
+  (TF_SparseSoftmaxCrossEntropyWithLogitsOp:$src_op
+    AnyStaticShapeTensor:$features, $sparse_labels),
+  [(TF_OneHotOp:$labels $sparse_labels,
+     (TF_ConstOp (GetDimSizeAsI32<1> $features, $src_op__0)),
+     (TF_ConstOp (GetScalarOfType<1> $features)),
+     (TF_ConstOp (GetScalarOfType<0> $features)),
+     ConstantAttr<I64Attr, "-1">),
+   (TF_SelectV2Op:$zero_or_nan
+     (TF_LogicalAndOp
+       (TF_LessEqualOp
+         (TF_ConstOp (GetScalarOfType<0> $sparse_labels)), $sparse_labels),
+       (TF_LessOp $sparse_labels,
+         (TF_ConstOp (GetDimSizeOfType<1> $features, $sparse_labels)))),
+     (TF_ConstOp (GetScalarOfType<0> $features)),
+     (TF_ConstOp (GetScalarNanOfType $labels))),
+   (TF_AddV2Op:$adjusted_labels $labels,
+     (TF_ExpandDimsOp $zero_or_nan,
+       (TF_ConstOp (GetI64ScalarElementsAttr<-1>)))),
+   (TF_SoftmaxCrossEntropyWithLogitsOp $features, $adjusted_labels)]>;
+
 //===----------------------------------------------------------------------===//
 // Difference op patterns.
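The NaN handling in the sparse pattern above is easiest to see outside tblgen. A standalone C++ sketch of the label expansion, assuming ordinary IEEE float semantics: labels in [0, depth) become one-hot rows, and an out-of-range label poisons its entire row with NaN so the downstream cross-entropy also produces NaN, matching the op's semantics:

#include <cstddef>
#include <limits>
#include <vector>

std::vector<std::vector<float>> ExpandSparseLabels(
    const std::vector<int>& sparse_labels, int depth) {
  std::vector<std::vector<float>> dense(sparse_labels.size(),
                                        std::vector<float>(depth, 0.0f));
  for (std::size_t i = 0; i < sparse_labels.size(); ++i) {
    int label = sparse_labels[i];
    if (label >= 0 && label < depth) {
      dense[i][label] = 1.0f;  // in range: ordinary one-hot row
    } else {
      for (float& v : dense[i])  // out of range: NaN poisons the whole row
        v = std::numeric_limits<float>::quiet_NaN();
    }
  }
  return dense;
}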
 //===----------------------------------------------------------------------===//

@@ -112,7 +160,7 @@ def LowerFillOp : Pat<(TF_FillOp $dims, $value),

 def GetAllAxes : NativeCodeCall<
     "GetI64ElementsAttrForSeq("
-    "0, $0->getType().cast<RankedTensorType>().getRank(), &$_builder)">;
+    "0, $0.getType().cast<RankedTensorType>().getRank(), &$_builder)">;

 // L2Loss is lowered using the formula,
 // L2Loss(input) = Sum(input * input) / 2
@@ -135,6 +183,14 @@ def : Pat<(TF_PadOp TensorOf<[AnyInteger, AnyFloat]>:$input, $paddings),
           (TF_PadV2Op $input, $paddings,
                       (TF_ConstOp (GetScalarOfType<0> $input)))>;

+//===----------------------------------------------------------------------===//
+// Reciprocal op patterns.
+//===----------------------------------------------------------------------===//
+
+// TODO(hinsu): Support complex and unsigned input types.
+def LowerReciprocal : Pat<(TF_ReciprocalOp TF_SintOrFpTensor:$x),
+                          (TF_DivOp (TF_ConstOp (GetScalarOfType<1> $x)), $x)>;
+
 //===----------------------------------------------------------------------===//
 // Rsqrt op patterns.
 //===----------------------------------------------------------------------===//
@@ -164,7 +220,7 @@ def LowerTanhGradOp :
 //===----------------------------------------------------------------------===//

 def CreateTFShapeOp : NativeCodeCall<
-    "$_builder.create<TF::ShapeOp>($0->getLoc(), $1, $2)">;
+    "$_builder.create<TF::ShapeOp>($0.getLoc(), $1, $2)">;

 // TODO(hinsu): Support inputs of TensorList types.
 def LowerZerosLikeOp :
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf_pass.cc
index 309d0147bc0..be9e0f4aef4 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf_pass.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf_pass.cc
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

-#include "mlir/IR/PatternMatch.h"  // TF:local_config_mlir
-#include "mlir/Pass/Pass.h"  // TF:local_config_mlir
+#include "mlir/IR/PatternMatch.h"  // TF:llvm-project
+#include "mlir/Pass/Pass.h"  // TF:llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h"

 namespace mlir {
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/materialize_mlir_passthrough_op.cc b/tensorflow/compiler/mlir/tensorflow/transforms/materialize_mlir_passthrough_op.cc
index 58dfab15d34..f9a459647c8 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/materialize_mlir_passthrough_op.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/materialize_mlir_passthrough_op.cc
@@ -17,16 +17,16 @@ limitations under the License.
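Two of the formulas in the patterns above are simple enough to check directly. Assuming ordinary float arithmetic, Reciprocal(x) lowers to Div(Const(1), x) and L2Loss(input) = Sum(input * input) / 2; a trivial standalone C++ restatement:

#include <vector>

// Mirrors LowerReciprocal: 1 / x via a division by a splat constant one.
float Reciprocal(float x) { return 1.0f / x; }

// Mirrors the L2Loss formula quoted in the comment above.
float L2Loss(const std::vector<float>& input) {
  float sum = 0.0f;
  for (float v : input) sum += v * v;
  return sum / 2.0f;
}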
#include "llvm/ADT/STLExtras.h" #include "llvm/Support/Casting.h" -#include "mlir/IR/Block.h" // TF:local_config_mlir -#include "mlir/IR/Diagnostics.h" // TF:local_config_mlir -#include "mlir/IR/Module.h" // TF:local_config_mlir -#include "mlir/IR/OpDefinition.h" // TF:local_config_mlir -#include "mlir/IR/Operation.h" // TF:local_config_mlir -#include "mlir/IR/Types.h" // TF:local_config_mlir -#include "mlir/IR/Value.h" // TF:local_config_mlir -#include "mlir/Parser.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir -#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir +#include "mlir/IR/Block.h" // TF:llvm-project +#include "mlir/IR/Diagnostics.h" // TF:llvm-project +#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/IR/OpDefinition.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/Types.h" // TF:llvm-project +#include "mlir/IR/Value.h" // TF:llvm-project +#include "mlir/Parser.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Pass/PassRegistry.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #define DEBUG_TYPE "tf-materialize-passthrough-op" @@ -79,7 +79,7 @@ void MaterializePassthroughOpPass::runOnFunction() { Block &block = body.front(); for (const auto &arg_mapping : llvm::zip(block.getArguments(), op->getOperands())) { - std::get<0>(arg_mapping)->replaceAllUsesWith(std::get<1>(arg_mapping)); + std::get<0>(arg_mapping).replaceAllUsesWith(std::get<1>(arg_mapping)); } op->getBlock()->getOperations().splice(op->getIterator(), block.getOperations(), block.begin(), @@ -87,7 +87,7 @@ void MaterializePassthroughOpPass::runOnFunction() { Operation &return_op = block.front(); for (auto ret_mapping : llvm::zip(op->getResults(), return_op.getOperands())) { - std::get<0>(ret_mapping)->replaceAllUsesWith(std::get<1>(ret_mapping)); + std::get<0>(ret_mapping).replaceAllUsesWith(std::get<1>(ret_mapping)); } op->erase(); }); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc index 6e28b19ad80..a52b30e2fd2 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc @@ -14,14 +14,14 @@ limitations under the License. 
 ==============================================================================*/

 #include <memory>

-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:local_config_mlir
-#include "mlir/IR/Attributes.h"  // TF:local_config_mlir
-#include "mlir/IR/Builders.h"  // TF:local_config_mlir
-#include "mlir/IR/Operation.h"  // TF:local_config_mlir
-#include "mlir/IR/PatternMatch.h"  // TF:local_config_mlir
-#include "mlir/Pass/Pass.h"  // TF:local_config_mlir
-#include "mlir/Pass/PassManager.h"  // TF:local_config_mlir
-#include "mlir/Transforms/Passes.h"  // TF:local_config_mlir
+#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/IR/Attributes.h"  // TF:llvm-project
+#include "mlir/IR/Builders.h"  // TF:llvm-project
+#include "mlir/IR/Operation.h"  // TF:llvm-project
+#include "mlir/IR/PatternMatch.h"  // TF:llvm-project
+#include "mlir/Pass/Pass.h"  // TF:llvm-project
+#include "mlir/Pass/PassManager.h"  // TF:llvm-project
+#include "mlir/Transforms/Passes.h"  // TF:llvm-project
 #include "tensorflow/compiler/mlir/lite/utils/validators.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.td b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.td
index 6c11067ce7a..5681b78882a 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.td
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.td
@@ -21,7 +21,7 @@ def BroadcastableElements : Constraint<CPred<
 def F32ElementsAttr : ElementsAttrBase<
   CPred<"$_self.cast<DenseElementsAttr>().getType().getElementType().isF32()">, "float constant tensor">;
-def DefinedByConv2D : Constraint<CPred<"llvm::isa_and_nonnull<TF::Conv2DOp>($0->getDefiningOp())">>;
+def DefinedByConv2D : Constraint<CPred<"llvm::isa_and_nonnull<TF::Conv2DOp>($0.getDefiningOp())">>;

 // If we see a Conv2D op followed by Mul, then multiply the filter
 // with the value in Mul.
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc b/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc
index e7acbb334ea..40f084af46b 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc
@@ -20,9 +20,9 @@ limitations under the License.

 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/STLExtras.h"
-#include "mlir/IR/Builders.h"  // TF:local_config_mlir
-#include "mlir/IR/Module.h"  // TF:local_config_mlir
-#include "mlir/Pass/Pass.h"  // TF:local_config_mlir
+#include "mlir/IR/Builders.h"  // TF:llvm-project
+#include "mlir/IR/Module.h"  // TF:llvm-project
+#include "mlir/Pass/Pass.h"  // TF:llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"

@@ -52,9 +52,9 @@ using GlobalTensorUsesMap =
 // be keep in sync.
 bool IsReadOnlyVariableOp(Operation* op) { return isa<TF::ReadVariableOp>(op); }

-void RewriteReadOnlyVariableOpToTensorOp(Operation* op, Value* tensor_value) {
+void RewriteReadOnlyVariableOpToTensorOp(Operation* op, Value tensor_value) {
   auto read_variable = cast<TF::ReadVariableOp>(op);
-  read_variable.value()->replaceAllUsesWith(tensor_value);
+  read_variable.value().replaceAllUsesWith(tensor_value);
 }

 bool IsFreezable(GlobalTensorOp global_tensor,
@@ -73,8 +73,8 @@ bool IsFreezable(GlobalTensorOp global_tensor,
   // func for tf.ReadVariableOp. If the resource is passed into other functions
   // or control flow, we fail to prove it is freezable even though we could.
   for (auto& global_tensor_use : global_tensor_uses) {
-    auto* arg = global_tensor_use.func.getArgument(global_tensor_use.arg_index);
-    for (auto user : arg->getUsers()) {
+    auto arg = global_tensor_use.func.getArgument(global_tensor_use.arg_index);
+    for (auto user : arg.getUsers()) {
       if (!IsReadOnlyVariableOp(user)) {
         return false;
       }
@@ -129,13 +129,13 @@ void FreezeGlobalTensors(ModuleOp module,
     for (auto global_tensor_use : global_tensor_uses) {
       auto func = global_tensor_use.func;
       auto arg_index = global_tensor_use.arg_index;
-      Value* arg = func.getArgument(arg_index);
-      for (Operation* user : llvm::make_early_inc_range(arg->getUsers())) {
+      Value arg = func.getArgument(arg_index);
+      for (Operation* user : llvm::make_early_inc_range(arg.getUsers())) {
         RewriteReadOnlyVariableOpToTensorOp(user, arg);
         user->erase();
       }
       Type new_type = global_tensor.value().Attribute::getType();
-      arg->setType(new_type);
+      arg.setType(new_type);
       auto old_ftype = func.getType();
       auto input_types = old_ftype.getInputs().vec();
       input_types[arg_index] = new_type;
@@ -168,7 +168,7 @@ void EraseUnusedBoundInputs(ModuleOp module) {
     SmallVector<unsigned, 4> args_to_erase;
     for (int i = 0, e = func.getNumArguments(); i < e; i++) {
       if (func.getArgAttr(i, "tf_saved_model.bound_input") &&
-          func.getArgument(i)->use_empty()) {
+          func.getArgument(i).use_empty()) {
         args_to_erase.push_back(i);
       }
     }
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
index c9c97735848..180e87eba46 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
@@ -18,7 +18,7 @@ limitations under the License.

 #include <memory>

-#include "mlir/Pass/Pass.h"  // TF:local_config_mlir
+#include "mlir/Pass/Pass.h"  // TF:llvm-project

 namespace mlir {

@@ -46,7 +46,8 @@ std::unique_ptr<OpPassBase<FuncOp>> CreateTFShapeInferencePass();
 // Optimizes Tensorflow graph.
 std::unique_ptr<OpPassBase<FuncOp>> CreateTFOptimizePass();

-struct StandardPipelineOptions : public PassOptions<StandardPipelineOptions> {
+struct StandardPipelineOptions
+    : public PassPipelineOptions<StandardPipelineOptions> {
   Option<bool> enable_inliner{*this, "enable-inliner",
                               llvm::cl::desc("Enable inliner."),
                               llvm::cl::init(false)};
@@ -79,8 +80,7 @@ std::unique_ptr<OpPassBase<FuncOp>> CreateSwitchFoldPass();
 std::unique_ptr<OpPassBase<FuncOp>> CreateTFExecutorIslandCoarseningPass();

 // Create a pass to prune tf_executor.graph from dead nodes.
-std::unique_ptr<OpPassBase<FuncOp>> CreateTFExecutorGraphPruningPass(
-    bool skip_main_func = false);
+std::unique_ptr<OpPassBase<FuncOp>> CreateTFExecutorGraphPruningPass();

 // Prunes unreachable operations of a tf_executor.graph operation.
 void PruneGraph(GraphOp graph);
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/raise_control_flow.cc b/tensorflow/compiler/mlir/tensorflow/transforms/raise_control_flow.cc
index d6acb7488e1..55cb1e2c3df 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/raise_control_flow.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/raise_control_flow.cc
@@ -22,9 +22,9 @@ limitations under the License.
 // eliminating control dependencies, and results in the code being in the
 // canonical TensorFlow dialect.
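The freezability rule used by IsFreezable and FreezeGlobalTensors above reduces to: a non-mutable global tensor whose handle argument is only ever read can have each read replaced by its constant initial value and the handle dropped. A toy C++ sketch of that check, with an invented Use record standing in for MLIR uses:

#include <vector>

struct Use {
  bool is_read_only;  // true when the use is a plain tf.ReadVariableOp
};

// A handle is freezable when the tensor is immutable and every use of the
// handle is a read; then each read can be rewritten to the constant value,
// as RewriteReadOnlyVariableOpToTensorOp does above.
bool IsFreezable(const std::vector<Use>& handle_uses, bool is_mutable) {
  if (is_mutable) return false;
  for (const Use& use : handle_uses)
    if (!use.is_read_only) return false;
  return true;
}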
-#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Operation.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" @@ -100,7 +100,7 @@ void RaiseTFControlFlow::rewriteOps() { // aren't necessary any more since the order within a block encodes the // same information. for (auto &operand : op.getOpOperands()) { - if (!operand.get()->getType().isa()) + if (!operand.get().getType().isa()) result.operands.push_back(operand.get()); // Drop all operands from the old operation, eliminating any @@ -110,14 +110,14 @@ void RaiseTFControlFlow::rewriteOps() { // Add a result type for each non-control result we find. bool sawControlResult = false; - for (auto *opResult : op.getResults()) { - if (opResult->getType().isa()) { + for (auto opResult : op.getResults()) { + if (opResult.getType().isa()) { sawControlResult = true; } else { // We assume all control inputs are at the end of the result list. assert(!sawControlResult && "all control results must be last"); (void)sawControlResult; - result.types.push_back(opResult->getType()); + result.types.push_back(opResult.getType()); } } @@ -129,7 +129,7 @@ void RaiseTFControlFlow::rewriteOps() { // We know that all the control results are last, so we can just rewrite // the first results. for (unsigned i = 0, e = result.types.size(); i != e; ++i) - op.getResult(i)->replaceAllUsesWith(replacement->getResult(i)); + op.getResult(i).replaceAllUsesWith(replacement->getResult(i)); } } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_invariant_op_hoisting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_invariant_op_hoisting.cc index 36f6f3a933c..7b4ae38726d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_invariant_op_hoisting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_invariant_op_hoisting.cc @@ -20,11 +20,11 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Casting.h" -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Value.h" // TF:local_config_mlir -#include "mlir/IR/Visitors.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir -#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Value.h" // TF:llvm-project +#include "mlir/IR/Visitors.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -71,19 +71,19 @@ struct ReplicateInvariantOpHoistingPass // } void MakeShapeOpInvariant(tf_device::ReplicateOp replicate_op, int num_replicas, Block* replicate_block, TF::ShapeOp shape_op) { - Value* input = shape_op.input(); + Value input = shape_op.input(); // If ShapeOp operand is replicate tensor block argument, replace with the // associated first replica operand. 
-  if (auto block_arg = llvm::dyn_cast<BlockArgument>(input)) {
-    if (block_arg->getOwner() != replicate_block) return;
+  if (auto block_arg = input.dyn_cast<BlockArgument>()) {
+    if (block_arg.getOwner() != replicate_block) return;

     shape_op.setOperand(
-        replicate_op.getOperand(num_replicas * block_arg->getArgNumber()));
+        replicate_op.getOperand(num_replicas * block_arg.getArgNumber()));
     return;
   }

-  Operation* input_def = input->getDefiningOp();
+  Operation* input_def = input.getDefiningOp();

   // If ShapeOp operand is a ReadVariableOp result where the ReadVariableOp
   // operand is a replicate resource block argument, replace ShapeOp with
@@ -96,13 +96,13 @@ void MakeShapeOpInvariant(tf_device::ReplicateOp replicate_op, int num_replicas,
   // shape has not changed in replicate prior to read. Currently after both
   // ResourceOpLiftingPass and TPURewritePass, there should not be any updates
   // to resources prior to their respective ReadVariableOp.
-  if (auto block_arg = llvm::dyn_cast<BlockArgument>(read_var_op.resource())) {
-    if (block_arg->getOwner() != replicate_block) return;
+  if (auto block_arg = read_var_op.resource().dyn_cast<BlockArgument>()) {
+    if (block_arg.getOwner() != replicate_block) return;

     OpBuilder builder(shape_op);
     auto new_shape_op = builder.create<TF::ShapeOp>(
         shape_op.getLoc(), shape_op.getType(),
-        replicate_op.getOperand(num_replicas * block_arg->getArgNumber()));
+        replicate_op.getOperand(num_replicas * block_arg.getArgNumber()));
     shape_op.replaceAllUsesWith(new_shape_op.getOperation());
     shape_op.erase();
   }
@@ -111,8 +111,8 @@ void MakeShapeOpInvariant(tf_device::ReplicateOp replicate_op, int num_replicas,
 // Checks if op and inner op operands are all replicate invariant.
 bool IsOpReplicateInvariant(Region* replicate_region, Operation* op) {
   auto result = op->walk([&](Operation* inner_op) {
-    for (Value* operand : inner_op->getOperands()) {
-      Region* parent_region = operand->getParentRegion();
+    for (Value operand : inner_op->getOperands()) {
+      Region* parent_region = operand.getParentRegion();
       if (!parent_region || !parent_region->isProperAncestor(replicate_region))
         return WalkResult::interrupt();
     }
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc
index 9787ac0f0f0..ec0125b913d 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc
@@ -24,13 +24,13 @@ limitations under the License.
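The hoisting condition checked by IsOpReplicateInvariant above boils down to: an op, including anything nested inside it, is replicate invariant when none of its operands are defined inside the replicate region. A toy C++ sketch of that predicate, with an invented Op record in place of MLIR operations:

#include <vector>

struct Op {
  // One flag per operand: true when the operand is defined inside the
  // replicate region (i.e. it is a per-replica value).
  std::vector<bool> operand_defined_in_region;
};

// Invariant ops depend only on values from outside the region, so they can
// be hoisted before the replicate without changing any replica's result.
bool IsReplicateInvariant(const Op& op) {
  for (bool in_region : op.operand_defined_in_region)
    if (in_region) return false;
  return true;
}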
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Block.h" // TF:local_config_mlir -#include "mlir/IR/BlockAndValueMapping.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Diagnostics.h" // TF:local_config_mlir -#include "mlir/IR/Dialect.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Block.h" // TF:llvm-project +#include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Diagnostics.h" // TF:llvm-project +#include "mlir/IR/Dialect.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" @@ -60,7 +60,7 @@ llvm::SmallVector ExpandReplicateIntoReplicas( Operation& terminator = replicate_op.GetBody().back(); llvm::SmallVector output_types(terminator.getOperandTypes()); auto control_type = tf_executor::ControlType::get(island_op.getContext()); - llvm::SmallVector replica_inputs(island_op.controlInputs()); + llvm::SmallVector replica_inputs(island_op.controlInputs()); // Replace replicate terminator with YieldOp. builder->setInsertionPoint(&terminator); @@ -83,7 +83,7 @@ llvm::SmallVector ExpandReplicateIntoReplicas( mapping.clear(); for (auto& block_arg : replicate_op.GetBody().getArguments()) mapping.map(block_arg, replicate_op.getOperand( - block_arg->getArgNumber() * num_replicas + i)); + block_arg.getArgNumber() * num_replicas + i)); // Copy over replicate region into replica island. replicate_op.body().cloneInto(&replica.body(), mapping); @@ -149,8 +149,8 @@ void CreateIslandsFromReplicate(const Dialect* tf_dialect, num_replicas); // Collect all replica results. - llvm::SmallVector replicas_outputs(replicate_op.getNumResults(), - nullptr); + llvm::SmallVector replicas_outputs(replicate_op.getNumResults(), + nullptr); for (auto replica_and_idx : llvm::enumerate(replicas)) for (auto replica_result_and_idx : llvm::enumerate(replica_and_idx.value().outputs())) @@ -163,7 +163,7 @@ void CreateIslandsFromReplicate(const Dialect* tf_dialect, // Collect per replica control dependency and add to island operand if replica // island has no uses. - llvm::SmallVector island_operands; + llvm::SmallVector island_operands; for (auto& replica : replicas) if (replica.use_empty()) island_operands.push_back(replica.control()); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc index 6dc3e87f8ec..c92ce1f01ad 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc @@ -26,16 +26,16 @@ limitations under the License. 
#include "llvm/ADT/StringRef.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Support/Casting.h" -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Function.h" // TF:local_config_mlir -#include "mlir/IR/Operation.h" // TF:local_config_mlir -#include "mlir/IR/Types.h" // TF:local_config_mlir -#include "mlir/IR/Value.h" // TF:local_config_mlir -#include "mlir/IR/Visitors.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir -#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir -#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/Types.h" // TF:llvm-project +#include "mlir/IR/Value.h" // TF:llvm-project +#include "mlir/IR/Visitors.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Pass/PassRegistry.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" @@ -64,7 +64,7 @@ class PerFunctionResult { // Returns the recorded device assignment for a resource, if any. llvm::Optional DeviceForResource( - const Value* resource) const { + const Value resource) const { llvm::Optional result; if (alias_analysis_.IsUnknownResource(resource)) return result; for (int64_t id : alias_analysis_.GetResourceUniqueIds(resource)) { @@ -87,7 +87,7 @@ class PerFunctionResult { // conflicts with an existing one, returns an error. // // If `changed` is provided, assign *changed to true if anything is modified. - LogicalResult AddResourceDevice(const Value* resource, llvm::StringRef device, + LogicalResult AddResourceDevice(const Value resource, llvm::StringRef device, bool* changed = nullptr) { if (alias_analysis_.IsUnknownResource(resource)) return success(); for (int64_t id : alias_analysis_.GetResourceUniqueIds(resource)) { @@ -108,7 +108,7 @@ class PerFunctionResult { }; // Tries to record device assignment for a resource. -LogicalResult AddResourceDeviceAndEmitError(const Value* resource, +LogicalResult AddResourceDeviceAndEmitError(const Value resource, llvm::StringRef device, Operation* error_reporting_op, PerFunctionResult* result, @@ -127,16 +127,16 @@ LogicalResult ComputeResourceDevicesInComputation(FuncOp func_op, OpBuilder builder(func_op); // Function arguments. for (auto arg : func_op.getArguments()) { - if (!mlir::getElementTypeOrSelf(arg->getType()).isa()) { + if (!mlir::getElementTypeOrSelf(arg.getType()).isa()) { continue; } auto device_attr = func_op.getArgAttrOfType( - arg->getArgNumber(), kFuncDeviceAttr); + arg.getArgNumber(), kFuncDeviceAttr); if (!device_attr || device_attr.getValue() == "") { // If device_attr does not exist, try to construct it from any recorded // assignment. if (auto device = result->DeviceForResource(arg)) { - func_op.setArgAttr(arg->getArgNumber(), kFuncDeviceAttr, + func_op.setArgAttr(arg.getArgNumber(), kFuncDeviceAttr, builder.getStringAttr(*device)); } continue; @@ -160,7 +160,7 @@ LogicalResult ComputeResourceDevicesInComputation(FuncOp func_op, } if (auto identity = llvm::dyn_cast(op)) { // Try to construct IdentityOp's attribute from recorded assignment. 
-      if (!mlir::getElementTypeOrSelf(identity.output()->getType())
+      if (!mlir::getElementTypeOrSelf(identity.output().getType())
               .isa<TF::ResourceType>()) {
         return WalkResult::advance();
       }
@@ -176,7 +176,7 @@ LogicalResult ComputeResourceDevicesInComputation(FuncOp func_op,
     // Propagate and record output device assignment for other ops based on
     // existing recording. E.g., IdentityN.
     for (auto output : op->getResults()) {
-      if (!mlir::getElementTypeOrSelf(output->getType())
+      if (!mlir::getElementTypeOrSelf(output.getType())
               .isa<TF::ResourceType>()) {
         continue;
       }
@@ -212,7 +212,7 @@ void ResourceDeviceInference::runOnModule() {
         for (auto operand_and_argument :
              llvm::zip(caller_operands, callee.getArguments())) {
           if (!mlir::getElementTypeOrSelf(
-                   std::get<0>(operand_and_argument)->getType())
+                   std::get<0>(operand_and_argument).getType())
                    .isa<TF::ResourceType>()) {
             continue;
           }
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc
index 2f32a3a2c28..70a69a36adf 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc
@@ -19,13 +19,13 @@ limitations under the License.
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Casting.h"
-#include "mlir/IR/BlockAndValueMapping.h"  // TF:local_config_mlir
-#include "mlir/IR/Builders.h"  // TF:local_config_mlir
-#include "mlir/IR/Diagnostics.h"  // TF:local_config_mlir
-#include "mlir/IR/Module.h"  // TF:local_config_mlir
-#include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
-#include "mlir/Pass/Pass.h"  // TF:local_config_mlir
-#include "mlir/Transforms/RegionUtils.h"  // TF:local_config_mlir
+#include "mlir/IR/BlockAndValueMapping.h"  // TF:llvm-project
+#include "mlir/IR/Builders.h"  // TF:llvm-project
+#include "mlir/IR/Diagnostics.h"  // TF:llvm-project
+#include "mlir/IR/Module.h"  // TF:llvm-project
+#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
+#include "mlir/Pass/Pass.h"  // TF:llvm-project
+#include "mlir/Transforms/RegionUtils.h"  // TF:llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
@@ -87,26 +87,26 @@ void ForwardStoreToLoad(tf_device::LaunchOp launch_op) {
   // resource_handle_to_last_store_op keeps track of the most recent (last)
   // store to each resource. Non-existent entry indicates that a resource has
   // not been stored to yet.
-  llvm::SmallDenseMap<Value*, TF::AssignVariableOp>
+  llvm::SmallDenseMap<Value, TF::AssignVariableOp>
       resource_handle_to_last_store_op;

   // Only iterate through ops directly in launch_op's body as we can't handle
   // ops nested deeper in regions.
   for (Operation& op : llvm::make_early_inc_range(launch_op.GetBody())) {
     if (auto read_variable_op = dyn_cast<TF::ReadVariableOp>(&op)) {
-      Value* resource = read_variable_op.resource();
+      Value resource = read_variable_op.resource();
       auto last_store = resource_handle_to_last_store_op[resource];
       if (!last_store) continue;

       // Use stored value in last_store to replace all uses of current resource
       // load's result, then erase this resource load.
-      read_variable_op.value()->replaceAllUsesWith(last_store.value());
+      read_variable_op.value().replaceAllUsesWith(last_store.value());
       read_variable_op.erase();
       continue;
     }

     if (auto assign_variable_op = dyn_cast<TF::AssignVariableOp>(&op)) {
-      Value* resource = assign_variable_op.resource();
+      Value resource = assign_variable_op.resource();
       auto last_store = resource_handle_to_last_store_op[resource];
       // Previous store ops to same resource can be erased.
       if (last_store) last_store.erase();
@@ -120,17 +120,17 @@ void ForwardStoreToLoad(tf_device::LaunchOp launch_op) {
 // forwarding has been performed on this launch_op such that all loads of same
 // resource are on its initial values.
 void HoistResourceLoads(tf_device::LaunchOp launch_op) {
-  llvm::SmallDenseMap<Value*, TF::ReadVariableOp> resource_to_read_ops;
+  llvm::SmallDenseMap<Value, TF::ReadVariableOp> resource_to_read_ops;

   // Only iterate through ops directly in launch_op's body as we can't handle
   // ops nested deeper in regions.
   for (Operation& op : llvm::make_early_inc_range(launch_op.GetBody())) {
     auto read_variable_op = dyn_cast<TF::ReadVariableOp>(&op);
     if (!read_variable_op) continue;
-    Value* resource = read_variable_op.resource();
+    Value resource = read_variable_op.resource();

     // Skip resources created inside of launch_op.
-    if (resource->getParentRegion() == &launch_op.body()) continue;
+    if (resource.getParentRegion() == &launch_op.body()) continue;

     auto p = resource_to_read_ops.insert({resource, read_variable_op});
     if (p.second) {
@@ -156,18 +156,18 @@ bool AppendResourceStoreValueToReturn(tf_device::LaunchOp launch_op) {
   Block* body = &launch_op.GetBody();
   auto old_return = body->getTerminator();

-  llvm::SmallVector<Value*, 4> new_return_operands(old_return->getOperands());
+  llvm::SmallVector<Value, 4> new_return_operands(old_return->getOperands());

   // Only iterate through ops directly in launch_op's body as we can't handle
   // ops nested deeper in regions.
   for (Operation& op : launch_op.GetBody()) {
     auto assign_variable_op = dyn_cast<TF::AssignVariableOp>(&op);
     if (!assign_variable_op) continue;
-    Value* resource = assign_variable_op.resource();
+    Value resource = assign_variable_op.resource();
     if (!resource) continue;

     // Skip resources created inside of launch_op.
-    if (resource->getParentRegion() == &launch_op.body()) continue;
+    if (resource.getParentRegion() == &launch_op.body()) continue;

     // TODO(ycao): Prevent same value from being returned multiple times.
     // TODO(ycao): Do not return resource store value if it is defined outside
@@ -202,12 +202,12 @@ void SinkResourceStores(tf_device::LaunchOp launch_op, OpBuilder* builder) {
   builder->setInsertionPoint(launch_op);
   auto new_launch_op = builder->create<tf_device::LaunchOp>(
       launch_op.getLoc(), new_launch_return_types,
-      /*operands=*/llvm::SmallVector<Value*, 4>(), launch_op.getAttrs());
+      /*operands=*/llvm::SmallVector<Value, 4>(), launch_op.getAttrs());
   new_launch_op.body().takeBody(launch_op.body());

   // Replace uses of old launch_op results with those of new_launch_op.
   for (auto p : llvm::zip(launch_op.getResults(), new_launch_op.getResults())) {
-    std::get<0>(p)->replaceAllUsesWith(std::get<1>(p));
+    std::get<0>(p).replaceAllUsesWith(std::get<1>(p));
   }

   // Create a mapping from operands of new_return_op operands to new_launch_op
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
index 39b7fbb4d07..4f69d18a96b 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
@@ -22,28 +22,33 @@ limitations under the License.
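ForwardStoreToLoad above is classic store-to-load forwarding over resource handles: one linear walk remembers the last value stored to each resource, forwards it to later loads of the same resource, and lets a newer store make the previous one dead. A self-contained C++ model of that walk, with invented Load/Store records and ints standing in for values:

#include <map>
#include <optional>
#include <variant>
#include <vector>

struct Load { int resource; };
struct Store { int resource; int value; };
using Op = std::variant<Load, Store>;

// Returns, for each load in program order, the forwarded value if a prior
// store to the same resource exists (nullopt models a load of the initial
// value, which HoistResourceLoads can then lift out of the body).
std::vector<std::optional<int>> ForwardStores(const std::vector<Op>& body) {
  std::map<int, int> last_store;  // resource -> last stored value
  std::vector<std::optional<int>> load_results;
  for (const Op& op : body) {
    if (const Load* load = std::get_if<Load>(&op)) {
      auto it = last_store.find(load->resource);
      load_results.push_back(it == last_store.end()
                                 ? std::nullopt
                                 : std::optional<int>(it->second));
    } else {
      const Store& store = std::get<Store>(op);
      last_store[store.resource] = store.value;  // shadows the earlier store
    }
  }
  return load_results;
}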
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/iterator_range.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "llvm/Support/FormatVariadic.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir -#include "mlir/IR/Block.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Diagnostics.h" // TF:local_config_mlir -#include "mlir/IR/Location.h" // TF:local_config_mlir -#include "mlir/IR/Operation.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/IR/SymbolTable.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir -#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir -#include "mlir/Support/LLVM.h" // TF:local_config_mlir -#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir -#include "mlir/Transforms/FoldUtils.h" // TF:local_config_mlir +#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/IR/Block.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Diagnostics.h" // TF:llvm-project +#include "mlir/IR/Location.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/SymbolTable.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Pass/PassRegistry.h" // TF:llvm-project +#include "mlir/Support/LLVM.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/Transforms/FoldUtils.h" // TF:llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/types.pb.h" #define DEBUG_TYPE "tf-shape-inference" @@ -68,29 +73,101 @@ Optional> InferShapeForFunctionReturnType( // Manually fold tf.Cast that precedes the return instruction and only differs // in shape refinement level. for (OpOperand& arg_op : return_op.getOperation()->getOpOperands()) { - Operation* arg_defining_op = arg_op.get()->getDefiningOp(); + Operation* arg_defining_op = arg_op.get().getDefiningOp(); if (auto cast_op = dyn_cast_or_null(arg_defining_op)) { // Shape inference should not change the element type. if (cast_op.SrcT() != cast_op.DstT()) continue; // We only refine the result shape if the result a dynamic shape, the // input has static shape, and the two shapes are compatible. 
-      auto has_static_shape = [](const Value* value) {
-        auto shaped_type = value->getType().dyn_cast<ShapedType>();
+      auto has_static_shape = [](const Value value) {
+        auto shaped_type = value.getType().dyn_cast<ShapedType>();
         return shaped_type && shaped_type.hasStaticShape();
       };
-      Value* input = cast_op.x();
-      Value* result = cast_op.y();
+      Value input = cast_op.x();
+      Value result = cast_op.y();
       if (!has_static_shape(input) || has_static_shape(result) ||
-          failed(verifyCompatibleShape(input->getType(), result->getType())))
+          failed(verifyCompatibleShape(input.getType(), result.getType())))
         continue;

       arg_op.set(cast_op.x());
-      if (cast_op.y()->use_empty()) cast_op.erase();
+      if (cast_op.y().use_empty()) cast_op.erase();
     }
   }
   return llvm::to_vector<4>(return_op.getOperandTypes());
 }
+
+// Returns if the shape inference pass supports an op outside the TF dialect.
+bool IsSupportedNonTFOp(Operation* op) {
+  return isa<tf_executor::YieldOp>(op) || isa<tf_executor::IslandOp>(op) ||
+         isa<tf_executor::FetchOp>(op) || isa<tf_executor::GraphOp>(op) ||
+         isa<ReturnOp>(op) || isa<tf_device::ReturnOp>(op);
+}
+
+// Inserts tf.Cast operation when changing the type of a result if the user is
+// not a TF operation, as we can't guarantee that the new type will be OK.
+void AddCastBackForUnsupportedNonTFUses(Operation* op, Value result,
+                                        Dialect* tf_dialect, Type old_type) {
+  OpBuilder builder(op);
+  builder.setInsertionPointAfter(op);
+  // A tf.Cast operation is lazily created on the first uses that isn't a TF
+  // operation.
+  TF::CastOp cast_op;
+  auto get_cast_op = [&]() {
+    if (!cast_op)
+      cast_op =
+          builder.create<TF::CastOp>(op->getLoc(), old_type, result,
+                                     /*truncate=*/builder.getBoolAttr(false));
+    return cast_op;
+  };
+  for (OpOperand& use : llvm::make_early_inc_range(result->getUses())) {
+    if (use.getOwner()->getDialect() != tf_dialect &&
+        !IsSupportedNonTFOp(use.getOwner()))
+      use.set(get_cast_op());
+  }
+}
+
+// Extracts a PartialTensorShape from the MLIR type.
+Optional<tensorflow::PartialTensorShape> GetShapeFromMlirType(Type t) {
+  if (auto ranked_type = t.dyn_cast<RankedTensorType>()) {
+    // Convert the MLIR shape indices (int64_t) to TensorFlow indices
+    // (int64).
+    ArrayRef<int64_t> shape = ranked_type.getShape();
+    SmallVector<tensorflow::int64, 8> tf_shape(shape.begin(), shape.end());
+    return tensorflow::PartialTensorShape({tf_shape.data(), tf_shape.size()});
+  }
+  return None;
+}
+
+// Passes the operand shapes/types to the op's results.
+bool InferShapeForPassThroughOps(OperandRange pass_through_operands,
+                                 Operation* op, Dialect* tf_dialect) {
+  bool changed = false;
+  for (auto entry : llvm::zip(pass_through_operands, op->getResults())) {
+    Type operand_type = std::get<0>(entry).getType();
+    Value result = std::get<1>(entry);
+    if (result.getType() == operand_type) continue;
+    AddCastBackForUnsupportedNonTFUses(op, result, tf_dialect,
+                                       result.getType());
+    result.setType(operand_type);
+    changed = true;
+  }
+  return changed;
+}
+
+// Infers shape for necessary ops that are not in the TF dialect.
+bool InferShapeForNonTFDialectOperation(Operation* op, Dialect* tf_dialect) {
+  if (auto graph_op = dyn_cast<tf_executor::GraphOp>(op)) {
+    return InferShapeForPassThroughOps(graph_op.GetFetch().fetches(), op,
                                        tf_dialect);
+  }
+  if (auto island_op = dyn_cast<tf_executor::IslandOp>(op)) {
+    return InferShapeForPassThroughOps(island_op.GetYield().fetches(), op,
+                                       tf_dialect);
+  }
+  return false;
+}
+
 }  // namespace

 bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
@@ -98,9 +175,13 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
   assert(tf_dialect == op->getDialect());

   // If no result for this op needs shape inference, we have a fast-path return.
+ // But if the type is a resource, we do not skip it because we might not have + // the handle shapes. if (llvm::all_of(op->getResultTypes(), [](Type type) { auto shape_type = type.dyn_cast(); - return !shape_type || shape_type.hasStaticShape(); + return !shape_type || + (shape_type.hasStaticShape() && + !shape_type.getElementType().isa()); })) { LLVM_DEBUG(llvm::dbgs() << "Skipping inference for statically shaped op '" << op->getName() << "'.\n";); @@ -111,7 +192,7 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, // This is necessary to avoid reprocessing the tf.Cast that are inserted at // the end of this function. if (isa(op) && - llvm::all_of(op->getResult(0)->getUsers(), [&](Operation* user) { + llvm::all_of(op->getResult(0).getUsers(), [&](Operation* user) { return user->getDialect() != tf_dialect; })) { LLVM_DEBUG(llvm::dbgs() << "Skipping inference for tf.Cast with no TF " @@ -127,10 +208,9 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, // Get information from the registry and check if we have a shape function for // this op. - const tensorflow::OpRegistrationData* op_reg_data; - if (!tensorflow::OpRegistry::Global() - ->LookUp(node_name.data(), &op_reg_data) - .ok()) { + const tensorflow::OpRegistrationData* op_reg_data = + tensorflow::OpRegistry::Global()->LookUp(node_name.data()); + if (!op_reg_data) { LLVM_DEBUG(llvm::dbgs() << "Skipping inference for unregistered op '" << op->getName() << "'.\n";); return false; @@ -161,8 +241,11 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, std::vector input_shapes( op->getNumOperands()); std::vector tensors(op->getNumOperands()); + std::vector>>> + handle_shapes_and_types(op->getNumOperands()); for (auto it : llvm::enumerate(op->getOperands())) { - Value* operand = it.value(); + Value operand = it.value(); size_t index = it.index(); // If the operand is constant, then convert it to Tensor. @@ -179,13 +262,32 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, } } - Type operand_type = operand->getType(); - if (auto ranked_type = operand_type.dyn_cast()) { - // Convert the MLIR shape indices (int64_t) to TensorFlow indices (int64). - ArrayRef shape = ranked_type.getShape(); - SmallVector tf_shape(shape.begin(), shape.end()); - input_shapes[index] = - tensorflow::PartialTensorShape({tf_shape.data(), tf_shape.size()}); + Type operand_type = operand.getType(); + if (auto shape = GetShapeFromMlirType(operand_type)) { + input_shapes[index] = *shape; + } + // Collect the handle shapes and types for a resource. + if (auto resource_type = operand_type.cast() + .getElementType() + .dyn_cast()) { + if (resource_type.getSubtypes().empty()) continue; + auto shapes_and_types = absl::make_unique>>(); + for (auto subtype : resource_type.getSubtypes()) { + auto shape = GetShapeFromMlirType(subtype); + // handle_shapes_and_types requires all shapes to be known. So if any + // subtype is unknown, clear the vector. + if (!shape) { + shapes_and_types = nullptr; + break; + } + tensorflow::DataType dtype; + auto status = + tensorflow::ConvertToDataType(subtype.getElementType(), &dtype); + assert(status.ok() && "Unknown element type"); + shapes_and_types->emplace_back(*shape, dtype); + } + handle_shapes_and_types[index] = std::move(shapes_and_types); } } @@ -194,8 +296,7 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, // function operates on. 
tensorflow::shape_inference::InferenceContext c( graph_version, *node_def, op_reg_data->op_def, input_shapes, - input_tensors, /*input_tensors_as_shapes=*/{}, - /*input_handle_shapes_and_types=*/{}); + input_tensors, /*input_tensors_as_shapes=*/{}, handle_shapes_and_types); auto status = c.Run(op_reg_data->shape_inference_fn); if (!status.ok()) { LLVM_DEBUG(llvm::dbgs() << "Shape inference error for '" << *op @@ -207,47 +308,52 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, "inference context matches the MLIR number of results."); // Update the shape for each of the operation result if the InferenceContext - // has more precise shapes recorded. A builder is used to insert tf.Cast - // operation when changing the type of a result is the user is not a TF - // operation, as we can't guarantee that the new type will be OK. + // has more precise shapes recorded. bool changed = false; - OpBuilder builder(op); - builder.setInsertionPointAfter(op); for (int output : llvm::seq(0, c.num_outputs())) { // Skip already statically shaped results. - Value* result = op->getResult(output); - auto shaped_type = result->getType().dyn_cast(); + Value result = op->getResult(output); + auto shaped_type = result.getType().dyn_cast(); if (!shaped_type || shaped_type.hasStaticShape()) continue; tensorflow::shape_inference::ShapeHandle shape_handle = c.output(output); LLVM_DEBUG(llvm::dbgs() << "Inferred output " << output << " : " << c.DebugString(shape_handle) << "\n"); - if (!c.RankKnown(shape_handle)) continue; - - // Convert the shape from TensorFlow (int64) to MLIR (int64_t). - SmallVector shape; - for (int dim : llvm::seq(0, c.Rank(shape_handle))) - shape.push_back(c.Value(c.Dim(shape_handle, dim))); - auto new_type = RankedTensorType::get(shape, shaped_type.getElementType()); - - // A tf.Cast operation is lazily created on the first uses that isn't a TF - // operation. - TF::CastOp cast_op; - auto get_cast_op = [&]() { - if (!cast_op) - cast_op = - builder.create(op->getLoc(), result->getType(), result, - /*truncate=*/builder.getBoolAttr(false)); - return cast_op; + auto get_tensor_type = + [&c](const tensorflow::shape_inference::ShapeHandle& sh, + Type element_type) -> TensorType { + if (!c.RankKnown(sh)) return UnrankedTensorType::get(element_type); + // Convert the shape from TensorFlow (int64) to MLIR (int64_t). + SmallVector shape; + for (int dim : llvm::seq(0, c.Rank(sh))) + shape.push_back(c.Value(c.Dim(sh, dim))); + return RankedTensorType::get(shape, element_type); }; - for (OpOperand& use : llvm::make_early_inc_range(result->getUses())) { - if (use.getOwner()->getDialect() != tf_dialect) use.set(get_cast_op()); + auto new_element_type = shaped_type.getElementType(); + // Populate the handle shapes for a resource. 
+ if (auto resource_type = new_element_type.dyn_cast()) { + auto handle_shapes_types = c.output_handle_shapes_and_types(output); + if (handle_shapes_types) { + llvm::SmallVector subtypes; + OpBuilder b(op); + for (const auto& shape_n_type : *handle_shapes_types) { + Type element_type; + auto status = + tensorflow::ConvertDataType(shape_n_type.dtype, b, &element_type); + assert(status.ok() && "Unknown element type"); + subtypes.push_back(get_tensor_type(shape_n_type.shape, element_type)); + } + new_element_type = TF::ResourceType::get(subtypes, op->getContext()); + } } - - if (result->getType() == new_type) continue; - + auto new_type = get_tensor_type(shape_handle, new_element_type); + if (result.getType() == new_type) continue; + // Inserts a cast back to the original type if any user is not in the TF + // dialect. + AddCastBackForUnsupportedNonTFUses(op, result, tf_dialect, + result.getType()); // Finally we inferred the shape and replace the type for this result. - result->setType(new_type); + result.setType(new_type); changed = true; } if (changed) @@ -285,7 +391,7 @@ LogicalResult RefineShapeForControlFlowFunc(FuncOp func, func.getContext())); for (auto arg_and_idx : llvm::enumerate(func.getArguments())) { - arg_and_idx.value()->setType(input_types[arg_and_idx.index()]); + arg_and_idx.value().setType(input_types[arg_and_idx.index()]); } auto res = @@ -307,8 +413,8 @@ LogicalResult PropagateShapeToIfWhileOpFunctions( int64_t max_iteration) { llvm::SmallVector input_types; input_types.reserve(std::distance(op.input().begin(), op.input().end())); - for (Value* v : op.input()) { - input_types.push_back(v->getType()); + for (Value v : op.input()) { + input_types.push_back(v.getType()); } ModuleOp module = op.template getParentOfType(); @@ -360,7 +466,10 @@ LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version, LLVM_DEBUG(llvm::dbgs() << "Shape inference, iteration " << iteration << "\n"); region->walk([&](Operation* op) { - if (op->getDialect() != tf_dialect) return; + if (op->getDialect() != tf_dialect) { + changed |= InferShapeForNonTFDialectOperation(op, tf_dialect); + return; + } // Before attempting inference, just try to fold the operation. if (succeeded(folder.tryToFold(op))) return; @@ -415,7 +524,7 @@ LogicalResult InferShapeForFunction(FuncOp func, auto new_arg_type = mlir::RankedTensorType::get(shape, element_type); if (new_arg_type != func_type.getInput(i)) { // If the new type is more detailed, trigger shape inference. - func.getArgument(i)->setType(new_arg_type); + func.getArgument(i).setType(new_arg_type); needs_refinement = true; } new_arg_types.push_back(new_arg_type); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h index 0529e6414b7..73993a07292 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h @@ -18,10 +18,10 @@ limitations under the License. 
#include -#include "mlir/IR/Function.h" // TF:local_config_mlir -#include "mlir/IR/Operation.h" // TF:local_config_mlir -#include "mlir/IR/Region.h" // TF:local_config_mlir -#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir +#include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/Region.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project namespace mlir { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc index d5b86173b69..129efd74f4f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc @@ -20,15 +20,15 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Support/Debug.h" -#include "mlir/IR/Block.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Location.h" // TF:local_config_mlir -#include "mlir/IR/Operation.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir -#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir -#include "mlir/Support/LLVM.h" // TF:local_config_mlir -#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir +#include "mlir/IR/Block.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Location.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Pass/PassRegistry.h" // TF:llvm-project +#include "mlir/Support/LLVM.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" @@ -65,10 +65,9 @@ struct ShapeInference : public ModulePass { } for (auto func : module.getOps()) { InferShapeUntilFixPoint(&func.getBody(), producer.getInt()); - } - - if (auto main_func = module.lookupSymbol("main")) { - InferShapeForFunctionType(main_func); + // TODO(yuanzx): Verify that it is always fine to refine a function's + // return type, as long as we do not change the argument shapes. + InferShapeForFunctionType(func); } } }; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/sink_constant.cc b/tensorflow/compiler/mlir/tensorflow/transforms/sink_constant.cc index e4358e7e1c7..9d872fb3d1a 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/sink_constant.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/sink_constant.cc @@ -19,11 +19,11 @@ limitations under the License. 
#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Debug.h" -#include "mlir/Pass/Pass.h" // TF:local_config_mlir -#include "mlir/Pass/PassManager.h" // TF:local_config_mlir -#include "mlir/Support/LLVM.h" // TF:local_config_mlir -#include "mlir/Transforms/Passes.h" // TF:local_config_mlir -#include "mlir/Transforms/RegionUtils.h" // TF:local_config_mlir +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Pass/PassManager.h" // TF:llvm-project +#include "mlir/Support/LLVM.h" // TF:llvm-project +#include "mlir/Transforms/Passes.h" // TF:llvm-project +#include "mlir/Transforms/RegionUtils.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -48,12 +48,11 @@ class ExecutorConstantSinking // The sunk_constant map keeps a mapping from a ConstOp defined above to // a sunk clone of it. This allows for reusing a sunk constant with // multiple uses in the region. - llvm::DenseMap sunk_constant; + llvm::DenseMap sunk_constant; Region &body = launch.body(); visitUsedValuesDefinedAbove(body, [&](OpOperand *use) { - Value *constant = use->get(); - auto const_op = - dyn_cast_or_null(constant->getDefiningOp()); + Value constant = use->get(); + auto const_op = dyn_cast_or_null(constant.getDefiningOp()); if (!const_op) return; // We found a constant, try to insert it in the map and re-use its @@ -62,13 +61,13 @@ class ExecutorConstantSinking if (!map_entry.second) { // This constant has already been cloned into the region, reuse it. use->set(map_entry.first->getSecond().getResult()); - LLVM_DEBUG(llvm::dbgs() << "Re-use sunk constant " << *use->get() - << "\n in " << *use->get() << "\n"); - if (constant->use_empty()) const_op.erase(); + LLVM_DEBUG(llvm::dbgs() << "Re-use sunk constant " << use->get() + << "\n in " << use->get() << "\n"); + if (constant.use_empty()) const_op.erase(); return; } - if (constant->hasOneUse()) { - LLVM_DEBUG(llvm::dbgs() << "Moved constant " << *constant << "\n"); + if (constant.hasOneUse()) { + LLVM_DEBUG(llvm::dbgs() << "Moved constant " << constant << "\n"); const_op.getOperation()->moveBefore(&body.begin()->front()); return; } @@ -76,8 +75,8 @@ class ExecutorConstantSinking body.begin()->getOperations().insert(body.begin()->begin(), map_entry.first->getSecond()); use->set(map_entry.first->getSecond().getResult()); - LLVM_DEBUG(llvm::dbgs() << "Sunk cloned constant " << *use->get() - << "\n in " << *use->get() << "\n"); + LLVM_DEBUG(llvm::dbgs() << "Sunk cloned constant " << use->get() + << "\n in " << use->get() << "\n"); }); }); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/test_side_effect_analysis.cc b/tensorflow/compiler/mlir/tensorflow/transforms/test_side_effect_analysis.cc index f0b7964389d..eb754cc3bbd 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/test_side_effect_analysis.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/test_side_effect_analysis.cc @@ -22,11 +22,11 @@ limitations under the License. 
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Debug.h" -#include "mlir/Pass/Pass.h" // TF:local_config_mlir -#include "mlir/Pass/PassManager.h" // TF:local_config_mlir -#include "mlir/Support/LLVM.h" // TF:local_config_mlir -#include "mlir/Transforms/Passes.h" // TF:local_config_mlir -#include "mlir/Transforms/RegionUtils.h" // TF:local_config_mlir +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Pass/PassManager.h" // TF:llvm-project +#include "mlir/Support/LLVM.h" // TF:llvm-project +#include "mlir/Transforms/Passes.h" // TF:llvm-project +#include "mlir/Transforms/RegionUtils.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc index 2eb12c80efe..5606428bb19 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc @@ -16,10 +16,10 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.h" #include "llvm/Support/CommandLine.h" -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Identifier.h" // TF:local_config_mlir -#include "mlir/IR/Location.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Identifier.h" // TF:llvm-project +#include "mlir/IR/Location.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h" #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.h b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.h index 8b97bd606a9..49d92bf3151 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_GRAPH_OPTIMIZATION_PASS_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_GRAPH_OPTIMIZATION_PASS_H_ -#include "mlir/Pass/Pass.h" // TF:local_config_mlir +#include "mlir/Pass/Pass.h" // TF:llvm-project #include "tensorflow/core/common_runtime/optimization_registry.h" namespace tensorflow { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc index 7a840aa0d12..98833a7de40 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc @@ -35,17 +35,17 @@ limitations under the License. 
#include "llvm/ADT/StringRef.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Support/Casting.h" -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Identifier.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/Operation.h" // TF:local_config_mlir -#include "mlir/IR/Types.h" // TF:local_config_mlir -#include "mlir/IR/Value.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir -#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir -#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir -#include "mlir/Transforms/RegionUtils.h" // TF:local_config_mlir +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Identifier.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/Types.h" // TF:llvm-project +#include "mlir/IR/Value.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Pass/PassRegistry.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/Transforms/RegionUtils.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -140,8 +140,8 @@ bool ShouldMoveOpAfterCluster( const llvm::SmallSetVector& cluster_ops, const llvm::SmallSetVector& preceding_users) { auto result = op->walk([&](Operation* op) { - for (Value* operand : op->getOperands()) { - Operation* def = operand->getDefiningOp(); + for (Value operand : op->getOperands()) { + Operation* def = operand.getDefiningOp(); // Operands may not have a defining op (BlockArgument) or is from a // different block. if (!def || def->getBlock() != block) continue; @@ -179,13 +179,13 @@ llvm::SmallSetVector CollectClusterPrecedingUsers( // `tf_device::LaunchOp` and associated terminator. Results that have no uses // outside of the cluster (i.e. results of ops in the cluster are only consumed // by other ops in the cluster) are pruned. -llvm::SmallVector CollectClusterResults( +llvm::SmallVector CollectClusterResults( Block* block, const llvm::SmallSetVector& cluster_ops) { - llvm::SmallVector results; + llvm::SmallVector results; for (Operation* op : cluster_ops) { - for (Value* result : op->getResults()) { - for (Operation* user : result->getUsers()) { + for (Value result : op->getResults()) { + for (Operation* user : result.getUsers()) { // Check if user is not an op in the cluster. if (cluster_ops.count(block->findAncestorOpInBlock(*user)) == 0) { results.push_back(result); @@ -200,13 +200,13 @@ llvm::SmallVector CollectClusterResults( // Creates a `tf_device::LaunchOp` to wrap cluster ops. tf_device::LaunchOp CreateLaunchOpForCluster(Operation* last_cluster_op, - llvm::ArrayRef results) { + llvm::ArrayRef results) { // `tf_device::LaunchOp` will be placed at where the last op of the cluster // is. OpBuilder builder(last_cluster_op); llvm::SmallVector result_types; - for (Value* result : results) result_types.push_back(result->getType()); + for (Value result : results) result_types.push_back(result.getType()); // An empty string placeholder is used for the device as that will be later // populated with the device of the associated TPUReplicateMetadata op. 
@@ -241,12 +241,12 @@ void MoveClusterOpsToLaunchOp( // Replaces uses of cluster ops results outside of cluster with the associated // `tf_device::LaunchOp` results. void UpdateLaunchOpResultExternalUses(tf_device::LaunchOp launch_op, - llvm::ArrayRef results) { + llvm::ArrayRef results) { Block& launch_op_block = launch_op.GetBody(); for (auto ret_vals : llvm::zip(results, launch_op.getResults())) { - Value* old_ret = std::get<0>(ret_vals); - Value* new_ret = std::get<1>(ret_vals); - for (auto& use : old_ret->getUses()) + Value old_ret = std::get<0>(ret_vals); + Value new_ret = std::get<1>(ret_vals); + for (auto& use : old_ret.getUses()) if (!launch_op_block.findAncestorOpInBlock(*use.getOwner())) use.set(new_ret); } @@ -307,7 +307,7 @@ LogicalResult ReplicateCluster(tf_device::LaunchOp launch_op, llvm::SmallSetVector unique_replicated_input_ops; mlir::visitUsedValuesDefinedAbove( launch_op.body(), launch_op.body(), [&](mlir::OpOperand* operand) { - Operation* def = operand->get()->getDefiningOp(); + Operation* def = operand->get().getDefiningOp(); if (def && llvm::isa(def)) unique_replicated_input_ops.insert(def); }); @@ -337,9 +337,9 @@ LogicalResult ReplicateCluster(tf_device::LaunchOp launch_op, // Replace replicated cluster results with replicate op results. for (auto result_and_idx : llvm::enumerate(launch_op.getResults())) { - Value* result = result_and_idx.value(); + Value result = result_and_idx.value(); int idx = result_and_idx.index(); - for (auto& use : result->getUses()) { + for (auto& use : result.getUses()) { Operation* def = use.getOwner(); if (!def || !llvm::isa(def)) return launch_op.emitError() @@ -360,7 +360,7 @@ LogicalResult ReplicateCluster(tf_device::LaunchOp launch_op, for (auto input_and_block_arg : llvm::zip(replicated_input_ops, replicate_op.GetBody().getArguments())) { Operation* input = std::get<0>(input_and_block_arg); - Value* block_arg = std::get<1>(input_and_block_arg); + Value block_arg = std::get<1>(input_and_block_arg); mlir::replaceAllUsesInRegionWith(input->getResult(0), block_arg, launch_op.body()); } @@ -412,7 +412,7 @@ LogicalResult FormClustersInBlock(Block* block, llvm::SmallSetVector preceding_users = CollectClusterPrecedingUsers(block, cluster_ops); - llvm::SmallVector results = + llvm::SmallVector results = CollectClusterResults(block, cluster_ops); tf_device::LaunchOp launch_op = @@ -470,7 +470,7 @@ void TPUClusterFormation::runOnFunction() { // `tf_device.replicate` is created and replicated (1) operands/results are // untouched. if (op->getNumOperands() == 1 && op->getNumResults() == 1) - op->getResult(0)->replaceAllUsesWith(op->getOperand(0)); + op->getResult(0).replaceAllUsesWith(op->getOperand(0)); // Leftover TPUReplicatedInput/TPUReplicatedOutput that are not of // `num_replicas` to 1. diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_padding_mapper.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_padding_mapper.cc index f2f885dbcc8..38a01e168f7 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_padding_mapper.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_padding_mapper.cc @@ -24,15 +24,15 @@ limitations under the License. 
#include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Block.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Function.h" // TF:local_config_mlir -#include "mlir/IR/Module.h" // TF:local_config_mlir -#include "mlir/IR/Value.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir -#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir -#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Block.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/IR/Value.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Pass/PassRegistry.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/core/protobuf/tpu/dynamic_padding.pb.h" @@ -60,9 +60,9 @@ llvm::SmallDenseMap GetRemappedReplicatedInputIndices( llvm::SmallDenseMap remapped_indices; for (auto operand_and_idx : llvm::enumerate(launch_func.getOperands())) - if (auto block_arg = llvm::dyn_cast(operand_and_idx.value())) - if (block_arg->getOwner() == replicate_block) - remapped_indices[block_arg->getArgNumber()] = operand_and_idx.index(); + if (auto block_arg = operand_and_idx.value().dyn_cast()) + if (block_arg.getOwner() == replicate_block) + remapped_indices[block_arg.getArgNumber()] = operand_and_idx.index(); return remapped_indices; } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_merge_variables_with_execute.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_merge_variables_with_execute.cc index 28332503adc..dddf916089b 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_merge_variables_with_execute.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_merge_variables_with_execute.cc @@ -28,18 +28,18 @@ limitations under the License. 
#include "llvm/ADT/iterator_range.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Function.h" // TF:local_config_mlir -#include "mlir/IR/Identifier.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/Operation.h" // TF:local_config_mlir -#include "mlir/IR/Types.h" // TF:local_config_mlir -#include "mlir/IR/Value.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir -#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir -#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir -#include "mlir/Transforms/RegionUtils.h" // TF:local_config_mlir +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/Identifier.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/Types.h" // TF:llvm-project +#include "mlir/IR/Value.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Pass/PassRegistry.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/Transforms/RegionUtils.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -92,15 +92,15 @@ struct VariableAccessInfo { // Information about all resource accesses to be fused into a TPUExecute op. struct VariableAccessesForTPUExecute { // Maps each resource detected to VariableAccessInfo. - llvm::SmallDenseMap per_resource_info; + llvm::SmallDenseMap per_resource_info; // The corresponding new output index in TPUExecuteAndUpdateVariables for // each old output index in TPUExecute. llvm::SmallVector old_to_new_output_mapping; // The resources read by ReadVariableOps that are inputs to TPUExecute. // Ordered by the input indices to TPUExecute - llvm::SmallVector resources_read; + llvm::SmallVector resources_read; // Operands for the new TPUExecuteAndUpdateVariables. - llvm::SmallVector new_operand_values; + llvm::SmallVector new_operand_values; }; // Returns if an op accesses a resource. @@ -135,23 +135,23 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo(Operation* execute, // Find inputs that are variable reads. for (auto operand : llvm::enumerate(execute->getOpOperands())) { infos.new_operand_values.push_back(operand.value().get()); - if (!operand.value().get()->getDefiningOp()) continue; + if (!operand.value().get().getDefiningOp()) continue; auto read_op = llvm::dyn_cast( - operand.value().get()->getDefiningOp()); + operand.value().get().getDefiningOp()); if (!read_op) continue; auto resource = read_op.resource(); if (check_device) { - if (auto resource_op = resource->getDefiningOp()) { + if (auto resource_op = resource.getDefiningOp()) { auto resource_attr = resource_op->getAttr(kDeviceAttr); // Check device matching for the node defining the resource. if (!resource_attr || resource_attr != device_attr) continue; } else { - auto resource_arg = llvm::dyn_cast(resource); + auto resource_arg = resource.dyn_cast(); assert(resource_arg); // Check device matching for the argument defining the resource. 
auto resource_attr = func.getArgAttrOfType( - resource_arg->getArgNumber(), kFuncDeviceAttr); + resource_arg.getArgNumber(), kFuncDeviceAttr); if (!resource_attr || resource_attr != device_attr) continue; } } @@ -206,7 +206,7 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo(Operation* execute, } infos.resources_read.erase( llvm::remove_if(infos.resources_read, - [&](const Value* resource) { + [&](const Value resource) { return infos.per_resource_info.count(resource) == 0; }), infos.resources_read.end()); @@ -222,9 +222,8 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo(Operation* execute, llvm::SmallVector output_fused(execute->getNumResults(), false); for (int i = 0; i < execute->getNumResults(); ++i) { auto result = execute->getResult(i); - if (!result->hasOneUse()) continue; - auto assign_op = - llvm::dyn_cast(*result->user_begin()); + if (!result.hasOneUse()) continue; + auto assign_op = llvm::dyn_cast(*result.user_begin()); if (!assign_op) continue; auto resource = assign_op.resource(); auto it = infos.per_resource_info.find(resource); @@ -330,7 +329,7 @@ void MergeForOneTPUExecute(Operation* execute, bool check_device, // Replace the uses. for (int i = 0; i < infos.old_to_new_output_mapping.size(); ++i) { if (infos.old_to_new_output_mapping[i] < 0) continue; - execute->getResult(i)->replaceAllUsesWith( + execute->getResult(i).replaceAllUsesWith( merged_execute.getResult(infos.old_to_new_output_mapping[i])); } // Remove the assign ops. diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc index 1033670dd1c..355c0afa40b 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc @@ -25,15 +25,15 @@ limitations under the License. #include "llvm/Support/CommandLine.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Module.h" // TF:local_config_mlir -#include "mlir/IR/Operation.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/IR/Types.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir -#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir -#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/Types.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Pass/PassRegistry.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" @@ -277,7 +277,7 @@ Operation* BuildCompileOp(tf_device::LaunchFuncOp launch_func, int num_replicas, // TODO(b/139377366): When shape inference is ready, we can use compile time // shape inference to get inputs that have static shapes and only use shape // ops for the rest. 
- llvm::SmallVector compile_op_operands; + llvm::SmallVector compile_op_operands; compile_op_operands.reserve(launch_func.getNumOperands()); for (auto operand_and_idx : llvm::enumerate(launch_func.getOperands())) { @@ -332,7 +332,7 @@ Operation* BuildExecuteOp(Operation* compile_op, OpBuilder* builder) { // TPUExecute inherits all launch_func inputs, and takes an additional input // for compilation cache key. - llvm::SmallVector tensor_inputs(launch_func.getOperands()); + llvm::SmallVector tensor_inputs(launch_func.getOperands()); tensor_inputs.push_back(compile_op->getResult(1)); // TODO(b/139377366): Need to snapshot all resource variable inputs in @@ -457,7 +457,7 @@ LogicalResult Rewrite( // the other ops that are intended to consume the compile result. Block* block = launch_func.getOperation()->getBlock(); for (auto compile_result_op : block->getOps()) - compile_result_op.output()->replaceAllUsesWith(compile_op->getResult(0)); + compile_result_op.output().replaceAllUsesWith(compile_op->getResult(0)); BuildTPUCompileSucceededAssertOp(compile_op, builder); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc b/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc index 764c7915577..98a043219db 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc @@ -19,12 +19,12 @@ limitations under the License. #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Operation.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir -#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir -#include "mlir/Support/STLExtras.h" // TF:local_config_mlir +#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Pass/PassRegistry.h" // TF:llvm-project +#include "mlir/Support/STLExtras.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" @@ -44,7 +44,7 @@ struct BreakUpIslands : OperationPass { void BreakUpIsland(tf_executor::IslandOp op, const TF::SideEffectAnalysis& side_effect_analysis, - llvm::DenseMap>* + llvm::DenseMap>* new_control_edges); }; @@ -64,7 +64,7 @@ void BreakUpIslands::runOnOperation() { // Map from the users of the existing islands to the list of control // edges that need to be added. - llvm::DenseMap> new_control_edges; + llvm::DenseMap> new_control_edges; auto& side_effect_analysis = getAnalysis(); // Iterate in reverse order to avoid invalidating Operation* stored in // new_control_edges. @@ -78,7 +78,7 @@ void BreakUpIslands::runOnOperation() { // Apply edge additions in reverse order so that the ops don't get // invalidated. - llvm::SmallVector edges; + llvm::SmallVector edges; llvm::SmallPtrSet dups; llvm::SmallVector types; for (auto& item : @@ -96,12 +96,12 @@ void BreakUpIslands::runOnOperation() { edges.assign(item.operand_begin(), item.operand_end()); dups.clear(); - for (Value* input : edges) { - dups.insert(input->getDefiningOp()); + for (Value input : edges) { + dups.insert(input.getDefiningOp()); } // Insert new control edges removing duplicates. 
- for (Value* value : llvm::reverse(edge.second)) { - if (dups.insert(value->getDefiningOp()).second) edges.push_back(value); + for (Value value : llvm::reverse(edge.second)) { + if (dups.insert(value.getDefiningOp()).second) edges.push_back(value); } state.addOperands(edges); Operation* new_op = builder.createOperation(state); @@ -114,7 +114,7 @@ void BreakUpIslands::runOnOperation() { // Helper that creates an island. If `sub_op` is not nullptr, it will be moved // to the island. tf_executor::IslandOp CreateIsland(ArrayRef result_types, - ArrayRef control_inputs, + ArrayRef control_inputs, const tf_executor::ControlType& control_type, const Location& loc, Operation* sub_op, tf_executor::IslandOp original_island) { @@ -132,7 +132,7 @@ tf_executor::IslandOp CreateIsland(ArrayRef result_types, if (sub_op) { island_builder.create(loc, sub_op->getResults()); } else { - island_builder.create(loc, ArrayRef{}); + island_builder.create(loc, ArrayRef{}); } return island; } @@ -160,7 +160,7 @@ IslandSourcesAndSinks FindSourcesAndSinksInIsland( for (auto predecessor : predecessors) result.sinks.erase(predecessor); bool has_in_island_operands = false; for (auto operand : sub_op.getOperands()) { - auto defining_op = operand->getDefiningOp(); + auto defining_op = operand.getDefiningOp(); if (!defining_op || defining_op->getParentOp() != island) continue; // Remove operands from sinks. result.sinks.erase(defining_op); @@ -178,7 +178,7 @@ IslandSourcesAndSinks FindSourcesAndSinksInIsland( void BreakUpIslands::BreakUpIsland( tf_executor::IslandOp op, const TF::SideEffectAnalysis& side_effect_analysis, - llvm::DenseMap>* + llvm::DenseMap>* new_control_edges) { auto island_body = op.GetBody().without_terminator(); // Skip islands that are already only a single op. @@ -188,18 +188,18 @@ void BreakUpIslands::BreakUpIsland( auto island_control_inputs = llvm::to_vector<4>(op.controlInputs()); // Add control dependencies for yields of values defined by other islands to // the island that defines that fetched value. - for (auto* fetch : op.GetYield().fetches()) { + for (auto fetch : op.GetYield().fetches()) { // Ok, because there is no op to add control to (eg: function args). - if (!fetch->getDefiningOp()) continue; - if (fetch->getDefiningOp()->getParentOp() == op) { + if (!fetch.getDefiningOp()) continue; + if (fetch.getDefiningOp()->getParentOp() == op) { // OK, because it is the same island. } else if (auto island_op = llvm::dyn_cast( - fetch->getDefiningOp())) { + fetch.getDefiningOp())) { island_control_inputs.push_back(island_op.control()); } else { // TODO(parkers): Any defining op that has a control output can be handled // just like an island. - fetch->getDefiningOp()->emitError("Fetching non-island as dependency."); + fetch.getDefiningOp()->emitError("Fetching non-island as dependency."); return signalPassFailure(); } } @@ -214,9 +214,9 @@ void BreakUpIslands::BreakUpIsland( auto sources_and_sinks = FindSourcesAndSinksInIsland(op, side_effect_analysis); // The corresponding control output of the new island created for each sub-op. - llvm::SmallDenseMap new_control_for_sub_ops; + llvm::SmallDenseMap new_control_for_sub_ops; // Control outputs of newly created islands that are sinks. - llvm::SmallVector sink_island_controls; + llvm::SmallVector sink_island_controls; // For each operation in the island, construct a new island to wrap the op, // yield all the results, and replace all the usages with the results of the // new island. 
@@ -224,7 +224,7 @@ void BreakUpIslands::BreakUpIsland( const auto predecessors = side_effect_analysis.DirectControlPredecessors(&sub_op); // Get the controls from the predecessors. - llvm::SmallVector predecessors_control; + llvm::SmallVector predecessors_control; predecessors_control.reserve(predecessors.size()); for (auto predecessor : predecessors) { predecessors_control.push_back(new_control_for_sub_ops[predecessor]); @@ -233,9 +233,9 @@ void BreakUpIslands::BreakUpIsland( // by inter-islands dependencies; otherwise, we do not need to include // island_control_inputs, since they must have been tracked by the (direct // or indirect) control predecessors or operands. - ArrayRef control = sources_and_sinks.sources.count(&sub_op) > 0 - ? island_control_inputs - : predecessors_control; + ArrayRef control = sources_and_sinks.sources.count(&sub_op) > 0 + ? island_control_inputs + : predecessors_control; auto island = CreateIsland(llvm::to_vector<4>(sub_op.getResultTypes()), control, control_type, sub_op.getLoc(), &sub_op, op); @@ -255,11 +255,11 @@ void BreakUpIslands::BreakUpIsland( sink_island_controls.push_back(island.control()); } assert(sink_island_controls.size() == 1); - op.control()->replaceAllUsesWith(sink_island_controls[0]); + op.control().replaceAllUsesWith(sink_island_controls[0]); // All existing outputs need to add a control flow edge from // sink_island_controls[0]. - for (Value* out : op.outputs()) { - for (auto& use : out->getUses()) { + for (Value out : op.outputs()) { + for (auto& use : out.getUses()) { Operation* owner = use.getOwner(); if (auto island_op = llvm::dyn_cast(owner->getParentOp())) { @@ -275,7 +275,7 @@ void BreakUpIslands::BreakUpIsland( } } for (auto item : llvm::zip(op.outputs(), op.GetYield().fetches())) - std::get<0>(item)->replaceAllUsesWith(std::get<1>(item)); + std::get<0>(item).replaceAllUsesWith(std::get<1>(item)); op.erase(); } diff --git a/tensorflow/compiler/mlir/tensorflow/translate/control_to_executor_dialect.cc b/tensorflow/compiler/mlir/tensorflow/translate/control_to_executor_dialect.cc index 29979c02116..696891289ca 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/control_to_executor_dialect.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/control_to_executor_dialect.cc @@ -22,13 +22,13 @@ limitations under the License. 
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" #include "llvm/Support/Debug.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Operation.h" // TF:local_config_mlir -#include "mlir/IR/Value.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir -#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir -#include "mlir/Support/LLVM.h" // TF:local_config_mlir +#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/Value.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Pass/PassRegistry.h" // TF:llvm-project +#include "mlir/Support/LLVM.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -68,9 +68,9 @@ static bool HasOnlyTFControlOperations(FuncOp function) { tf_executor::IslandOp ControlToExecutorDialectConversion::CreateIslandForOp( Operation *op, OpBuilder *builder) { // Create a new region for the tf_executor.island body - SmallVector operands; - for (Value *operand : op->getOperands()) - if (operand->getType().isa()) + SmallVector operands; + for (Value operand : op->getOperands()) + if (operand.getType().isa()) operands.push_back(operand); SmallVector types; for (Type result_type : op->getResultTypes()) @@ -118,8 +118,8 @@ void ControlToExecutorDialectConversion::runOnFunction() { // This is the return of the function, we will create a fetch in the graph // matching the operands of the returns. The return is then updated to // take as operands the results of the tf_executor.graph operation. - SmallVector ret_vals; - for (Value *operand : op.getOperands()) ret_vals.push_back(operand); + SmallVector ret_vals; + for (Value operand : op.getOperands()) ret_vals.push_back(operand); for (auto &graph_result : llvm::enumerate(graph_op.getResults())) op.setOperand(graph_result.index(), graph_result.value()); builder.create(getFunction().getLoc(), ret_vals); @@ -128,7 +128,7 @@ void ControlToExecutorDialectConversion::runOnFunction() { assert(IsUnderscoredTFOp(&op) && "Expected only _tf operations"); // The operands and types arrays are used to create the tf_executor ops. - SmallVector operands; + SmallVector operands; operands.append(op.getOperands().begin(), op.getOperands().end()); SmallVector types; for (Type result_type : op.getResultTypes()) { @@ -155,7 +155,7 @@ void ControlToExecutorDialectConversion::runOnFunction() { loc, types, operands, ArrayRef{}); } else if (op.getName().getStringRef() == "_tf.NextIteration.source") { replacement = builder.create( - loc, op.getResult(0)->getType()); + loc, op.getResult(0).getType()); // Record a mapping of the name to the nextiteration.source so that when // we convert the sink we can get the token. StringAttr frame = op.getAttrOfType("name"); @@ -164,9 +164,9 @@ void ControlToExecutorDialectConversion::runOnFunction() { cast(replacement); // Replace the results here since the _tf source does not produce a token // there isn't a mapping for the new result #1. 
- op.getResult(0)->replaceAllUsesWith(replacement->getResult(0)); + op.getResult(0).replaceAllUsesWith(replacement->getResult(0)); for (int i : llvm::seq(1, op.getNumResults())) - op.getResult(i)->replaceAllUsesWith(replacement->getResult(i + 1)); + op.getResult(i).replaceAllUsesWith(replacement->getResult(i + 1)); replacement->setAttrs(op.getAttrList()); op.erase(); continue; @@ -201,8 +201,8 @@ void ControlToExecutorDialectConversion::runOnFunction() { // Only the non-control operands are carried over, the island is handling // the control input. - for (Value *operand : op.getOperands()) - if (!operand->getType().isa()) + for (Value operand : op.getOperands()) + if (!operand.getType().isa()) result.operands.push_back(operand); // Add a result type for each non-control result we find @@ -223,7 +223,7 @@ void ControlToExecutorDialectConversion::runOnFunction() { inner_op->setAttrs(op.getAttrList()); // Add the terminator for the island - SmallVector ret_vals(inner_op->getResults()); + SmallVector ret_vals(inner_op->getResults()); island_builder.create(loc, ret_vals); } @@ -232,7 +232,7 @@ void ControlToExecutorDialectConversion::runOnFunction() { if (!isa(replacement)) replacement->setAttrs(op.getAttrList()); for (int i : llvm::seq(0, op.getNumResults())) - op.getResult(i)->replaceAllUsesWith(replacement->getResult(i)); + op.getResult(i).replaceAllUsesWith(replacement->getResult(i)); op.erase(); } } diff --git a/tensorflow/compiler/mlir/tensorflow/translate/derived_attr_populator_gen.cc b/tensorflow/compiler/mlir/tensorflow/translate/derived_attr_populator_gen.cc index 222463e1d29..be146ab63a0 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/derived_attr_populator_gen.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/derived_attr_populator_gen.cc @@ -23,7 +23,7 @@ limitations under the License. #include "llvm/TableGen/Main.h" #include "llvm/TableGen/Record.h" #include "llvm/TableGen/TableGenBackend.h" -#include "mlir/TableGen/Operator.h" // TF:local_config_mlir +#include "mlir/TableGen/Operator.h" // TF:llvm-project using llvm::LessRecord; using llvm::raw_ostream; diff --git a/tensorflow/compiler/mlir/tensorflow/translate/executor_to_control_dialect.cc b/tensorflow/compiler/mlir/tensorflow/translate/executor_to_control_dialect.cc index 8a4f8aacc0d..96a7fcbb5ba 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/executor_to_control_dialect.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/executor_to_control_dialect.cc @@ -21,13 +21,13 @@ limitations under the License. 
#include "llvm/ADT/SmallString.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Operation.h" // TF:local_config_mlir -#include "mlir/IR/Value.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir -#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir -#include "mlir/Support/LLVM.h" // TF:local_config_mlir +#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/Value.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Pass/PassRegistry.h" // TF:llvm-project +#include "mlir/Support/LLVM.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -42,54 +42,6 @@ struct ExecutorToControlDialectConversion : public FunctionPass { void runOnFunction() override; }; - -// Replace all uses of value `v` with a list of new values. Because number of -// new values might be greater than 1, users of `v` might be replaced with their -// clones in case of non-resizable operands list. -void ReplaceAllUsesOfValueWithValues(Value *v, - Operation::operand_range new_values) { - int new_values_size = std::distance(new_values.begin(), new_values.end()); - if (new_values_size == 1) { - v->replaceAllUsesWith(*new_values.begin()); - return; - } - - OpBuilder builder(v->getContext()); - for (Operation *user : llvm::make_early_inc_range(v->getUsers())) { - builder.setInsertionPoint(user); - - llvm::SmallVector new_operands; - new_operands.reserve(user->getNumOperands() - 1 + new_values_size); - for (Value *operand : user->getOperands()) { - if (operand == v) { - new_operands.append(new_values.begin(), new_values.end()); - } else { - new_operands.push_back(operand); - } - } - - if (user->hasResizableOperandsList()) { - user->setOperands(new_operands); - continue; - } - - OperationState state(user->getLoc(), user->getName().getStringRef()); - state.addOperands(new_operands); - - llvm::SmallVector result_types(user->getResultTypes()); - state.addTypes(result_types); - - state.addAttributes(user->getAttrs()); - for (auto &old_region : user->getRegions()) { - Region *r = state.addRegion(); - r->takeBody(old_region); - } - Operation *replacement = builder.createOperation(state); - user->replaceAllUsesWith(replacement); - user->erase(); - } -} - } // end anonymous namespace static bool HasSingleGraph(FuncOp function) { @@ -127,7 +79,7 @@ void ExecutorToControlDialectConversion::runOnFunction() { for (auto ops_and_ret_vals : llvm::zip(graph.getResults(), fetch.getOperands())) std::get<0>(ops_and_ret_vals) - ->replaceAllUsesWith(std::get<1>(ops_and_ret_vals)); + .replaceAllUsesWith(std::get<1>(ops_and_ret_vals)); op.erase(); continue; } @@ -135,7 +87,18 @@ void ExecutorToControlDialectConversion::runOnFunction() { builder.setInsertionPoint(&op); if (auto island = dyn_cast(op)) { - Value *ctl_sequence = nullptr; + Value ctl_sequence = nullptr; + if (island.GetBody().without_terminator().empty() && + island.getNumOperands() > 1) { + // For an empty island with multiple control inputs, we create a no-op + // inside it which will group all the inputs into one control output. 
+ // This helps reducing the number of edges when there are multiple + // islands depending on this one. + builder.setInsertionPointToStart(&island.GetBody()); + builder.create(op.getLoc(), ArrayRef{}, + ArrayRef{}, ArrayRef{}); + builder.setInsertionPoint(&op); + } for (Operation &wrapped_op : island.GetBody()) { LLVM_DEBUG(llvm::dbgs() << " In island: " << wrapped_op.getName() << "\n"); @@ -143,7 +106,7 @@ void ExecutorToControlDialectConversion::runOnFunction() { for (auto ops_and_ret_vals : llvm::zip(island.getResults(), wrapped_op.getOperands())) std::get<0>(ops_and_ret_vals) - ->replaceAllUsesWith(std::get<1>(ops_and_ret_vals)); + .replaceAllUsesWith(std::get<1>(ops_and_ret_vals)); break; } // Add a leading _ off the name. @@ -162,7 +125,7 @@ void ExecutorToControlDialectConversion::runOnFunction() { if (ctl_sequence) { state.operands.push_back(ctl_sequence); } else { - for (Value *ctl_operand : island.getOperands()) + for (Value ctl_operand : island.getOperands()) state.operands.push_back(ctl_operand); } @@ -178,7 +141,7 @@ void ExecutorToControlDialectConversion::runOnFunction() { for (auto ops_and_ret_vals : llvm::zip(wrapped_op.getResults(), replacement->getResults())) std::get<0>(ops_and_ret_vals) - ->replaceAllUsesWith(std::get<1>(ops_and_ret_vals)); + .replaceAllUsesWith(std::get<1>(ops_and_ret_vals)); ctl_sequence = replacement->getResult(replacement->getNumResults() - 1); } @@ -188,12 +151,13 @@ void ExecutorToControlDialectConversion::runOnFunction() { // been rewritten from ops in island. Last op rewritten must logically // carry // all the island control inputs, we can simply use it to // replace all uses of island's control output. - island.control()->replaceAllUsesWith(ctl_sequence); - } else { - // Getting here means island had an effectively empty body. In this - // case, island's control output should be replaced with all the control - // inputs of island. - ReplaceAllUsesOfValueWithValues(island.control(), island.getOperands()); + island.control().replaceAllUsesWith(ctl_sequence); + } else if (island.getNumOperands() > 0) { + // Getting here means island had an effectively empty body and there is + // just one control input. In this case, island's control output should + // be replaced with the control input. + assert(island.getNumOperands() == 1); + island.control().replaceAllUsesWith(island.getOperand(0)); } op.erase(); @@ -228,7 +192,7 @@ void ExecutorToControlDialectConversion::runOnFunction() { // dialect. auto non_null_operands = llvm::make_filter_range( op.getOperands(), - [](Value *v) { return !v->getType().isa(); }); + [](Value v) { return !v.getType().isa(); }); state.operands.append(non_null_operands.begin(), non_null_operands.end()); for (Type result_type : op.getResultTypes()) { // Filter out TokenType, they don't exist in the control dialect. 
@@ -248,14 +212,14 @@ void ExecutorToControlDialectConversion::runOnFunction() { if (auto next_iteration = dyn_cast(op)) { - next_iteration.output()->replaceAllUsesWith(replacement->getResult(0)); - next_iteration.token()->dropAllUses(); - next_iteration.control()->replaceAllUsesWith(replacement->getResult(1)); + next_iteration.output().replaceAllUsesWith(replacement->getResult(0)); + next_iteration.token().dropAllUses(); + next_iteration.control().replaceAllUsesWith(replacement->getResult(1)); } else { for (auto ops_and_ret_vals : llvm::zip(op.getResults(), replacement->getResults())) std::get<0>(ops_and_ret_vals) - ->replaceAllUsesWith(std::get<1>(ops_and_ret_vals)); + .replaceAllUsesWith(std::get<1>(ops_and_ret_vals)); } op.erase(); } diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc index 9d572209b31..39698c0f96b 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc @@ -27,18 +27,19 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Function.h" // TF:local_config_mlir -#include "mlir/IR/Identifier.h" // TF:local_config_mlir -#include "mlir/IR/Module.h" // TF:local_config_mlir -#include "mlir/IR/Operation.h" // TF:local_config_mlir -#include "mlir/IR/Types.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir -#include "mlir/Pass/PassManager.h" // TF:local_config_mlir -#include "mlir/Support/DebugStringHelper.h" // TF:local_config_mlir -#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir +#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/Identifier.h" // TF:llvm-project +#include "mlir/IR/Location.h" // TF:llvm-project +#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/Types.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Pass/PassManager.h" // TF:llvm-project +#include "mlir/Support/DebugStringHelper.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" @@ -110,25 +111,28 @@ std::string LegalizeNodeName(llvm::StringRef name) { return legalized_name; } -// TODO(jpienaar): unify and move from here to be able to reuse with tflite -std::string GetName(Operation* inst) { - // TODO(prakalps): b/137006652 prevents us from using location info (derived - // from experimental_debug_info) to generate node names. Until it is fixed, - // first check for "name" attribute to get node name. - - // Default name is Operation type. 
- auto name = inst->getName().getStringRef(); - if (auto attr = inst->getAttrOfType("name")) { - name = attr.getValue(); - } else if (auto name_loc = inst->getLoc().dyn_cast()) { - name = name_loc.getName().strref(); - } else if (auto call_loc = inst->getLoc().dyn_cast()) { +llvm::StringRef GetNameFromLoc(mlir::Location loc, + llvm::StringRef default_name) { + if (auto name_loc = loc.dyn_cast()) { + return name_loc.getName().strref().split('@').first; + } else if (auto call_loc = loc.dyn_cast()) { // Return name if CallSiteLoc's callee has a NameLoc (as should be the case // if imported with DebugInfo), else use the fallback naming scheme below. if (auto name_loc = call_loc.getCallee().dyn_cast()) - name = name_loc.getName().strref(); + return name_loc.getName().strref().split('@').first; + } else if (auto fused_loc = loc.dyn_cast()) { + // According to the importer, the last location of a fused location is + // the name from the node_def and the rest are from the experimental debug + // info. + return GetNameFromLoc(fused_loc.getLocations().back(), default_name); } + return default_name; +} +// TODO(jpienaar): unify and move from here to be able to reuse with tflite +std::string GetName(Operation* inst) { + // Default name is Operation type. + auto name = GetNameFromLoc(inst->getLoc(), inst->getName().getStringRef()); return LegalizeNodeName(name); } @@ -161,7 +165,7 @@ class Exporter { explicit Exporter(Graph* graph, const Dialect* tf_dialect) : graph_(graph), tf_dialect_(tf_dialect) {} - Status AddArgumentNode(BlockArgument* arg, unsigned index, + Status AddArgumentNode(BlockArgument arg, unsigned index, llvm::StringRef name); Status AddReturnNode(mlir::ReturnOp op, llvm::ArrayRef names); @@ -169,7 +173,7 @@ class Exporter { Status AddNextIterationNode(Operation* inst); Status AddEdge(Operation* inst); - StatusOr> GetArgumentNode(BlockArgument* arg, + StatusOr> GetArgumentNode(BlockArgument arg, unsigned index, llvm::StringRef name); StatusOr> GetReturnNode(Operation* inst, @@ -177,7 +181,7 @@ class Exporter { llvm::StringRef name); // Adds one edge between src_node and dst_node. If it is not a control edge, // an index is used to find out the right operand of the dst_node. - Status AddEdgeBetweenNodes(Value* src, Node* dst_node, unsigned dst_index); + Status AddEdgeBetweenNodes(Value src, Node* dst_node, unsigned dst_index); // Returns a unique name for `op`. std::string UniqueName(Operation* op); @@ -189,7 +193,7 @@ class Exporter { absl::flat_hash_map op_to_name_; absl::flat_hash_map name_to_count_; absl::flat_hash_map nodes_; - absl::flat_hash_map args_; + llvm::DenseMap args_; // One single return operation can return multiple results, and each of them // will be converted to one node in the graph.
typedef absl::InlinedVector NodeVector; @@ -231,8 +235,8 @@ std::string Exporter::UniqueName(Operation* op) { } StatusOr> Exporter::GetArgumentNode( - BlockArgument* arg, unsigned index, llvm::StringRef name) { - auto func = arg->getParentRegion()->getParentOfType(); + BlockArgument arg, unsigned index, llvm::StringRef name) { + auto func = arg.getParentRegion()->getParentOfType(); auto node_def = absl::make_unique(); if (!name.empty()) @@ -244,7 +248,7 @@ StatusOr> Exporter::GetArgumentNode( DataType dtype; TF_RETURN_IF_ERROR(ConvertToDataType( - arg->getType().cast().getElementType(), &dtype)); + arg.getType().cast().getElementType(), &dtype)); AttrValue type_attr; type_attr.set_type(dtype); (*node_def->mutable_attr())["T"] = type_attr; @@ -279,10 +283,10 @@ StatusOr> Exporter::GetReturnNode( UniqueName(inst->getParentOfType().getName().str())); node_def->set_op(FunctionLibraryDefinition::kRetOp); - auto* inst_op = inst->getOperand(index); + auto inst_op = inst->getOperand(index); DataType dtype; TF_RETURN_IF_ERROR(ConvertToDataType( - inst_op->getType().cast().getElementType(), &dtype)); + inst_op.getType().cast().getElementType(), &dtype)); AttrValue type_attr; type_attr.set_type(dtype); (*node_def->mutable_attr())["T"] = type_attr; @@ -292,10 +296,10 @@ StatusOr> Exporter::GetReturnNode( return node_def; } -Status Exporter::AddEdgeBetweenNodes(Value* src, Node* dst_node, +Status Exporter::AddEdgeBetweenNodes(Value src, Node* dst_node, unsigned dst_index) { - if (auto* input_result = dyn_cast(src)) { - auto* input_inst = input_result->getOwner(); + if (auto input_result = src.dyn_cast()) { + auto* input_inst = input_result.getOwner(); // replaces the input node by the sink one if it is a NextIteration source: auto it = source_to_sink_.find(input_inst); if (it != source_to_sink_.end()) { @@ -304,16 +308,16 @@ Status Exporter::AddEdgeBetweenNodes(Value* src, Node* dst_node, auto node_it = nodes_.find(input_inst); TF_RET_CHECK(node_it != nodes_.end()) << "Use of OpResult encountered before def!"; - if (input_result->getType().isa()) { + if (input_result.getType().isa()) { graph_->AddControlEdge(node_it->second, dst_node); } else { - graph_->AddEdge(node_it->second, input_result->getResultNumber(), - dst_node, dst_index); + graph_->AddEdge(node_it->second, input_result.getResultNumber(), dst_node, + dst_index); } return Status::OK(); } - auto* input_arg = cast(src); + auto input_arg = src.cast(); auto input_node_it = args_.find(input_arg); TF_RET_CHECK(input_node_it != args_.end()) << "Use of BlockArgument encountered before def!"; @@ -326,7 +330,7 @@ Status Exporter::AddEdge(Operation* inst) { auto* dst_node = nodes_[inst]; bool is_return_op = isa(inst); for (int index = 0, e = inst->getNumOperands(); index < e; index++) { - auto* src = inst->getOperand(index); + auto src = inst->getOperand(index); // For return operation, the edge is from the operand owner to one of the // faked return nodes. The input index is always 0 for the return node. if (is_return_op) { @@ -361,14 +365,14 @@ Status Exporter::AddInstructionNode(Operation* inst) { return Status::OK(); } -bool IsEntryFunctionArg(BlockArgument* arg) { - return arg->getParentRegion()->getParentOfType().getName() == +bool IsEntryFunctionArg(BlockArgument arg) { + return arg.getParentRegion()->getParentOfType().getName() == "main"; } // Creates argument nodes from Block argument. If a name is supplied, that // name will be used instead of generating a unique name.
-Status Exporter::AddArgumentNode(BlockArgument* arg, unsigned index, +Status Exporter::AddArgumentNode(BlockArgument arg, unsigned index, llvm::StringRef name) { if (!IsEntryFunctionArg(arg) || !name.empty()) { TF_ASSIGN_OR_RETURN(auto node_def, GetArgumentNode(arg, index, name)); @@ -383,21 +387,21 @@ Status Exporter::AddArgumentNode(BlockArgument* arg, unsigned index, // is an input node. We recover the original input node and skip adding the // argument node. The new input node will be handled as normal in the // following steps. - if (!arg->hasOneUse()) { + if (!arg.hasOneUse()) { return errors::FailedPrecondition( "Arg in 'main' should only have one user."); } - auto* input = *arg->user_begin(); + auto* input = *arg.user_begin(); auto input_name = input->getName().getStringRef(); input_name.consume_back(".input"); - mlir::OpBuilder builder(arg->getOwner()); + mlir::OpBuilder builder(arg.getOwner()); auto loc = mlir::NameLoc::get(builder.getIdentifier(UniqueName(input)), builder.getContext()); OperationState state(loc, input_name.str()); state.attributes.append(input->getAttrs().begin(), input->getAttrs().end()); - for (auto* op : input->getOperands()) { + for (auto op : input->getOperands()) { // Skip the argument in the new operation. - if (llvm::isa(op)) continue; + if (op.isa()) continue; state.operands.push_back(op); } state.types.append(input->getResultTypes().begin(), @@ -405,9 +409,17 @@ Status Exporter::AddArgumentNode(BlockArgument* arg, unsigned index, auto* inst = builder.createOperation(state); // If it is one of the specified input names, then the new // instruction should have the same name. - op_to_name_[inst].assign(op_to_name_[input]); + auto& mapped_name = op_to_name_[inst]; + const auto& input_mapped_name = op_to_name_[input]; + DCHECK(mapped_name.empty()) + << "AddArgumentNode() attempted to change the op_to_name_ mapping for " + << inst << " from " << mapped_name << " to " << input_mapped_name << "."; + DCHECK(!input_mapped_name.empty()) + << "AddArgumentNode() attempted to set the op_to_name_ mapping for " + << inst << " to an empty string."; + mapped_name.assign(input_mapped_name); for (int index : llvm::seq(0, input->getNumResults())) { - input->getResult(index)->replaceAllUsesWith(inst->getResult(index)); + input->getResult(index).replaceAllUsesWith(inst->getResult(index)); } input->dropAllReferences(); input->erase(); @@ -511,9 +523,15 @@ StatusOr> Exporter::Convert( // Only assign the output names to the defining ops of the return operands // if the main graph did not have its _Retval nodes lifted into the function's // returns. - if (!graph_as_function) - exporter.op_to_name_[it.value()->getDefiningOp()] = - output_names[it.index()]; + if (!graph_as_function) { + auto defining_op = it.value().getDefiningOp(); + auto& mapped_name = exporter.op_to_name_[defining_op]; + DCHECK(mapped_name.empty()) + << "Convert() attempted to change the op_to_name_ mapping for " + << defining_op << " from " << mapped_name << " to output " + << it.index() << " name " << output_names[it.index()].str() << "."; + mapped_name = output_names[it.index()]; + } } } if (!input_names.empty()) { @@ -522,17 +540,23 @@ StatusOr> Exporter::Convert( exporter.name_to_count_[input_names[it.index()].str()] = 1; // Only assign the input name to the user of the argument if the main graph // did not have its _Arg nodes lifted into the function's arguments.
- if (!graph_as_function) - exporter.op_to_name_[*it.value()->user_begin()] = - input_names[it.index()]; + if (!graph_as_function) { + auto first_user = *it.value().user_begin(); + auto& mapped_name = exporter.op_to_name_[first_user]; + DCHECK(mapped_name.empty()) + << "Convert() attempted to change the op_to_name_ mapping for " + << first_user << " from " << mapped_name << " to input " + << it.index() << " name " << input_names[it.index()].str() << "."; + mapped_name = input_names[it.index()]; + } } } // Adds nodes for basic block (function) arguments. for (auto it : llvm::enumerate(block.getArguments())) { int index = it.index(); - auto* arg = it.value(); - mlir::Type type = arg->getType(); + auto arg = it.value(); + mlir::Type type = arg.getType(); if (!type.isa()) { return errors::InvalidArgument( "FuncOps arguments must have tensor types. Found ", diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h index ab9b9731ab4..71ef3c8c493 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h @@ -17,9 +17,9 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_EXPORT_GRAPHDEF_H_ #include "llvm/ADT/StringRef.h" -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/Module.h" // TF:local_config_mlir -#include "mlir/IR/Operation.h" // TF:local_config_mlir +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc index adb5ba2b569..8cc12869704 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc @@ -83,11 +83,10 @@ Status GetUnregisteredAttrs( TF_ASSIGN_OR_RETURN(auto op_name, GetTensorFlowOpName(inst->getName().getStringRef())); - const tensorflow::OpRegistrationData* op_reg_data; - auto status = tensorflow::OpRegistry::Global()->LookUp(op_name, &op_reg_data); - if (!status.ok()) { + const tensorflow::OpRegistrationData* op_reg_data = + tensorflow::OpRegistry::Global()->LookUp(op_name); + if (!op_reg_data) { // This is likely a function call node, so we should continue. - VLOG(1) << status.ToString(); return Status::OK(); } @@ -132,8 +131,8 @@ StatusOr> ConvertTFDialectOpToNodeDef( if (inst->getDialect() && inst->getDialect()->getNamespace() == "_tf") { mlir::OperationState result(inst->getLoc(), inst->getName().getStringRef().drop_front()); - for (mlir::Value* operand : inst->getOperands()) - if (!operand->getType().isa()) + for (mlir::Value operand : inst->getOperands()) + if (!operand.getType().isa()) result.operands.push_back(operand); // Add a result type for each non-control result we find @@ -161,6 +160,13 @@ StatusOr> ConvertTFDialectOpToNodeDef( TF_RETURN_IF_ERROR(GetUnregisteredAttrs(inst, &attrs_to_ignore)); } + if (inst->hasTrait()) { + // TODO(b/146937733): Don't use here. 
+ llvm::StringRef attr_name = mlir::OpTrait::AttrSizedResultSegments< + void>::getResultSegmentSizeAttr(); + attrs_to_ignore.insert(attr_name.data()); + } + TF_ASSIGN_OR_RETURN(auto node_def, GetOperationNodeDef(attrs_to_ignore, inst, name)); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h b/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h index 1e18a2d5d3b..df1f4859ded 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h @@ -17,7 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_EXPORT_TF_DIALECT_OP_H_ #include "llvm/ADT/StringRef.h" -#include "mlir/IR/Operation.h" // TF:local_config_mlir +#include "mlir/IR/Operation.h" // TF:llvm-project #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/stream_executor/lib/statusor.h" diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index 868faed9b0b..0f258495f47 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -35,20 +35,22 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/Analysis/Verifier.h" // TF:local_config_mlir -#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Function.h" // TF:local_config_mlir -#include "mlir/IR/Identifier.h" // TF:local_config_mlir -#include "mlir/IR/Location.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/Module.h" // TF:local_config_mlir -#include "mlir/IR/Types.h" // TF:local_config_mlir +#include "mlir/Analysis/Verifier.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/Identifier.h" // TF:llvm-project +#include "mlir/IR/Location.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/IR/OpDefinition.h" // TF:llvm-project +#include "mlir/IR/Types.h" // TF:llvm-project #include "tensorflow/compiler/jit/shape_inference_helpers.h" #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" #include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h" @@ -70,6 +72,7 @@ limitations under the License. #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/resource_var.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/types.h" @@ -80,6 +83,7 @@ limitations under the License. 
#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/graph/tensor_id.h" +#include "tensorflow/core/grappler/utils/transitive_fanin.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/protobuf.h" @@ -264,7 +268,7 @@ class ImporterBase { mlir::Operation* createOperation( const Node& node, llvm::StringRef node_type_name, const mlir::OperationState& result, - const llvm::SmallVectorImpl& control_operands, + const llvm::SmallVectorImpl& control_operands, bool convert_to_legacy_call = false); // Converts one NodeDef from the input GraphDef into an Operation and @@ -421,7 +425,6 @@ Status UpdateLegacyFedInputNode(const GraphDef& graph_def, // - Replacing LegacyFedInput nodes with Placeholder nodes if // convert_legacy_fed_inputs option is enabled. Status PreprocessGraphDef(const GraphImportConfig* specs, GraphDef* graph_def) { - const tensorflow::OpRegistrationData* op_reg_data; for (auto& node_def : *graph_def->mutable_node()) { // TODO(hinsu): Completely deprecate support for LegacyFedInput ops. One // solution could be have a tool to let users upgrade old serialized graphs. @@ -431,11 +434,10 @@ Status PreprocessGraphDef(const GraphImportConfig* specs, GraphDef* graph_def) { UpdateLegacyFedInputNode(*graph_def, specs->inputs, &node_def)); } - auto status = - tensorflow::OpRegistry::Global()->LookUp(node_def.op(), &op_reg_data); - if (!status.ok()) { + const tensorflow::OpRegistrationData* op_reg_data = + tensorflow::OpRegistry::Global()->LookUp(node_def.op()); + if (!op_reg_data) { // This is likely a function call node, so we should continue. - VLOG(1) << status.ToString(); continue; } ::tensorflow::AddDefaultsToNodeDef(op_reg_data->op_def, &node_def); @@ -1176,7 +1178,7 @@ Status ImporterBase::ConvertFunctionArgAndRets( const absl::InlinedVector& ret_nodes, const absl::InlinedVector& control_ret_nodes) { auto* bb = &func.front(); - llvm::SmallDenseMap, mlir::Value*, 4> + llvm::SmallDenseMap, mlir::Value, 4> arg_nodes_to_values; for (int i = 0, e = arg_types.size(); i < e; ++i) { auto& arg_node = arg_nodes[i]; @@ -1184,8 +1186,8 @@ Status ImporterBase::ConvertFunctionArgAndRets( // be converted to mlir operations and don't have a mapping. mlir::Operation* island = node_values_.find(arg_node.node->id())->second; - auto* bb_arg = bb->getArgument(i); - mlir::Value* arg_def = bb_arg; + auto bb_arg = bb->getArgument(i); + mlir::Value arg_def = bb_arg; if (island->getNumResults() != 2) return errors::InvalidArgument( @@ -1193,9 +1195,9 @@ Status ImporterBase::ConvertFunctionArgAndRets( // Collect mapping of OutputTensor to associated block arg. arg_nodes_to_values.try_emplace({arg_node.node, arg_node.index}, arg_def); - island->getResult(0)->replaceAllUsesWith(arg_def); + island->getResult(0).replaceAllUsesWith(arg_def); // Erase control outputs from feed. 
- auto control_uses = island->getResult(1)->getUses(); + auto control_uses = island->getResult(1).getUses(); for (auto& control_use : llvm::make_early_inc_range(control_uses)) control_use.getOwner()->eraseOperand(control_use.getOperandNumber()); @@ -1208,7 +1210,7 @@ Status ImporterBase::ConvertFunctionArgAndRets( island->erase(); } - llvm::SmallVector inst_to_return; + llvm::SmallVector inst_to_return; for (const auto& ret : ret_nodes) { auto* inst = node_values_[ret.node->id()]; auto op = absl::string_view(ret.node->type_string()); @@ -1320,15 +1322,21 @@ mlir::Location ImporterBase::GetLocation(const NodeDef& node_def) { return create_location(node_def.name(), function_name_for_debug_info_); } else { // If the original nodes are defined, then we use them to get a list of - // call sites, and then fuse them to a single fused location. - llvm::SmallVector node_call_sites; - node_call_sites.reserve(original_nodes.size()); + // call sites, and then fuse them to a single fused location, with the name + // of the node_def. + llvm::SmallVector node_locations; + node_locations.reserve(original_nodes.size() + 1); + + // store the names in the experimental_debug_info for (int i = 0, e = original_nodes.size(); i != e; ++i) { auto node_name = original_nodes[i]; auto func_name = (i < original_funcs.size()) ? original_funcs[i] : ""; - node_call_sites.push_back(create_location(node_name, func_name)); + node_locations.push_back(create_location(node_name, func_name)); } - return mlir::FusedLoc::get(node_call_sites, context_); + // store the name of the node_def + node_locations.push_back( + create_location(node_def.name(), function_name_for_debug_info_)); + return mlir::FusedLoc::get(node_locations, context_); } } @@ -1349,14 +1357,14 @@ std::string ImporterBase::GetLocationStr(const Node& node, mlir::Operation* ImporterBase::createOperation( const Node& node, llvm::StringRef node_type_name, const mlir::OperationState& result, - const llvm::SmallVectorImpl& control_operands, + const llvm::SmallVectorImpl& control_operands, bool convert_to_legacy_call) { // For the tf.executor specific operations (not wrapped in an island), we // have an extra returned value for the control result, and we concatenate // control and non-control operands. mlir::SmallVector types(result.types); types.push_back(mlir::tf_executor::ControlType::get(builder_.getContext())); - mlir::SmallVector operands(result.operands); + mlir::SmallVector operands(result.operands); operands.append(control_operands.begin(), control_operands.end()); auto loc = result.location; @@ -1384,7 +1392,7 @@ mlir::Operation* ImporterBase::createOperation( builder_.getBlock()->begin()); auto source_op = builder_at_begin.create( - loc, operands[0]->getType(), result.attributes); + loc, operands[0].getType(), result.attributes); return builder_.create( loc, source_op.token(), operands, result.attributes); } @@ -1434,6 +1442,32 @@ mlir::Operation* ImporterBase::createOperation( inner_op = island_builder.createOperation(result); } + if (inner_op->hasTrait()) { + // The op has multiple variadic outputs. + // Calculate result segment sizes using the OpDef. + NameRangeMap output_ranges; + // This will fail only if the OpDef is syntactically invalid. + // TODO(jpienaar): Convert this CHECK into a properly propagated error. 
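+ // Illustration (hypothetical op, not from this patch): if the OpDef
+ // declares outputs (values: N * T, indices: int32) and the node was
+ // instantiated with N = 3, NameRangesForNode reports the output ranges
+ // {"values": [0, 3), "indices": [3, 4)}, so the loop below collects the
+ // segment sizes {3, 1} for the derived result_segment_sizes attribute.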
+ TF_CHECK_OK( + NameRangesForNode(node, node.op_def(), nullptr, &output_ranges)); + std::vector values; + values.reserve(node.op_def().output_arg_size()); + for (const auto& output_arg : node.op_def().output_arg()) { + auto range = output_ranges[output_arg.name()]; + values.push_back( + island_builder.getI32IntegerAttr(range.second - range.first)); + } + + // Add derived "result_segment_sizes" attr to the created operation. + // TODO(b/146937733): Don't use here. + llvm::StringRef attr_name = mlir::OpTrait::AttrSizedResultSegments< + void>::getResultSegmentSizeAttr(); + auto attr_type = mlir::VectorType::get(node.op_def().output_arg_size(), + builder_.getIntegerType(32)); + auto attr_value = mlir::DenseElementsAttr::get(attr_type, values); + inner_op->setAttr(attr_name, attr_value); + } + // Add the terminator for the island island_builder.create(result.location, inner_op->getResults()); @@ -1499,7 +1533,7 @@ Status ImporterBase::ConvertNode(const Node& node) { result.operands.reserve(in_edges.size()); // Collect the control operands separately, they will be held by the island. - mlir::SmallVector control_operands; + mlir::SmallVector control_operands; for (const auto* input_edge : in_edges) { const Node& input_node = *input_edge->src(); @@ -1568,8 +1602,6 @@ Status ImporterBase::ConvertNode(const Node& node) { &result.attributes)); } - result.attributes.push_back(builder_.getNamedAttr( - "name", builder_.getStringAttr(std::string(node.name())))); result.attributes.push_back(builder_.getNamedAttr( "device", builder_.getStringAttr(std::string(node_def.device())))); @@ -1625,7 +1657,7 @@ Status ImporterBase::AddBackedges() { Status ImporterBase::AddBackedge(mlir::Operation* sink, mlir::Operation* dst, int dst_input) { // Get the NextIteration.Source operation from the token operand of the sink. - mlir::Operation* source = sink->getOperand(0)->getDefiningOp(); + mlir::Operation* source = sink->getOperand(0).getDefiningOp(); // Adds the "source" to the operands of the dst by creating a new dst // operation. @@ -1650,8 +1682,8 @@ Status ImporterBase::AddBackedge(mlir::Operation* sink, mlir::Operation* dst, // Replaces the output uses of the old operation by the corresponding // result of the new operation, and deletes the old operation. 
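// For example, if `dst` is the Merge operation of a while loop, `new_dst`
// is identical except for the extra NextIteration source operand; every use
// of each old result is redirected to the matching new result before `dst`
// is erased.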
for (unsigned i = 0, e = dst->getNumResults(); i != e; ++i) { - auto* new_output = new_dst->getResult(i); - dst->getResult(i)->replaceAllUsesWith(new_output); + auto new_output = new_dst->getResult(i); + dst->getResult(i).replaceAllUsesWith(new_output); } dst->dropAllReferences(); dst->erase(); @@ -1705,8 +1737,8 @@ class GraphDefImporter : public ImporterBase { static StatusOr Convert( mlir::MLIRContext* context, const Graph& graph, const GraphDebugInfo& debug_info, - const FunctionLibraryDefinition& flib_def, - const GraphImportConfig& specs); + const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs, + llvm::StringRef func_name); private: explicit GraphDefImporter( @@ -1744,7 +1776,7 @@ class GraphDefImporter : public ImporterBase { StatusOr GraphDefImporter::Convert( mlir::MLIRContext* context, const Graph& graph, const GraphDebugInfo& debug_info, const FunctionLibraryDefinition& flib_def, - const GraphImportConfig& specs) { + const GraphImportConfig& specs, llvm::StringRef func_name) { mlir::OwningModuleRef module = mlir::ModuleOp::create(mlir::UnknownLoc::get(context)); std::unordered_map tf_name_to_mlir_name; @@ -1832,7 +1864,7 @@ StatusOr GraphDefImporter::Convert( {producer, min_consumer, bad_consumers}))); TF_RETURN_IF_ERROR(importer.ImporterBase::Convert( - "main", func_type, arg_nodes, ret_nodes, control_ret_nodes, attrs, + func_name, func_type, arg_nodes, ret_nodes, control_ret_nodes, attrs, resource_arg_unique_ids)); return module; } @@ -2535,7 +2567,7 @@ Status CreateSavedModelIR( module.insert(module.getBody()->begin(), func); func.addEntryBlock(); func.setName("__sm_exported_" + orig_func.getName().str()); - llvm::SmallVector args_as_values; + llvm::SmallVector args_as_values; for (auto block_argument : func.getArguments()) { args_as_values.push_back(block_argument); } @@ -2742,6 +2774,292 @@ StatusOr SavedModelImporter::Convert( return module; } +// A helper class to import a TensorFlow model expressed in SavedModel V1 into +// an MLIR Module. +class SavedModelV1Importer { + public: + // Main entry point: converts all functions (specified by SignatureDefs) in + // the given meta graph to an MLIR Module. + static StatusOr Convert(const SavedModelBundle& bundle, + mlir::MLIRContext* context) { + SavedModelV1Importer importer(bundle, context); + + return importer.ConvertSignatures(); + } + + private: + SavedModelV1Importer(const SavedModelBundle& bundle, + mlir::MLIRContext* context) + : bundle_(bundle), + module_(mlir::ModuleOp::create(mlir::UnknownLoc::get(context))) {} + + // Convert the SavedModel to TF Executor Dialect. It creates an MLIR function + // for each signature. + StatusOr ConvertSignatures(); + StatusOr ConvertSignature( + const GraphImportConfig& specs, llvm::StringRef func_name, + const SignatureDef& signature_def, const GraphDef& sub_graph_def, + const GraphDebugInfo& debug_info, + const FunctionLibraryDefinition& flib_def); + + // Create GlobalTensorOp for each variable and move each VarHandle op to + // the enclosing function's arguments. + Status LiftVariables(); + void LiftVariable(mlir::TF::VarHandleOp op); + + // Read all variables from the SavedModel through the session, and create + // GlobalTensorOp for these variables.
+ Status ReadVariablesFromSession( + const llvm::SmallVectorImpl& ops); + + GraphImportConfig::InputArrays ParseInputArrays( + const tensorflow::protobuf::Map& inputs); + + std::vector ParseOutputArrays( + const tensorflow::protobuf::Map& outputs); + + const SavedModelBundle& bundle_; + mlir::OwningModuleRef module_; +}; + +// Convert the SavedModel to TF Executor Dialect. It creates an MLIR function +// for each signature. +StatusOr SavedModelV1Importer::ConvertSignatures() { + const auto& signatures = bundle_.GetSignatures(); + const auto& graphdef = bundle_.meta_graph_def.graph_def(); + + FunctionLibraryDefinition flib_def(OpRegistry::Global(), graphdef.library()); + + // debug_info might not be loaded with loader_lite. + GraphDebugInfo debug_info; + if (bundle_.debug_info != nullptr) debug_info = *bundle_.debug_info; + + for (const auto& key_and_signature_def : signatures) { + const auto& func_name = key_and_signature_def.first; + const auto& signature_def = key_and_signature_def.second; + GraphImportConfig specs; + specs.inputs = ParseInputArrays(signature_def.inputs()); + specs.outputs = ParseOutputArrays(signature_def.outputs()); + + // Remove unused nodes and create a sub graphdef. + GraphDef sub_graph_def; + TF_RETURN_IF_ERROR(tensorflow::grappler::SetTransitiveFaninGraph( + graphdef, &sub_graph_def, + /* terminal_nodes = */ {specs.outputs.begin(), specs.outputs.end()})); + + auto status_or_sub_module = ConvertSignature( + specs, func_name, signature_def, sub_graph_def, debug_info, flib_def); + if (!status_or_sub_module.ok()) { + LOG(ERROR) << "Failed to convert SignatureDef for " << func_name << ": " + << status_or_sub_module.status(); + continue; + } + + auto& sub_module = status_or_sub_module.ValueOrDie(); + + // Move the converted functions to the top-level MLIR module. + auto* block = module_->getBody(); + auto* sub_block = sub_module->getBody(); + block->getOperations().splice( + mlir::Block::iterator(block->getTerminator()), + sub_block->getOperations(), sub_block->begin(), + mlir::Block::iterator(sub_block->getTerminator())); + } + + TF_RETURN_IF_ERROR(LiftVariables()); + + return std::move(module_); +} + +StatusOr SavedModelV1Importer::ConvertSignature( + const GraphImportConfig& specs, llvm::StringRef func_name, + const SignatureDef& signature_def, const GraphDef& sub_graph_def, + const GraphDebugInfo& debug_info, + const FunctionLibraryDefinition& flib_def) { + // Convert this sub graphdef to a sub graph. + GraphConstructorOptions options; + options.allow_internal_ops = true; + options.add_default_attributes = true; + Graph sub_graph(OpRegistry::Global()); + + TF_RETURN_IF_ERROR( + ConvertGraphDefToGraph(options, sub_graph_def, &sub_graph)); + + // Convert the sub graphdef to an MLIR function. + return GraphDefImporter::Convert(module_->getContext(), sub_graph, debug_info, + flib_def, specs, func_name); +} + +// Create GlobalTensorOp for each variable and move each VarHandle op to +// the enclosing function's arguments.
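+// For instance (names illustrative), a signature function
+// @serve(%x: tensor<f32>) that reads a variable through a tf.VarHandleOp
+// with shared_name = "w" becomes
+// @serve(%x: tensor<f32>, %w: tensor<!tf.resource>), where the new argument
+// carries a tf_saved_model.bound_input attribute referencing the
+// tf_saved_model.global_tensor created for "w".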
+Status SavedModelV1Importer::LiftVariables() { + llvm::SmallVector ops; + + bool contains_ref_variable = false; + + module_->walk([&ops, &contains_ref_variable](mlir::Operation* op) { + if (auto var_handle_op = llvm::dyn_cast(op)) + ops.push_back(var_handle_op); + else if (op->getName().getStringRef() == "tf.VariableV2") + contains_ref_variable = true; + }); + + if (contains_ref_variable) + return errors::InvalidArgument( + "Ref variable created by VariableV2 is not supported."); + + if (ops.empty()) return Status::OK(); + + TF_RETURN_IF_ERROR(ReadVariablesFromSession(ops)); + + for (auto op : ops) LiftVariable(op); + + return Status::OK(); +} + +// Move the result of the VarHandleOp to the enclosing function's argument list +// and erase this VarHandleOp. +void SavedModelV1Importer::LiftVariable(mlir::TF::VarHandleOp op) { + mlir::OpBuilder builder(&module_->getBodyRegion()); + + auto func_op = op.getParentOfType(); + builder.setInsertionPoint(func_op); + + auto func_type = func_op.getType(); + + // Create the new function type by adding the variable type to the arguments. + llvm::SmallVector new_input_types( + func_type.getInputs().begin(), func_type.getInputs().end()); + new_input_types.push_back(op.resource()->getType()); + auto new_func_type = + builder.getFunctionType(new_input_types, func_type.getResults()); + + auto new_func_op = builder.create( + func_op.getLoc(), func_op.getName(), new_func_type, + llvm::ArrayRef()); + + // Bind the argument to the corresponding global tensor op. + new_func_op.setArgAttr(new_func_op.getNumArguments() - 1, + "tf_saved_model.bound_input", + builder.getSymbolRefAttr(op.shared_name())); + + // Replace the function body and update its signature. + auto& new_region = new_func_op.getBody(); + new_region.getBlocks().splice(new_region.end(), + func_op.getBody().getBlocks()); + + func_op.getOperation()->erase(); + + auto& new_block = new_region.front(); + auto new_value = new_block.addArgument(op.resource()->getType()); + + op.getOperation()->replaceAllUsesWith(llvm::ArrayRef(new_value)); + + op.getOperation()->erase(); +} + +// Read all variables from the SavedModel through the session, and create +// GlobalTensorOp for these variables. +Status SavedModelV1Importer::ReadVariablesFromSession( + const llvm::SmallVectorImpl& ops) { + mlir::OpBuilder builder(&module_->getBodyRegion()); + + // Find all variables and their corresponding read ops. + + llvm::MapVector + variable_names_and_ops; + for (auto op : ops) { + variable_names_and_ops[op.shared_name()] = op; + } + + // Read all resource variables from the session. + + std::vector variable_names; + variable_names.reserve(variable_names_and_ops.size()); + for (const auto& name_and_location : variable_names_and_ops) + variable_names.push_back(name_and_location.first); + + std::vector resource_tensors; + TF_RETURN_IF_ERROR(bundle_.GetSession()->Run( + /*inputs=*/{}, variable_names, + /*target_node_names=*/{}, &resource_tensors)); + + const DeviceMgr* device_manager; + TF_RETURN_IF_ERROR(bundle_.GetSession()->LocalDeviceManager(&device_manager)); + + // Read all underlying tensors of the variables from the session.
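+ // The lookup chain used below: each fetched Tensor is a scalar
+ // ResourceHandle naming (device, container, name); the handle's device is
+ // resolved through the DeviceMgr, the Var is looked up in that device's
+ // ResourceMgr, and the Var's payload tensor is what gets materialized as a
+ // tf_saved_model.global_tensor.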
+ std::vector tensors; + tensors.reserve(resource_tensors.size()); + for (const auto& resource_tensor : resource_tensors) { + const auto& resource_handle = resource_tensor.scalar()(); + + Device* device; + TF_RETURN_IF_ERROR( + device_manager->LookupDevice(resource_handle.device(), &device)); + + Var* var_ptr; + TF_RETURN_IF_ERROR(device->resource_manager()->Lookup( + resource_handle.container(), resource_handle.name(), &var_ptr)); + core::RefCountPtr var(var_ptr); + + // The variable tensor is already loaded into the corresponding device's + // resource manager when we load the saved model using LoadSavedModel(). + // Here we just read its value. + mutex_lock ml(*var->mu()); + tensors.push_back(*var->tensor()); + } + + for (const auto& iter : llvm::zip(variable_names_and_ops, tensors)) { + const auto& name = std::get<0>(iter).first; + auto location = std::get<0>(iter).second.getLoc(); + const auto& tensor = std::get<1>(iter); + + // Create a tensor attribute for this variable. + TF_ASSIGN_OR_RETURN(auto tensor_attr, ConvertTensor(tensor, &builder)); + + builder.create( + location, builder.getStringAttr(name), tensor_attr, + mlir::TypeAttr::get(tensor_attr.getType()), builder.getUnitAttr()); + } + + return Status::OK(); +} + +GraphImportConfig::InputArrays SavedModelV1Importer::ParseInputArrays( + const tensorflow::protobuf::Map& inputs) { + GraphImportConfig::InputArrays results; + for (const auto& iter : inputs) { + const auto& tensor_info = iter.second; + + // Only dense tensors are supported. + DCHECK_EQ(tensor_info.encoding_case(), tensorflow::TensorInfo::kName); + + ArrayInfo array_info; + array_info.imported_dtype = tensor_info.dtype(); + array_info.shape = tensor_info.tensor_shape(); + + std::vector node_names = + absl::StrSplit(tensor_info.name(), ':'); + + results.insert(std::pair(node_names.at(0), + std::move(array_info))); + } + return results; +} + +std::vector SavedModelV1Importer::ParseOutputArrays( + const tensorflow::protobuf::Map& outputs) { + std::vector results; + for (const auto& iter : outputs) { + const auto& tensor_info = iter.second; + + std::vector node_names = + absl::StrSplit(tensor_info.name(), ':'); + results.push_back(node_names.at(0)); + } + return results; +} + } // namespace Status UpgradeLegacyGraph(Graph* graph, FunctionLibraryDefinition* flib_def) { @@ -2777,7 +3095,8 @@ StatusOr ConvertGraphToMlir( UpgradeLegacyGraph(const_cast(&graph), const_cast(&flib_def))); } - return GraphDefImporter::Convert(context, graph, debug_info, flib_def, specs); + return GraphDefImporter::Convert(context, graph, debug_info, flib_def, specs, + /* func_name = */ "main"); } StatusOr ConvertSavedModelToMlir( @@ -2787,6 +3106,11 @@ StatusOr ConvertSavedModelToMlir( add_default_attributes); } +StatusOr ConvertSavedModelV1ToMlir( + const SavedModelBundle& saved_model, mlir::MLIRContext* context) { + return SavedModelV1Importer::Convert(saved_model, context); +} + std::string MlirModuleToString(mlir::ModuleOp module, bool show_debug_info) { std::string txt_module; { diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.h b/tensorflow/compiler/mlir/tensorflow/translate/import_model.h index d4b17073bd5..efc316483fe 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.h @@ -18,9 +18,10 @@ limitations under the License.
#include -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/Module.h" // TF:local_config_mlir +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Module.h" // TF:llvm-project #include "tensorflow/cc/saved_model/bundle_v2.h" +#include "tensorflow/cc/saved_model/loader.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" @@ -50,6 +51,12 @@ stream_executor::port::StatusOr ConvertSavedModelToMlir( SavedModelV2Bundle* saved_model, mlir::MLIRContext* context, absl::Span exported_names, bool add_default_attributes = true); +// Given a V1 SavedModel, returns a MLIR module containing the functions, +// expressed with tf_executor dialect. +stream_executor::port::StatusOr +ConvertSavedModelV1ToMlir(const SavedModelBundle& saved_model, + mlir::MLIRContext* context); + // Serialize a MLIR module to a string. std::string MlirModuleToString(mlir::ModuleOp m, bool show_debug_info = false); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.cc b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.cc index ca13db56df3..004293410b3 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.cc @@ -15,9 +15,9 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.h" -#include "mlir/Analysis/Verifier.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/Module.h" // TF:local_config_mlir +#include "mlir/Analysis/Verifier.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Module.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h" #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" diff --git a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.h b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.h index 1daa29045c5..79a302b066b 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_MLIR_ROUNDTRIP_PASS_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_MLIR_ROUNDTRIP_PASS_H_ -#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir +#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/lib/core/status.h" diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_functional_to_executor.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_functional_to_executor.cc index 86fbff91db1..a97bca9fc3d 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_functional_to_executor.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_functional_to_executor.cc @@ -14,11 +14,11 @@ limitations under the License. 
==============================================================================*/ #include "llvm/Support/Debug.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Operation.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir -#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir +#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Pass/PassRegistry.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #define DEBUG_TYPE "tf-functional-to-executor" @@ -75,7 +75,7 @@ void FunctionalToExecutorDialectConversion::runOnFunction() { builder.setInsertionPointToEnd(&graph_op.GetBody()); auto island = builder.create( loc, getFunction().getType().getResults(), - tf_executor::ControlType::get(&getContext()), ArrayRef()); + tf_executor::ControlType::get(&getContext()), ArrayRef()); // Create Fetch. ValueRange to_fetch = island.getResults(); if (to_fetch.size() != 1) { diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc index 5c59eace5cc..8f3cab0e619 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc @@ -17,14 +17,15 @@ limitations under the License. #include "absl/memory/memory.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Function.h" // TF:local_config_mlir -#include "mlir/IR/Identifier.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/Module.h" // TF:local_config_mlir -#include "mlir/IR/Operation.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/Parser.h" // TF:local_config_mlir +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/Identifier.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/Parser.h" // TF:llvm-project +#include "tensorflow/cc/saved_model/bundle_v2.h" #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" @@ -129,6 +130,27 @@ mlir::OwningModuleRef SavedModelToMlirImport( return module_or.ConsumeValueOrDie(); } +mlir::OwningModuleRef SavedModelV1ToMlirImport( + absl::string_view saved_model_dir, + const std::unordered_set& tags, mlir::MLIRContext* context) { + tensorflow::SavedModelBundle bundle; + auto load_status = tensorflow::LoadSavedModel( + /* session_options = */ {}, /* run_options = */ {}, + std::string(saved_model_dir), tags, &bundle); + if (!load_status.ok()) { + LOG(ERROR) << "Failed to load saved model v1 '" << saved_model_dir + << "': " << load_status; + return nullptr; + } + + auto module_or = ConvertSavedModelV1ToMlir(bundle, context); + if (!module_or.status().ok()) { + LOG(ERROR) << "SavedModel V1 import failed: " << module_or.status(); + return nullptr; + 
} + return module_or.ConsumeValueOrDie(); +} + mlir::OwningModuleRef GraphdefToSplattedMlirTranslateFunction( llvm::StringRef input, absl::string_view debug_info_file, absl::string_view input_arrays, absl::string_view input_dtypes, diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h index ce5337949c1..46e6376207c 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h @@ -21,8 +21,8 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/Module.h" // TF:local_config_mlir +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Module.h" // TF:llvm-project namespace tensorflow { // TODO(antiagainst): Directly manipulating files in library functions is not @@ -54,6 +54,14 @@ mlir::OwningModuleRef SavedModelToMlirImport( absl::string_view saved_model_dir, const std::unordered_set& tags, absl::Span exported_names, mlir::MLIRContext* context); + +// Converts a TensorFlow V1 SavedModel stored in the directory with the given +// `saved_model_dir` into a MLIR module. Creates MLIR entities into the +// given MLIR `context`. +mlir::OwningModuleRef SavedModelV1ToMlirImport( + absl::string_view saved_model_dir, + const std::unordered_set& tags, mlir::MLIRContext* context); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_TF_MLIR_TRANSLATE_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc index 08b09924fd1..db46fdcf931 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc @@ -21,8 +21,8 @@ limitations under the License. #include "llvm/Support/FileSystem.h" #include "llvm/Support/MemoryBuffer.h" -#include "mlir/IR/Module.h" // TF:local_config_mlir -#include "mlir/Translation.h" // TF:local_config_mlir +#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/Translation.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" diff --git a/tensorflow/compiler/mlir/tensorflow/translate/translate_tf_dialect_op.cc b/tensorflow/compiler/mlir/tensorflow/translate/translate_tf_dialect_op.cc index 38d6a572584..a9b5021559c 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/translate_tf_dialect_op.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/translate_tf_dialect_op.cc @@ -14,11 +14,11 @@ limitations under the License. 
==============================================================================*/ #include "llvm/Support/ToolOutputFile.h" -#include "mlir/IR/Function.h" // TF:local_config_mlir -#include "mlir/IR/Location.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/Module.h" // TF:local_config_mlir -#include "mlir/Translation.h" // TF:local_config_mlir +#include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/Location.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/Translation.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.cc b/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.cc index a37e092aa56..7d449b8775f 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.cc @@ -17,26 +17,39 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "llvm/Support/FormatVariadic.h" -#include "mlir/IR/Operation.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" namespace tensorflow { +BridgeLoggerConfig::BridgeLoggerConfig(bool print_module_scope, + bool print_after_only_on_change) + : mlir::PassManager::IRPrinterConfig(print_module_scope, + print_after_only_on_change) {} + // Logs op to file with name of format `mlir_bridge_{pass-name}_{file-suffix}.mlir`. -inline static void Log(mlir::Pass* pass, mlir::Operation* op, +inline static void Log(BridgeLoggerConfig::PrintCallbackFn print_callback, + mlir::Pass* pass, mlir::Operation* op, llvm::StringRef file_suffix) { - DumpMlirOpToFile( - llvm::formatv("mlir_bridge-{0}-{1}", pass->getName(), file_suffix).str(), - op); + std::string name = + llvm::formatv("mlir_bridge_{0}_{1}", pass->getName(), file_suffix).str(); + + std::unique_ptr os; + std::string filepath; + if (CreateFileForDumping(name, &os, &filepath).ok()) print_callback(*os); } -void BridgeLogger::runBeforePass(mlir::Pass* pass, mlir::Operation* op) { - Log(pass, op, "before"); +void BridgeLoggerConfig::printBeforeIfEnabled(mlir::Pass* pass, + mlir::Operation* operation, + PrintCallbackFn print_callback) { + Log(print_callback, pass, operation, "before"); } -void BridgeLogger::runAfterPass(mlir::Pass* pass, mlir::Operation* op) { - Log(pass, op, "after"); +void BridgeLoggerConfig::printAfterIfEnabled(mlir::Pass* pass, + mlir::Operation* operation, + PrintCallbackFn print_callback) { + Log(print_callback, pass, operation, "after"); } } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h b/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h index 2943a37886a..4f6d49b77e9 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h @@ -16,18 +16,32 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_BRIDGE_LOGGER_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_BRIDGE_LOGGER_H_ -#include "mlir/IR/Operation.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir -#include "mlir/Pass/PassInstrumentation.h" // TF:local_config_mlir +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Pass/PassManager.h" // TF:llvm-project namespace tensorflow { // Logger for logging/dumping MLIR modules before and after passes in bridge // targeting TPUs. -class BridgeLogger : public mlir::PassInstrumentation { +class BridgeLoggerConfig : public mlir::PassManager::IRPrinterConfig { public: - void runBeforePass(mlir::Pass* pass, mlir::Operation* op) override; - void runAfterPass(mlir::Pass* pass, mlir::Operation* op) override; + explicit BridgeLoggerConfig(bool print_module_scope = false, + bool print_after_only_on_change = true); + + // A hook that may be overridden by a derived config that checks if the IR + // of 'operation' should be dumped *before* the pass 'pass' has been + // executed. If the IR should be dumped, 'print_callback' should be invoked + // with the stream to dump into. + void printBeforeIfEnabled(mlir::Pass *pass, mlir::Operation *operation, + PrintCallbackFn print_callback) override; + + // A hook that may be overridden by a derived config that checks if the IR + // of 'operation' should be dumped *after* the pass 'pass' has been + // executed. If the IR should be dumped, 'print_callback' should be invoked + // with the stream to dump into. + void printAfterIfEnabled(mlir::Pass *pass, mlir::Operation *operation, + PrintCallbackFn print_callback) override; }; } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc index 4e914a5a20d..02ffae658cc 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc @@ -17,15 +17,15 @@ limitations under the License. 
#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringRef.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir -#include "mlir/IR/Function.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/OpDefinition.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/Parser.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir -#include "mlir/Pass/PassManager.h" // TF:local_config_mlir -#include "mlir/Transforms/Passes.h" // TF:local_config_mlir +#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/OpDefinition.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/Parser.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Pass/PassManager.h" // TF:llvm-project +#include "mlir/Transforms/Passes.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" @@ -216,8 +216,7 @@ Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op, // and canonicalization opportunities that are necessary for the second // LegalizeTFPass(allow_partial_conversion=false) invocation. tf2xla.addNestedPass(mlir::xla_hlo::createLegalizeTFPass(true)); - tf2xla.addPass(mlir::tf_executor::CreateTFExecutorGraphPruningPass( - /*skip_main_func=*/true)); + tf2xla.addPass(mlir::tf_executor::CreateTFExecutorGraphPruningPass()); tf2xla.addNestedPass(mlir::createCanonicalizerPass()); tf2xla.addNestedPass( mlir::xla_hlo::createLegalizeTFPass(false)); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h index a07927ce432..4a462898276 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h @@ -18,7 +18,7 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringRef.h" -#include "mlir/IR/Module.h" // TF:local_config_mlir +#include "mlir/IR/Module.h" // TF:llvm-project #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/stream_executor/lib/statusor.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc index 1c1f9803bd7..fafd6cc11cb 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc @@ -23,10 +23,10 @@ limitations under the License. 
#include "llvm/ADT/APFloat.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/IR/Types.h" // TF:local_config_mlir +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/Types.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h index 7e982bb489b..b2646c265ad 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h @@ -18,8 +18,8 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc index 69cda63e889..bcd37e39de9 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc @@ -15,9 +15,9 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project #include "tensorflow/compiler/xla/test.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status_test_util.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc index e2d970c8dfd..7b0cbe6d5b5 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc @@ -17,9 +17,9 @@ limitations under the License. 
#include "absl/strings/str_cat.h" #include "llvm/Support/Casting.h" -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/IR/Types.h" // TF:local_config_mlir -#include "mlir/Support/DebugStringHelper.h" // TF:local_config_mlir +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/Types.h" // TF:llvm-project +#include "mlir/Support/DebugStringHelper.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_type.h b/tensorflow/compiler/mlir/tensorflow/utils/convert_type.h index fa5c92c12fe..24c4273ad0e 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_type.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_type.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_CONVERT_TYPE_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_CONVERT_TYPE_H_ -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Types.h" // TF:local_config_mlir +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Types.h" // TF:llvm-project #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.pb.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_type_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_type_test.cc index 423d61dc2c6..e7206096d2c 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_type_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_type_test.cc @@ -16,9 +16,9 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project #include "tensorflow/compiler/xla/test.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status_test_util.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/device_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/device_util.cc index 8309ab39feb..e983f3e9c0c 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/device_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/device_util.cc @@ -21,10 +21,10 @@ limitations under the License. 
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/FormatVariadic.h" -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Operation.h" // TF:local_config_mlir -#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/util/device_name_utils.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/device_util.h b/tensorflow/compiler/mlir/tensorflow/utils/device_util.h index fa8a09801fa..73ae18d2487 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/device_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/device_util.h @@ -17,8 +17,8 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_DEVICE_UTIL_H_ #include "llvm/ADT/SmallVector.h" -#include "mlir/IR/Operation.h" // TF:local_config_mlir -#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project #include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/util/device_name_utils.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/device_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/device_util_test.cc index a8d628b153a..cb25e000f7a 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/device_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/device_util_test.cc @@ -21,12 +21,12 @@ limitations under the License. #include #include "llvm/ADT/SmallVector.h" -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Location.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/Module.h" // TF:local_config_mlir -#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Location.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/framework/device_attributes.pb.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc index 4b1d059bfa4..423e5012768 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc @@ -24,7 +24,7 @@ limitations under the License. 
#include "llvm/ADT/Twine.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/IR/Operation.h" // TF:local_config_mlir +#include "mlir/IR/Operation.h" // TF:llvm-project #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" @@ -60,10 +60,43 @@ std::string MakeUniqueFilename(string name) { filename = llvm::Twine(filename).concat(".mlir").str(); return filename; } + +// Simple raw_ostream that prints to LOG(INFO). +struct LogInfoRawStream : public llvm::raw_ostream { + LogInfoRawStream() { SetUnbuffered(); } + ~LogInfoRawStream() override = default; + uint64_t current_pos() const override { return 0; } + + void write_impl(const char* ptr, size_t size) override { + LOG(INFO) << absl::string_view(ptr, size); + } +}; + +// Simple raw_ostream that prints to a file. +struct WritableFileRawStream : public llvm::raw_ostream { + explicit WritableFileRawStream(std::unique_ptr file) + : file(std::move(file)) { + SetUnbuffered(); + } + ~WritableFileRawStream() override = default; + uint64_t current_pos() const override { return 0; } + + void write_impl(const char* ptr, size_t size) override { + // Write the file if it is still valid. If the write fails, null out the + // file to avoid encountering another error. + if (file && !file->Append(StringPiece(ptr, size)).ok()) { + file = nullptr; + } + } + + // The file being written to. + std::unique_ptr file; +}; } // namespace -std::string DumpMlirOpToFile(llvm::StringRef name, mlir::Operation* op, - llvm::StringRef dirname) { +Status CreateFileForDumping(llvm::StringRef name, + std::unique_ptr* os, + std::string* filepath, llvm::StringRef dirname) { const char* dir = nullptr; if (!dirname.empty()) dir = dirname.data(); @@ -72,44 +105,49 @@ std::string DumpMlirOpToFile(llvm::StringRef name, mlir::Operation* op, if (!dir) { LOG(WARNING) - << "Failed to dump MLIR operation '" - << op->getName().getStringRef().str() << "' to '" << name.str() - << "' because dump location is not specified through either " + << "Failed to generate file because dump location is not specified " + "through either " "TF_DUMP_GRAPH_PREFIX environment variable or function argument."; - return "(TF_DUMP_GRAPH_PREFIX not specified)"; + return Status(error::Code::INVALID_ARGUMENT, + "(TF_DUMP_GRAPH_PREFIX not specified)"); } - std::string txt_op; - { - llvm::raw_string_ostream os(txt_op); - op->print(os, mlir::OpPrintingFlags().useLocalScope()); - os.flush(); - } - - Env* env = Env::Default(); - std::string filepath; if (std::strncmp(dir, "-", 2) == 0) { - LOG(INFO) << txt_op; - filepath = "LOG(INFO)"; - } else { - Status status = env->RecursivelyCreateDir(dir); - if (!status.ok()) { - LOG(WARNING) << "Failed to create '" << dir - << "' directory for dumping MLIR operation '" - << op->getName().getStringRef().str() << "': " << status; - return "(unavailable)"; - } - filepath = - llvm::Twine(dir).concat("/").concat(MakeUniqueFilename(name)).str(); - status = WriteStringToFile(env, filepath, txt_op); - if (!status.ok()) { - LOG(WARNING) << "Failed to dump MLIR operation '" - << op->getName().getStringRef().str() << "' to file '" - << filepath << "': " << status; - return "(unavailable)"; - } + *os = std::make_unique(); + *filepath = "LOG(INFO)"; + return Status(); } + // Get a valid file path to dump with. 
+ Env* env = Env::Default(); + Status status = env->RecursivelyCreateDir(dir); + if (!status.ok()) { + LOG(WARNING) << "Failed to create '" << dir + << "' directory for dumping: " << status; + return Status(error::Code::UNAVAILABLE, "(unavailable)"); + } + *filepath = + llvm::Twine(dir).concat("/").concat(MakeUniqueFilename(name)).str(); + + // Try to open the file and generate a raw_ostream. + std::unique_ptr<WritableFile> file; + status = env->NewWritableFile(*filepath, &file); + if (!status.ok()) { + LOG(WARNING) << "Failed to create file '" << *filepath << "': " << status; + return Status(error::Code::UNAVAILABLE, "(unavailable)"); + } + *os = std::make_unique<WritableFileRawStream>(std::move(file)); + return Status(); +} + +std::string DumpMlirOpToFile(llvm::StringRef name, mlir::Operation* op, + llvm::StringRef dirname) { + std::unique_ptr<llvm::raw_ostream> os; + std::string filepath; + Status result = CreateFileForDumping(name, &os, &filepath, dirname); + if (!result.ok()) return result.error_message(); + + op->print(*os, mlir::OpPrintingFlags().useLocalScope()); LOG(INFO) << "Dumped MLIR operation '" << op->getName().getStringRef().str() << "' to '" << filepath << "'"; return filepath; diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h index 8ae6797a4f8..c2e4683c1c6 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h @@ -19,10 +19,26 @@ limitations under the License. #include <string> #include "llvm/ADT/StringRef.h" -#include "mlir/IR/Operation.h" // TF:local_config_mlir +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "tensorflow/core/platform/status.h" namespace tensorflow { +// Creates a file to use for dumping and returns success if a file could be +// created. The opened file is placed in 'os' and the path of the file used is +// placed in 'filepath'. +// +// If the TF_DUMP_GRAPH_PREFIX environment variable is "-", then the LOG(INFO) +// macro is used instead. +// +// This will create a file name via prefixing `name` with the value of the +// TF_DUMP_GRAPH_PREFIX environment variable if `dirname` is empty and +// suffixing `name` with ".mlir". +Status CreateFileForDumping(llvm::StringRef name, + std::unique_ptr<llvm::raw_ostream>* os, + std::string* filepath, + llvm::StringRef dirname = ""); + // Dumps MLIR operation to a file and returns the file name used. // // If the TF_DUMP_GRAPH_PREFIX environment variable is "-", then the MLIR diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc index 59d8da91e7b..947a0ef0af3 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc @@ -16,9 +16,9 @@ limitations under the License.
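A hypothetical caller of the new CreateFileForDumping API (the function and dump name below are placeholders, not part of the patch):

```cpp
#include <memory>
#include <string>
#include "llvm/Support/raw_ostream.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
#include "tensorflow/core/platform/logging.h"

// Dump arbitrary text through the new API; when TF_DUMP_GRAPH_PREFIX is "-"
// the stream forwards to LOG(INFO) instead of a file.
void DumpExampleText() {
  std::unique_ptr<llvm::raw_ostream> os;
  std::string filepath;
  tensorflow::Status status =
      tensorflow::CreateFileForDumping("example_dump", &os, &filepath);
  if (!status.ok()) return;
  *os << "some debug text\n";
  LOG(INFO) << "Dumped to " << filepath;
}
```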
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/IR/Location.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/Module.h" // TF:local_config_mlir +#include "mlir/IR/Location.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Module.h" // TF:llvm-project #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/error_util.h b/tensorflow/compiler/mlir/tensorflow/utils/error_util.h index a60d90cbfb7..7eb30ee2c46 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/error_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/error_util.h @@ -18,9 +18,9 @@ limitations under the License. #include "llvm/Support/SourceMgr.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/IR/Diagnostics.h" // TF:local_config_mlir -#include "mlir/IR/Location.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir +#include "mlir/IR/Diagnostics.h" // TF:llvm-project +#include "mlir/IR/Location.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project #include "tensorflow/core/lib/core/status.h" // Error utilities for MLIR when interacting with code using Status returns. diff --git a/tensorflow/compiler/mlir/tensorflow/utils/error_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/error_util_test.cc index 4e59cec86ab..3f4947bec23 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/error_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/error_util_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "llvm/ADT/Twine.h" -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project #include "tensorflow/compiler/xla/test.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status_test_util.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/eval_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/eval_util.cc index e70ab3197d5..dae0a6cf515 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/eval_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/eval_util.cc @@ -20,10 +20,10 @@ limitations under the License. 
#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Types.h" // TF:local_config_mlir -#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Types.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project #include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/eval_util.h b/tensorflow/compiler/mlir/tensorflow/utils/eval_util.h index 657ea688b93..39fd91afe40 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/eval_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/eval_util.h @@ -19,7 +19,7 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" -#include "mlir/IR/Operation.h" // TF:local_config_mlir +#include "mlir/IR/Operation.h" // TF:llvm-project #include "tensorflow/c/eager/c_api.h" namespace tensorflow { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc index e35b7130de8..d2f17586ad3 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc @@ -23,17 +23,17 @@ limitations under the License. #include "absl/strings/string_view.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Function.h" // TF:local_config_mlir -#include "mlir/IR/Identifier.h" // TF:local_config_mlir -#include "mlir/IR/Location.h" // TF:local_config_mlir -#include "mlir/IR/Module.h" // TF:local_config_mlir -#include "mlir/IR/Operation.h" // TF:local_config_mlir -#include "mlir/IR/OperationSupport.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir -#include "mlir/Support/DebugStringHelper.h" // TF:local_config_mlir +#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/Identifier.h" // TF:llvm-project +#include "mlir/IR/Location.h" // TF:llvm-project +#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/OperationSupport.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/TypeUtilities.h" // TF:llvm-project +#include "mlir/Support/DebugStringHelper.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" @@ -65,8 +65,12 @@ Status ConvertLocation(mlir::Location inst_loc, debug_info->add_original_node_names(name_loc.getName().c_str()); } } else if (auto fused = inst_loc.dyn_cast()) { - for (auto loc : fused.getLocations()) { - TF_RETURN_IF_ERROR(ConvertLocation(loc, debug_info)); + auto locations = fused.getLocations(); + if 
(locations.size() <= 1) + return errors::InvalidArgument("expected experimental debug info."); + // Skip the last one, which is the name of the node_def. + for (int i = 0; i < locations.size() - 1; ++i) { + TF_RETURN_IF_ERROR(ConvertLocation(locations[i], debug_info)); } } return Status::OK(); @@ -218,12 +222,12 @@ static bool IsRefTypeControlOp(mlir::Operation* op) { auto op_name = op_name_or_status.ConsumeValueOrDie(); if (op_name.equals("NextIteration")) - return mlir::getElementTypeOrSelf(op->getOperand(0)->getType()) + return mlir::getElementTypeOrSelf(op->getOperand(0).getType()) .isa<mlir::TF::TensorFlowRefType>(); if (op_name.equals("Enter") || op_name.equals("Exit") || op_name.equals("Switch") || op_name.equals("Merge")) { - return getElementTypeOrSelf(op->getResult(0)->getType()) + return getElementTypeOrSelf(op->getResult(0).getType()) .isa<mlir::TF::TensorFlowRefType>(); } return false; } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h index df176762c07..a8c91c0b494 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h @@ -23,10 +23,10 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringRef.h" -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Location.h" // TF:local_config_mlir -#include "mlir/IR/Operation.h" // TF:local_config_mlir -#include "mlir/IR/Types.h" // TF:local_config_mlir +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Location.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/Types.h" // TF:llvm-project #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/node_def.pb.h" diff --git a/tensorflow/compiler/mlir/tf_mlir_opt_main.cc b/tensorflow/compiler/mlir/tf_mlir_opt_main.cc index f70868e217f..736e954278e 100644 --- a/tensorflow/compiler/mlir/tf_mlir_opt_main.cc +++ b/tensorflow/compiler/mlir/tf_mlir_opt_main.cc @@ -17,10 +17,10 @@ limitations under the License. #include "llvm/Support/InitLLVM.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/ToolOutputFile.h" -#include "mlir/Pass/Pass.h" // TF:local_config_mlir -#include "mlir/Pass/PassManager.h" // TF:local_config_mlir -#include "mlir/Support/FileUtilities.h" // TF:local_config_mlir -#include "mlir/Support/MlirOptMain.h" // TF:local_config_mlir +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Pass/PassManager.h" // TF:llvm-project +#include "mlir/Support/FileUtilities.h" // TF:llvm-project +#include "mlir/Support/MlirOptMain.h" // TF:llvm-project #include "tensorflow/compiler/mlir/init_mlir.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/compiler/mlir/tf_mlir_translate_main.cc b/tensorflow/compiler/mlir/tf_mlir_translate_main.cc index 9ab31265a33..f5fc56556ec 100644 --- a/tensorflow/compiler/mlir/tf_mlir_translate_main.cc +++ b/tensorflow/compiler/mlir/tf_mlir_translate_main.cc @@ -21,11 +21,11 @@ limitations under the License.
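For reference, ConvertLocation above assumes a FusedLoc laid out with the original node names first and the node_def's own NameLoc last, which is why the loop stops one short of the end. A sketch of building such a location, with hypothetical node names:

```cpp
#include "mlir/IR/Builders.h"     // TF:llvm-project
#include "mlir/IR/Location.h"     // TF:llvm-project
#include "mlir/IR/MLIRContext.h"  // TF:llvm-project

// Original node names first, the node_def's own name last.
mlir::Location MakeDebugLocation(mlir::MLIRContext* context) {
  mlir::Builder b(context);
  mlir::Location original =
      mlir::NameLoc::get(b.getIdentifier("original_node"), context);
  mlir::Location node_name =
      mlir::NameLoc::get(b.getIdentifier("node_def_name"), context);
  return mlir::FusedLoc::get({original, node_name}, context);
}
```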
#include "llvm/Support/SMLoc.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/ToolOutputFile.h" -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/Support/FileUtilities.h" // TF:local_config_mlir -#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir -#include "mlir/Support/ToolUtilities.h" // TF:local_config_mlir -#include "mlir/Support/TranslateClParser.h" // TF:local_config_mlir +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/Support/FileUtilities.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/Support/ToolUtilities.h" // TF:llvm-project +#include "mlir/Support/TranslateClParser.h" // TF:llvm-project #include "tensorflow/compiler/mlir/init_mlir.h" #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h" @@ -54,6 +54,12 @@ static llvm::cl::opt import_saved_model( llvm::cl::desc("Import a saved model to its MLIR representation"), llvm::cl::value_desc("dir")); +// NOLINTNEXTLINE +static llvm::cl::opt import_saved_model_v1( + "savedmodel-v1-to-mlir", + llvm::cl::desc("Import a saved model V1 to its MLIR representation"), + llvm::cl::value_desc("dir")); + // NOLINTNEXTLINE static llvm::cl::opt saved_model_tags( "tf-savedmodel-tags", @@ -77,10 +83,11 @@ int main(int argc, char** argv) { llvm::cl::ParseCommandLineOptions(argc, argv, "TF MLIR translation driver\n"); - if (!import_saved_model && !requested_translation) { + if (!import_saved_model && !import_saved_model_v1 && !requested_translation) { llvm::errs() << "error: need to specify one translation to perform\n"; return 1; - } else if (import_saved_model && requested_translation) { + } else if (import_saved_model && import_saved_model_v1 && + requested_translation) { llvm::errs() << "error: cannot specify more than one translation to perform\n"; return 1; @@ -105,6 +112,16 @@ int main(int argc, char** argv) { &context); if (!module) return 1; + module->print(output->os()); + } else if (import_saved_model_v1) { + std::unordered_set tags = + absl::StrSplit(saved_model_tags, ','); + mlir::MLIRContext context; + + auto module = + tensorflow::SavedModelV1ToMlirImport(input_filename, tags, &context); + if (!module) return 1; + module->print(output->os()); } else { auto input = mlir::openInputFile(input_filename, &error_message); diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index 17629440871..451f37211e8 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -1,4 +1,4 @@ -load("@local_config_mlir//:tblgen.bzl", "gentbl") +load("//third_party/mlir:tblgen.bzl", "gentbl") load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_native_cc_binary") package( @@ -8,7 +8,7 @@ package( package_group( name = "friends", - includes = ["@local_config_mlir//:subpackages"], + includes = ["//third_party/mlir:subpackages"], packages = [ "//babelfish/device/...", "//learning/brain/experimental/mlir/...", @@ -32,7 +32,7 @@ filegroup( "ir/hlo_ops_base.td", "ir/hlo_utils.td", "ir/lhlo_ops.td", - "@local_config_mlir//:OpBaseTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", ], ) @@ -44,7 +44,7 @@ gentbl( ("-gen-struct-attr-decls", "ir/hlo_structs.h.inc"), ("-gen-struct-attr-defs", "ir/hlo_structs.cc.inc"), ], - tblgen = "@local_config_mlir//:mlir-tblgen", + tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/hlo_ops.td", td_includes = ["ir/hlo_utils.td"], td_srcs = 
[":hlo_ops_td_files"], @@ -56,7 +56,7 @@ gentbl( ("-gen-op-decls", "ir/hlo_ops_base.h.inc"), ("-gen-op-defs", "ir/hlo_ops_base.cc.inc"), ], - tblgen = "@local_config_mlir//:mlir-tblgen", + tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/hlo_ops_base.td", td_srcs = [":hlo_ops_td_files"], ) @@ -67,7 +67,7 @@ gentbl( ("-gen-op-decls", "ir/lhlo_ops.h.inc"), ("-gen-op-defs", "ir/lhlo_ops.cc.inc"), ], - tblgen = "@local_config_mlir//:mlir-tblgen", + tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/lhlo_ops.td", td_srcs = [":hlo_ops_td_files"], ) @@ -77,12 +77,12 @@ gentbl( tbl_outs = [ ("-gen-rewriters", "transforms/generated_legalize_tf.inc"), ], - tblgen = "@local_config_mlir//:mlir-tblgen", + tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/legalize_tf_patterns.td", td_srcs = [ ":hlo_ops_td_files", - "@llvm//:support", - "@local_config_mlir//:StdOpsTdFiles", + "@llvm-project//llvm:support", + "@llvm-project//mlir:StdOpsTdFiles", "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files", ], ) @@ -95,7 +95,7 @@ gentbl( "transforms/generated_canonicalize.inc", ), ], - tblgen = "@local_config_mlir//:mlir-tblgen", + tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/canonicalize.td", td_srcs = [ ":hlo_ops_td_files", @@ -114,15 +114,16 @@ cc_library( ":hlo", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:lower_tf_lib", + "//tensorflow/compiler/xla/client:padding", "//tensorflow/core:framework", "//tensorflow/core/kernels:conv_grad_shape_utils", - "@llvm//:support", - "@local_config_mlir//:Analysis", - "@local_config_mlir//:IR", - "@local_config_mlir//:Pass", - "@local_config_mlir//:StandardOps", - "@local_config_mlir//:Support", - "@local_config_mlir//:Transforms", + "@llvm-project//llvm:support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", ], alwayslink = 1, ) @@ -135,11 +136,11 @@ cc_library( ":lhlo", "//tensorflow/compiler/xla:status", "@com_google_absl//absl/memory", - "@llvm//:support", - "@local_config_mlir//:AffineOps", - "@local_config_mlir//:IR", - "@local_config_mlir//:Pass", - "@local_config_mlir//:StandardOps", + "@llvm-project//llvm:support", + "@llvm-project//mlir:AffineOps", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:StandardOps", ], alwayslink = 1, ) @@ -151,13 +152,13 @@ cc_library( deps = [ ":lhlo", "@com_google_absl//absl/memory", - "@llvm//:support", - "@local_config_mlir//:IR", - "@local_config_mlir//:Linalg", - "@local_config_mlir//:LinalgDialectRegistration", - "@local_config_mlir//:Pass", - "@local_config_mlir//:StandardOps", - "@local_config_mlir//:Transforms", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Linalg", + "@llvm-project//mlir:LinalgDialectRegistration", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Transforms", ], alwayslink = 1, ) @@ -169,14 +170,14 @@ cc_library( deps = [ ":lhlo", "@com_google_absl//absl/memory", - "@llvm//:support", - "@local_config_mlir//:GPUDialect", - "@local_config_mlir//:IR", - "@local_config_mlir//:Linalg", - "@local_config_mlir//:LoopOps", - "@local_config_mlir//:Pass", - "@local_config_mlir//:StandardOps", - "@local_config_mlir//:Transforms", + "@llvm-project//llvm:support", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Linalg", + 
"@llvm-project//mlir:LoopOps", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Transforms", ], alwayslink = 1, ) @@ -187,9 +188,9 @@ cc_library( deps = [ ":lhlo", "@com_google_absl//absl/memory", - "@local_config_mlir//:Linalg", - "@local_config_mlir//:LinalgDialectRegistration", - "@local_config_mlir//:Pass", + "@llvm-project//mlir:Linalg", + "@llvm-project//mlir:LinalgDialectRegistration", + "@llvm-project//mlir:Pass", ], alwayslink = 1, ) @@ -202,10 +203,10 @@ cc_library( ":lhlo", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", - "@local_config_mlir//:IR", - "@local_config_mlir//:Pass", - "@local_config_mlir//:StandardOps", - "@local_config_mlir//:Transforms", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Transforms", ], alwayslink = 1, ) @@ -215,11 +216,11 @@ gentbl( tbl_outs = [ ("-gen-rewriters", "transforms/generated_legalize_to_standard.inc"), ], - tblgen = "@local_config_mlir//:mlir-tblgen", + tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/legalize_to_standard_patterns.td", td_srcs = [ ":hlo_ops_td_files", - "@local_config_mlir//:StdOpsTdFiles", + "@llvm-project//mlir:StdOpsTdFiles", ], ) @@ -230,12 +231,12 @@ cc_library( ], deps = [ ":hlo", - "@llvm//:support", - "@local_config_mlir//:Analysis", - "@local_config_mlir//:IR", - "@local_config_mlir//:Pass", - "@local_config_mlir//:StandardOps", - "@local_config_mlir//:Support", + "@llvm-project//llvm:support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", ], alwayslink = 1, ) @@ -246,11 +247,11 @@ cc_library( deps = [ ":hlo", ":xla_legalize_to_standard_inc_gen", - "@llvm//:support", - "@local_config_mlir//:Analysis", - "@local_config_mlir//:IR", - "@local_config_mlir//:Pass", - "@local_config_mlir//:StandardOps", + "@llvm-project//llvm:support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:StandardOps", ], alwayslink = 1, ) @@ -260,12 +261,12 @@ gentbl( tbl_outs = [ ("-gen-rewriters", "transforms/generated_lower_complex.inc"), ], - tblgen = "@local_config_mlir//:mlir-tblgen", + tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/lower_complex_patterns.td", td_srcs = [ ":hlo_ops_td_files", - "@llvm//:support", - "@local_config_mlir//:StdOpsTdFiles", + "@llvm-project//llvm:support", + "@llvm-project//mlir:StdOpsTdFiles", ], ) @@ -279,13 +280,13 @@ cc_library( deps = [ ":hlo", ":xla_dialect_registration", - "@llvm//:support", - "@local_config_mlir//:Analysis", - "@local_config_mlir//:IR", - "@local_config_mlir//:Pass", - "@local_config_mlir//:StandardOps", - "@local_config_mlir//:Support", - "@local_config_mlir//:Transforms", + "@llvm-project//llvm:support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", ], alwayslink = 1, ) @@ -310,13 +311,13 @@ cc_library( ":hlo_ops_base_inc_gen", ":hlo_ops_inc_gen", ":xla_canonicalize_inc_gen", - "@llvm//:support", - "@local_config_mlir//:Analysis", - "@local_config_mlir//:IR", - "@local_config_mlir//:Pass", - "@local_config_mlir//:StandardOps", - "@local_config_mlir//:Support", - "@local_config_mlir//:TransformUtils", + "@llvm-project//llvm:support", + "@llvm-project//mlir:Analysis", + 
"@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", ], alwayslink = 1, ) @@ -337,13 +338,13 @@ cc_library( deps = [ ":hlo_ops_base_inc_gen", ":lhlo_ops_inc_gen", - "@llvm//:support", - "@local_config_mlir//:Analysis", - "@local_config_mlir//:IR", - "@local_config_mlir//:Pass", - "@local_config_mlir//:StandardOps", - "@local_config_mlir//:Support", - "@local_config_mlir//:TransformUtils", + "@llvm-project//llvm:support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", ], alwayslink = 1, ) @@ -357,7 +358,7 @@ cc_library( ":convert_op_folder", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla/service:hlo", - "@local_config_mlir//:IR", + "@llvm-project//mlir:IR", ], alwayslink = 1, ) @@ -369,7 +370,7 @@ cc_library( deps = [ ":hlo", ":lhlo", - "@local_config_mlir//:IR", + "@llvm-project//mlir:IR", ], alwayslink = 1, ) @@ -388,9 +389,9 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core/platform:logging", "//tensorflow/core/platform:types", - "@llvm//:support", - "@local_config_mlir//:IR", - "@local_config_mlir//:Support", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", ], ) @@ -405,7 +406,7 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test_main", - "@local_config_mlir//:IR", + "@llvm-project//mlir:IR", ], ) @@ -430,13 +431,13 @@ cc_library( "//tensorflow/compiler/xla/client/lib:slicing", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/stream_executor/lib", - "@llvm//:support", - "@local_config_mlir//:Analysis", - "@local_config_mlir//:IR", - "@local_config_mlir//:Pass", - "@local_config_mlir//:StandardOps", - "@local_config_mlir//:TransformUtils", - "@local_config_mlir//:Transforms", + "@llvm-project//llvm:support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", ], ) @@ -474,9 +475,9 @@ cc_library( "//tensorflow/compiler/xla:xla_proto_cc", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:lib", - "@llvm//:support", - "@local_config_mlir//:IR", - "@local_config_mlir//:StandardOps", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:StandardOps", ], ) @@ -493,9 +494,9 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_proto_cc", "//tensorflow/core:lib", "@com_google_protobuf//:protobuf_headers", - "@llvm//:support", - "@local_config_mlir//:IR", - "@local_config_mlir//:Translation", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Translation", ], alwayslink = 1, ) @@ -504,24 +505,24 @@ tf_native_cc_binary( name = "operator_writer_gen", srcs = ["operator_writer_gen.cc"], deps = [ - "@llvm//:support", - "@llvm//:tablegen", - "@local_config_mlir//:Support", - "@local_config_mlir//:TableGen", + "@llvm-project//llvm:support", + "@llvm-project//llvm:tablegen", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TableGen", ], ) genrule( name = "operator_writer_inc", srcs = [ - "@local_config_mlir//:include/mlir/IR/OpBase.td", + "@llvm-project//mlir:include/mlir/IR/OpBase.td", ":ir/hlo_ops.td", ":ir/hlo_ops_base.td", ":ir/hlo_utils.td", ], outs = ["operator_writers.inc"], cmd 
= ("$(location :operator_writer_gen) " + - "-I external/local_config_mlir/include " + + "-I external/llvm-project/mlir/include " + "$(location //tensorflow/compiler/mlir/xla:ir/hlo_ops.td) " + " -o $@"), tools = [":operator_writer_gen"], @@ -532,6 +533,6 @@ cc_library( srcs = ["convert_op_folder.cc"], hdrs = ["convert_op_folder.h"], deps = [ - "@local_config_mlir//:IR", + "@llvm-project//mlir:IR", ], ) diff --git a/tensorflow/compiler/mlir/xla/convert_op_folder.cc b/tensorflow/compiler/mlir/xla/convert_op_folder.cc index 8245b4a0585..dfd7cb39bf9 100644 --- a/tensorflow/compiler/mlir/xla/convert_op_folder.cc +++ b/tensorflow/compiler/mlir/xla/convert_op_folder.cc @@ -17,9 +17,9 @@ limitations under the License. #include "tensorflow/compiler/mlir/xla/convert_op_folder.h" -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/TypeUtilities.h" // TF:llvm-project namespace mlir { namespace xla { diff --git a/tensorflow/compiler/mlir/xla/convert_op_folder.h b/tensorflow/compiler/mlir/xla/convert_op_folder.h index 63ac0e61df5..37a4db0227f 100644 --- a/tensorflow/compiler/mlir/xla/convert_op_folder.h +++ b/tensorflow/compiler/mlir/xla/convert_op_folder.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_XLA_CONVERT_OP_FOLDER_H_ #define TENSORFLOW_COMPILER_MLIR_XLA_CONVERT_OP_FOLDER_H_ -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project namespace mlir { namespace xla { diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc index fe468e26ff6..5300824aabc 100644 --- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc +++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc @@ -19,14 +19,14 @@ limitations under the License. 
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/BlockAndValueMapping.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Identifier.h" // TF:local_config_mlir -#include "mlir/IR/Location.h" // TF:local_config_mlir -#include "mlir/IR/Region.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir +#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Identifier.h" // TF:llvm-project +#include "mlir/IR/Location.h" // TF:llvm-project +#include "mlir/IR/Region.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/xla/protobuf_util.h" @@ -264,6 +264,12 @@ StatusOr HloFunctionImporter::ImportInstruction( attributes.push_back(ConvertComparisonDirection(instruction)); MakeAndReturn(CompareOp); } + case HloOpcode::kCholesky: { + attributes.push_back(builder_->getNamedAttr( + "lower", + builder_->getBoolAttr(instruction->cholesky_options().lower()))); + MakeAndReturn(CholeskyOp); + } case HloOpcode::kGather: { auto gather_instruction = static_cast(instruction); attributes.push_back(ConvertGatherDimensionNumbers( @@ -284,9 +290,21 @@ StatusOr HloFunctionImporter::ImportInstruction( return func_builder ->create( loc, result_type, operands[0], operands[1], - llvm::ArrayRef(operands.begin() + 2, operands.end())) + llvm::ArrayRef(operands.begin() + 2, operands.end())) .getOperation(); } + case HloOpcode::kInfeed: { + attributes.push_back(builder_->getNamedAttr( + "infeed_config", mlir::StringAttr::get(instruction->infeed_config(), + builder_->getContext()))); + MakeAndReturn(InfeedOp); + } + case HloOpcode::kOutfeed: { + attributes.push_back(builder_->getNamedAttr( + "outfeed_config", mlir::StringAttr::get(instruction->outfeed_config(), + builder_->getContext()))); + MakeAndReturn(OutfeedOp); + } case HloOpcode::kPad: { const auto& padding_config = instruction->padding_config(); llvm::SmallVector edge_padding_low; @@ -309,6 +327,12 @@ StatusOr HloFunctionImporter::ImportInstruction( Convert(interior_padding)) .getOperation(); } + case HloOpcode::kSetDimensionSize: { + attributes.push_back(builder_->getNamedAttr( + "dimension", builder_->getIntegerAttr(builder_->getIntegerType(32), + instruction->dimension()))); + MakeAndReturn(SetDimensionSizeOp); + } case HloOpcode::kSlice: { return func_builder ->create( @@ -359,9 +383,31 @@ StatusOr HloFunctionImporter::ImportInstruction( ConvertDimensions(instruction->dimensions())) .getOperation(); } + case HloOpcode::kRng: { + auto shape = func_builder->create( + loc, Convert(result_type.cast().getShape())); + switch (instruction->random_distribution()) { + case xla::RNG_UNIFORM: + return func_builder + ->create( + loc, result_type, operands[0], operands[1], shape) + .getOperation(); + + case xla::RNG_NORMAL: + return func_builder + ->create( + loc, result_type, operands[0], operands[1], shape) + .getOperation(); + + default: + return tensorflow::errors::InvalidArgument(absl::StrCat( + "Unsupported distribution: ", + 
RandomDistributionToString(instruction->random_distribution()))); + } + } case HloOpcode::kWhile: { auto op = func_builder->create<mlir::xla_hlo::WhileOp>( - loc, operands[0]->getType(), operands[0]); + loc, operands[0].getType(), operands[0]); TF_RETURN_IF_ERROR( ImportComputation(instruction->while_condition(), &op.cond())); TF_RETURN_IF_ERROR( @@ -461,10 +507,12 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction( NoAttributeCase(kPower, PowOp); NoAttributeCase(kReal, RealOp); NoAttributeCase(kRemainder, RemOp); + NoAttributeCase(kReplicaId, ReplicaIdOp); // The dimensions attribute is not present on the HLO Reshape instruction. // If dimensions are non-default, the XLA builder implements it as a // separate transpose. NoAttributeCase(kReshape, ReshapeOp); + NoAttributeCase(kRoundNearestAfz, RoundOp); NoAttributeCase(kRsqrt, RsqrtOp); NoAttributeCase(kSelect, SelectOp); NoAttributeCase(kShiftLeft, ShiftLeftOp); @@ -500,9 +548,9 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction( } } -StatusOr<llvm::SmallVector<mlir::Value*, 4>> HloFunctionImporter::GetOperands( +StatusOr<llvm::SmallVector<mlir::Value, 4>> HloFunctionImporter::GetOperands( HloInstruction* instruction) { - llvm::SmallVector<mlir::Value*, 4> operands; + llvm::SmallVector<mlir::Value, 4> operands; for (const auto& operand : instruction->operands()) { auto input_it = instruction_value_map_.find(operand); if (input_it == instruction_value_map_.end()) { @@ -590,8 +638,7 @@ tensorflow::Status HloFunctionImporter::GetMlirTypes( return tensorflow::Status::OK(); } -StatusOr<mlir::Value*> HloFunctionImporter::GetMlirValue( - HloInstruction* instruction) { +StatusOr<mlir::Value> HloFunctionImporter::GetMlirValue(HloInstruction* instruction) { auto lookup = instruction_value_map_.find(instruction); if (lookup != instruction_value_map_.end()) { return lookup->second; diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.h b/tensorflow/compiler/mlir/xla/hlo_function_importer.h index bd36c9b2b54..9085e23ffd8 100644 --- a/tensorflow/compiler/mlir/xla/hlo_function_importer.h +++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.h @@ -18,12 +18,12 @@ limitations under the License. #include <unordered_map> -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Function.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/Module.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/xla/status.h" @@ -71,7 +71,7 @@ class HloFunctionImporter { mlir::OpBuilder* func_builder); // Gets the MLIR operand values from an HLO Instruction. - StatusOr<llvm::SmallVector<mlir::Value*, 4>> GetOperands( + StatusOr<llvm::SmallVector<mlir::Value, 4>> GetOperands( xla::HloInstruction* instruction); // Converts xla Tensor type to the corresponding MLIR type. @@ -89,7 +89,7 @@ class HloFunctionImporter { llvm::SmallVectorImpl<mlir::Type>* types); // Returns the Mlir Value for the corresponding HloInstruction. - StatusOr<mlir::Value*> GetMlirValue(xla::HloInstruction* instruction); + StatusOr<mlir::Value> GetMlirValue(xla::HloInstruction* instruction); // Converts an XLA PrecisionConfig to the corresponding MLIR attribute.
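The pervasive '->' to '.' changes in these files track MLIR's switch from pointer-typed mlir::Value* to a lightweight value type that is passed and copied by value. A minimal illustration:

```cpp
#include "mlir/IR/Operation.h"  // TF:llvm-project
#include "mlir/IR/Types.h"      // TF:llvm-project
#include "mlir/IR/Value.h"      // TF:llvm-project

// mlir::Value is now a value type: copy it freely, use '.' instead of '->'.
mlir::Type FirstOperandType(mlir::Operation* op) {
  mlir::Value v = op->getOperand(0);  // previously mlir::Value*
  return v.getType();                 // previously v->getType()
}
```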
mlir::NamedAttribute ConvertPrecisionConfig(xla::HloInstruction* instruction); @@ -129,7 +129,7 @@ class HloFunctionImporter { std::unordered_map<HloComputation*, mlir::FuncOp>* function_map_; // Mapping from HloInstructions to the associative MLIR values. - std::unordered_map<HloInstruction*, mlir::Value*> instruction_value_map_; + std::unordered_map<HloInstruction*, mlir::Value> instruction_value_map_; }; } // namespace xla diff --git a/tensorflow/compiler/mlir/xla/hlo_module_importer.cc b/tensorflow/compiler/mlir/xla/hlo_module_importer.cc index 60a2b93d907..f8eabeb046d 100644 --- a/tensorflow/compiler/mlir/xla/hlo_module_importer.cc +++ b/tensorflow/compiler/mlir/xla/hlo_module_importer.cc @@ -15,12 +15,12 @@ limitations under the License. #include "tensorflow/compiler/mlir/xla/hlo_module_importer.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Location.h" // TF:local_config_mlir -#include "mlir/IR/OperationSupport.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/IR/Types.h" // TF:local_config_mlir +#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Location.h" // TF:llvm-project +#include "mlir/IR/OperationSupport.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/Types.h" // TF:llvm-project #include "tensorflow/compiler/mlir/xla/hlo_function_importer.h" #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" diff --git a/tensorflow/compiler/mlir/xla/hlo_module_importer.h b/tensorflow/compiler/mlir/xla/hlo_module_importer.h index 5e8005f9489..c3e8c04cdcd 100644 --- a/tensorflow/compiler/mlir/xla/hlo_module_importer.h +++ b/tensorflow/compiler/mlir/xla/hlo_module_importer.h @@ -18,10 +18,10 @@ limitations under the License. #include <unordered_map> -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Function.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/Module.h" // TF:local_config_mlir +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Module.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/xla/status.h" diff --git a/tensorflow/compiler/mlir/xla/hlo_utils.cc b/tensorflow/compiler/mlir/xla/hlo_utils.cc index 7fa9dd71345..bfa57d97336 100644 --- a/tensorflow/compiler/mlir/xla/hlo_utils.cc +++ b/tensorflow/compiler/mlir/xla/hlo_utils.cc @@ -17,9 +17,9 @@ limitations under the License. #include "tensorflow/compiler/mlir/xla/hlo_utils.h" -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/TypeUtilities.h" // TF:llvm-project #include "tensorflow/compiler/xla/literal.h" namespace xla { diff --git a/tensorflow/compiler/mlir/xla/hlo_utils.h b/tensorflow/compiler/mlir/xla/hlo_utils.h index b267b39ce5a..74bd4391395 100644 --- a/tensorflow/compiler/mlir/xla/hlo_utils.h +++ b/tensorflow/compiler/mlir/xla/hlo_utils.h @@ -18,9 +18,9 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_HLO_UTILS_H_ #define TENSORFLOW_COMPILER_MLIR_XLA_HLO_UTILS_H_ -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project #include "tensorflow/compiler/mlir/xla/convert_op_folder.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc index 41e561fd731..be0cd1bdc53 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc @@ -29,22 +29,22 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/MathExtras.h" -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Dialect.h" // TF:local_config_mlir -#include "mlir/IR/Location.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/OpDefinition.h" // TF:local_config_mlir -#include "mlir/IR/OpImplementation.h" // TF:local_config_mlir -#include "mlir/IR/Operation.h" // TF:local_config_mlir -#include "mlir/IR/OperationSupport.h" // TF:local_config_mlir -#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir -#include "mlir/IR/Types.h" // TF:local_config_mlir -#include "mlir/IR/Value.h" // TF:local_config_mlir -#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir -#include "mlir/Transforms/InliningUtils.h" // TF:local_config_mlir +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Dialect.h" // TF:llvm-project +#include "mlir/IR/Location.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/OpDefinition.h" // TF:llvm-project +#include "mlir/IR/OpImplementation.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/OperationSupport.h" // TF:llvm-project +#include "mlir/IR/PatternMatch.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/TypeUtilities.h" // TF:llvm-project +#include "mlir/IR/Types.h" // TF:llvm-project +#include "mlir/IR/Value.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/Transforms/InliningUtils.h" // TF:llvm-project #include "tensorflow/compiler/mlir/xla/convert_op_folder.h" #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h.inc" #include "tensorflow/compiler/mlir/xla/ir/hlo_utils.h" @@ -175,7 +175,7 @@ void ConstOp::build(Builder* builder, OperationState& result, Attribute value) { //===----------------------------------------------------------------------===// OpFoldResult IotaOp::fold(ArrayRef<Attribute> operands) { - const auto output_type = getResult()->getType().cast<ShapedType>(); + const auto output_type = getResult().getType().cast<ShapedType>(); const auto output_size = output_type.getNumElements(); const auto dimension = iota_dimension().getSExtValue(); const auto max_dim_size = output_type.getDimSize(dimension); @@ -203,16 +203,15 @@ OpFoldResult IotaOp::fold(ArrayRef<Attribute> operands) { // AbsOp
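For intuition about IotaOp::fold above: when folding succeeds, the op collapses to a dense constant counting up along the iota dimension, e.g. [0, 1, 2, 3] for tensor<4xi32>. A sketch of the 1-D case (a simplified helper, not the actual fold body):

```cpp
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Attributes.h"     // TF:llvm-project
#include "mlir/IR/StandardTypes.h"  // TF:llvm-project

// Materialize 0..N-1 for a 1-D ranked integer tensor type.
mlir::DenseElementsAttr FoldIota1D(mlir::RankedTensorType ty) {
  unsigned width = ty.getElementType().getIntOrFloatBitWidth();
  llvm::SmallVector<llvm::APInt, 8> values;
  for (int64_t i = 0, e = ty.getNumElements(); i != e; ++i)
    values.push_back(llvm::APInt(width, i));
  return mlir::DenseElementsAttr::get(ty, values);
}
```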
//===----------------------------------------------------------------------===// -void AbsOp::build(Builder* builder, OperationState& result, Value* operand) { - auto shaped_type = operand->getType().cast<ShapedType>(); +void AbsOp::build(Builder* builder, OperationState& result, Value operand) { + auto shaped_type = operand.getType().cast<ShapedType>(); Type new_type; if (!shaped_type.getElementType().isa<ComplexType>()) { - new_type = operand->getType(); + new_type = operand.getType(); } else if (shaped_type.hasRank()) { - new_type = - RankedTensorType::get(shaped_type.getShape(), operand->getType()); + new_type = RankedTensorType::get(shaped_type.getShape(), operand.getType()); } else { - new_type = UnrankedTensorType::get(operand->getType()); + new_type = UnrankedTensorType::get(operand.getType()); } return AbsOp::build(builder, result, new_type, operand); @@ -222,10 +221,10 @@ void AbsOp::build(Builder* builder, OperationState& result, Value* operand) { // ConvertOp //===----------------------------------------------------------------------===// -void ConvertOp::build(Builder* builder, OperationState& result, Value* operand, +void ConvertOp::build(Builder* builder, OperationState& result, Value operand, Type result_element_ty) { Type result_ty; - Type operand_ty = operand->getType(); + Type operand_ty = operand.getType(); if (auto ranked_ty = operand_ty.dyn_cast<RankedTensorType>()) { result_ty = RankedTensorType::get(ranked_ty.getShape(), result_element_ty); } else { @@ -235,7 +234,7 @@ void ConvertOp::build(Builder* builder, OperationState& result, Value* operand, } OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) { - if (getOperand()->getType() == getResult()->getType()) return getOperand(); + if (getOperand().getType() == getResult().getType()) return getOperand(); // If the operand is constant, we can do the conversion now. if (auto elementsAttr = operands.front().dyn_cast_or_null<ElementsAttr>()) { @@ -252,7 +251,7 @@ OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) { static LogicalResult Verify(GetTupleElementOp op) { auto indexVal = op.index().getZExtValue(); - auto operandType = op.getOperand()->getType().cast<TupleType>(); + auto operandType = op.getOperand().getType().cast<TupleType>(); if (indexVal >= operandType.size()) { return op.emitOpError( llvm::formatv("index {0} is out of bounds of operand with size {1}", @@ -269,7 +268,7 @@ static LogicalResult Verify(GetTupleElementOp op) { OpFoldResult GetTupleElementOp::fold(ArrayRef<Attribute> operands) { if (auto tupleOp = - dyn_cast_or_null<TupleOp>(getOperand()->getDefiningOp())) { + dyn_cast_or_null<TupleOp>(getOperand().getDefiningOp())) { return tupleOp.getOperand(index().getLimitedValue()); } @@ -291,6 +290,25 @@ static LogicalResult Verify(TupleOp op) { return success(); } +//===----------------------------------------------------------------------===// +// AllToAllOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(AllToAllOp op) { + // If operand is ranked, size of split dimension should be a multiple of split + // count.
+ auto type = op.getOperand().getType().dyn_cast<RankedTensorType>(); + if (!type) return success(); + auto split_dim_size = type.getDimSize(op.split_dimension().getSExtValue()); + auto split_count = op.split_count().getSExtValue(); + if (split_dim_size % split_count != 0) { + return op.emitError() << "split dimension has size " << split_dim_size + << ", expected to be a multiple of split_count " + << split_count; + } + return success(); +} + //===----------------------------------------------------------------------===// // BroadcastOp //===----------------------------------------------------------------------===// @@ -305,9 +323,9 @@ static LogicalResult Verify(BroadcastOp op) { "broadcast_sizes has rank {0} instead of rank 1", sizesRank)); } - auto resultType = op.getResult()->getType().cast<RankedTensorType>(); + auto resultType = op.getResult().getType().cast<RankedTensorType>(); auto resultRank = resultType.getRank(); - auto operandType = op.operand()->getType().cast<RankedTensorType>(); + auto operandType = op.operand().getType().cast<RankedTensorType>(); auto operandRank = operandType.getRank(); auto sizesSize = sizesType.getNumElements(); auto expectedRank = operandRank + sizesSize; @@ -341,7 +359,7 @@ static LogicalResult Verify(BroadcastOp op) { //===----------------------------------------------------------------------===// static LogicalResult Verify(BroadcastInDimOp op) { - auto operandType = op.operand()->getType().cast<RankedTensorType>(); + auto operandType = op.operand().getType().cast<RankedTensorType>(); auto operandRank = operandType.getRank(); if (!op.broadcast_dimensions()) { if (operandRank == 0) { @@ -368,7 +386,7 @@ static LogicalResult Verify(BroadcastInDimOp op) { dimensionsSize, operandRank)); } - auto resultType = op.getResult()->getType().cast<RankedTensorType>(); + auto resultType = op.getResult().getType().cast<RankedTensorType>(); auto resultRank = resultType.getRank(); if (resultRank < operandRank) { return op.emitOpError( @@ -403,9 +421,9 @@ static LogicalResult Verify(BroadcastInDimOp op) { //===----------------------------------------------------------------------===// static LogicalResult Verify(ClampOp op) { - auto operandType = op.operand()->getType().cast<RankedTensorType>(); + auto operandType = op.operand().getType().cast<RankedTensorType>(); auto operandShape = operandType.getShape(); - auto minType = op.min()->getType().cast<RankedTensorType>(); + auto minType = op.min().getType().cast<RankedTensorType>(); auto minShape = minType.getShape(); if (minShape != operandShape && minType.getRank() != 0) { @@ -415,7 +433,7 @@ static LogicalResult Verify(ClampOp op) { llvm::make_range(operandShape.begin(), operandShape.end()))); } - auto maxType = op.max()->getType().cast<RankedTensorType>(); + auto maxType = op.max().getType().cast<RankedTensorType>(); auto maxShape = maxType.getShape(); if (maxShape != operandShape && maxType.getRank() != 0) { return op.emitOpError(llvm::formatv( @@ -431,9 +449,9 @@ static LogicalResult Verify(ClampOp op) { // ComplexOp //===----------------------------------------------------------------------===// -void ComplexOp::build(Builder* builder, OperationState& state, Value* lhs, - Value* rhs) { - auto type = lhs->getType(); +void ComplexOp::build(Builder* builder, OperationState& state, Value lhs, + Value rhs) { + auto type = lhs.getType(); auto element_ty = ComplexType::get(getElementTypeOrSelf(type)); Type result_ty; if (auto ranked_type = type.dyn_cast<RankedTensorType>()) { @@ -449,9 +467,9 @@ void ComplexOp::build(Builder* builder, OperationState& state, Value* lhs, OpFoldResult ComplexOp::fold(ArrayRef<Attribute> operands) { auto real_op = - dyn_cast_or_null<RealOp>(getOperand(0)->getDefiningOp()); + dyn_cast_or_null<RealOp>(getOperand(0).getDefiningOp()); auto imag_op = - dyn_cast_or_null<ImagOp>(getOperand(1)->getDefiningOp()); +
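The new AllToAll verifier enforces a simple divisibility rule. In isolation (a hypothetical helper, just to make the arithmetic concrete): a split dimension of size 8 with split_count 4 verifies, while split_count 3 would be rejected because 8 % 3 != 0.

```cpp
#include <cstdint>

// The divisibility rule enforced by the AllToAllOp verifier above.
bool SplitIsValid(int64_t split_dim_size, int64_t split_count) {
  return split_count > 0 && split_dim_size % split_count == 0;
}
```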
dyn_cast_or_null(getOperand(1).getDefiningOp()); if (real_op && imag_op && real_op.getOperand() == imag_op.getOperand()) { return real_op.getOperand(); } @@ -476,26 +494,26 @@ Type CreateRealType(Type type) { } } // namespace -void ImagOp::build(Builder* builder, OperationState& state, Value* val) { - build(builder, state, CreateRealType(val->getType()), val); +void ImagOp::build(Builder* builder, OperationState& state, Value val) { + build(builder, state, CreateRealType(val.getType()), val); } OpFoldResult ImagOp::fold(ArrayRef operands) { if (auto complex_op = - dyn_cast_or_null(getOperand()->getDefiningOp())) { + dyn_cast_or_null(getOperand().getDefiningOp())) { return complex_op.getOperand(1); } return {}; } -void RealOp::build(Builder* builder, OperationState& state, Value* val) { - build(builder, state, CreateRealType(val->getType()), val); +void RealOp::build(Builder* builder, OperationState& state, Value val) { + build(builder, state, CreateRealType(val.getType()), val); } OpFoldResult RealOp::fold(ArrayRef operands) { if (auto complex_op = - dyn_cast_or_null(getOperand()->getDefiningOp())) { + dyn_cast_or_null(getOperand().getDefiningOp())) { return complex_op.getOperand(0); } @@ -512,12 +530,12 @@ OpFoldResult ConcatenateOp::fold(ArrayRef operands) { } static LogicalResult Verify(ConcatenateOp op) { - auto firstType = op.getOperand(0)->getType().cast(); + auto firstType = op.getOperand(0).getType().cast(); auto firstShape = firstType.getShape(); int numOperands = op.getNumOperands(); for (int i = 1; i < numOperands; i++) { - auto secondType = op.getOperand(i)->getType().cast(); + auto secondType = op.getOperand(i).getType().cast(); if (firstType.getRank() != secondType.getRank()) { return op.emitOpError( @@ -552,18 +570,18 @@ void DynamicSliceOp::getCanonicalizationPatterns( //===----------------------------------------------------------------------===// OpFoldResult ReshapeOp::fold(ArrayRef operands) { - if (getOperand()->getType() == getType()) { + if (getOperand().getType() == getType()) { return getOperand(); } if (auto prev_op = - dyn_cast_or_null(getOperand()->getDefiningOp())) { + dyn_cast_or_null(getOperand().getDefiningOp())) { setOperand(prev_op.getOperand()); return getResult(); } if (auto elements = operands.front().dyn_cast_or_null()) { - return elements.reshape(getResult()->getType().cast()); + return elements.reshape(getResult().getType().cast()); } return {}; @@ -611,9 +629,9 @@ void ReduceOp::build(Builder* builder, OperationState& state, SmallVector result_ty; result_ty.reserve(operands.size()); - for (Value* operand : operands) { + for (Value operand : operands) { result_ty.push_back( - GetReduceResultType(operand->getType(), dimensions, builder)); + GetReduceResultType(operand.getType(), dimensions, builder)); } build(builder, state, result_ty, operands, init_values, dimensions); } @@ -622,7 +640,7 @@ LogicalResult ReduceOp::fold(ArrayRef operands, SmallVectorImpl& results) { // No dimensions to reduce. 
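// In that case the reduction is an identity: each input operand folds to // itself, e.g. a reduce with dimensions = [] over %input simply yields %input.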
if (dimensions().getNumElements() == 0) { - for (Value* input : this->operands()) { + for (Value input : this->operands()) { results.push_back(input); } return success(); @@ -645,8 +663,8 @@ static LogicalResult Verify(SelectOp op) { //===----------------------------------------------------------------------===// static LogicalResult Verify(PadOp op) { - auto input_type = op.operand()->getType().cast<RankedTensorType>(); - auto pad_type = op.padding_value()->getType().cast<RankedTensorType>(); + auto input_type = op.operand().getType().cast<RankedTensorType>(); + auto pad_type = op.padding_value().getType().cast<RankedTensorType>(); if (pad_type.getRank() != 0) { return op.emitOpError( @@ -678,7 +696,7 @@ static LogicalResult Verify(PadOp op) { auto input_shape = input_type.getShape(); auto output_shape = - op.getResult()->getType().cast<RankedTensorType>().getShape(); + op.getResult().getType().cast<RankedTensorType>().getShape(); if (input_shape.size() != output_shape.size()) { return op.emitOpError( llvm::formatv("operand rank ({0}) and result rank({0}) should match", @@ -757,15 +775,15 @@ static Type GetBroadcastType(Builder* builder, Type x, Type y, } } // namespace -#define BINARY_BUILDER(Op) \ - void Op::build(Builder* builder, OperationState& result, Value* left, \ - Value* right, DenseIntElementsAttr broadcast_dimensions) { \ - auto type = GetBroadcastType(builder, left->getType().cast<TensorType>(), \ - right->getType().cast<TensorType>(), \ - getElementTypeOrSelf(right->getType()), \ - broadcast_dimensions); \ - return Op::build(builder, result, type, left, right, \ - broadcast_dimensions); \ +#define BINARY_BUILDER(Op) \ + void Op::build(Builder* builder, OperationState& result, Value left, \ + Value right, DenseIntElementsAttr broadcast_dimensions) { \ + auto type = GetBroadcastType(builder, left.getType().cast<TensorType>(), \ + right.getType().cast<TensorType>(), \ + getElementTypeOrSelf(right.getType()), \ + broadcast_dimensions); \ + return Op::build(builder, result, type, left, right, \ + broadcast_dimensions); \ } BINARY_BUILDER(AddOp); @@ -790,7 +808,7 @@ BINARY_BUILDER(XorOp); // SliceOp //===----------------------------------------------------------------------===// -void SliceOp::build(Builder* builder, OperationState& result, Value* operand, +void SliceOp::build(Builder* builder, OperationState& result, Value operand, DenseIntElementsAttr start_indices, DenseIntElementsAttr limit_indices, DenseIntElementsAttr strides) { @@ -811,11 +829,11 @@ static int64_t InferSliceDim(int64_t input_dim, int64_t start, int64_t end, return llvm::divideCeil(end - start, stride); } -Type SliceOp::InferOutputTypes(Builder* builder, Value* operand, +Type SliceOp::InferOutputTypes(Builder* builder, Value operand, DenseIntElementsAttr start_indices, DenseIntElementsAttr limit_indices, DenseIntElementsAttr strides) { - Type ty = operand->getType(); + Type ty = operand.getType(); RankedTensorType ranked_ty = ty.dyn_cast<RankedTensorType>(); if (!ranked_ty) return ty; int64_t rank = ranked_ty.getRank(); @@ -852,7 +870,7 @@ void SortOp::build(Builder* builder, OperationState& state, ValueRange operands, SmallVector element_types; element_types.reserve(operands.size()); - for (Value* operand : operands) element_types.push_back(operand->getType()); + for (Value operand : operands) element_types.push_back(operand.getType()); state.addTypes(builder->getTupleType(element_types)); state.addRegion(); @@ -863,15 +881,14 @@ static LogicalResult Verify(SortOp op) { if (operands.empty()) return op.emitOpError("requires at least one input"); // TODO(antiagainst): verify partially dynamic shapes - if (llvm::all_of(operands, [](Value* operand) { - return
operand->getType().cast().hasRank(); + if (llvm::all_of(operands, [](Value operand) { + return operand.getType().cast().hasRank(); })) { ArrayRef input_shape = - (*operands.begin())->getType().cast().getShape(); + (*operands.begin()).getType().cast().getShape(); - if (llvm::any_of(llvm::drop_begin(operands, 1), [&](Value* operand) { - return operand->getType().cast().getShape() != - input_shape; + if (llvm::any_of(llvm::drop_begin(operands, 1), [&](Value operand) { + return operand.getType().cast().getShape() != input_shape; })) return op.emitOpError("requires all inputs to have the same dimensions"); @@ -889,10 +906,10 @@ static LogicalResult Verify(SortOp op) { for (auto indexed_operand : llvm::enumerate(operands)) { int index = indexed_operand.index(); Type element_type = - indexed_operand.value()->getType().cast().getElementType(); + indexed_operand.value().getType().cast().getElementType(); Type tensor_type = RankedTensorType::get({}, element_type); for (int i : {2 * index, 2 * index + 1}) { - Type arg_type = block.getArgument(i)->getType(); + Type arg_type = block.getArgument(i).getType(); if (arg_type != tensor_type) return op.emitOpError("comparator block argument #") << i << " should be of type " << tensor_type << " but got " @@ -926,7 +943,7 @@ static LogicalResult Verify(TransposeOp op) { } auto permutationSize = permutationType.getNumElements(); - auto operandType = op.operand()->getType().dyn_cast(); + auto operandType = op.operand().getType().dyn_cast(); if (operandType) { auto operandRank = operandType.getRank(); if (operandRank != permutationSize) { @@ -936,7 +953,7 @@ static LogicalResult Verify(TransposeOp op) { } } - auto resultType = op.getResult()->getType().dyn_cast(); + auto resultType = op.getResult().getType().dyn_cast(); if (resultType) { auto resultRank = resultType.getRank(); if (resultRank != permutationSize) { @@ -971,15 +988,15 @@ static LogicalResult Verify(TransposeOp op) { //===----------------------------------------------------------------------===// void GetTupleElementOp::build(Builder* builder, OperationState& result, - Value* tuple, int32_t index) { - if (auto tuple_type = tuple->getType().dyn_cast()) { + Value tuple, int32_t index) { + if (auto tuple_type = tuple.getType().dyn_cast()) { auto element_type = tuple_type.getType(index); build(builder, result, element_type, tuple, builder->getI32IntegerAttr(index)); return; } - build(builder, result, tuple->getType(), tuple, + build(builder, result, tuple.getType(), tuple, builder->getI32IntegerAttr(index)); } @@ -992,7 +1009,7 @@ void TupleOp::build(Builder* builder, OperationState& result, SmallVector types; types.reserve(values.size()); for (auto val : values) { - types.push_back(val->getType()); + types.push_back(val.getType()); } build(builder, result, builder->getTupleType(types), values); @@ -1011,10 +1028,10 @@ void UnaryEinsumOp::getCanonicalizationPatterns( // CompareOp //===----------------------------------------------------------------------===// -void CompareOp::build(Builder* builder, OperationState& result, Value* lhs, - Value* rhs, DenseIntElementsAttr broadcast_dimensions, +void CompareOp::build(Builder* builder, OperationState& result, Value lhs, + Value rhs, DenseIntElementsAttr broadcast_dimensions, StringAttr comparison_direction) { - auto new_type = GetBroadcastType(builder, lhs->getType(), rhs->getType(), + auto new_type = GetBroadcastType(builder, lhs.getType(), rhs.getType(), builder->getI1Type(), broadcast_dimensions); build(builder, result, new_type, lhs, rhs, 
broadcast_dimensions, comparison_direction); diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.h b/tensorflow/compiler/mlir/xla/ir/hlo_ops.h index 9610a787b7d..d0bc9619db9 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.h +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.h @@ -19,16 +19,16 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_OPS_H_ #include "llvm/ADT/StringRef.h" -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Dialect.h" // TF:local_config_mlir -#include "mlir/IR/DialectImplementation.h" // TF:local_config_mlir -#include "mlir/IR/Location.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/OpDefinition.h" // TF:local_config_mlir -#include "mlir/IR/Operation.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/IR/Types.h" // TF:local_config_mlir -#include "mlir/Support/Functional.h" // TF:local_config_mlir +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Dialect.h" // TF:llvm-project +#include "mlir/IR/DialectImplementation.h" // TF:llvm-project +#include "mlir/IR/Location.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/OpDefinition.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/Types.h" // TF:llvm-project +#include "mlir/Support/Functional.h" // TF:llvm-project namespace mlir { class OpBuilder; diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td index 6eeb32e804c..5c30ff8f134 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td @@ -56,7 +56,7 @@ def HLO_Tensor : TensorOf<[AnyFloat, AnyInteger, AnyComplex]>; def HLO_ComplexTensor : TensorOf<[AnyComplex]>; -def HLO_Tuple : NestedTupleOf<[HLO_Tensor]>; +def HLO_Tuple : NestedTupleOf<[HLO_Tensor, HLO_Token]>; def HLO_TensorOrTuple : AnyTypeOf<[HLO_Tensor, HLO_Tuple]>; @@ -76,6 +76,9 @@ def HLO_FpOrComplexTensor : TensorOf<[AnyFloat, AnyComplex]>; // Any int, floating-point or complex tensor types def HLO_IntFpOrComplexTensor : TensorOf<[HLO_Int, AnyFloat, AnyComplex]>; +// Any pred, int or floating-point tensor types +def HLO_PredIntOrFpTensor : TensorOf<[HLO_Pred, HLO_Int, AnyFloat]>; + //===----------------------------------------------------------------------===// // XLA nullary op definitions. 
//===----------------------------------------------------------------------===// @@ -128,7 +131,7 @@ class HLO_UnaryElementwiseOp traits, def HLO_AbsOp: HLO_UnaryElementwiseOp<"abs", [NoSideEffect, SameOperandsAndResultShape], HLO_Tensor>, BASE_HLO_AbsOp { let builders = [OpBuilder< - "Builder *builder, OperationState &result, Value *operand" + "Builder *builder, OperationState &result, Value operand" >]; } @@ -140,7 +143,7 @@ def HLO_ConvertOp : HLO_UnaryElementwiseOp< BASE_HLO_ConvertOp { let builders = [OpBuilder< - "Builder *, OperationState &tblgen_state, Value *operand, " + "Builder *, OperationState &tblgen_state, Value operand, " "Type result_element_ty" >]; @@ -149,6 +152,10 @@ def HLO_ConvertOp : HLO_UnaryElementwiseOp< let hasCustomHLOConverter = 1; } +def HLO_ClzOp: HLO_UnaryElementwiseOp<"count_leading_zeros", + [NoSideEffect, SameOperandsAndResultType], HLO_IntTensor>, + BASE_HLO_ClzOp; + def HLO_CosOp: HLO_UnaryElementwiseOp<"cos", [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>, BASE_HLO_CosOp; @@ -191,6 +198,9 @@ def HLO_PopulationCountOp: HLO_UnaryElementwiseOp<"popcnt", [NoSideEffect, SameOperandsAndResultType], HLO_IntTensor>, BASE_HLO_PopulationCountOp; +def HLO_RoundOp: HLO_UnaryElementwiseOp<"round", + [NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_RoundOp; + def HLO_RsqrtOp: HLO_UnaryElementwiseOp<"rsqrt", [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>, BASE_HLO_RsqrtOp; @@ -220,7 +230,7 @@ def HLO_ComplexOp: HLO_Op<"complex", [NoSideEffect, SameOperandsElementType, SameOperandsAndResultShape]>, BASE_HLO_ComplexOp { let builders = [OpBuilder< - "Builder *, OperationState &tblgen_state, Value *lhs, Value *rhs">]; + "Builder *, OperationState &tblgen_state, Value lhs, Value rhs">]; let arguments = (ins HLO_FpTensor:$lhs, HLO_FpTensor:$rhs); let results = (outs HLO_ComplexTensor); @@ -230,7 +240,7 @@ def HLO_ComplexOp: HLO_Op<"complex", def HLO_ImagOp: HLO_Op< "imag", [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_ImagOp { let builders = [OpBuilder< - "Builder *, OperationState &tblgen_state, Value *val">]; + "Builder *, OperationState &tblgen_state, Value val">]; let arguments = (ins HLO_ComplexTensor); let results = (outs HLO_FpTensor); @@ -240,7 +250,7 @@ def HLO_ImagOp: HLO_Op< def HLO_RealOp: HLO_Op< "real", [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_RealOp { let builders = [OpBuilder< - "Builder *, OperationState &tblgen_state, Value *val">]; + "Builder *, OperationState &tblgen_state, Value val">]; let arguments = (ins HLO_ComplexTensor); let results = (outs HLO_FpTensor); @@ -261,7 +271,7 @@ class HLO_BinaryElementwiseOp traits> : ); let builders = [OpBuilder< - "Builder *builder, OperationState &result, Value *left, Value* right, " + "Builder *builder, OperationState &result, Value left, Value right, " "DenseIntElementsAttr broadcast_dimensions" >]; @@ -324,6 +334,101 @@ def HLO_AndOp: HLO_BinaryLogicalElementwiseOp<"and">, BASE_HLO_AndOp; def HLO_OrOp: HLO_BinaryLogicalElementwiseOp<"or">, BASE_HLO_OrOp; def HLO_XorOp : HLO_BinaryLogicalElementwiseOp<"xor">, BASE_HLO_XorOp; +//===----------------------------------------------------------------------===// +// XLA communication op definitions. +//===----------------------------------------------------------------------===// + +// Represents a unique identifier for each Send/Recv instruction pair or +// optionally for collective instructions (AllReduce, CollectivePermute, +// AllToAll). 
Non-positive channel_id handle is equivalent to no channel id. +def ChannelHandle : StructAttr<"ChannelHandle", HLO_Dialect, [ + StructFieldAttr<"handle", I64Attr>, + StructFieldAttr<"type", I64Attr>]> { + let description = "two 64-bit integers 'handle' and 'type'"; +} + +// InfeedOp corresponds to the 'InfeedWithToken' xla client API and not 'Infeed'. +// InfeedWithToken allows ordering of infeed HLO instructions using tokens. +def HLO_InfeedOp : HLO_Op<"infeed", []> { + + string summary = "Infeed operator"; + + string description = [{ + Reads a single data item from the implicit Infeed streaming interface of + the device, interpreting the data as the given shape, and returns an XlaOp + of the data. Multiple Infeed operations are allowed in a computation, but + there must be a total order among the Infeed operations. + + See https://www.tensorflow.org/xla/operation_semantics#infeed. + }]; + + let arguments = (ins + HLO_Token:$token, + DefaultValuedAttr<StrAttr, "">:$infeed_config + ); + let results = (outs HLO_Tuple); + let hasCustomHLOConverter = 1; +} + +// OutfeedOp corresponds to the 'OutfeedWithToken' xla client API and not 'Outfeed'. +// OutfeedWithToken allows ordering of outfeed HLO instructions using tokens. +def HLO_OutfeedOp : HLO_Op<"outfeed", []> { + + string summary = "Outfeed operator"; + + string description = [{ + Generates outgoing data transfers for the given data. It takes data and a + token type operand and produces a token type value. Tokens are used for + ordering side-effecting operations. + + See https://www.tensorflow.org/xla/operation_semantics#outfeed. + }]; + + let arguments = (ins + HLO_TensorOrTuple:$operand, + HLO_Token:$token, + DefaultValuedAttr<StrAttr, "">:$outfeed_config + ); + let results = (outs HLO_Token); + let hasCustomHLOConverter = 1; +} + +def HLO_SendOp : HLO_Op<"send", []> { + + string summary = "Send operator"; + + string description = [{ + Sends the given operand data to a Recv instruction in another computation + that shares the same channel handle. Does not return any data. Similar to + the Recv operation, the Send operation represents synchronous communication, + and is internally decomposed into 2 HLO instructions (Send and SendDone) to + enable asynchronous data transfers. + + See https://www.tensorflow.org/xla/operation_semantics#send. + }]; + + let arguments = (ins + HLO_TensorOrTuple:$operand, + HLO_Token:$token, + ChannelHandle:$channel_id, + DefaultValuedAttr<BoolAttr, "false">:$is_host_transfer + ); + + let results = (outs HLO_Token); + let hasCustomHLOConverter = 1; +} + +//===----------------------------------------------------------------------===// +// XLA parallelism related op definitions. +//===----------------------------------------------------------------------===// + +def HLO_ReplicaIdOp : HLO_Op<"replica_id", [NoSideEffect]>, + BASE_HLO_ReplicaIdOp { + // TODO(prakalps): The output should be an unsigned 32-bit integer, but mlir + // does not differentiate between signed and unsigned int. + let results = (outs I32Tensor); +} + +//===----------------------------------------------------------------------===// // XLA control flow op definitions.
//===----------------------------------------------------------------------===// @@ -343,7 +448,6 @@ def HLO_AfterAllOp : HLO_Op<"after_all", []> { let arguments = (ins Variadic:$operands); let results = (outs HLO_Token); - let hasCustomHLOConverter = 1; } def HLO_ConditionalOp: HLO_Op<"conditional", [NoSideEffect]> { @@ -390,15 +494,6 @@ def HLO_WhileOp: HLO_Op<"while", [NoSideEffect, SameOperandsAndResultType]> { let hasCustomHLOConverter = 1; } -// Represents a unique identifier for each Send/Recv instruction pair or -// optionally for collective instructions (AllReduce, CollectivePermute, -// AllToAll). Non-positive channel_id handle is equivalent to no channel id. -def ChannelHandle : StructAttr<"ChannelHandle", HLO_Dialect, [ - StructFieldAttr<"handle", I64Attr>, - StructFieldAttr<"type", I64Attr>]> { - let description = "two 64-bit integers 'handle' and 'type'"; -} - def HLO_AllReduceOp : HLO_Op<"all_reduce", [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_AllReduceOp { @@ -413,6 +508,19 @@ def HLO_AllReduceOp : HLO_Op<"all_reduce", let hasCustomHLOConverter = 1; } +def HLO_AllToAllOp : HLO_Op<"all_to_all", + [NoSideEffect, SameOperandsElementType, SameOperandsShape]>, BASE_HLO_AllToAllOp { + + let arguments = (ins + HLO_Tensor:$operand, + I64Attr:$split_dimension, + I64Attr:$concat_dimension, + I64Attr:$split_count, + I64ElementsAttr:$replica_groups + ); + let results = (outs HLO_Tensor); +} + def HLO_ReduceOp: HLO_Op<"reduce", [ NoSideEffect, SameVariadicOperandSize, @@ -458,7 +566,7 @@ def HLO_GetTupleElementOp: HLO_Op<"get_tuple_element", [NoSideEffect]>, BASE_HLO let builders = [OpBuilder< "Builder *builder, OperationState &results, " - "Value* value, int32_t index">]; + "Value value, int32_t index">]; } def HLO_TupleOp : HLO_Op<"tuple", [NoSideEffect]>, BASE_HLO_TupleOp { @@ -469,8 +577,6 @@ def HLO_TupleOp : HLO_Op<"tuple", [NoSideEffect]>, BASE_HLO_TupleOp { "Builder *builder, OperationState &results, " "ValueRange values">]; - // TupleOp has special conversion logic to HLO. - let hasCustomHLOConverter = 1; } def HLO_CompareOp: HLO_Op<"compare", @@ -482,14 +588,14 @@ def HLO_CompareOp: HLO_Op<"compare", HLO_ComparisonDirectionAttr:$comparison_direction ); let builders = [OpBuilder< - "Builder *builder, OperationState &result, Value *left, Value* right, " + "Builder *builder, OperationState &result, Value left, Value right, " "DenseIntElementsAttr broadcast_dimensions, " "StringAttr comparison_direction" >]; let results = (outs HLO_PredTensor); let builders = [OpBuilder< - "Builder *builder, OperationState &result, Value *lhs, Value *rhs, " + "Builder *builder, OperationState &result, Value lhs, Value rhs, " "DenseIntElementsAttr broadcast_dimensions, StringAttr comparison_direction" >]; } @@ -512,7 +618,7 @@ def HLO_SliceOp: HLO_Op< let results = (outs HLO_Tensor); let builders = [OpBuilder< - "Builder *builder, OperationState &result, Value *operand, " + "Builder *builder, OperationState &result, Value operand, " "DenseIntElementsAttr start_indices, DenseIntElementsAttr limit_indices, " "DenseIntElementsAttr strides" >]; @@ -520,7 +626,7 @@ def HLO_SliceOp: HLO_Op< let extraClassDeclaration = [{ // Infers output type for given operand and attributes. Result type is // unranked if any of the attributes is illegal. 
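// A worked example (hypothetical values): for a tensor<10xf32> operand with // start_indices = [1], limit_indices = [9] and strides = [2], the inferred // type is tensor<4xf32>, since ceil((9 - 1) / 2) == 4.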
- static Type InferOutputTypes(Builder *builder, Value *operand, + static Type InferOutputTypes(Builder *builder, Value operand, DenseIntElementsAttr start_indices, DenseIntElementsAttr limit_indices, DenseIntElementsAttr strides); @@ -572,8 +678,8 @@ def HLO_BatchNormGradOp : HLO_Op<"batch_norm_grad", [NoSideEffect]>, let results = (outs HLO_Tuple); } -def HLO_BatchNormInferenceOp : HLO_Op<"batch_norm_inference", [NoSideEffect]>, - BASE_HLO_BatchNormInferenceOp { +def HLO_BatchNormInferenceOp : HLO_Op<"batch_norm_inference", + [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_BatchNormInferenceOp { let arguments = (ins HLO_Tensor:$operand, @@ -634,6 +740,16 @@ def HLO_BroadcastInDimOp : HLO_Op<"broadcast_in_dim", let hasCustomHLOConverter = 1; } +def HLO_CholeskyOp : HLO_Op<"cholesky", + [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_CholeskyOp { + let arguments = (ins + HLO_FpOrComplexTensor:$a, + DefaultValuedAttr:$lower + ); + + let results = (outs HLO_FpOrComplexTensor); +} + def HLO_ClampOp : HLO_Op<"clamp", [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_ClampOp { let arguments = (ins @@ -657,8 +773,6 @@ def HLO_ConcatenateOp : HLO_Op<"concatenate", let hasFolder = 1; - // TODO(b/129422361) ConcatOp has special conversion logic to HLO. - let hasCustomHLOConverter = 1; } def HLO_CrossReplicaSumOp : HLO_Op<"cross-replica-sum", @@ -708,8 +822,6 @@ def HLO_ConvOp : HLO_Op<"conv", [NoSideEffect]>, BASE_HLO_ConvOp { let results = (outs HLO_Tensor); - // TODO(b/129422361): Conv Op has special conversion logic to HLO. - let hasCustomHLOConverter = 1; } def HLO_CopyOp: HLO_Op<"copy", [NoSideEffect, SameOperandsAndResultType]> { @@ -751,7 +863,9 @@ def HLO_DotGeneralOp: HLO_Op<"dot_general", [NoSideEffect]>, BASE_HLO_DotGeneral let results = (outs HLO_Tensor); } -def BASE_EinsumOp { +// Define Base Einsum op within the HLO dialect as these are client ops and +// therefore this class is not common between HLO and LHLO ops. +class BASE_EinsumOp { string summary = "Einsum operator"; string description = [{ @@ -760,7 +874,7 @@ def BASE_EinsumOp { }]; } -def HLO_EinsumOp: HLO_Op<"einsum", [NoSideEffect]> { +def HLO_EinsumOp: HLO_Op<"einsum", [NoSideEffect]>, BASE_EinsumOp { let arguments = (ins HLO_Tensor:$lhs, HLO_Tensor:$rhs, @@ -773,7 +887,7 @@ def HLO_EinsumOp: HLO_Op<"einsum", [NoSideEffect]> { // side HLO ops. } -def HLO_UnaryEinsumOp: HLO_Op<"unary_einsum", [NoSideEffect]> { +def HLO_UnaryEinsumOp: HLO_Op<"unary_einsum", [NoSideEffect]>, BASE_EinsumOp { let arguments = (ins HLO_Tensor:$operand, StrAttr:$einsum_config @@ -796,9 +910,6 @@ def HLO_FftOp: HLO_Op<"fft", [NoSideEffect]>, BASE_HLO_FftOp { ); let results = (outs HLO_Tensor); - - // TODO(b/129422361) Attributes are not supported by the codegen. 
- let hasCustomHLOConverter = 1; } def GatherDimensionNumbers : StructAttr<"GatherDimensionNumbers", HLO_Dialect, @@ -819,8 +930,6 @@ def HLO_GatherOp: HLO_Op<"gather", [NoSideEffect]>, BASE_HLO_GatherOp { ); let results = (outs HLO_Tensor); - - let hasCustomHLOConverter = 1; } def HLO_GetDimensionSizeOp: HLO_Op<"get_dimension_size", [NoSideEffect]>, @@ -896,6 +1005,16 @@ def HLO_SelectAndScatterOp: HLO_Op<"select_and_scatter", let hasCustomHLOConverter = 1; } +def HLO_SetDimensionSizeOp: HLO_Op<"set_dimension_size", [NoSideEffect]>, + BASE_HLO_SetDimensionSizeOp { + let arguments = (ins + HLO_Tensor:$operand, + I32Tensor:$size, + I32Attr:$dimension + ); + let results = (outs HLO_Tensor); +} + def HLO_SortOp : HLO_Op<"sort", [NoSideEffect]>, BASE_HLO_SortOp { let arguments = (ins Variadic:$operands, @@ -926,9 +1045,6 @@ def HLO_ReverseOp: HLO_Op<"reverse", let results = (outs HLO_Tensor); let hasFolder = 1; - - // TODO(b/129422361): ReverseOp has a custom constructor for HLO. - let hasCustomHLOConverter = 1; } def HLO_PadOp: HLO_Op<"pad", @@ -1029,12 +1145,24 @@ def HLO_TorchIndexSelectOp : HLO_Op<"torch_index_select", [NoSideEffect]> { //===----------------------------------------------------------------------===// def HLO_RngUniformOp : HLO_Op<"rng_uniform", []>, BASE_HLO_RngUniformOp { let arguments = (ins - HLO_Tensor:$a, - HLO_Tensor:$b, + HLO_PredIntOrFpTensor:$a, + HLO_PredIntOrFpTensor:$b, I64Tensor:$shape ); - let results = (outs HLO_Tensor); + let results = (outs HLO_PredIntOrFpTensor); + + let hasCustomHLOConverter = 1; +} + +def HLO_RngNormalOp : HLO_Op<"rng_normal", []>, BASE_HLO_RngNormalOp { + let arguments = (ins + HLO_FpTensor:$mu, + HLO_FpTensor:$sigma, + I64Tensor:$shape + ); + + let results = (outs HLO_FpTensor); let hasCustomHLOConverter = 1; } diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td index 3be2c26a1bf..f2010bb56cb 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td @@ -68,6 +68,17 @@ class BASE_HLO_CeilOp { }]; } +class BASE_HLO_ClzOp { + string summary = "Count-leading-zeros (Clz) operator"; + + string description = [{ + Returns the number of leading zeros in each operand element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} + class BASE_HLO_ComplexOp { string summary = "Complex operator"; @@ -228,6 +239,18 @@ class BASE_HLO_RealOp { }]; } +class BASE_HLO_RoundOp { + string summary = "Round operator"; + + string description = [{ + Returns `Round(operand)` element-wise, rounding to nearest integer with + half-way cases rounding away from zero. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} + class BASE_HLO_RsqrtOp { string summary = "Reciprocal Square-root operator"; @@ -465,6 +488,26 @@ class BASE_HLO_XorOp { }]; } +//===----------------------------------------------------------------------===// +// XLA parallelism related op definitions. +//===----------------------------------------------------------------------===// + +class BASE_HLO_ReplicaIdOp { + string summary = "ReplicaId operator"; + + string description = [{ + Returns the unique ID (int32 scalar) of the replica. + + The unique ID of each replica is an unsigned integer in the interval [0, N), + where N is the number of replicas. Since all the replicas are running the + same program, a ReplicaId() call in the program will return a different + value on each replica. 
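+ + For example, with N = 4 replicas the op evaluates to 0, 1, 2 and 3 on the + respective replicas.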
+ + See https://www.tensorflow.org/xla/operation_semantics#replicaid. + }]; +} + + class BASE_HLO_AllReduceOp { string summary = "AllReduce operator"; @@ -626,6 +669,39 @@ class BASE_HLO_DynamicUpdateSliceOp { // XLA Other op definitions. //===----------------------------------------------------------------------===// +class BASE_HLO_AllToAllOp { + string summary = "AllToAll"; + + string description = [{ + AllToAll is a collective operation that sends data from all cores to all + cores. It has two phases: + - The scatter phase. On each core, the operand is split into `split_count` + number of blocks along the `split_dimension`, and the blocks are + scattered to all cores, e.g., the i-th block is sent to the i-th core. + - The gather phase. Each core concatenates the received blocks along the + `concat_dimension`. + + The participating cores can be configured by: + - replica_groups: each ReplicaGroup contains a list of replica ids + participating in the computation (the replica id for the current replica + can be retrieved using the ReplicaId op). AllToAll will be applied within + subgroups in the specified order. For example, + `replica_groups` = {{1,2,3}, {4,5,0}} means that an AllToAll will be applied + within replicas {1, 2, 3}, and in the gather phase, the received blocks + will be concatenated in the order 1, 2, 3. Then, another AllToAll will be + applied within replicas 4, 5, 0, and the concatenation order is also + 4, 5, 0. If `replica_groups` is empty, all replicas belong to one group, + in the concatenation order of their appearance. + + Prerequisites: + - The dimension size of the operand on the split_dimension is divisible by + `split_count`. + - The operand's shape is not a tuple. + + See https://www.tensorflow.org/xla/operation_semantics#alltoall. + }]; +} + class BASE_HLO_BatchNormGradOp { string summary = "Batch Normalization Gradient"; @@ -707,6 +783,32 @@ class BASE_HLO_BroadcastInDimOp { }]; } +class BASE_HLO_CholeskyOp { + string summary = "Cholesky operator"; + + string description = [{ + Computes the Cholesky decomposition of a batch of symmetric (Hermitian) + positive definite matrices. + + If lower is true, computes lower-triangular matrices l such that + `a=l.Transpose(l)`. If lower is false, computes upper-triangular matrices u such + that `a=Transpose(u).u`. + + Input data is read only from the lower/upper triangle of a, depending on the + value of lower. Values from the other triangle are ignored. Output data is + returned in the same triangle; the values in the other triangle are + implementation-defined and may be anything. + + If the rank of a is greater than 2, a is treated as a batch of matrices, where + all dimensions except the minor 2 are batch dimensions. + + If a is not symmetric (Hermitian) positive definite, the result is + implementation-defined. + + See https://www.tensorflow.org/xla/operation_semantics#cholesky. + }]; +} + class BASE_HLO_ClampOp { string summary = "Clamp operator"; @@ -846,6 +948,18 @@ class BASE_HLO_SelectAndScatterOp { }]; } +class BASE_HLO_SetDimensionSizeOp { + string summary = "SetDimensionSize operator"; + + string description = [{ + Sets the dynamic size of the operand's given dimension. Passes the operand + through as the result, with the dynamic dimension tracked by the compiler. + Padded values will be ignored by downstream reduction ops. + + See https://www.tensorflow.org/xla/operation_semantics#setdimensionsize.
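+ + For example (an illustrative sketch, not from this change): with a + tensor<4xf32> operand, dimension = 0 and a runtime size of 3, the result + keeps the static type tensor<4xf32>, but a subsequent reduce over that + dimension only accumulates the first 3 elements.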
+ }]; +} + class BASE_HLO_SortOp { string summary = "Sort operator"; @@ -895,11 +1009,26 @@ class BASE_HLO_RngUniformOp { string summary = "RNG with uniform distribution."; string description = [{ - Constructs an output of a given shape with random numbers generated following - the uniform distribution over the interval `[a,b)`. + Constructs an output of a given shape with random numbers generated + following the uniform distribution over the interval `[a,b)`. The parameters + and output element type have to be a boolean type, an integral type or a + floating point type, and the types have to be consistent. See https://www.tensorflow.org/xla/operation_semantics#rnguniform. }]; } +class BASE_HLO_RngNormalOp { + string summary = "RNG with normal distribution."; + + string description = [{ + Constructs an output of a given shape with random numbers generated + following the normal distribution with parameters `mu` and `sigma`. The + parameters and output shape have to have a floating point element type. + The parameters furthermore have to be scalar valued. + + See https://www.tensorflow.org/xla/operation_semantics#rngnormal. + }]; +} + #endif // HLO_OPS_BASE diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_utils.cc b/tensorflow/compiler/mlir/xla/ir/hlo_utils.cc index 7d3e2ca2384..08f4dc536cf 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_utils.cc +++ b/tensorflow/compiler/mlir/xla/ir/hlo_utils.cc @@ -17,15 +17,14 @@ limitations under the License. #include -#include "mlir/IR/Attributes.h" // TF:local_config_mlir +#include "mlir/IR/Attributes.h" // TF:llvm-project namespace mlir { namespace xla { -DenseIntElementsAttr getBroadcastDimensionsAttr(Builder *b, Value *x, - Value *y) { - TensorType xType = x->getType().dyn_cast<TensorType>(); - TensorType yType = y->getType().dyn_cast<TensorType>(); +DenseIntElementsAttr getBroadcastDimensionsAttr(Builder *b, Value x, Value y) { + TensorType xType = x.getType().dyn_cast<TensorType>(); + TensorType yType = y.getType().dyn_cast<TensorType>(); if (xType == yType || !xType || !yType) return {}; // If the shapes have the same rank, then there is nothing to do. diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_utils.h b/tensorflow/compiler/mlir/xla/ir/hlo_utils.h index 86c90b49f16..120b035e5d0 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_utils.h +++ b/tensorflow/compiler/mlir/xla/ir/hlo_utils.h @@ -16,11 +16,11 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_UTILS_H_ #define TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_UTILS_H_ -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/PatternMatch.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/TypeUtilities.h" // TF:llvm-project #include "tensorflow/compiler/mlir/xla/convert_op_folder.h" namespace mlir { @@ -29,14 +29,14 @@ namespace xla { // Computes the broadcast dimensions attr for an elementwise binary operator // between two ranked tensors. mlir::DenseIntElementsAttr getBroadcastDimensionsAttr(mlir::Builder* b, - mlir::Value* x, - mlir::Value* y); + mlir::Value x, + mlir::Value y); /// Get a constant splat for the given value type.
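/// For example, a hypothetical call getSplat(&builder, val, 0) for a `val` of /// type tensor<4xf32> yields a dense<0.000000e+00> : tensor<4xf32> attribute.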
template <typename T> -static ElementsAttr getSplat(Builder* b, Value* val, T constant) { - auto valType = val->getType().cast<TensorType>(); - auto valElementType = getElementTypeOrSelf(val->getType()); +static ElementsAttr getSplat(Builder* b, Value val, T constant) { + auto valType = val.getType().cast<TensorType>(); + auto valElementType = getElementTypeOrSelf(val.getType()); // Handle integer elements. Attribute elementAttr; diff --git a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.cc index c121aa703a3..0fbe5915fe8 100644 --- a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.cc +++ b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.cc @@ -28,20 +28,20 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/FormatVariadic.h" -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Dialect.h" // TF:local_config_mlir -#include "mlir/IR/Location.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/OpDefinition.h" // TF:local_config_mlir -#include "mlir/IR/OpImplementation.h" // TF:local_config_mlir -#include "mlir/IR/Operation.h" // TF:local_config_mlir -#include "mlir/IR/OperationSupport.h" // TF:local_config_mlir -#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir -#include "mlir/IR/Types.h" // TF:local_config_mlir -#include "mlir/IR/Value.h" // TF:local_config_mlir +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Dialect.h" // TF:llvm-project +#include "mlir/IR/Location.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/OpDefinition.h" // TF:llvm-project +#include "mlir/IR/OpImplementation.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/OperationSupport.h" // TF:llvm-project +#include "mlir/IR/PatternMatch.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/TypeUtilities.h" // TF:llvm-project +#include "mlir/IR/Types.h" // TF:llvm-project +#include "mlir/IR/Value.h" // TF:llvm-project #include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h.inc" namespace mlir { diff --git a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.h b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.h index f73e5026541..1a07b1a45f3 100644 --- a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.h +++ b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.h @@ -19,15 +19,15 @@ limitations under the License.
#define TENSORFLOW_COMPILER_MLIR_XLA_IR_LHLO_OPS_H_ #include "llvm/ADT/StringRef.h" -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Dialect.h" // TF:local_config_mlir -#include "mlir/IR/Location.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/OpDefinition.h" // TF:local_config_mlir -#include "mlir/IR/Operation.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/IR/Types.h" // TF:local_config_mlir -#include "mlir/Support/Functional.h" // TF:local_config_mlir +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Dialect.h" // TF:llvm-project +#include "mlir/IR/Location.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/OpDefinition.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/Types.h" // TF:llvm-project +#include "mlir/Support/Functional.h" // TF:llvm-project namespace mlir { class OpBuilder; diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc index 5c351876440..c64b4ef9f4a 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc @@ -27,16 +27,16 @@ limitations under the License. #include "llvm/Support/SMLoc.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Function.h" // TF:local_config_mlir -#include "mlir/IR/Location.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/Matchers.h" // TF:local_config_mlir -#include "mlir/IR/Module.h" // TF:local_config_mlir -#include "mlir/IR/Operation.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir +#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/Location.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Matchers.h" // TF:llvm-project +#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/TypeUtilities.h" // TF:llvm-project #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/type_to_shape.h" #include "tensorflow/compiler/xla/client/lib/matrix.h" @@ -91,6 +91,8 @@ static double ConvertAPFloat(llvm::APFloat value) { return value.convertToDouble(); } +static inline bool Convertbool(bool value) { return value; } + static absl::string_view ConvertStringRef(mlir::StringRef value) { return {value.data(), value.size()}; } @@ -115,6 +117,15 @@ static std::vector Convert_broadcast_dimensions( return ConvertDenseIntAttr(*broadcast_dimensions); } +// Converts StringRef to xla FftType enum +static xla::FftType Convert_fft_type(llvm::StringRef fft_type_str) { + xla::FftType fft_type_enum; + // Illegal fft_type string would be caught by the verifier, so 'FftType_Parse' + // call below should never return false. 
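+ // For example, the strings "FFT", "IFFT", "RFFT" and "IRFFT" parse to the + // corresponding xla::FftType enum values.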
+ if (!FftType_Parse(fft_type_str, &fft_type_enum)) return xla::FftType::FFT; + return fft_type_enum; +} + // Convert a nx2 dense attribute to a list of tuples. This is the way padding // is defined in hlo. static std::vector> Convert_padding( @@ -151,10 +162,10 @@ static std::vector Convert_replica_groups( return result; } -#define I64_ELEMENTS_ATTR_TO_VECTOR(attribute) \ - static std::vector Convert_##attribute( \ - mlir::DenseIntElementsAttr attribute) { \ - return ConvertDenseIntAttr(attribute); \ +#define I64_ELEMENTS_ATTR_TO_VECTOR(attribute) \ + static std::vector Convert_##attribute( \ + llvm::Optional attribute) { \ + return ConvertDenseIntAttr(attribute); \ } I64_ELEMENTS_ATTR_TO_VECTOR(broadcast_sizes); @@ -163,6 +174,11 @@ I64_ELEMENTS_ATTR_TO_VECTOR(start_indices); I64_ELEMENTS_ATTR_TO_VECTOR(limit_indices); I64_ELEMENTS_ATTR_TO_VECTOR(strides); I64_ELEMENTS_ATTR_TO_VECTOR(slice_sizes); +I64_ELEMENTS_ATTR_TO_VECTOR(fft_length); +I64_ELEMENTS_ATTR_TO_VECTOR(dimensions); +I64_ELEMENTS_ATTR_TO_VECTOR(window_strides); +I64_ELEMENTS_ATTR_TO_VECTOR(lhs_dilation); +I64_ELEMENTS_ATTR_TO_VECTOR(rhs_dilation); #undef I64_ELEMENTS_ATTR_TO_VECTOR @@ -230,7 +246,7 @@ static xla::DotDimensionNumbers Convert_dot_dimension_numbers( return dot_dimension_numbers; } -static xla::ConvolutionDimensionNumbers Convert_convolution_dimension_numbers( +static xla::ConvolutionDimensionNumbers Convert_dimension_numbers( mlir::xla_hlo::ConvDimensionNumbers input) { xla::ConvolutionDimensionNumbers output; @@ -281,7 +297,7 @@ static xla::ComparisonDirection Convert_comparison_direction( .ValueOrDie(); } -static xla::GatherDimensionNumbers Convert_gather_dimension_numbers( +static xla::GatherDimensionNumbers Convert_dimension_numbers( mlir::xla_hlo::GatherDimensionNumbers input) { xla::GatherDimensionNumbers output; @@ -335,7 +351,7 @@ namespace mlir { namespace { class ConvertToHloModule { public: - using ValueLoweringMap = llvm::DenseMap; + using ValueLoweringMap = llvm::DenseMap; using FunctionLoweringMap = llvm::DenseMap; // If use_tuple_args is true, then the entry function's arguments are @@ -417,7 +433,7 @@ class ConvertToHloModule { namespace { struct OpLoweringContext { - llvm::DenseMap* values; + llvm::DenseMap* values; mlir::ConvertToHloModule* converter; xla::XlaBuilder* builder; }; @@ -425,7 +441,7 @@ struct OpLoweringContext { llvm::SmallVector GetTuple(mlir::Operation::operand_range values, OpLoweringContext ctx) { llvm::SmallVector ops; - for (mlir::Value* value : values) { + for (mlir::Value value : values) { ops.push_back((*ctx.values)[value]); } return ops; @@ -437,16 +453,6 @@ namespace mlir { namespace xla_hlo { namespace { -LogicalResult ExportXlaOp(AfterAllOp op, OpLoweringContext ctx) { - auto& value_map = *ctx.values; - std::vector tokens(op.operands().size()); - for (auto index_and_value : llvm::enumerate(op.operands())) { - tokens[index_and_value.index()] = value_map[index_and_value.value()]; - } - value_map[op] = xla::AfterAll(ctx.builder, tokens); - return mlir::success(); -} - LogicalResult ExportXlaOp(AllReduceOp op, OpLoweringContext ctx) { auto& value_map = *ctx.values; xla::XlaComputation computation; @@ -485,13 +491,6 @@ LogicalResult ExportXlaOp(BroadcastInDimOp op, OpLoweringContext ctx) { return success(); } -LogicalResult ExportXlaOp(ConcatenateOp op, OpLoweringContext ctx) { - auto& value_map = *ctx.values; - value_map[op] = xla::ConcatInDim(ctx.builder, GetTuple(op.val(), ctx), - op.dimension().getSExtValue()); - return success(); -} - LogicalResult 
ExportXlaOp(ConditionalOp op, OpLoweringContext ctx) { xla::XlaComputation true_branch; xla::XlaComputation false_branch; @@ -514,21 +513,6 @@ LogicalResult ExportXlaOp(ConstOp op, OpLoweringContext ctx) { return failure(); } -LogicalResult ExportXlaOp(ConvOp op, OpLoweringContext ctx) { - auto& value_map = *ctx.values; - value_map[op] = xla::ConvGeneralDilated( - value_map[op.lhs()], value_map[op.rhs()], - Convert_broadcast_dimensions(op.window_strides()), - Convert_padding(op.padding()), - Convert_broadcast_dimensions(op.lhs_dilation()), - Convert_broadcast_dimensions(op.rhs_dilation()), - Convert_convolution_dimension_numbers(op.dimension_numbers()), - op.feature_group_count().getSExtValue(), - op.batch_group_count().getSExtValue(), - Convert_precision_config(op.precision_config()).get()); - return success(); -} - LogicalResult ExportXlaOp(ConvertOp op, OpLoweringContext ctx) { auto& value_map = *ctx.values; value_map[op] = xla::ConvertElementType( @@ -537,19 +521,13 @@ LogicalResult ExportXlaOp(ConvertOp op, OpLoweringContext ctx) { return success(); } -LogicalResult ExportXlaOp(CopyOp op, OpLoweringContext ctx) { - return failure(); -} - -LogicalResult ExportXlaOp(FftOp op, OpLoweringContext ctx) { return failure(); } - -LogicalResult ExportXlaOp(GatherOp op, OpLoweringContext ctx) { +LogicalResult ExportXlaOp(InfeedOp op, OpLoweringContext ctx) { auto& value_map = *ctx.values; - xla::GatherDimensionNumbers dimension_numbers = - Convert_gather_dimension_numbers(op.dimension_numbers()); - value_map[op] = xla::Gather( - value_map[op.operand()], value_map[op.start_indices()], dimension_numbers, - Convert_slice_sizes(op.slice_sizes()), op.indices_are_sorted()); + // The shape argument expected by the xla client API is the type of the first + // element in the result tuple. 
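+ // For example, if the op's result type is tuple<tensor<3xi32>, !xla_hlo.token>, + // the shape passed to InfeedWithToken below is that of tensor<3xi32>.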
+ auto result_type = op.getType().cast().getType(0); + value_map[op] = xla::InfeedWithToken( + value_map[op.token()], xla::TypeToShape(result_type), op.infeed_config()); return success(); } @@ -560,6 +538,14 @@ LogicalResult ExportXlaOp(IotaOp op, OpLoweringContext ctx) { return success(); } +LogicalResult ExportXlaOp(OutfeedOp op, OpLoweringContext ctx) { + auto& value_map = *ctx.values; + value_map[op] = xla::OutfeedWithToken( + value_map[op.operand()], value_map[op.token()], + xla::TypeToShape(op.operand().getType()), op.outfeed_config()); + return success(); +} + LogicalResult ExportXlaOp(PadOp op, OpLoweringContext ctx) { auto& value_map = *ctx.values; xla::PaddingConfig padding_config; @@ -627,10 +613,10 @@ LogicalResult ExportXlaOp(ReturnOp op, OpLoweringContext ctx) { return failure(); } -LogicalResult ExportXlaOp(ReverseOp op, OpLoweringContext ctx) { +LogicalResult ExportXlaOp(RngNormalOp op, OpLoweringContext ctx) { auto& value_map = *ctx.values; - value_map[op] = xla::Rev(value_map[op.operand()], - Convert_broadcast_dimensions(op.dimensions())); + value_map[op] = xla::RngNormal(value_map[op.mu()], value_map[op.sigma()], + xla::TypeToShape(op.getType())); return success(); } @@ -674,6 +660,21 @@ LogicalResult ExportXlaOp(SelectAndScatterOp op, OpLoweringContext ctx) { return success(); } +LogicalResult ExportXlaOp(SendOp op, OpLoweringContext ctx) { + auto& value_map = *ctx.values; + if (op.is_host_transfer()) { + value_map[op] = + xla::SendToHost(value_map[op.operand()], value_map[op.token()], + xla::TypeToShape(op.operand().getType()), + Convert_channel_handle(op.channel_id())); + return success(); + } + value_map[op] = + xla::SendWithToken(value_map[op.operand()], value_map[op.token()], + Convert_channel_handle(op.channel_id())); + return success(); +} + LogicalResult ExportXlaOp(SliceOp op, OpLoweringContext ctx) { return failure(); } @@ -690,12 +691,6 @@ LogicalResult ExportXlaOp(SortOp op, OpLoweringContext ctx) { return success(); } -LogicalResult ExportXlaOp(TupleOp op, OpLoweringContext ctx) { - auto& value_map = *ctx.values; - value_map[op] = xla::Tuple(ctx.builder, GetTuple(op.val(), ctx)); - return success(); -} - LogicalResult ExportXlaOp(UnaryEinsumOp op, OpLoweringContext ctx) { // Intentional as UnaryEinsumOp is always lowered to the EinsumOp with two // operands. 
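// (A canonicalization pattern rewrites unary_einsum into a binary einsum, so // this export path is not expected to be exercised in practice.)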
@@ -888,7 +883,7 @@ LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction( std::vector arg_shapes; arg_shapes.reserve(bb.getNumArguments()); for (auto& arg : bb.getArguments()) - arg_shapes.push_back(xla::TypeToShape(arg->getType())); + arg_shapes.push_back(xla::TypeToShape(arg.getType())); xla::Shape input_shape = xla::ShapeUtil::MakeTupleShape(arg_shapes); auto tuple = xla::Parameter(builder, 0, input_shape, "arg_tuple"); for (auto& it : llvm::enumerate(bb.getArguments())) { @@ -896,9 +891,9 @@ LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction( } } else { for (auto& it : llvm::enumerate(bb.getArguments())) { - auto* arg = it.value(); + auto arg = it.value(); auto num = it.index(); - xla::Shape shape = xla::TypeToShape(arg->getType()); + xla::Shape shape = xla::TypeToShape(arg.getType()); lowering[arg] = xla::Parameter(builder, num, shape, absl::StrCat("Arg_", num)); } @@ -1029,7 +1024,7 @@ LogicalResult AddDynamicParameterBindings(mlir::ModuleOp module, llvm::SmallDenseSet used_shape_indices; auto arg_type = - entry_func.getArgument(i)->getType().dyn_cast(); + entry_func.getArgument(i).getType().dyn_cast(); for (auto shape_and_padding : llvm::enumerate(llvm::zip( shape_indices.getValue(), padding_arg_indices.getValue()))) { const int element_index = shape_and_padding.index(); @@ -1064,7 +1059,7 @@ LogicalResult AddDynamicParameterBindings(mlir::ModuleOp module, kPaddingArgIndicesAttr, i, element_index, e, padding_arg_index)); Type padding_arg_type = - entry_func.getArgument(padding_arg_index)->getType(); + entry_func.getArgument(padding_arg_index).getType(); if (auto tensor_type = padding_arg_type.dyn_cast()) if (tensor_type.getRank() != 0) return entry_func.emitError() diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h index 3dffe2bc461..6f91213b31a 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_XLA_MLIR_HLO_TO_HLO_H_ #define TENSORFLOW_COMPILER_MLIR_XLA_MLIR_HLO_TO_HLO_H_ -#include "mlir/IR/Module.h" // TF:local_config_mlir +#include "mlir/IR/Module.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -37,7 +37,7 @@ Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto, // from `value_lowering` map. llvm::Optional CreateXlaOperator( mlir::Operation* op, - llvm::DenseMap* value_lowering); + llvm::DenseMap* value_lowering); } // namespace mlir diff --git a/tensorflow/compiler/mlir/xla/operator_writer_gen.cc b/tensorflow/compiler/mlir/xla/operator_writer_gen.cc index acc3c17baf5..9a578c83ce6 100644 --- a/tensorflow/compiler/mlir/xla/operator_writer_gen.cc +++ b/tensorflow/compiler/mlir/xla/operator_writer_gen.cc @@ -17,6 +17,7 @@ limitations under the License. #include "llvm/ADT/Sequence.h" #include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringMap.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/InitLLVM.h" @@ -25,8 +26,8 @@ limitations under the License. 
#include "llvm/TableGen/Main.h" #include "llvm/TableGen/Record.h" #include "llvm/TableGen/TableGenBackend.h" -#include "mlir/Support/STLExtras.h" // TF:local_config_mlir -#include "mlir/TableGen/Operator.h" // TF:local_config_mlir +#include "mlir/Support/STLExtras.h" // TF:llvm-project +#include "mlir/TableGen/Operator.h" // TF:llvm-project using llvm::raw_ostream; using llvm::RecordKeeper; @@ -42,14 +43,31 @@ static std::string GetDefaultAttrExport( Attribute attr = named_attr.attr; StringRef storage_type = attr.getStorageType(); // For some attribute types we have a general conversion, so use that. - if (!attr.isEnumAttr() && (storage_type.endswith("IntegerAttr") || + if (!attr.isEnumAttr() && (storage_type.endswith("BoolAttr") || storage_type.endswith("FloatAttr") || + storage_type.endswith("IntegerAttr") || storage_type.endswith("StringAttr"))) { return "Convert" + attr.getReturnType().str(); } return "Convert_" + named_attr.name.str(); } +static std::string GetClientBuilder(const Operator& op) { + static const auto* kOpToXLABuilderMap = + new llvm::StringMap{{"ReverseOp", "Rev"}, + {"ConcatenateOp", "ConcatInDim"}, + {"ConvOp", "ConvGeneralDilated"}}; + + StringRef op_name = op.getCppClassName(); + + // Default case where the client builder method names closely follow the op + // names in the dialect. For e.g., AddOp -> xla::Add method. + if (!kOpToXLABuilderMap->count(op_name)) return op_name.drop_back(2); + + // Otherwise, if the op to client builder method mapping is provided. + return kOpToXLABuilderMap->lookup(op_name); +} + static void BuildOperator(const Operator& op, raw_ostream* output) { auto& os = *output; os << " auto& value_map = *lowering_context.values;\n" @@ -71,7 +89,7 @@ static void BuildOperator(const Operator& op, raw_ostream* output) { } // Otherwise, this is a varidiac operand list. - os << " std::vector xla_arg_" << index << ";" + os << " std::vector xla_arg_" << index << ";\n" << " for (auto operand : xla_op.getODSOperands(" << operand_number++ << "))\n xla_arg_" << index << ".push_back(value_map[operand]);\n"; @@ -85,10 +103,15 @@ static void BuildOperator(const Operator& op, raw_ostream* output) { << op.getArgName(index) << "());\n"; } - // Assumes that the client builder method names closely follow the op names - // in the dialect. For e.g., AddOp -> xla::Add method. - StringRef op_name = op.getCppClassName(); - os << " auto xla_result = xla::" << op_name.drop_back(2) << "("; + // Emit call to client API + os << " auto xla_result = xla::" << GetClientBuilder(op) << "("; + + // If all operands are variadic, then pass the builder explicitly to xla + // client API call + if (op.getNumOperands() == op.getNumVariadicOperands()) { + os << "lowering_context.builder"; + if (op.getNumArgs() != 0) os << ", "; + } // Emit each of the arguments. 
interleaveComma(llvm::seq(0, op.getNumArgs()), os, diff --git a/tensorflow/compiler/mlir/xla/tests/BUILD b/tensorflow/compiler/mlir/xla/tests/BUILD index 9f47185e90a..4faa8d2efe8 100644 --- a/tensorflow/compiler/mlir/xla/tests/BUILD +++ b/tensorflow/compiler/mlir/xla/tests/BUILD @@ -4,7 +4,7 @@ package(licenses = ["notice"]) glob_lit_tests( data = [":test_utilities"], - driver = "@local_config_mlir//:run_lit.sh", + driver = "@llvm-project//mlir:run_lit.sh", test_file_exts = ["mlir"], ) @@ -14,6 +14,6 @@ filegroup( testonly = True, data = [ "//tensorflow/compiler/mlir:tf-opt", - "@llvm//:FileCheck", + "@llvm-project//llvm:FileCheck", ], ) diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir index c4b0e9f9d14..7e743cacb2b 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir @@ -6,7 +6,7 @@ // CHECK-LABEL: fusedBatchNorm_notraining func @fusedBatchNorm_notraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { - // CHECK-NEXT: "xla_hlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> + // CHECK: "xla_hlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> %0:5 = "tf.FusedBatchNorm"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) return %0#0 : tensor<8x8x8x8xf32> } @@ -14,11 +14,332 @@ func @fusedBatchNorm_notraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32> // CHECK-LABEL: fusedBatchNorm_training func @fusedBatchNorm_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { // TODO(riverriddle) Support training. 
- // CHECK-NEXT: "tf.FusedBatchNorm" + // CHECK: "tf.FusedBatchNorm" %0:5 = "tf.FusedBatchNorm"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) return %0#0 : tensor<8x8x8x8xf32> } +// CHECK-LABEL: fusedBatchNormV3_noTraining +func @fusedBatchNormV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { + // CHECK: "xla_hlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> + %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + return %0#0 : tensor<8x8x8x8xf32> +} + +//CHECK-LABEL: fusedBatchNormV3_noTraining_mixedPrecision +func @fusedBatchNormV3_noTraining_mixedPrecision(%arg0: tensor<8x8x8x8xbf16>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) { + // CHECK: %[[RESULT0:.*]] = "xla_hlo.convert"(%arg0) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32> + // CHECK: %[[RESULT1:.*]] = "xla_hlo.batch_norm_inference"(%[[RESULT0]], %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> + %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + // CHECK-NEXT: "xla_hlo.convert"(%[[RESULT1]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16> + return %0#0 : tensor<8x8x8x8xbf16> +} + +//CHECK-LABEL: fusedBatchNormV3_training +func @fusedBatchNormV3_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { + // CHECK: %[[RESULT0:.*]] = "xla_hlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tuple, tensor<8xf32>, tensor<8xf32>> + %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + // CHECK: "xla_hlo.get_tuple_element"(%[[RESULT0]]) {index = 0 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32> + // CHECK: "xla_hlo.get_tuple_element"(%[[RESULT0]]) {index = 1 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> + // CHECK: %[[VAR:.*]] = "xla_hlo.get_tuple_element"(%[[RESULT0]]) {index = 2 : i32} : (tuple, tensor<8xf32>, 
tensor<8xf32>>) -> tensor<8xf32> + // CHECK: xla_hlo.constant + // CHECK: "xla_hlo.mul"(%[[VAR]], {{.*}}) : (tensor<8xf32>, tensor) -> tensor<8xf32> + return %0#0 : tensor<8x8x8x8xf32> +} + +//CHECK-LABEL: fusedBatchNormV3_training_mixedPrecision +func @fusedBatchNormV3_training_mixedPrecision(%arg0: tensor<8x8x8x8xbf16>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) { + // CHECK: "xla_hlo.convert"(%arg0) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32> + %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + // CHECK: "xla_hlo.convert"({{.*}}) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16> + return %0#0 : tensor<8x8x8x8xbf16> +} + +//CHECK-LABEL: fusedBatchNormV3_NCHW +func @fusedBatchNormV3_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { + // CHECK: "xla_hlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 1 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tuple, tensor<8xf32>, tensor<8xf32>> + %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + return %0#0 : tensor<8x8x8x8xf32> +} + +// CHECK-LABEL: fusedBatchNormGrad_noTraining +func @fusedBatchNormGrad_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { + // CHECK-NEXT: %[[grad:.*]] = "xla_hlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[act:.*]] = "xla_hlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[eps:.*]] = xla_hlo.constant dense<1.000000e-03> : tensor + + // CHECK-NEXT: %[[add:.*]] = "xla_hlo.add"(%arg4, %[[eps]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> + // CHECK-NEXT: %[[scr1:.*]] = "xla_hlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32> + + // CHECK-NEXT: %[[sub:.*]] = "xla_hlo.sub"(%[[act]], %arg3) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul:.*]] = xla_hlo.mul %[[grad]], %[[sub]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8x8x8x8xf32> + // CHECK-NEXT: xla_hlo.constant dense<[0, 1, 2]> : tensor<3xi64> + // CHECK-NEXT: %[[cmul:.*]] = "xla_hlo.convert"(%[[mul]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[init:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor + // CHECK-NEXT: %[[red1:.*]] = "xla_hlo.reduce"(%[[cmul]], %[[init]]) ( { + // CHECK-NEXT: ^bb0(%arg5: tensor, %arg6: tensor): // no predecessors + // CHECK-NEXT: %[[reduced:.*]] = xla_hlo.add %arg5, %arg6 : tensor + // CHECK-NEXT: "xla_hlo.return"(%[[reduced]]) : (tensor) -> () + // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor) -> tensor<8xf32> + // CHECK-NEXT: 
%[[scr2:.*]] = "xla_hlo.convert"(%[[red1]]) : (tensor<8xf32>) -> tensor<8xf32> + + // CHECK-NEXT: %[[mul2:.*]] = xla_hlo.mul %arg2, %[[scr1]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> + // CHECK-NEXT: %[[mul3:.*]] = "xla_hlo.mul"(%[[grad]], %[[mul2]]) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> + + // CHECK-NEXT: %[[scale_backprop:.*]] = xla_hlo.mul %[[scr1]], %[[scr2]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> + + // CHECK-NEXT: xla_hlo.constant dense<[0, 1, 2]> : tensor<3xi64> + // CHECK-NEXT: %[[cgrad:.*]] = "xla_hlo.convert"(%[[grad]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[init2:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor + // CHECK-NEXT: %[[red2:.*]] = "xla_hlo.reduce"(%[[cgrad]], %[[init2]]) ( { + // CHECK-NEXT: ^bb0(%arg5: tensor, %arg6: tensor): // no predecessors + // CHECK-NEXT: %[[reduced1:.*]] = xla_hlo.add %arg5, %arg6 : tensor + // CHECK-NEXT: "xla_hlo.return"(%[[reduced1]]) : (tensor) -> () + // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor) -> tensor<8xf32> + // CHECK-NEXT: %[[offset_backprop:.*]] = "xla_hlo.convert"(%[[red2]]) : (tensor<8xf32>) -> tensor<8xf32> + + // CHECK-NEXT: %[[x_backprop:.*]] = "xla_hlo.convert"(%[[mul3]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32> + + %0:5 = "tf.FusedBatchNormGrad"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + return %0#0 : tensor<8x8x8x8xf32> +} + +// CHECK-LABEL: fusedBatchNormGrad_Training +func @fusedBatchNormGrad_Training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { + // CHECK-NEXT: %[[grad:.*]] = "xla_hlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[act:.*]] = "xla_hlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[training:.*]] = "xla_hlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> tuple, tensor<8xf32>, tensor<8xf32>> + // CHECK-NEXT: %[[tact:.*]] = "xla_hlo.get_tuple_element"(%[[training]]) {index = 0 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[scale_backprop:.*]] = "xla_hlo.get_tuple_element"(%[[training]]) {index = 1 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> + // CHECK-NEXT: %[[offset_backprop:.*]] = "xla_hlo.get_tuple_element"(%[[training]]) {index = 2 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> + // CHECK-NEXT: %[[x_backprop:.*]] = "xla_hlo.convert"(%[[tact]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32> + + %0:5 = "tf.FusedBatchNormGrad"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, 
tensor<8xf32>, tensor<8xf32>) + return %0#0 : tensor<8x8x8x8xf32> +} + +// CHECK-LABEL: fusedBatchNormGradV2_noTraining +func @fusedBatchNormGradV2_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { + // CHECK-NEXT: %[[grad:.*]] = "xla_hlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[act:.*]] = "xla_hlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[eps:.*]] = xla_hlo.constant dense<1.000000e-03> : tensor + + // CHECK-NEXT: %[[add:.*]] = "xla_hlo.add"(%arg4, %[[eps]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> + // CHECK-NEXT: %[[scr1:.*]] = "xla_hlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32> + + // CHECK-NEXT: %[[sub:.*]] = "xla_hlo.sub"(%[[act]], %arg3) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul:.*]] = xla_hlo.mul %[[grad]], %[[sub]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8x8x8x8xf32> + // CHECK-NEXT: xla_hlo.constant dense<[0, 1, 2]> : tensor<3xi64> + // CHECK-NEXT: %[[cmul:.*]] = "xla_hlo.convert"(%[[mul]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[init:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor + // CHECK-NEXT: %[[red1:.*]] = "xla_hlo.reduce"(%[[cmul]], %[[init]]) ( { + // CHECK-NEXT: ^bb0(%arg5: tensor, %arg6: tensor): // no predecessors + // CHECK-NEXT: %[[reduced:.*]] = xla_hlo.add %arg5, %arg6 : tensor + // CHECK-NEXT: "xla_hlo.return"(%[[reduced]]) : (tensor) -> () + // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor) -> tensor<8xf32> + // CHECK-NEXT: %[[scr2:.*]] = "xla_hlo.convert"(%[[red1]]) : (tensor<8xf32>) -> tensor<8xf32> + + // CHECK-NEXT: %[[mul2:.*]] = xla_hlo.mul %arg2, %[[scr1]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> + // CHECK-NEXT: %[[mul3:.*]] = "xla_hlo.mul"(%[[grad]], %[[mul2]]) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> + + // CHECK-NEXT: %[[scale_backprop:.*]] = xla_hlo.mul %[[scr1]], %[[scr2]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> + + // CHECK-NEXT: xla_hlo.constant dense<[0, 1, 2]> : tensor<3xi64> + // CHECK-NEXT: %[[cgrad:.*]] = "xla_hlo.convert"(%[[grad]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[init2:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor + // CHECK-NEXT: %[[red2:.*]] = "xla_hlo.reduce"(%[[cgrad]], %[[init2]]) ( { + // CHECK-NEXT: ^bb0(%arg5: tensor, %arg6: tensor): // no predecessors + // CHECK-NEXT: %[[reduced1:.*]] = xla_hlo.add %arg5, %arg6 : tensor + // CHECK-NEXT: "xla_hlo.return"(%[[reduced1]]) : (tensor) -> () + // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor) -> tensor<8xf32> + // CHECK-NEXT: %[[offset_backprop:.*]] = "xla_hlo.convert"(%[[red2]]) : (tensor<8xf32>) -> tensor<8xf32> + + // CHECK-NEXT: %[[x_backprop:.*]] = "xla_hlo.convert"(%[[mul3]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32> + + %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> 
(tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + return %0#0 : tensor<8x8x8x8xf32> +} + +// CHECK-LABEL: fusedBatchNormGradV2_Training +func @fusedBatchNormGradV2_Training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { + // CHECK-NEXT: %[[grad:.*]] = "xla_hlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[act:.*]] = "xla_hlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[training:.*]] = "xla_hlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> tuple, tensor<8xf32>, tensor<8xf32>> + // CHECK-NEXT: %[[tact:.*]] = "xla_hlo.get_tuple_element"(%[[training]]) {index = 0 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[scale_backprop:.*]] = "xla_hlo.get_tuple_element"(%[[training]]) {index = 1 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> + // CHECK-NEXT: %[[offset_backprop:.*]] = "xla_hlo.get_tuple_element"(%[[training]]) {index = 2 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> + // CHECK-NEXT: %[[x_backprop:.*]] = "xla_hlo.convert"(%[[tact]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32> + + %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + return %0#0 : tensor<8x8x8x8xf32> +} + +// CHECK-LABEL: fusedBatchNormGradV2_noTraining_mixed_precision +func @fusedBatchNormGradV2_noTraining_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) { + // CHECK-NEXT: %[[grad:.*]] = "xla_hlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[act:.*]] = "xla_hlo.convert"(%arg1) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32> + + // CHECK: %[[x_backprop:.*]] = "xla_hlo.convert"({{.*}}) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16> + // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xbf16> + + %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + return %0#0 : tensor<8x8x8x8xbf16> +} + +// CHECK-LABEL: fusedBatchNormGradV2_Training_mixed_precision +func @fusedBatchNormGradV2_Training_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) { + // CHECK-NEXT: %[[grad:.*]] = "xla_hlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[act:.*]] = "xla_hlo.convert"(%arg1) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[training:.*]] = "xla_hlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} 
: (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> tuple, tensor<8xf32>, tensor<8xf32>> + // CHECK-NEXT: %[[tact:.*]] = "xla_hlo.get_tuple_element"(%[[training]]) {index = 0 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[scale_backprop:.*]] = "xla_hlo.get_tuple_element"(%[[training]]) {index = 1 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> + // CHECK-NEXT: %[[offset_backprop:.*]] = "xla_hlo.get_tuple_element"(%[[training]]) {index = 2 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> + // CHECK-NEXT: %[[x_backprop:.*]] = "xla_hlo.convert"(%[[tact]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16> + // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xbf16> + + %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + return %0#0 : tensor<8x8x8x8xbf16> +} + +// CHECK-LABEL: fusedBatchNormGradV3_noTraining +func @fusedBatchNormGradV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { + // CHECK-NEXT: %[[grad:.*]] = "xla_hlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[act:.*]] = "xla_hlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[eps:.*]] = xla_hlo.constant dense<1.000000e-03> : tensor + + // CHECK-NEXT: %[[add:.*]] = "xla_hlo.add"(%arg4, %[[eps]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> + // CHECK-NEXT: %[[scr1:.*]] = "xla_hlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32> + + // CHECK-NEXT: %[[sub:.*]] = "xla_hlo.sub"(%[[act]], %arg3) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul:.*]] = xla_hlo.mul %[[grad]], %[[sub]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8x8x8x8xf32> + // CHECK-NEXT: xla_hlo.constant dense<[0, 1, 2]> : tensor<3xi64> + // CHECK-NEXT: %[[cmul:.*]] = "xla_hlo.convert"(%[[mul]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[init:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor + // CHECK-NEXT: %[[red1:.*]] = "xla_hlo.reduce"(%[[cmul]], %[[init]]) ( { + // CHECK-NEXT: ^bb0(%arg6: tensor, %arg7: tensor): // no predecessors + // CHECK-NEXT: %[[reduced:.*]] = xla_hlo.add %arg6, %arg7 : tensor + // CHECK-NEXT: "xla_hlo.return"(%[[reduced]]) : (tensor) -> () + // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor) -> tensor<8xf32> + // CHECK-NEXT: %[[scr2:.*]] = "xla_hlo.convert"(%[[red1]]) : (tensor<8xf32>) -> tensor<8xf32> + + // CHECK-NEXT: %[[mul2:.*]] = xla_hlo.mul %arg2, %[[scr1]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> + // CHECK-NEXT: %[[mul3:.*]] = "xla_hlo.mul"(%[[grad]], %[[mul2]]) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> + + // CHECK-NEXT: %[[scale_backprop:.*]] = xla_hlo.mul %[[scr1]], %[[scr2]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> + + // CHECK-NEXT: xla_hlo.constant dense<[0, 1, 2]> : tensor<3xi64> + // 
CHECK-NEXT: %[[cgrad:.*]] = "xla_hlo.convert"(%[[grad]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[init2:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor + // CHECK-NEXT: %[[red2:.*]] = "xla_hlo.reduce"(%[[cgrad]], %[[init2]]) ( { + // CHECK-NEXT: ^bb0(%arg6: tensor, %arg7: tensor): // no predecessors + // CHECK-NEXT: %[[reduced1:.*]] = xla_hlo.add %arg6, %arg7 : tensor + // CHECK-NEXT: "xla_hlo.return"(%[[reduced1]]) : (tensor) -> () + // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor) -> tensor<8xf32> + // CHECK-NEXT: %[[offset_backprop:.*]] = "xla_hlo.convert"(%[[red2]]) : (tensor<8xf32>) -> tensor<8xf32> + + // CHECK-NEXT: %[[x_backprop:.*]] = "xla_hlo.convert"(%[[mul3]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32> + + %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + return %0#0 : tensor<8x8x8x8xf32> +} + +// CHECK-LABEL: fusedBatchNormGradV3_Training +func @fusedBatchNormGradV3_Training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { + // CHECK-NEXT: %[[grad:.*]] = "xla_hlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[act:.*]] = "xla_hlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[training:.*]] = "xla_hlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> tuple, tensor<8xf32>, tensor<8xf32>> + // CHECK-NEXT: %[[tact:.*]] = "xla_hlo.get_tuple_element"(%[[training]]) {index = 0 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[scale_backprop:.*]] = "xla_hlo.get_tuple_element"(%[[training]]) {index = 1 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> + // CHECK-NEXT: %[[offset_backprop:.*]] = "xla_hlo.get_tuple_element"(%[[training]]) {index = 2 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> + // CHECK-NEXT: %[[x_backprop:.*]] = "xla_hlo.convert"(%[[tact]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32> + + %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + return %0#0 : tensor<8x8x8x8xf32> +} + +// CHECK-LABEL: fusedBatchNormGradV3_noTraining_mixed_precision +func @fusedBatchNormGradV3_noTraining_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) { + // CHECK-NEXT: %[[grad:.*]] = "xla_hlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[act:.*]] = "xla_hlo.convert"(%arg1) : 
(tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32> + + // CHECK: %[[x_backprop:.*]] = "xla_hlo.convert"({{.*}}) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16> + // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xbf16> + + %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + return %0#0 : tensor<8x8x8x8xbf16> +} + +// CHECK-LABEL: fusedBatchNormGradV3_Training_mixed_precision +func @fusedBatchNormGradV3_Training_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) { + // CHECK-NEXT: %[[grad:.*]] = "xla_hlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[act:.*]] = "xla_hlo.convert"(%arg1) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[training:.*]] = "xla_hlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> tuple, tensor<8xf32>, tensor<8xf32>> + // CHECK-NEXT: %[[tact:.*]] = "xla_hlo.get_tuple_element"(%[[training]]) {index = 0 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[scale_backprop:.*]] = "xla_hlo.get_tuple_element"(%[[training]]) {index = 1 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> + // CHECK-NEXT: %[[offset_backprop:.*]] = "xla_hlo.get_tuple_element"(%[[training]]) {index = 2 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> + // CHECK-NEXT: %[[x_backprop:.*]] = "xla_hlo.convert"(%[[tact]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16> + // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xbf16> + + %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + return %0#0 : tensor<8x8x8x8xbf16> +} + +// CHECK-LABEL: fusedBatchNormGradV3_noTraining_NCHW +func @fusedBatchNormGradV3_noTraining_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { + // CHECK-NEXT: %[[grad:.*]] = "xla_hlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[act:.*]] = "xla_hlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[eps:.*]] = xla_hlo.constant dense<1.000000e-03> : tensor + + // CHECK-NEXT: %[[add:.*]] = "xla_hlo.add"(%arg4, %[[eps]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> + // CHECK-NEXT: %[[scr1:.*]] = "xla_hlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32> + + // CHECK-NEXT: %[[sub:.*]] = "xla_hlo.sub"(%[[act]], %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul:.*]] = xla_hlo.mul %[[grad]], %[[sub]] {broadcast_dimensions = dense<[]> : 
tensor<0xi64>} : tensor<8x8x8x8xf32> + // CHECK-NEXT: xla_hlo.constant dense<[0, 2, 3]> : tensor<3xi64> + // CHECK-NEXT: %[[cmul:.*]] = "xla_hlo.convert"(%[[mul]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[init:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor + // CHECK-NEXT: %[[red1:.*]] = "xla_hlo.reduce"(%[[cmul]], %[[init]]) ( { + // CHECK-NEXT: ^bb0(%arg6: tensor, %arg7: tensor): // no predecessors + // CHECK-NEXT: %[[reduced:.*]] = xla_hlo.add %arg6, %arg7 : tensor + // CHECK-NEXT: "xla_hlo.return"(%[[reduced]]) : (tensor) -> () + // CHECK-NEXT: }) {dimensions = dense<[0, 2, 3]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor) -> tensor<8xf32> + // CHECK-NEXT: %[[scr2:.*]] = "xla_hlo.convert"(%[[red1]]) : (tensor<8xf32>) -> tensor<8xf32> + + // CHECK-NEXT: %[[mul2:.*]] = xla_hlo.mul %arg2, %[[scr1]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> + // CHECK-NEXT: %[[mul3:.*]] = "xla_hlo.mul"(%[[grad]], %[[mul2]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> + + // CHECK-NEXT: %[[scale_backprop:.*]] = xla_hlo.mul %[[scr1]], %[[scr2]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> + + // CHECK-NEXT: xla_hlo.constant dense<[0, 2, 3]> : tensor<3xi64> + // CHECK-NEXT: %[[cgrad:.*]] = "xla_hlo.convert"(%[[grad]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[init2:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor + // CHECK-NEXT: %[[red2:.*]] = "xla_hlo.reduce"(%[[cgrad]], %[[init2]]) ( { + // CHECK-NEXT: ^bb0(%arg6: tensor, %arg7: tensor): // no predecessors + // CHECK-NEXT: %[[reduced1:.*]] = xla_hlo.add %arg6, %arg7 : tensor + // CHECK-NEXT: "xla_hlo.return"(%[[reduced1]]) : (tensor) -> () + // CHECK-NEXT: }) {dimensions = dense<[0, 2, 3]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor) -> tensor<8xf32> + // CHECK-NEXT: %[[offset_backprop:.*]] = "xla_hlo.convert"(%[[red2]]) : (tensor<8xf32>) -> tensor<8xf32> + + // CHECK-NEXT: %[[x_backprop:.*]] = "xla_hlo.convert"(%[[mul3]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32> + + %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + return %0#0 : tensor<8x8x8x8xf32> +} + +// CHECK-LABEL: fusedBatchNormGradV3_Training_NCHW +func @fusedBatchNormGradV3_Training_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { + // CHECK: %{{.*}} = "xla_hlo.batch_norm_grad"(%{{.*}}, %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 1 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> tuple, tensor<8xf32>, tensor<8xf32>> + %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + return %0#0 : tensor<8x8x8x8xf32> +} + 
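The NHWC batch-norm tests above all expect feature_index = 3, while the NCHW variants expect feature_index = 1: the legalization only has to map the data_format attribute to the channel dimension of the 4-D input. A minimal sketch of that mapping, assuming NHWC and NCHW are the only layouts handled (the helper name and signature are hypothetical):

    #include <cstdint>
    #include "llvm/ADT/StringRef.h"

    // Channel dimension used as feature_index for xla_hlo.batch_norm_* ops:
    // NHWC stores channels last (rank - 1, i.e. 3 for a 4-D tensor), while
    // NCHW stores them at dimension 1.
    static int64_t GetFeatureDimension(llvm::StringRef data_format,
                                       int64_t rank) {
      return data_format == "NCHW" ? 1 : rank - 1;
    }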
//===----------------------------------------------------------------------===// // Bias op legalizations. //===----------------------------------------------------------------------===// @@ -87,6 +408,27 @@ func @broadcast_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2x return %0: tensor<1x2xi32> } +// CHECK-LABEL: func @shift_left +func @shift_left(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { + // CHECK: xla_hlo.shift_left %arg0, %arg1 : tensor<4xi32> + %0 = "tf.LeftShift"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + return %0 : tensor<4xi32> +} + +// CHECK-LABEL: func @div_dynamic +func @div_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "xla_hlo.div"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} + %0 = "tf.Div"(%arg0, %arg1) : (tensor, tensor) -> tensor + return %0: tensor +} + +// CHECK-LABEL: func @div_unranked +func @div_unranked(%arg0: tensor<*xi32>, %arg1: tensor) -> tensor { + // CHECK: tf.Div + %0 = "tf.Div"(%arg0, %arg1) : (tensor<*xi32>, tensor) -> tensor + return %0: tensor +} + // CHECK-LABEL: func @maximum func @maximum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { // CHECK: xla_hlo.max %arg0, %arg1 : tensor<4xf32> @@ -145,6 +487,34 @@ func @broadcast_sub(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2x return %0: tensor<1x2xi32> } +// CHECK-LABEL: func @shift_right +func @shift_right(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { + // CHECK: xla_hlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32> + %0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + return %0 : tensor<4xi32> +} + +// CHECK-LABEL: func @broadcast_shift_right +func @broadcast_shift_right(%arg0: tensor<4xi32>, %arg1: tensor<2x4xi32>) -> tensor<2x4xi32> { + // CHECK: "xla_hlo.shift_right_arithmetic"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} + %0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32> + return %0 : tensor<2x4xi32> +} + +// CHECK-LABEL: func @shift_right_unsigned +func @shift_right_unsigned(%arg0: tensor<4x!tf.uint8>, %arg1: tensor<4x!tf.uint8>) -> tensor<4x!tf.uint8> { + // CHECK: tf.RightShift + %0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4x!tf.uint8>, tensor<4x!tf.uint8>) -> tensor<4x!tf.uint8> + return %0 : tensor<4x!tf.uint8> +} + +// CHECK-LABEL: func @broadcast_shift_right_unsigned +func @broadcast_shift_right_unsigned(%arg0: tensor<4x!tf.uint8>, %arg1: tensor<2x4x!tf.uint8>) -> tensor<2x4x!tf.uint8> { + // CHECK: tf.RightShift + %0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4x!tf.uint8>, tensor<2x4x!tf.uint8>) -> tensor<2x4x!tf.uint8> + return %0 : tensor<2x4x!tf.uint8> +} + // CHECK-LABEL: func @and func @and(%arg0: tensor<2xi1>) -> tensor<2xi1> { // CHECK-NEXT: xla_hlo.and @@ -166,6 +536,13 @@ func @and_dynamic(%arg0: tensor, %arg1: tensor<1xi1>) -> tensor { return %0: tensor } +// CHECK-LABEL: func @and_unranked +func @and_unranked(%arg0: tensor<*xi1>, %arg1: tensor<*xi1>) -> tensor<*xi1> { + // CHECK: tf.LogicalAnd + %0 = "tf.LogicalAnd"(%arg0, %arg1) : (tensor<*xi1>, tensor<*xi1>) -> tensor<*xi1> + return %0: tensor<*xi1> +} + // CHECK-LABEL: func @or func @or(%arg0: tensor<2xi1>) -> tensor<2xi1> { // CHECK-NEXT: xla_hlo.or @@ -310,6 +687,20 @@ func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> te return %0: tensor<2x3xf16> } +// CHECK-LABEL: func @floordiv_dynamic +func @floordiv_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { + // 
CHECK: tf.FloorDiv + %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor, tensor) -> tensor + return %0: tensor +} + +// CHECK-LABEL: func @floordiv_unranked +func @floordiv_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> { + // CHECK: tf.FloorDiv + %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> + return %0: tensor<*xi32> +} + // CHECK-LABEL: func @floormod_broadcast_numerator func @floormod_broadcast_numerator(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> { // CHECK-DAG: [[REM:%.+]] = "xla_hlo.remainder"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} @@ -344,6 +735,20 @@ func @floormod_broadcast_denominator(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32 return %0: tensor<2x3xi32> } +// CHECK-LABEL: func @floormod_dynamic +func @floormod_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: tf.FloorMod + %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor, tensor) -> tensor + return %0: tensor +} + +// CHECK-LABEL: func @floormod_unranked +func @floormod_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> { + // CHECK: tf.FloorMod + %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> + return %0: tensor<*xi32> +} + // CHECK-LABEL: func @broadcast_to func @broadcast_to(%arg0: tensor<16xf32>) -> tensor<16x16x16x16xf32> { %cst = "tf.Const"() { value = dense<16> : tensor<4xi32> } : () -> tensor<4xi32> @@ -415,6 +820,13 @@ func @equal_incompatible_shape_both_dynamic(%arg0: tensor, %arg1: tensor< return %0: tensor<*xi1> } +// CHECK-LABEL: func @equal_unranked +func @equal_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi1> { + // CHECK: "tf.Equal" + %0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi1> + return %0: tensor<*xi1> +} + // CHECK-LABEL: func @notequal func @notequal(%arg0: tensor<2xi32>) -> tensor<2xi1> { // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} @@ -482,6 +894,20 @@ func @broadcast_greater(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor< return %0: tensor<1x2xi1> } +// CHECK-LABEL: func @greater_dynamic +func @greater_dynamic(%arg0: tensor) -> tensor { + // CHECK: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} + %0 = "tf.Greater"(%arg0, %arg0) : (tensor, tensor) -> tensor + return %0: tensor +} + +// CHECK-LABEL: func @greater_uranked +func @greater_uranked(%arg0: tensor<*xi32>) -> tensor<*xi1> { + // CHECK: "tf.Greater" + %0 = "tf.Greater"(%arg0, %arg0) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi1> + return %0: tensor<*xi1> +} + // CHECK-LABEL: func @greater_equal func @greater_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} @@ -761,13 +1187,22 @@ func @maxpool_valid_padding(%arg0: tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32> return %0 : tensor<2x3x5x7xi32> } +// CHECK-LABEL: maxpool_same_padding +// CHECK-SAME: %[[ARG:.*]]: tensor +func @maxpool_same_padding(%arg0: tensor<2x13x25x7xi32>) -> tensor<2x4x7x7xi32> { + // CHECK: padding = dense<{{\[\[}}0, 0, 1, 0], [0, 1, 1, 0]]> : tensor<2x4xi64> + + %0 = "tf.MaxPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 4, 1]} : (tensor<2x13x25x7xi32>) -> tensor<2x4x7x7xi32> + return %0 : tensor<2x4x7x7xi32> +} + //===----------------------------------------------------------------------===// // MaxPoolGrad op legalizations. 
//===----------------------------------------------------------------------===// -// CHECK-LABEL: @max_pool_grad +// CHECK-LABEL: @max_pool_grad_valid // CHECK-SAME: %[[INPUT:.*]]: tensor<10x24x24x64xf32>, %arg1: tensor<10x12x12x64xf32>, %[[GRAD:.*]]: tensor<10x12x12x64xf32> -func @max_pool_grad(%orig_input: tensor<10x24x24x64xf32>, %orig_output: tensor<10x12x12x64xf32>, %grad: tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32> { +func @max_pool_grad_valid(%orig_input: tensor<10x24x24x64xf32>, %orig_output: tensor<10x12x12x64xf32>, %grad: tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32> { // CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor // CHECK: %[[RESULT:.*]] = "xla_hlo.select_and_scatter"(%[[INPUT]], %[[GRAD]], %[[ZERO]]) ( { // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor, %[[VALUE_B:.*]]: tensor): @@ -789,6 +1224,18 @@ func @max_pool_grad(%orig_input: tensor<10x24x24x64xf32>, %orig_output: tensor<1 return %result : tensor<10x24x24x64xf32> } +// CHECK-LABEL: @max_pool_grad_same +func @max_pool_grad_same(%orig_input: tensor<2x13x25x7xf32>, %orig_output: tensor<2x4x7x7xf32>, %grad: tensor<2x4x7x7xf32>) -> tensor<2x13x25x7xf32> { + // CHECK: padding = dense<{{\[\[}}0, 0, 1, 0], [0, 1, 1, 0]]> : tensor<2x4xi64> + %result = "tf.MaxPoolGrad"(%orig_input, %orig_output, %grad) { + data_format = "NHWC", + ksize = [1, 2, 3, 1], + padding = "SAME", + strides = [1, 4, 4, 1] + } : (tensor<2x13x25x7xf32>, tensor<2x4x7x7xf32>, tensor<2x4x7x7xf32>) -> tensor<2x13x25x7xf32> + return %result : tensor<2x13x25x7xf32> +} + //===----------------------------------------------------------------------===// // OneHot op legalizations. //===----------------------------------------------------------------------===// @@ -1243,6 +1690,34 @@ func @log_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { return %0 : tensor<*xf32> } +// CHECK-LABEL: @log1p +func @log1p(%arg0: tensor<2xf32>) -> tensor<2xf32> { + // CHECK: "xla_hlo.log_plus_one"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + %0 = "tf.Log1p"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + return %0 : tensor<2xf32> +} + +// CHECK-LABEL: func @log1p_dynamic +func @log1p_dynamic(%arg0: tensor) -> tensor { + // CHECK: "xla_hlo.log_plus_one"(%arg0) : (tensor) -> tensor + %0 = "tf.Log1p"(%arg0) : (tensor) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: func @log1p_unranked +func @log1p_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { + // CHECK: "xla_hlo.log_plus_one"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + %0 = "tf.Log1p"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// CHECK-LABEL: func @not_op_unranked +func @not_op_unranked(%arg0: tensor<*xi1>) -> tensor<*xi1> { + // CHECK: "xla_hlo.not"(%arg0) : (tensor<*xi1>) -> tensor<*xi1> + %0 = "tf.LogicalNot"(%arg0) : (tensor<*xi1>) -> tensor<*xi1> + return %0 : tensor<*xi1> +} + // CHECK-LABEL: @neg func @neg(%arg0: tensor<2xf32>) -> tensor<2xf32> { // CHECK: "xla_hlo.neg"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> @@ -1404,6 +1879,18 @@ func @expand_dims(%arg0: tensor<2xf32>, %axis: tensor) -> tensor<1x2xf32> { return %0 : tensor<1x2xf32> } +// CHECK-LABEL: func @sign +// CHECK-SAME: [[ARG:%arg.*]]: tensor<1x2x3x4xf32> +func @sign(%arg0: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> { + // CHECK: [[PRED:%.*]] = "xla_hlo.compare"([[ARG]], [[ARG]]) + // CHECK: [[ZEROS:%.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<1x2x3x4xf32> + // CHECK: [[SIGN:%.*]] = "xla_hlo.sign"([[ARG]]) + // CHECK: [[SELECT:%.*]] = "xla_hlo.select"([[PRED]], [[ZEROS]], [[SIGN]]) + // CHECK: 
return [[SELECT]] : tensor<1x2x3x4xf32> + %0 = "tf.Sign"(%arg0) : (tensor<1x2x3x4xf32>) -> (tensor<1x2x3x4xf32>) + return %0 : tensor<1x2x3x4xf32> +} + // CHECK-LABEL: slice_constant_start func @slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> { // CHECK: %[[START:.*]] = xla_hlo.constant dense<1> : tensor<1xi64> @@ -1525,23 +2012,45 @@ func @strided_slice_range_clamping(%input: tensor<4x8xf32>) -> tensor<0x3xf32> { return %output : tensor<0x3xf32> } -// CHECK-LABEL: strided_slice_shrink_axis -func @strided_slice_shrink_axis(%input: tensor<4x8xf32>) -> tensor { - %begin = "tf.Const"() {value = dense<[1, 3]> : tensor<2xi32>} : () -> (tensor<2xi32>) - %end = "tf.Const"() {value = dense<[2, 4]> : tensor<2xi32>} : () -> (tensor<2xi32>) - %strides = "tf.Const"() {value = dense<[1, 3]> : tensor<2xi32>} : () -> (tensor<2xi32>) +// CHECK-LABEL: strided_slice_begin_end_mask +// CHECK-SAME: %[[INPUT:[a-z0-9]+]]: tensor<4x128x1024xf32> +func @strided_slice_begin_end_mask(%input: tensor<4x128x1024xf32>) { - // CHECK: %[[SLICED:.*]] = "xla_hlo.slice" - // CHECK-DAG-SAME: start_indices = dense<[1, 3]> - // CHECK-DAG-SAME: limit_indices = dense<[2, 4]> - // CHECK-DAG-SAME: strides = dense<[1, 3]> - // CHECK-SAME: -> tensor<1x1xf32> + // For StridedSlice + // Dim #: 0, 1, 2 + // Input shape: [4, 128, 1024] + // Begin: 1, 4, -3 + // End: 8, 65, 42 + // Stride: 1, 4, -1 + // Begin mask: 1, 0, 0 (= 1) + // End mask: 0, 0, 1 (= 4) - // CHECK: "xla_hlo.reshape"(%[[SLICED]]) : (tensor<1x1xf32>) -> tensor + // So result shape: + // Dim #0: begin mask (1) -> begin = 0; end 8 canonicalized to 4: so 4 + // Dim #1: 4 to 65 stride 4: so 16 + // Dim #2: begin -3 + 1024 = 1021; end mask (1) -> end = -1: so 1022 + // result shape: [4, 16, 1022] - %output = "tf.StridedSlice"(%input, %begin, %end, %strides) {shrink_axis_mask = 3 - : i64} : (tensor<4x8xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor - return %output : tensor + // As output shape of StridedSlice differs, a reshape will follow. 
+ + %begin = "tf.Const"() {value = dense<[1, 4, -3]> : tensor<3xi32>} : () -> (tensor<3xi32>) + %end = "tf.Const"() {value = dense<[8, 65, 42]> : tensor<3xi32>} : () -> (tensor<3xi32>) + %strides = "tf.Const"() {value = dense<[1, 4, -1]> : tensor<3xi32>} : () -> (tensor<3xi32>) + + // CHECK: %[[REVERSE:.*]] = "xla_hlo.reverse"(%[[INPUT]]) + + // CHECK: %[[SLICE:.*]] = "xla_hlo.slice"(%[[REVERSE]]) + // CHECK-DAG-SAME: limit_indices = dense<[4, 65, 1024]> + // CHECK-DAG-SAME: start_indices = dense<[0, 4, 2]> + // CHECK-DAG-SAME: strides = dense<[1, 4, 1]> + // CHECK-SAME: -> tensor<4x16x1022xf32> + + %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {begin_mask = 1, end_mask = 4} : (tensor<4x128x1024xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor + + // CHECK: "xla_hlo.reshape"(%[[SLICE]]) + // CHECK-SAME: -> tensor + + return } //===----------------------------------------------------------------------===// @@ -2268,7 +2777,7 @@ func @gather_v2_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xi32>) -> tensor<* func @strided_slice_grad(%grad: tensor<4x16x1022xf32>) -> tensor<4x128x1024xf32> { // For StridedSlice - // Dim #: 0, 1, 2 + // Dim #: 0, 1, 2 // Input shape: [4, 128, 1024] // Begin: 1, 4, -3 // End: 8, 65, 42 @@ -2277,7 +2786,7 @@ func @strided_slice_grad(%grad: tensor<4x16x1022xf32>) -> tensor<4x128x1024xf32> // End mask: 0, 0, 1 (= 4) // So result shape: - // Dim #0: begin mask (1) -> begin = 0; end 8 cannonicalized to 4: so 4 + // Dim #0: begin mask (1) -> begin = 0; end 8 canonicalized to 4: so 4 // Dim #1: 4 to 65 stride 4: so 16 // Dim #2: begin -3 + 1024 = 1021; end mask (1) -> end = -1: so 1022 // result shape: [4, 16, 1022] @@ -2302,3 +2811,20 @@ func @strided_slice_grad(%grad: tensor<4x16x1022xf32>) -> tensor<4x128x1024xf32> // CHECK: return [[PAD]] return %0: tensor<4x128x1024xf32> } + +// CHECK-LABEL: @tensor_scatter_update +func @tensor_scatter_update(%tensor: tensor, %indices: tensor, %updates: tensor) -> tensor { + // CHECK: "xla_hlo.scatter"(%arg0, %arg1, %arg2) ( { + // CHECK: ^bb0(%arg3: tensor, %arg4: tensor): + // CHECK: "xla_hlo.return"(%arg4) : (tensor) -> () + // CHECK: }) + // CHECK-SAME: indices_are_sorted = false + // CHECK-SAME: scatter_dimension_numbers + // CHECK-SAME: index_vector_dim = 1 : i64 + // CHECK-SAME: inserted_window_dims = dense<[0, 1]> : tensor<2xi64> + // CHECK-SAME: scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64> + // CHECK-SAME: update_window_dims = dense<1> : tensor<1xi64> + // CHECK-SAME: unique_indices = false + %0 = "tf.TensorScatterUpdate"(%tensor, %indices, %updates) : (tensor, tensor, tensor) -> tensor + return %0 : tensor +} diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir index dae20d0f469..1d2cf767939 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir @@ -32,10 +32,10 @@ func @binary_ops_int(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32 // CHECK-NEXT: %2 = subi %1, %arg1 : tensor<4xi32> %2 = "xla_hlo.sub"(%1, %arg1) {name = "sub.5"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - // CHECK-NEXT: %3 = divis %2, %arg1 : tensor<4xi32> + // CHECK-NEXT: %3 = divi_signed %2, %arg1 : tensor<4xi32> %3 = "xla_hlo.div"(%2, %arg1) {name = "div.6"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - // CHECK-NEXT: %4 = remis %3, %arg1 : tensor<4xi32> + // CHECK-NEXT: %4 = remi_signed %3, %arg1 : tensor<4xi32> %4 = "xla_hlo.remainder"(%3, 
%arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> // CHECK-NEXT: return %4 : tensor<4xi32> diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-affine.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-affine.mlir index d4ee0fdc2e2..74fea0cc687 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-affine.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-affine.mlir @@ -59,7 +59,7 @@ func @float_div_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>, // CHECK-LABEL: func @int_div_op func @int_div_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, %result: memref<7xi32>) -> () { - // CHECK: divis %{{.*}}, %{{.*}} : i32 + // CHECK: divi_signed %{{.*}}, %{{.*}} : i32 "xla_lhlo.div"(%lhs, %rhs, %result) {name = "div.1"} : (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> () return diff --git a/tensorflow/compiler/mlir/xla/tests/ops.mlir b/tensorflow/compiler/mlir/xla/tests/ops.mlir index a315a2318b5..c6db931e239 100644 --- a/tensorflow/compiler/mlir/xla/tests/ops.mlir +++ b/tensorflow/compiler/mlir/xla/tests/ops.mlir @@ -13,6 +13,45 @@ func @invalid_type() -> !xla_hlo.foobar // ----- +// CHECK-LABEL: func @alltoall +func @alltoall(%data: tensor<4x16xf32>) -> tensor<16x4xf32> { + %0 = "xla_hlo.all_to_all"(%data) { + split_dimension = 1 : i64, + concat_dimension = 0 : i64, + split_count = 4 : i64, + replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64> + } : (tensor<4x16xf32>) -> tensor<16x4xf32> + return %0 : tensor<16x4xf32> +} + +// ----- + +// CHECK-LABEL: func @alltoall_unranked_input +func @alltoall_unranked_input(%data: tensor<*xf32>) -> tensor<*xf32> { + %0 = "xla_hlo.all_to_all"(%data) { + split_dimension = 1 : i64, + concat_dimension = 0 : i64, + split_count = 5 : i64, + replica_groups = dense<[[0, 1, 2, 3, 4]]> : tensor<1x5xi64> + } : (tensor<*xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// ----- + +func @alltoall_invalid_split_dim_size(%data: tensor<4x16xf32>) -> tensor<16x4xf32> { +// expected-error@+1 {{split dimension has size 16, expected to be a multiple of split_count 5}} + %0 = "xla_hlo.all_to_all"(%data) { + split_dimension = 1 : i64, + concat_dimension = 0 : i64, + split_count = 5 : i64, + replica_groups = dense<[[0, 1, 2, 3, 4]]> : tensor<1x5xi64> + } : (tensor<4x16xf32>) -> tensor<16x4xf32> + return %0 : tensor<16x4xf32> +} + +// ----- + // CHECK-LABEL: func @broadcast func @broadcast(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { %0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<[1, 2]> : tensor<2xi64>} : (tensor<3xi32>) -> tensor<1x2x3xi32> @@ -189,6 +228,15 @@ func @dot_bad_precision_config(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) - // ----- +func @rng_uniform_invalid_type(%mu: tensor>, %sigma: tensor) -> tensor<2x3x5xf32> { + %shape = xla_hlo.constant dense<[2, 3, 5]> : tensor<3xi64> + // expected-error@+1 {{must be tensor of pred (AKA boolean or 1-bit integer) or 8/16/32/64-bit integer or floating-point values, but got 'tensor>'}} + %0 = "xla_hlo.rng_uniform"(%mu, %sigma, %shape) : (tensor>, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> + return %0 : tensor<2x3x5xf32> +} + +// ----- + // CHECK-LABEL: func @select func @select(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> { %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> diff --git a/tensorflow/compiler/mlir/xla/tests/translate/BUILD b/tensorflow/compiler/mlir/xla/tests/translate/BUILD index 857ee2896a2..c4e747c90f3 100644 --- 
a/tensorflow/compiler/mlir/xla/tests/translate/BUILD +++ b/tensorflow/compiler/mlir/xla/tests/translate/BUILD @@ -4,7 +4,7 @@ package(licenses = ["notice"]) glob_lit_tests( data = [":test_utilities"], - driver = "@local_config_mlir//:run_lit.sh", + driver = "@llvm-project//mlir:run_lit.sh", test_file_exts = [ "mlir", "hlo", @@ -18,7 +18,7 @@ filegroup( testonly = True, data = [ "//tensorflow/compiler/mlir:tf-mlir-translate", - "@llvm//:FileCheck", - "@llvm//:not", + "@llvm-project//llvm:FileCheck", + "@llvm-project//llvm:not", ], ) diff --git a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir index 442780a520c..125c958d6c3 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir @@ -355,6 +355,18 @@ func @main(%arg0: tensor<3x4xi32>, %arg1: tensor<4x5xi32>) -> tensor<3x5xi32> { // ----- +// CHECK: HloModule +func @main(%arg0: tensor<3x9xf32>) -> tensor<3x5xcomplex> { + %0 = "xla_hlo.fft"(%arg0) {fft_length = dense<9> : tensor<1xi64>, fft_type = "RFFT"} : (tensor<3x9xf32>) -> tensor<3x5xcomplex> + return %0 : tensor<3x5xcomplex> +} + +// CHECK: ENTRY +// CHECK: [[ARG:%.*]] = f32[3,9] parameter(0) +// CHECK: c64[3,5] fft(f32[3,9] [[ARG]]), fft_type=RFFT, fft_length={9} + +// ----- + // CHECK: HloModule func @main(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>) -> tensor<10x300xf32> { // CHECK: [[ARG0:%.*]] = f32[200,100,300] parameter(0) @@ -396,6 +408,18 @@ func @main(%arg0: tuple, tensor>) -> tensor { // ----- +// CHECK: HloModule +func @main(%arg0: !xla_hlo.token) -> tuple, tensor>, !xla_hlo.token> { + %0 = "xla_hlo.infeed"(%arg0) {infeed_config = "foobar"} : (!xla_hlo.token) -> tuple, tensor>, !xla_hlo.token> + return %0 : tuple, tensor>, !xla_hlo.token> +} + +// CHECK: ENTRY +// CHECK: [[ARG:%.*]] = token[] parameter(0) +// CHECK: ROOT %[[RESULT:.*]] = ((s32[3], pred[]), token[]) infeed(token[] [[ARG]]), infeed_config="foobar" + +// ----- + // CHECK: HloModule func @main() -> tensor<1x10xf32> { %result = "xla_hlo.iota"() { @@ -409,6 +433,19 @@ func @main() -> tensor<1x10xf32> { // ----- +// CHECK: HloModule +func @main(%data: tensor<3xi32>, %token: !xla_hlo.token) -> !xla_hlo.token { + %0 = "xla_hlo.outfeed"(%data, %token) {outfeed_config = "foobar"} : (tensor<3xi32>, !xla_hlo.token) -> !xla_hlo.token + return %0 : !xla_hlo.token +} + +// CHECK: ENTRY +// CHECK: [[DATA:%.*]] = s32[3] parameter(0) +// CHECK: [[TOKEN:%.*]] = token[] parameter(1) +// CHECK: ROOT %[[RESULT:.*]] = token[] outfeed(s32[3] [[DATA]], token[] [[TOKEN]]), outfeed_config="foobar" + +// ----- + // CHECK: HloModule func @main(%arg: tensor<4x6xf32>, %pad: tensor) -> tensor<13x19xf32> { %0 = "xla_hlo.pad"(%arg, %pad) {edge_padding_high = dense<[4,5]> : tensor<2xi64>, edge_padding_low = dense<[2,3]> : tensor<2xi64>, interior_padding = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>, tensor) -> tensor<13x19xf32> @@ -504,6 +541,20 @@ func @main(%arg0 : tensor<10x11x12x13xf32>) -> tensor<10x11x12x13xf32> { // ----- +// CHECK: HloModule +func @main(%mu: tensor, %sigma: tensor) -> tensor<2x3x5xf32> { + %shape = xla_hlo.constant dense<[2, 3, 5]> : tensor<3xi64> + %0 = "xla_hlo.rng_normal"(%mu, %sigma, %shape) : (tensor, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> + return %0 : tensor<2x3x5xf32> +} + +// CHECK: ENTRY +// CHECK: %[[MU:.*]] = f32[] parameter(0) +// CHECK: %[[SIGMA:.*]] = f32[] parameter(1) +// CHECK: ROOT %[[RESULT:.*]] = f32[2,3,5] rng(f32[] %[[MU]], f32[] 
%[[SIGMA]]), distribution=rng_normal + +// ----- + // CHECK: HloModule func @main() -> tensor<2x3x5xf32> { %0 = xla_hlo.constant dense<0.000000e+00> : tensor @@ -599,6 +650,62 @@ func @main(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<10x12x12x64xf32>) -> te // ----- +// CHECK: HloModule +func @main(%arg: tensor<3x4xi32>, %token: !xla_hlo.token) -> !xla_hlo.token { + %0 = "xla_hlo.send"(%arg, %token) { + channel_id = { + handle = 5 : i64, + type = 2 : i64 // Device to host channel + }, + is_host_transfer = true + } : (tensor<3x4xi32>, !xla_hlo.token) -> !xla_hlo.token + return %0 : !xla_hlo.token +} + +// CHECK: ENTRY +// CHECK: [[ARG:%.*]] = s32[3,4] parameter(0) +// CHECK: [[TOKEN:%.*]] = token[] parameter(1) +// CHECK: [[SEND:%.*]] = (s32[3,4], u32[], token[]) send(s32[3,4] [[ARG]], token[] [[TOKEN]]), channel_id=5, is_host_transfer=true +// CHECK: ROOT +// CHECK-SAME: token[] send-done((s32[3,4], u32[], token[]) [[SEND]]), channel_id=5, is_host_transfer=true + +// ----- + +// CHECK: HloModule +func @main(%arg: tensor<3x4xi32>, %token: !xla_hlo.token) -> !xla_hlo.token { + %0 = "xla_hlo.send"(%arg, %token) { + channel_id = { + handle = 5 : i64, + type = 1 : i64 // Device to device channel + }, + is_host_transfer = false + } : (tensor<3x4xi32>, !xla_hlo.token) -> !xla_hlo.token + return %0 : !xla_hlo.token +} + +// CHECK: ENTRY +// CHECK: [[ARG:%.*]] = s32[3,4] parameter(0) +// CHECK: [[TOKEN:%.*]] = token[] parameter(1) +// CHECK: [[SEND:%.*]] = (s32[3,4], u32[], token[]) send(s32[3,4] [[ARG]], token[] [[TOKEN]]), channel_id=5 +// CHECK: ROOT +// CHECK-SAME: token[] send-done((s32[3,4], u32[], token[]) [[SEND]]), channel_id=5 + +// ----- + +// CHECK: HloModule +func @main(%arg: tensor<4x4xf32>, %size: tensor) -> tensor<4x4xf32> { + %0 = "xla_hlo.set_dimension_size"(%arg, %size) {dimension = 1 : i32} : (tensor<4x4xf32>, tensor) -> tensor<4x4xf32> + return %0 : tensor<4x4xf32> +} + +// CHECK: ENTRY +// CHECK: [[ARG:%.*]] = f32[4,4] parameter(0) +// CHECK: [[SIZE:%.*]] = s32[] parameter(1) +// CHECK: ROOT +// CHECK-SAME: f32[4,<=4] set-dimension-size(f32[4,4] [[ARG]], s32[] [[SIZE]]), dimensions={1} + +// ----- + // CHECK: HloModule func @main(%arg: tensor<3x4xi32>) -> tensor<1x2xi32> { %0 = "xla_hlo.slice"(%arg) {start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x2xi32> diff --git a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt index 5f9670be2f1..b598a9b8852 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt @@ -95,6 +95,15 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] { ROOT %call.2 = s64[] call(%arg0.1), to_apply=%call } +// CHECK-LABEL: func @test_cholesky +// CHECK-SAME: ([[ARG:%.*]]: tensor<1x291x291xf32>) -> tensor<1x291x291xf32> +%test_cholesky (a: f32[1,291,291]) -> f32[1,291,291] { + %a = f32[1,291,291] parameter(0) + // CHECK-NEXT: "xla_hlo.cholesky"([[ARG]]) {lower = true, name = {{.*}}} : (tensor<1x291x291xf32>) -> tensor<1x291x291xf32> + ROOT %out = f32[1,291,291] cholesky(f32[1,291,291] %a), lower=true +} + + // CHECK-LABEL: func @test_clamp( %test_clamp (Arg_0.1: f32[], Arg_1.2: f32[4], Arg_1.3: f32[]) -> f32[4] { %Arg_0.1 = f32[] parameter(0) @@ -364,6 +373,16 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] { ROOT %imag.3 = f32[4] imag(c64[4] %Arg_0.1) } +// CHECK-LABEL: func @test_infeed +// 
CHECK-SAME: ([[TOKEN:%.*]]: !xla_hlo.token) -> tuple<tensor<3xi32>, !xla_hlo.token> { +%test_infeed (token0: token[]) -> (s32[3], token[]) { + %token0 = token[] parameter(0) + // CHECK-NEXT: "xla_hlo.infeed"([[TOKEN]]) + // CHECK-SAME: infeed_config = "foobar" + ROOT %infeed = (s32[3], token[]) infeed(token[] %token0), infeed_config="foobar" +} + + // CHECK-LABEL: func @test_iota_1() -> tensor<4xf32> { %test_iota_1 () -> f32[4] { // CHECK-NEXT: "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf32> @@ -444,6 +463,16 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] { ROOT %or.3 = pred[4] or(pred[4] %Arg_0.1, pred[4] %Arg_1.2) } +// CHECK-LABEL: func @test_outfeed +// CHECK-SAME: ([[DATA:%.*]]: tensor<3xi32>, [[TOKEN:%.*]]: !xla_hlo.token) -> !xla_hlo.token { +%test_outfeed (Arg_0.1: s32[3], Arg_1.2: token[]) -> token[] { + %Arg_0.1 = s32[3] parameter(0) + %Arg_1.2 = token[] parameter(1) + // CHECK-NEXT: "xla_hlo.outfeed"([[DATA]], [[TOKEN]]) + // CHECK-SAME: outfeed_config = "foobar" + ROOT %outfeed.3 = token[] outfeed(s32[3] %Arg_0.1, token[] %Arg_1.2), outfeed_config="foobar" +} + // CHECK-LABEL: func @test_pad(%arg0: tensor<4xf32>, %arg1: tensor<f32>) -> tensor<4xf32> { %test_pad (Arg_0.1: f32[4], Arg_1.2: f32[]) -> f32[4] { %Arg_0.1 = f32[4] parameter(0) @@ -488,6 +517,26 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] { ROOT %power.3 = f32[4] power(f32[4] %Arg_0.1, f32[4] %Arg_1.2) } +// CHECK-LABEL: func @test_rng_normal +// CHECK-SAME: ([[ARG0:%.*]]: tensor<f32>, [[ARG1:%.*]]: tensor<f32>) -> tensor<2x3x5xf32> +%test_rng_normal (Arg_0.1: f32[], Arg_1.2: f32[]) -> f32[2,3,5] { + %Arg_0.1 = f32[] parameter(0) + %Arg_1.2 = f32[] parameter(1) + // CHECK: [[CST:%.*]] = constant dense<[2, 3, 5]> : tensor<3xi64> + // CHECK: "xla_hlo.rng_normal"([[ARG0]], [[ARG1]], [[CST]]) + ROOT %rng.4 = f32[2,3,5] rng(f32[] %Arg_0.1, f32[] %Arg_1.2), distribution=rng_normal +} + +// CHECK-LABEL: func @test_rng_uniform +// CHECK-SAME: ([[ARG0:%.*]]: tensor<f32>, [[ARG1:%.*]]: tensor<f32>) -> tensor<2x3x5xf32> +%test_rng_uniform (Arg_0.1: f32[], Arg_1.2: f32[]) -> f32[2,3,5] { + %Arg_0.1 = f32[] parameter(0) + %Arg_1.2 = f32[] parameter(1) + // CHECK: [[CST:%.*]] = constant dense<[2, 3, 5]> : tensor<3xi64> + // CHECK: "xla_hlo.rng_uniform"([[ARG0]], [[ARG1]], [[CST]]) + ROOT %rng.4 = f32[2,3,5] rng(f32[] %Arg_0.1, f32[] %Arg_1.2), distribution=rng_uniform +} + // CHECK-LABEL: func @test_real %test_real (Arg_0.1: c64[4]) -> f32[4] { %Arg_0.1 = c64[4] parameter(0) @@ -603,6 +652,15 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] { ROOT %select.4 = s32[2,3] select(pred[2,3] %Arg_0.1, s32[2,3] %Arg_1.2, s32[2,3] %Arg_2.3) } +// CHECK-LABEL: func @test_set_dimension_size +// CHECK-SAME: ([[ARG:%.*]]: tensor<4x4xf32>, [[SIZE:%.*]]: tensor<i32>) +%test_set_dimension_size (Arg_0.1: f32[4,4], Arg_1.2: s32[]) -> f32[4,<=4] { + %Arg_0.1 = f32[4,4] parameter(0) + %Arg_1.2 = s32[] parameter(1) + // CHECK-NEXT: "xla_hlo.set_dimension_size"([[ARG]], [[SIZE]]) {dimension = 1 : i32, name = "{{.*}}"} : (tensor<4x4xf32>, tensor<i32>) -> tensor<4x4xf32> + ROOT %set-dimension-size.2 = f32[4,<=4] set-dimension-size(f32[4,4] %Arg_0.1, s32[] %Arg_1.2), dimensions={1} +} + // CHECK-LABEL: func @test_sine(%arg0: tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32> { %test_sine (arg0.1: f32[1,16,16,3]) -> f32[1,16,16,3] { %arg0.1 = f32[1,16,16,3]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"} diff --git a/tensorflow/compiler/mlir/xla/transforms/canonicalize.td b/tensorflow/compiler/mlir/xla/transforms/canonicalize.td index d510a3df994..df9be382f11 100644 ---
a/tensorflow/compiler/mlir/xla/transforms/canonicalize.td +++ b/tensorflow/compiler/mlir/xla/transforms/canonicalize.td @@ -29,7 +29,7 @@ def BuildSliceLimits : NativeCodeCall< def BuildSliceStrides : NativeCodeCall< "GetI64ElementsAttr(SmallVector(" - "$0->getType().cast().getRank(), 1), &$_builder)">; + "$0.getType().cast().getRank(), 1), &$_builder)">; def DynamicSliceToSlice: Pat<(HLO_DynamicSliceOp HLO_Tensor:$input, (HLO_ConstOp I64ElementsAttr:$starting_indices), diff --git a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc index 4a74fe4b2ae..9170b217471 100644 --- a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc +++ b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc @@ -16,18 +16,18 @@ limitations under the License. // This file implements logic for lowering HLO dialect to LHLO dialect. #include "absl/memory/memory.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/BlockAndValueMapping.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Function.h" // TF:local_config_mlir -#include "mlir/IR/Location.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/Operation.h" // TF:local_config_mlir -#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir -#include "mlir/Transforms/DialectConversion.h" // TF:local_config_mlir +#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/Location.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/PatternMatch.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Transforms/DialectConversion.h" // TF:llvm-project #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" @@ -39,8 +39,8 @@ namespace { constexpr StringRef kTempBufferAttr = "temp"; -Value* GetTensorStoreOrReturnMemRef(Value* value) { - for (const auto& user : value->getUsers()) { +Value GetTensorStoreOrReturnMemRef(Value value) { + for (const auto& user : value.getUsers()) { if (auto tensor_store = dyn_cast(user)) { if (tensor_store.getOperand(0) == value) { return tensor_store.getOperand(1); @@ -56,9 +56,9 @@ Value* GetTensorStoreOrReturnMemRef(Value* value) { return nullptr; } -Operation* GetLastUse(Value* value) { - Operation* last = value->getDefiningOp(); - for (auto& user : value->getUses()) { +Operation* GetLastUse(Value value) { + Operation* last = value.getDefiningOp(); + for (auto& user : value.getUses()) { Operation* user_op = user.getOwner(); if (!user_op->isBeforeInBlock(last)) { last = user_op; @@ -67,9 +67,9 @@ Operation* GetLastUse(Value* value) { return last; } -Value* InsertAllocAndDealloc(Location loc, Value* result, - ConversionPatternRewriter* rewriter) { - auto result_type = result->getType().dyn_cast(); +Value 
InsertAllocAndDealloc(Location loc, Value result, + ConversionPatternRewriter* rewriter) { + auto result_type = result.getType().dyn_cast<ShapedType>(); if (!result_type || !result_type.hasStaticShape()) { emitError(loc, "tensor to buffer conversion expects statically shaped results"); @@ -79,7 +79,7 @@ Value* InsertAllocAndDealloc(Location loc, Value* result, Operation* last = GetLastUse(result); - Operation* op = result->getDefiningOp(); + Operation* op = result.getDefiningOp(); OpBuilder allocBuilder(op); auto alloc = allocBuilder.create<AllocOp>(loc, memref_type); alloc.setAttr(kTempBufferAttr, rewriter->getBoolAttr(true)); @@ -93,8 +93,8 @@ Value* InsertAllocAndDealloc(Location loc, Value* result, /// For every tensor-type value that is produced in the original function, /// this function returns the buffer that can be used in the converted /// function to store the values held in the tensor. -Value* GetBufferForResultValue(Location loc, Value* result, - ConversionPatternRewriter* rewriter) { +Value GetBufferForResultValue(Location loc, Value result, + ConversionPatternRewriter* rewriter) { if (auto existing_memref = GetTensorStoreOrReturnMemRef(result)) { return existing_memref; } @@ -108,7 +108,7 @@ class HloToLhloOpConverter : public ConversionPattern { : ConversionPattern(HloOpTy::getOperationName(), 1, context) {} PatternMatchResult matchAndRewrite( - Operation* op, ArrayRef<Value*> operands, + Operation* op, ArrayRef<Value> operands, ConversionPatternRewriter& rewriter) const final { if (op->getParentRegion()->getBlocks().size() != 1) { emitError(op->getLoc(), @@ -116,14 +116,14 @@ class HloToLhloOpConverter : public ConversionPattern { "region containing the operation"); } const auto& original_results = op->getResults(); - SmallVector<Value*, 4> buffer_args(operands.begin(), operands.end()); + SmallVector<Value, 4> buffer_args(operands.begin(), operands.end()); for (auto result : original_results) { buffer_args.push_back( GetBufferForResultValue(op->getLoc(), result, &rewriter)); } rewriter.create<LhloOpTy>(op->getLoc(), llvm::None, buffer_args, op->getAttrs()); - rewriter.replaceOp(op, ArrayRef<Value*>(buffer_args).slice(operands.size()), + rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()), original_results); return matchSuccess(); } @@ -135,7 +135,7 @@ struct HloToLHloReduceConverter using OpConversionPattern::OpConversionPattern; PatternMatchResult matchAndRewrite( - xla_hlo::ReduceOp op, ArrayRef<Value*> operands, + xla_hlo::ReduceOp op, ArrayRef<Value> operands, ConversionPatternRewriter& rewriter) const final { auto loc = op.getLoc(); // TODO(b/137624192) Implement variadic reduce. @@ -146,7 +146,7 @@ struct HloToLHloReduceConverter "region containing the operation"); } const auto& original_results = op.getResults(); - SmallVector<Value*, 4> buffer_args(operands.begin(), operands.end()); + SmallVector<Value, 4> buffer_args(operands.begin(), operands.end()); for (auto result : original_results) { buffer_args.push_back(GetBufferForResultValue(loc, result, &rewriter)); } @@ -161,7 +161,7 @@ struct HloToLHloReduceConverter int original_arg_count = entry_block.getNumArguments(); for (int i = 0; i < original_arg_count; ++i) { auto old_arg = entry_block.getArgument(i); - auto old_type = old_arg->getType().cast<TensorType>(); + auto old_type = old_arg.getType().cast<TensorType>(); auto new_type = MemRefType::get(old_type.getShape(), old_type.getElementType()); auto new_arg = entry_block.addArgument(new_type); @@ -169,7 +169,7 @@ struct HloToLHloReduceConverter } // Add an argument for the result.
entry_block.addArgument( - entry_block.getArgument(original_arg_count)->getType()); + entry_block.getArgument(original_arg_count).getType()); // Remove the old arguments. for (int i = original_arg_count - 1; i >= 0; --i) { entry_block.eraseArgument(i); @@ -178,7 +178,7 @@ struct HloToLHloReduceConverter rewriter.setInsertionPointToEnd(&entry_block); rewriter.create(loc); - rewriter.replaceOp(op, ArrayRef(buffer_args).slice(operands.size()), + rewriter.replaceOp(op, ArrayRef(buffer_args).slice(operands.size()), original_results); return matchSuccess(); @@ -191,7 +191,7 @@ class HloToLhloTensorLoadConverter : public ConversionPattern { : ConversionPattern(TensorLoadOp::getOperationName(), 1, context) {} PatternMatchResult matchAndRewrite( - Operation* op, ArrayRef operands, + Operation* op, ArrayRef operands, ConversionPatternRewriter& rewriter) const final { rewriter.replaceOp(op, operands, op->getResults()); return matchSuccess(); @@ -205,7 +205,7 @@ class HloToLhloTensorStoreConverter : public ConversionPattern { : ConversionPattern(TensorStoreOp::getOperationName(), 1, context) {} PatternMatchResult matchAndRewrite( - Operation* op, ArrayRef operands, + Operation* op, ArrayRef operands, ConversionPatternRewriter& rewriter) const final { rewriter.eraseOp(op); return matchSuccess(); @@ -218,7 +218,7 @@ class HloToLhloReturnConverter : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; PatternMatchResult matchAndRewrite( - xla_hlo::ReturnOp op, ArrayRef operands, + xla_hlo::ReturnOp op, ArrayRef operands, ConversionPatternRewriter& rewriter) const final { rewriter.eraseOp(op); return matchSuccess(); diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_control_flow.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_control_flow.cc index 8a8afc01bec..8351f94d172 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_control_flow.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_control_flow.cc @@ -18,17 +18,17 @@ limitations under the License. 
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/Support/Casting.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir -#include "mlir/IR/Block.h" // TF:local_config_mlir -#include "mlir/IR/BlockAndValueMapping.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Function.h" // TF:local_config_mlir -#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir -#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir -#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir +#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/IR/Block.h" // TF:llvm-project +#include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/PatternMatch.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/TypeUtilities.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Pass/PassRegistry.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" @@ -99,8 +99,8 @@ LogicalResult LowerConditionalOp(mlir::xla_hlo::ConditionalOp conditional_op) { mapper, &builder))) return failure(); - tail_block->addArguments(conditional_op.getResult()->getType()); - conditional_op.getResult()->replaceAllUsesWith(tail_block->getArgument(0)); + tail_block->addArguments(conditional_op.getResult().getType()); + conditional_op.getResult().replaceAllUsesWith(tail_block->getArgument(0)); op_inst->erase(); return success(); @@ -171,8 +171,8 @@ LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) { auto cond_value = builder.create(loc, return_value); // Get the body block arguments. - llvm::SmallVector successor_args(cond_block->args_begin(), - cond_block->args_end()); + llvm::SmallVector successor_args(cond_block->args_begin(), + cond_block->args_end()); builder.create(loc, cond_value, body_block, successor_args, tail_block, successor_args); @@ -201,7 +201,7 @@ LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) { // Erase the original while loop. tail_block->addArgument(while_op.getType()); - while_op.getResult()->replaceAllUsesWith(tail_block->getArgument(0)); + while_op.getResult().replaceAllUsesWith(tail_block->getArgument(0)); op_inst->erase(); return success(); diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc index eaed2da8fa7..9c58b242460 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc @@ -24,25 +24,26 @@ limitations under the License. 
#include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Diagnostics.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/Matchers.h" // TF:local_config_mlir -#include "mlir/IR/Module.h" // TF:local_config_mlir -#include "mlir/IR/Operation.h" // TF:local_config_mlir -#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir -#include "mlir/IR/Types.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir -#include "mlir/Transforms/DialectConversion.h" // TF:local_config_mlir +#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Diagnostics.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Matchers.h" // TF:llvm-project +#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/PatternMatch.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/TypeUtilities.h" // TF:llvm-project +#include "mlir/IR/Types.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Transforms/DialectConversion.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h" #include "tensorflow/compiler/mlir/xla/convert_op_folder.h" #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/ir/hlo_utils.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" +#include "tensorflow/compiler/xla/client/padding.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/kernels/conv_grad_shape_utils.h" #include "tensorflow/core/util/padding.h" @@ -54,25 +55,20 @@ namespace { class LegalizeTF : public FunctionPass { public: - struct Options : public PassOptions { - Option allow_partial_conversion{ - *this, "allow-partial-conversion", - llvm::cl::desc("Allow operations that can't be legalized."), - llvm::cl::init(false)}; - }; - - explicit LegalizeTF(bool allow_partial_conversion) - : FunctionPass(), - allow_partial_conversion_(allow_partial_conversion) {} - - explicit LegalizeTF(const Options &option) - : LegalizeTF(option.allow_partial_conversion) {} + LegalizeTF() = default; + LegalizeTF(const LegalizeTF &) {} + explicit LegalizeTF(bool allow_partial_conversion) { + allow_partial_conversion_ = allow_partial_conversion; + } /// Performs the lowering to XLA dialect. void runOnFunction() override; private: - bool allow_partial_conversion_; + Option allow_partial_conversion_{ + *this, "allow-partial-conversion", + llvm::cl::desc("Allow operations that can't be legalized."), + llvm::cl::init(false)}; }; /// Returns if the given TF data format string is the default format. @@ -126,7 +122,7 @@ static IntegerAttr GetHLOAxisFromTFAxis(IntegerAttr attr, int64_t rank, // corresponding to the tensorflow axis. In particular, the tensorflow axis can // be negative, in which case, the corresponding HLO axis is // (axis + rank-of-the-tensor). 
-static llvm::Optional<int64_t> GetIntegerHLOAxisFromTFAxis(Value *value, +static llvm::Optional<int64_t> GetIntegerHLOAxisFromTFAxis(Value value, int64_t rank) { DenseIntElementsAttr attrs; if (!matchPattern(value, m_Constant(&attrs)) || @@ -139,7 +135,7 @@ static llvm::Optional<int64_t> GetIntegerHLOAxisFromTFAxis(Value *value, /// Returns a `ConvertOp` that casts the elements to a i64 type while retaining /// the shape of the input value. -static ConvertOp CastValueToI64(Location loc, Value *value, +static ConvertOp CastValueToI64(Location loc, Value value, PatternRewriter *rewriter) { return rewriter->create<ConvertOp>(loc, value, rewriter->getIntegerType(64)); } @@ -223,14 +219,30 @@ static void BuildReduceBody(Type element_type, Region *body, builder->create<ReturnOp>(loc, reducer.getResult()); } +// Builds a region that takes two arguments and returns the second argument as +// the result. Corresponds to the function f(x, y) = y. +// Used in Scatter op's computation to update specific elements. +static void BuildBinaryAssignmentRegion(Type element_type, Region *region, + OpBuilder *builder) {} + +// Builds a set of operations for applying reduction on the input value. A +// tf.sum op is created and will be legalized to HLO ops automatically. +static Value ApplyReduction(Location loc, Value input, + DenseIntElementsAttr reduce_dims, + OpBuilder *builder) { + auto reduce_dims_op = builder->create<ConstOp>(loc, reduce_dims); + return builder->create<TF::SumOp>(loc, input, reduce_dims_op, + builder->getBoolAttr(false)); +} + //===----------------------------------------------------------------------===// // BatchNorm op utilities. //===----------------------------------------------------------------------===// static IntegerAttr getFeatureDimensionAttr(Builder &b, StringAttr format, - Value *input) { + Value input) { return b.getI64IntegerAttr( - getFeatureDimension(format, input->getType().cast<RankedTensorType>())); + getFeatureDimension(format, input.getType().cast<RankedTensorType>())); } //===----------------------------------------------------------------------===// @@ -241,8 +253,8 @@ static IntegerAttr getFeatureDimensionAttr(Builder &b, StringAttr format, // Requires input to have a ranked tensor type. static DenseIntElementsAttr getBiasFeatureDimension(Builder &b, StringAttr format, - Value *input) { - auto inputType = input->getType().cast<RankedTensorType>(); + Value input) { + auto inputType = input.getType().cast<RankedTensorType>(); size_t featureDim = getFeatureDimension(format, inputType); RankedTensorType type = RankedTensorType::get(1, b.getIntegerType(64)); return DenseIntElementsAttr::get(type, featureDim); @@ -306,9 +318,9 @@ static DenseIntElementsAttr GetInteriorPadding(ElementsAttr tf_padding) { // same shape, this broadcasts size 1 tensors up to any rank. Dynamic dimensions // must be broadcasted with a size 1 tensor or another dynamic dimension. // Returns false for unranked types.
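// For example (illustrative): tensor<1x4xf32> and tensor<3x4xf32> are broadcast compatible, since the size-1 dimension can broadcast up to 3; tensor<2x4xf32> and tensor<3x4xf32> are not, and any unranked operand makes the check return false.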
-static bool AreBroadcastCompatible(Value *x, Value *y) { - auto x_rankless = x->getType().dyn_cast(); - auto y_rankless = y->getType().dyn_cast(); +static bool AreBroadcastCompatible(Value x, Value y) { + auto x_rankless = x.getType().dyn_cast(); + auto y_rankless = y.getType().dyn_cast(); if (!x_rankless || !y_rankless) { return false; } @@ -387,16 +399,16 @@ static void BuildArgMinMaxReductionBody(Type input_element_type, Location loc = body->getLoc(); StringAttr compare_direction = StringAttr::get(direction, builder->getContext()); - Value *compare = builder->create( + Value compare = builder->create( loc, block->getArgument(0), block->getArgument(2), /*broadcast_dimensions=*/nullptr, compare_direction); - Value *selected_input = builder->create( + Value selected_input = builder->create( loc, input_type, compare, block->getArgument(0), block->getArgument(2)); - Value *selected_index = builder->create( + Value selected_index = builder->create( loc, index_type, compare, block->getArgument(1), block->getArgument(3)); - Value *return_values[] = {selected_input, selected_index}; + Value return_values[] = {selected_input, selected_index}; builder->create(loc, return_values); } @@ -404,9 +416,9 @@ static void BuildArgMinMaxReductionBody(Type input_element_type, // Slice op utilities. //===----------------------------------------------------------------------===// -static bool CanBeTranslatedToDynamicSlice(Value *input, Value *start_indices, +static bool CanBeTranslatedToDynamicSlice(Value input, Value start_indices, DenseIntElementsAttr slice_sizes) { - auto input_ty = input->getType().dyn_cast(); + auto input_ty = input.getType().dyn_cast(); int64_t input_rank = input_ty.getRank(); ArrayRef input_shape = input_ty.getShape(); DenseIntElementsAttr constant_start_indices; @@ -445,7 +457,7 @@ static bool CanBeTranslatedToDynamicSlice(Value *input, Value *start_indices, // the end. HLO slice size can't be -1. As such, we need to translate TF slice // size -1 to HLO slice size. 
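// A worked example (illustrative): for an input of shape [4, 8] with start_indices [1, 2] and TF slice_sizes [-1, 3], the normalized HLO slice sizes are [3, 3]; the -1 is replaced by the remaining extent of dimension 0, i.e. 4 - 1 = 3.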
static DenseIntElementsAttr TFSliceSizes2HLOSliceSizes( - Value *input, Value *start_indices, DenseIntElementsAttr slice_sizes, + Value input, Value start_indices, DenseIntElementsAttr slice_sizes, Builder *builder) { DenseIntElementsAttr constant_start_indices; if (!matchPattern(start_indices, m_Constant(&constant_start_indices))) { @@ -453,7 +465,7 @@ static DenseIntElementsAttr TFSliceSizes2HLOSliceSizes( .cast(); } - auto input_ty = input->getType().dyn_cast(); + auto input_ty = input.getType().dyn_cast(); int64_t input_rank = input_ty.getRank(); ArrayRef input_shape = input_ty.getShape(); SmallVector normalized_sizes; @@ -495,7 +507,7 @@ static void BuildSortComparisonBody(llvm::ArrayRef element_types, Location loc = body->getLoc(); StringAttr compare_direction = StringAttr::get(direction, builder->getContext()); - Value *compare = builder->create( + Value compare = builder->create( loc, block->getArgument(0), block->getArgument(1), /*broadcast_dimensions=*/nullptr, compare_direction); @@ -562,9 +574,9 @@ class ConvertConv : public OpRewritePattern { std::string data_format = op.data_format().str(); if (!FormatFromString(data_format, &format)) return Pattern::matchFailure(); - auto input_ty = op.input()->getType().template dyn_cast(); + auto input_ty = op.input().getType().template dyn_cast(); auto filter_ty = - op.filter()->getType().template dyn_cast(); + op.filter().getType().template dyn_cast(); auto result_ty = op.getType().template dyn_cast(); // Input, filter and the result needs to have static shape for calculation @@ -654,7 +666,7 @@ class ConvertConv : public OpRewritePattern { auto paddings_attr = rewriter.getNamedAttr( "padding", DenseElementsAttr::get(paddings_ty, paddings)); - SmallVector operands(op.getOperands()); + SmallVector operands(op.getOperands()); NamedAttribute attrs[] = {rhs_dilations_attr, window_strides_attr, dimension_numbers_attr, feature_group_count_attr, batch_group_count_attr, paddings_attr}; @@ -686,10 +698,10 @@ class ConvertBF16FloorDivOp : public OpRewritePattern { PatternRewriter &rewriter) const override { auto l = op.x(); auto r = op.y(); - auto element_type = getElementTypeOrSelf(l->getType()); + auto element_type = getElementTypeOrSelf(l.getType()); if (!element_type.isBF16()) return matchFailure(); - auto out_type = op.z()->getType().cast(); + auto out_type = op.z().getType().cast(); l = rewriter.create(op.getLoc(), l, rewriter.getF32Type()); r = rewriter.create(op.getLoc(), r, rewriter.getF32Type()); @@ -731,6 +743,263 @@ class ConvertEinsumOp : public OpRewritePattern { } }; +// The base class to convert TensorFlow FusedBatchNormGrad*Op to HLO +// BatchNormGradOp for training and a sequence of binary ops for inference. +// TODO(b/145536565): move to legalize_tf_patterns.td if it applies. +template +class ConvertFusedBatchNormGradBase + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(FusedBatchNormGradOpT op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value grad = op.y_backprop(); + Value act = op.x(); + Value scale = op.scale(); + Value mean = op.reserve_space_1(); + Value var = op.reserve_space_2(); + + // TODO(b/141785544): Update this to not require static shapes. + // activation shape needs to be static to convert negative indices in + // TensorFlow to absolute indices required by HLO. 
+ RankedTensorType act_type = + act.getType().template dyn_cast<RankedTensorType>(); + if (!act_type) return Pattern::matchFailure(); + Type act_ele_type = act_type.getElementType(); + // To support mixed precision, the statistics type, which may be more + // precise than the input types, is used for this op. + Type kernel_type = + scale.getType().template cast<TensorType>().getElementType(); + grad = rewriter.create<ConvertOp>(loc, grad, kernel_type); + act = rewriter.create<ConvertOp>(loc, act, kernel_type); + + auto feature_dim_attr = + getFeatureDimensionAttr(rewriter, op.data_formatAttr(), act); + auto feature_dim = feature_dim_attr.getValue().getSExtValue(); + + // Gets the result values. + Value x_backprop, scale_backprop, offset_backprop; + if (op.is_training()) { // training + // TODO(b/145536565): handle GPU logic separately. + // Infers the output type with the converted `act`. + Type feature_type = RankedTensorType::get( + {GetDimSize(act_type, feature_dim)}, kernel_type); + Type result_type = TupleType::get( + {act.getType(), feature_type, feature_type}, rewriter.getContext()); + + auto training_op = rewriter.create<BatchNormGradOp>( + loc, result_type, act, scale, mean, var, grad, op.epsilon(), + feature_dim_attr.getValue()); + + x_backprop = + rewriter.create<GetTupleElementOp>(loc, training_op.getResult(), 0); + + scale_backprop = + rewriter.create<GetTupleElementOp>(loc, training_op.getResult(), 1); + + offset_backprop = + rewriter.create<GetTupleElementOp>(loc, training_op.getResult(), 2); + } else { // inference + SmallVector<int64_t, 4> non_feature_dims; + for (int64_t i = 0; i < act_type.getRank(); ++i) { + if (i == feature_dim) continue; + non_feature_dims.push_back(i); + } + auto reduce_dims = GetI64ElementsAttr(non_feature_dims, &rewriter); + auto broadcast_dims = GetI64ElementsAttr({feature_dim}, &rewriter); + auto no_broadcast_dims = GetI64ElementsAttr({}, &rewriter); + + // scratch1 = rsqrt(var + epsilon) + RankedTensorType scalar_float = RankedTensorType::get({}, kernel_type); + auto epsilon = rewriter.create<ConstOp>( + loc, DenseFPElementsAttr::get(scalar_float, {op.epsilon()})); + auto add_op = rewriter.create<AddOp>(loc, var, epsilon.getResult(), + no_broadcast_dims); + Value scratch1 = rewriter.create<RsqrtOp>(loc, add_op); + + // scratch2 = sum(y_backprop * (x - mean)) + auto sub_op = rewriter.create<SubOp>(loc, act, mean, broadcast_dims); + auto weighted_grad = + rewriter.create<MulOp>(loc, grad, sub_op, no_broadcast_dims); + Value scratch2 = + ApplyReduction(loc, weighted_grad, reduce_dims, &rewriter); + + // x_backprop = y_backprop * (scale * scratch1) + auto scaled_grad = + rewriter.create<MulOp>(loc, op.scale(), scratch1, no_broadcast_dims); + x_backprop = + rewriter.create<MulOp>(loc, grad, scaled_grad, broadcast_dims); + + // scale_backprop = scratch2 * scratch1 + scale_backprop = + rewriter.create<MulOp>(loc, scratch1, scratch2, no_broadcast_dims); + + // offset_backprop = sum(y_backprop) + offset_backprop = ApplyReduction(loc, grad, reduce_dims, &rewriter); + } + + x_backprop = rewriter.create<ConvertOp>(loc, x_backprop, act_ele_type); + // It doesn't matter what values we provide for the last 2 results.
+ rewriter.replaceOp(op, + {/*x_backprop=*/x_backprop, + /*scale_backprop=*/scale_backprop, + /*offset_backprop=*/offset_backprop, op.x(), op.x()}); + return Pattern::matchSuccess(); + } +}; + +using ConvertFusedBatchNormGradOp = + ConvertFusedBatchNormGradBase; +using ConvertFusedBatchNormGradV2Op = + ConvertFusedBatchNormGradBase; +using ConvertFusedBatchNormGradV3Op = + ConvertFusedBatchNormGradBase; + +// Converts TensorFlow FusedBatchNormV3Op to either HLO BatchNormTrainingOp or +// HLO BatchNormInferenceOp, depending on the value of the 'is_training' +// parameter. +class ConvertFusedBatchNormV3Op + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(TF::FusedBatchNormV3Op op, + PatternRewriter &rewriter) const override { + auto feature_dim = + getFeatureDimensionAttr(rewriter, op.data_formatAttr(), op.x()); + + auto input_type_tensor = op.x().getType().dyn_cast(); + auto input_element_type = input_type_tensor.getElementType(); + + auto scale_type_tensor = op.scale().getType().dyn_cast(); + auto scale_element_type = scale_type_tensor.getElementType(); + + // TODO(b/69928690): Support mixed precision in the XLA batch + // normalization operators. As a workaround, create a new x with the same + // element type as scale (which may be more precise than the input type). + Value bn_train_input = rewriter.create( + op.getLoc(), op.x(), scale_element_type); + TensorType bn_train_input_type_tensor = + bn_train_input.getType().cast(); + + if (op.is_training()) { + // Training case. + auto operand_shape = bn_train_input_type_tensor.getShape(); + // The mean and variance are each 1 dimensional arrays the size of the + // feature dimension, with the same element type as the operand (x). + // This shape must be constructed manually because the mean and variance + // inputs are empty in the training case. + Type mean_var_type = RankedTensorType::get( + {operand_shape[feature_dim.getInt()]}, scale_element_type); + // Op result type is a tuple of 3 values: output with same shape as input; + // batch_mean, and batch_var. + SmallVector operand_types = {bn_train_input_type_tensor, + mean_var_type, mean_var_type}; + Type result_type = TupleType::get(operand_types, rewriter.getContext()); + + auto bn_train_op = rewriter.create( + op.getLoc(), result_type, bn_train_input, op.scale(), op.offset(), + op.epsilon(), feature_dim.getValue()); + // HLO op outputs a tuple of tensors. Extract those results. + auto bn_train_op_result = bn_train_op.getResult(); + Value y_out = rewriter.create( + op.getLoc(), bn_train_op_result, 0); + Value batch_mean = rewriter.create( + op.getLoc(), bn_train_op_result, 1); + Value batch_variance = rewriter.create( + op.getLoc(), bn_train_op_result, 2); + + // Apply Bessel's correction on the variance. + int total_input_size = bn_train_input_type_tensor.getNumElements(); + int total_scale_size = scale_type_tensor.getNumElements(); + int sample_size = total_input_size / total_scale_size; + int sample_size_minus_one = std::max(1, sample_size - 1); + double factor = static_cast(sample_size) / + static_cast(sample_size_minus_one); + auto factor_const_op = rewriter.create( + op.getLoc(), rewriter.getFloatAttr(scale_element_type, factor)); + + auto corrected_variance = rewriter.create( + op.getLoc(), batch_variance.getType(), batch_variance, + factor_const_op, /*DenseIntElementsAttr=*/DenseIntElementsAttr()); + + // Convert back to input type to stay aligned with expected output type + // for TF op. 
+ y_out = rewriter.create<ConvertOp>(op.getLoc(), y_out, + input_element_type); + + // TF FusedBatchNormV3 op expects 6 outputs. Outputs 3 and 4 are + // currently marked as "reserved spaces 1 and 2". They are used to + // pass the per-batch mean and variance to the gradient. Here we + // maintain the same behavior by setting them to the mean and variance + // calculated by BatchNormTraining. Output 5 is unused; it doesn't + // matter what we pass there. + rewriter.replaceOp(op, {y_out, /*batch_mean=*/batch_mean, + /*batch_variance=*/corrected_variance, + /*reserve_space_1=*/batch_mean, + /*reserve_space_2=*/corrected_variance, + /*reserve_space_3=*/op.x()}); + } else { // Inference case. + auto bn_train_op = rewriter.create<BatchNormInferenceOp>( + op.getLoc(), + /*result_type=*/bn_train_input_type_tensor, bn_train_input, + op.scale(), op.offset(), op.mean(), op.variance(), op.epsilon(), + feature_dim.getValue()); + + // Convert back to input type to stay aligned with expected output type + // for TF op. + auto y_out = rewriter.create<ConvertOp>(op.getLoc(), bn_train_op, + input_element_type); + + // The mean, variance, and reserved space outputs of the batch norm op are + // not used for inference. It doesn't matter what values we provide for + // the last 5 results. + rewriter.replaceOp( + op, {/*y=*/y_out, /*batch_mean=*/op.x(), + /*batch_variance=*/op.x(), /*reserve_space_1=*/op.x(), + /*reserve_space_2=*/op.x(), /*reserve_space_3=*/op.x()}); + } + return Pattern::matchSuccess(); + } +}; + +// Returns padding attribute for ReduceWindow op with given params. +// +// Requires padding to be either 'SAME' or 'VALID' and the number of input +// dimensions to be equal to the size of window dimensions and window strides. +static DenseIntElementsAttr GetReduceWindowPadding( + llvm::ArrayRef<int64_t> input_dims, ArrayAttr window_dims, + ArrayAttr window_strides, StringRef padding, Builder *builder) { + if (padding == "VALID") return {}; + DCHECK_EQ(padding.str(), "SAME"); + + llvm::SmallVector<int64_t, 4> input_shape, window_shape, strides; + input_shape.reserve(input_dims.size()); + window_shape.reserve(window_dims.size()); + strides.reserve(window_strides.size()); + + for (const auto &dim : input_dims) input_shape.push_back(dim); + for (Attribute attr : window_dims) + window_shape.push_back(attr.cast<IntegerAttr>().getInt()); + for (Attribute attr : window_strides) + strides.push_back(attr.cast<IntegerAttr>().getInt()); + + std::vector<std::pair<int64_t, int64_t>> paddings = + ::xla::MakePadding(input_shape, window_shape, strides, + ::xla::Padding::kSame); + int64_t rank = paddings.size(); + llvm::SmallVector<int64_t, 8> flatten_paddings(rank * 2); + for (int i = 0; i < rank; i++) { + flatten_paddings[i] = paddings[i].first; + flatten_paddings[rank + i] = paddings[i].second; + } + return DenseIntElementsAttr::get( + RankedTensorType::get({2, rank}, builder->getIntegerType(64)), + flatten_paddings); +} + // Converts MaxPool op to HLO ReduceWindow op by setting appropriate window // dimensions with max as the reduction function. // @@ -746,21 +1015,21 @@ class ConvertMaxPoolOp : public OpRewritePattern<TF::MaxPoolOp> { PatternMatchResult matchAndRewrite(TF::MaxPoolOp op, PatternRewriter &rewriter) const override { - // TODO(hinsu): Support 'SAME' padding mode.
- if (op.padding() != "VALID") return matchFailure(); - Type element_type = - op.input()->getType().cast().getElementType(); + op.input().getType().cast().getElementType(); if (!element_type.isIntOrFloat()) return matchFailure(); Location loc = op.getLoc(); ConstOp init = GetMinValueForType(element_type, loc, &rewriter); + auto input_ty = op.input().getType().dyn_cast(); + if (!input_ty) return matchFailure(); + DenseIntElementsAttr paddings_attr = GetReduceWindowPadding( + input_ty.getShape(), op.ksize(), op.strides(), op.padding(), &rewriter); auto reduce = rewriter.create( loc, op.getType(), op.input(), init.getResult(), GetI64ElementsAttr(op.ksize()), GetI64ElementsAttr(op.strides()), /*base_dilations=*/DenseIntElementsAttr(), - /*window_dilations=*/DenseIntElementsAttr(), - /*paddings=*/DenseIntElementsAttr()); + /*window_dilations=*/DenseIntElementsAttr(), paddings_attr); BuildReduceBody(element_type, &reduce.body(), &rewriter); rewriter.replaceOp(op, reduce.getResult()); @@ -798,9 +1067,9 @@ class ConvertSigmoidOp : public OpRewritePattern { auto scalar_one = rewriter.create( op.getLoc(), - rewriter.getFloatAttr(getElementTypeOrSelf(operand->getType()), 0.5)); + rewriter.getFloatAttr(getElementTypeOrSelf(operand.getType()), 0.5)); - auto shaped_type = operand->getType().cast(); + auto shaped_type = operand.getType().cast(); auto constant_ones = rewriter.create( op.getLoc(), shaped_type, scalar_one, DenseIntElementsAttr::get( @@ -811,7 +1080,7 @@ class ConvertSigmoidOp : public OpRewritePattern { auto scaled_input = rewriter.create( op.getLoc(), operand, constant_ones, DenseIntElementsAttr()); auto tanh_op = - rewriter.create(op.getLoc(), operand->getType(), scaled_input); + rewriter.create(op.getLoc(), operand.getType(), scaled_input); auto mul_op = rewriter.create(op.getLoc(), tanh_op, constant_ones, /*DenseIntElementsAttr=*/DenseIntElementsAttr()); @@ -856,11 +1125,11 @@ class ConvertSoftmaxOp : public OpRewritePattern { PatternMatchResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { - Value *logits = op.logits(); + Value logits = op.logits(); // Softmax converter requires ranked type because the XLA reduce ops used // while lowering requires dimensions attribute to reduce along. - RankedTensorType type = logits->getType().dyn_cast(); + RankedTensorType type = logits.getType().dyn_cast(); if (!type) return Pattern::matchFailure(); auto loc = op.getLoc(); @@ -886,16 +1155,16 @@ class ConvertSoftmaxOp : public OpRewritePattern { rewriter.create(loc, type, logits, max_logits, batch_dims); // Exponentiate the inputs. - Value *exp = rewriter.create(loc, type, shifted_logits); + Value exp = rewriter.create(loc, type, shifted_logits); // Compute summation of the exponentials. 
auto exp_sum = rewriter.create(loc, exp, reduce_dim, /*keep_dims=*/rewriter.getBoolAttr(false)); - Value *sum = exp_sum.getResult(); + Value sum = exp_sum.getResult(); if (use_log) { - Value *log = rewriter.create(loc, sum); + Value log = rewriter.create(loc, sum); rewriter.replaceOpWithNewOp(op, shifted_logits, log, batch_dims); } else { rewriter.replaceOpWithNewOp(op, exp, sum, batch_dims); @@ -932,12 +1201,12 @@ class ConvertSizeOp : public OpRewritePattern { PatternMatchResult matchAndRewrite(TF::SizeOp op, PatternRewriter &rewriter) const override { - Value *input = op.input(); - auto input_ty = input->getType().dyn_cast(); + Value input = op.input(); + auto input_ty = input.getType().dyn_cast(); if (!input_ty) return Pattern::matchFailure(); const int64_t rank = input_ty.getRank(); - auto result_type = op.getResult()->getType(); + auto result_type = op.getResult().getType(); Operation *size = GetScalarConstOfType(result_type.cast().getElementType(), op.getLoc(), 1, &rewriter); @@ -995,7 +1264,7 @@ class ConvertSplitOp : public OpRewritePattern { PatternMatchResult matchAndRewrite(TF::SplitOp op, PatternRewriter &rewriter) const override { // We can only split along static dimensions. - auto input_type = op.value()->getType().dyn_cast(); + auto input_type = op.value().getType().dyn_cast(); if (!input_type) return matchFailure(); // We can only match when the split dimension is a constant scalar. @@ -1029,7 +1298,7 @@ class ConvertSplitOp : public OpRewritePattern { SmallVector strides(input_rank, 1); // All HLO slice results used to replace the original tf.Split op. - SmallVector slices; + SmallVector slices; slices.reserve(num_splits); for (int i = 0; i < num_splits; ++i) { @@ -1087,7 +1356,7 @@ class ConvertSplitVOp : public OpRewritePattern { PatternRewriter &rewriter) const override { // We can only split along static dimensions. // TODO(b/145731001): enhance to support dynamic-shaped inputs. - auto input_type = op.value()->getType().dyn_cast(); + auto input_type = op.value().getType().dyn_cast(); if (!input_type) return matchFailure(); // We can only match when the split dimension is a constant scalar. @@ -1141,7 +1410,7 @@ class ConvertSplitVOp : public OpRewritePattern { SmallVector strides(input_rank, 1); // All HLO slice results used to replace the original tf.Split op. - SmallVector slices; + SmallVector slices; slices.reserve(op.getNumResults()); for (int i = 0; i < op.getNumResults(); ++i) { @@ -1184,7 +1453,7 @@ class ConvertStridedSliceOp : public OpRewritePattern { // // TODO(hinsu): Relax this constraint for ops without negative indices and // strides. - auto input_ty = op.input()->getType().dyn_cast(); + auto input_ty = op.input().getType().dyn_cast(); if (!input_ty || !input_ty.hasStaticShape()) return matchFailure(); ArrayRef input_shape = input_ty.getShape(); @@ -1195,20 +1464,9 @@ class ConvertStridedSliceOp : public OpRewritePattern { auto result_ty = op.getType().dyn_cast(); if (!result_ty || !result_ty.hasStaticShape()) return matchFailure(); - // TODO(hinsu): Support non-zero mask values. Currently only - // 'shrink_axis_mask' is supported. - for (StringRef mask : - {"begin_mask", "end_mask", "ellipsis_mask", "new_axis_mask"}) { - auto attr = op.getAttrOfType(mask); - if (attr && attr.getValue() != 0) return matchFailure(); - } - - // TODO(hinsu): Support lowering for ops with dynamic begin and end values - // when it is possible to derive indices based on mask attributes. 
- DenseIntElementsAttr begin_indices, end_indices, strides; - if (!matchPattern(op.begin(), m_Constant(&begin_indices)) || - !matchPattern(op.end(), m_Constant(&end_indices)) || - !matchPattern(op.strides(), m_Constant(&strides))) + SmallVector<int64_t, 4> begin_indices, end_indices, strides; + if (!op.GetSlicedBoundRanges(input_shape, &begin_indices, &end_indices, + &strides)) return matchFailure(); SmallVector<int64_t, 4> hlo_begin_indices, hlo_end_indices, hlo_strides, @@ -1218,18 +1476,15 @@ class ConvertStridedSliceOp : public OpRewritePattern<TF::StridedSliceOp> { hlo_end_indices.reserve(input_rank); hlo_strides.reserve(input_rank); - int64_t indices_elements = begin_indices.getNumElements(); + int64_t indices_elements = begin_indices.size(); if (input_rank < indices_elements) return matchFailure(); // Convert from TensorFlow negative or out of range indices and strides // values to legal HLO Slice attributes. for (int i = 0, e = indices_elements; i != e; i++) { - int64_t begin = begin_indices.getValue<IntegerAttr>(i).getInt(); - int64_t end = end_indices.getValue<IntegerAttr>(i).getInt(); - int64_t stride = strides.getValue<IntegerAttr>(i).getInt(); - - if (begin < 0) begin = input_shape[i] + begin; - if (end < 0) end = input_shape[i] + end; + int64_t begin = begin_indices[i]; + int64_t end = end_indices[i]; + int64_t stride = strides[i]; if (stride < 0) { // Negative stride means that the output values are computed starting @@ -1297,8 +1552,8 @@ class ConvertStridedSliceGradOp &strides)) return matchFailure(); - Value *grad = op.dy(); - Type element_type = grad->getType().cast<ShapedType>().getElementType(); + Value grad = op.dy(); + Type element_type = grad.getType().cast<ShapedType>().getElementType(); // Perform reshape to undo any new/shrink axes done by strided slice. grad = rewriter.create<ReshapeOp>( @@ -1338,7 +1593,7 @@ class ConvertStridedSliceGradOp if (!dims_to_reverse.empty()) { grad = rewriter.create<ReverseOp>( - op.getLoc(), grad->getType(), grad, + op.getLoc(), grad.getType(), grad, GetI64ElementsAttr(dims_to_reverse, &rewriter)); } @@ -1376,7 +1631,7 @@ class ConvertRangeOp : public OpRewritePattern<TF::RangeOp> { PatternMatchResult matchAndRewrite(TF::RangeOp op, PatternRewriter &rewriter) const override { auto result = op.getResult(); - auto result_type = result->getType(); + auto result_type = result.getType(); if (!result_type.cast<ShapedType>().hasStaticShape()) { return matchFailure(); } @@ -1408,7 +1663,7 @@ class GenericConvertReductionOp : public OpRewritePattern<OpTy> { // TODO(b/141785544): Update this to not require static shapes. // Input shape needs to be static to convert negative indices in TensorFlow // to absolute indices required by HLO. - auto input_ty = op.input()->getType().template dyn_cast<RankedTensorType>(); + auto input_ty = op.input().getType().template dyn_cast<RankedTensorType>(); if (!input_ty) return this->matchFailure(); ArrayRef<int64_t> input_shape = input_ty.getShape(); @@ -1439,14 +1694,14 @@ class GenericConvertReductionOp : public OpRewritePattern<OpTy> { rewriter.create<ConvertOp>(loc, op.input(), reduce_element_type); // Each reduction op can have a different initial value. - Value *init = Derived::GetInitialValue(reduce_element_type, loc, rewriter); + Value init = Derived::GetInitialValue(reduce_element_type, loc, rewriter); auto reduction = rewriter.create<ReduceOp>( loc, casted_input.getResult(), init, GetI64ElementsAttr(xla_dimensions, &rewriter)); BuildReduceBody<ReductionOp>(reduce_element_type, &reduction.body(), &rewriter); - Value *result = reduction.getResult(0); + Value result = reduction.getResult(0); // The mean op needs to divide by the product of the reduced dimensions.
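// For example (illustrative): taking the mean of a f32[8,16] input over both dimensions divides the reduced sum by 8 * 16 = 128.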
if (std::is_same::value) { @@ -1490,8 +1745,8 @@ class ConvertMeanOp : public GenericConvertReductionOp { public: using GenericConvertReductionOp::GenericConvertReductionOp; - static Value *GetInitialValue(Type reduce_element_type, Location loc, - PatternRewriter &rewriter) { + static Value GetInitialValue(Type reduce_element_type, Location loc, + PatternRewriter &rewriter) { return GetScalarConstOfType(reduce_element_type, loc, 0, &rewriter); } }; @@ -1506,8 +1761,8 @@ class ConvertSumOp public: using GenericConvertReductionOp::GenericConvertReductionOp; - static Value *GetInitialValue(Type reduce_element_type, Location loc, - PatternRewriter &rewriter) { + static Value GetInitialValue(Type reduce_element_type, Location loc, + PatternRewriter &rewriter) { return GetScalarConstOfType(reduce_element_type, loc, 0, &rewriter); } }; @@ -1523,8 +1778,8 @@ class ConvertMaxOp public: using GenericConvertReductionOp::GenericConvertReductionOp; - static Value *GetInitialValue(Type reduce_element_type, Location loc, - PatternRewriter &rewriter) { + static Value GetInitialValue(Type reduce_element_type, Location loc, + PatternRewriter &rewriter) { return GetMinValueForType(reduce_element_type, loc, &rewriter); } }; @@ -1538,8 +1793,8 @@ class ConvertAllOp : public GenericConvertReductionOp { public: using GenericConvertReductionOp::GenericConvertReductionOp; - static Value *GetInitialValue(Type reduce_element_type, Location loc, - PatternRewriter &rewriter) { + static Value GetInitialValue(Type reduce_element_type, Location loc, + PatternRewriter &rewriter) { return GetScalarConstOfType(reduce_element_type, loc, 1, &rewriter); } }; @@ -1553,8 +1808,8 @@ class ConvertAnyOp : public GenericConvertReductionOp { public: using GenericConvertReductionOp::GenericConvertReductionOp; - static Value *GetInitialValue(Type reduce_element_type, Location loc, - PatternRewriter &rewriter) { + static Value GetInitialValue(Type reduce_element_type, Location loc, + PatternRewriter &rewriter) { return GetScalarConstOfType(reduce_element_type, loc, 0, &rewriter); } }; @@ -1571,7 +1826,7 @@ class ConvertArgMinMaxOp : public OpRewritePattern { PatternMatchResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { RankedTensorType input_type = - op.input()->getType().template dyn_cast(); + op.input().getType().template dyn_cast(); if (!input_type) { return this->matchFailure(); } @@ -1582,17 +1837,17 @@ class ConvertArgMinMaxOp : public OpRewritePattern { if (!input_element_type.isIntOrFloat()) return this->matchFailure(); Location loc = op.getLoc(); - Value *init_value = + Value init_value = Derived::GetInitialValue(input_element_type, loc, rewriter); RankedTensorType output_type = - op.output()->getType().template dyn_cast(); + op.output().getType().template dyn_cast(); if (!output_type) { return this->matchFailure(); } Type index_element_type = output_type.getElementType(); - Value *index_init_value = + Value index_init_value = GetScalarConstOfType(index_element_type, loc, 0, &rewriter); RankedTensorType index_type = @@ -1607,21 +1862,21 @@ class ConvertArgMinMaxOp : public OpRewritePattern { IntegerAttr iota_dimension = IntegerAttr::get(rewriter.getIntegerType(64), axis); - Value *index_values = + Value index_values = rewriter.create(loc, index_type, iota_dimension); std::vector dimensions = input_type.getShape(); dimensions.erase(dimensions.begin() + axis); ArrayRef reduction_result_shape(dimensions); - Value *operands[] = {op.input(), index_values}; - Value *init_values[] = {init_value, 
index_init_value}; + Value operands[] = {op.input(), index_values}; + Value init_values[] = {init_value, index_init_value}; DenseIntElementsAttr reduction_dimensions = GetI64ElementsAttr({axis}, &rewriter); auto reduction = rewriter.create<ReduceOp>( - loc, llvm::ArrayRef<Value *>(operands), - llvm::ArrayRef<Value *>(init_values), reduction_dimensions); + loc, llvm::ArrayRef<Value>(operands), + llvm::ArrayRef<Value>(init_values), reduction_dimensions); StringRef direction = Derived::GetDirection(); BuildArgMinMaxReductionBody(input_element_type, index_element_type, direction, &reduction.body(), &rewriter); @@ -1643,14 +1898,70 @@ class ConvertArgMaxOp public: using ConvertArgMinMaxOp::ConvertArgMinMaxOp; - static Value *GetInitialValue(Type reduce_element_type, Location loc, - PatternRewriter &rewriter) { + static Value GetInitialValue(Type reduce_element_type, Location loc, + PatternRewriter &rewriter) { return GetMinValueForType(reduce_element_type, loc, &rewriter); } static StringRef GetDirection() { return "GT"; } }; +// Converts TF TensorScatterUpdate op into Scatter Op with assignment: +// +// %result = "xla_hlo.scatter"(%tensor, %indices, %updates) +// { dimensions = ... } +// +class ConvertTensorScatterUpdateOp + : public OpRewritePattern<TF::TensorScatterUpdateOp> { + public: + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(TF::TensorScatterUpdateOp op, + PatternRewriter &rewriter) const override { + auto tensor_ty = op.tensor().getType().dyn_cast<RankedTensorType>(); + auto indices_ty = op.indices().getType().dyn_cast<RankedTensorType>(); + auto updates_ty = op.updates().getType().dyn_cast<RankedTensorType>(); + + if (!tensor_ty || !indices_ty || !updates_ty) return matchFailure(); + // Last dimension of the indices needs to be known at compile time for + // computation of the 'update_window_dims' attribute in the dimensions + // struct. + int64_t num_index_dims = indices_ty.getShape().back(); + if (ShapedType::isDynamic(num_index_dims)) return matchFailure(); + + int64_t tensor_rank = tensor_ty.getRank(); + int64_t indices_rank = indices_ty.getRank(); + int64_t updates_rank = updates_ty.getRank(); + + int64_t window_dims = tensor_rank - num_index_dims; + auto dims_attr = ScatterDimensionNumbers::get( + GetI64ElementsAttrForSeq(updates_rank - window_dims, updates_rank, + &rewriter), + GetI64ElementsAttrForSeq(0, num_index_dims, &rewriter), + GetI64ElementsAttrForSeq(0, num_index_dims, &rewriter), + rewriter.getI64IntegerAttr(indices_rank - 1), rewriter.getContext()); + + Location loc = op.getLoc(); + auto scatter = rewriter.create<ScatterOp>( + loc, op.getType(), op.tensor(), op.indices(), op.updates(), dims_attr); + + // Build region to assign the new value. + [&](Region *region) { + OpBuilder::InsertionGuard guard(rewriter); + Block *block = rewriter.createBlock(region); + + // Block arguments are scalars of the given element type. + Type type = + RankedTensorType::get(/*shape=*/{}, tensor_ty.getElementType()); + block->addArguments({type, type}); + rewriter.create<ReturnOp>(loc, block->getArgument(1)); + }(&scatter.update_computation()); + + rewriter.replaceOp(op, scatter.getResult()); + return matchSuccess(); + } +}; + // Converts Tile op to HLO BroadcastInDim and Reshape ops.
// For shape [S1, S2] and multiples [M1, M2], // MS1 = M1 * S1; MS2 = M2 * S2 @@ -1666,7 +1977,7 @@ class ConvertTileOp : public OpRewritePattern<TF::TileOp> { PatternMatchResult matchAndRewrite(TF::TileOp op, PatternRewriter &rewriter) const override { - auto input_ty = op.input()->getType().dyn_cast<RankedTensorType>(); + auto input_ty = op.input().getType().dyn_cast<RankedTensorType>(); if (!input_ty || !input_ty.hasStaticShape()) return matchFailure(); ArrayRef<int64_t> input_shape = input_ty.getShape(); Type element_type = input_ty.getElementType(); @@ -1707,7 +2018,7 @@ class ConvertTileOp : public OpRewritePattern<TF::TileOp> { RankedTensorType::get(broadcasted_shape, element_type); Type output_type = op.getType(); - Value *result = rewriter.create<BroadcastInDimOp>( + Value result = rewriter.create<BroadcastInDimOp>( loc, broadcasted_type, op.input(), GetI64ElementsAttr(broadcast_dimensions, &rewriter)); @@ -1727,19 +2038,24 @@ class ConvertMaxPoolGradOp : public OpRewritePattern<TF::MaxPoolGradOp> { PatternMatchResult matchAndRewrite(TF::MaxPoolGradOp op, PatternRewriter &rewriter) const override { - // TODO(parkers): Support 'SAME' padding mode. - if (op.padding() != "VALID") return matchFailure(); - Location loc = op.getLoc(); Type element_type = - op.orig_input()->getType().cast<TensorType>().getElementType(); + op.orig_input().getType().cast<TensorType>().getElementType(); + + // Compute paddings using the original input and kernel shape and strides. + // Here, the ReduceWindow padding computation is reused, as the MaxPool op + // is lowered to the ReduceWindow op. + auto input_ty = op.orig_input().getType().dyn_cast<RankedTensorType>(); + if (!input_ty) return matchFailure(); + DenseIntElementsAttr paddings_attr = GetReduceWindowPadding( + input_ty.getShape(), op.ksize(), op.strides(), op.padding(), &rewriter); auto result = rewriter.create<SelectAndScatterOp>( loc, op.getType(), op.orig_input(), op.grad(), GetScalarConstOfType(element_type, loc, 0, &rewriter), GetI64ElementsAttr(op.ksize()), GetI64ElementsAttr(op.strides()), - nullptr); + paddings_attr); BuildReduceBody<AddOp>(element_type, &result.scatter(), &rewriter); { @@ -1783,11 +2099,11 @@ class ConvertConv2DBackpropInputOp return Pattern::matchFailure(); auto out_backprop_ty = - op.out_backprop()->getType().dyn_cast<RankedTensorType>(); + op.out_backprop().getType().dyn_cast<RankedTensorType>(); if (!out_backprop_ty || !out_backprop_ty.hasStaticShape()) return matchFailure(); ArrayRef<int64_t> out_backprop_shape = out_backprop_ty.getShape(); - auto filter_ty = op.filter()->getType().dyn_cast<RankedTensorType>(); + auto filter_ty = op.filter().getType().dyn_cast<RankedTensorType>(); if (!filter_ty || !filter_ty.hasStaticShape()) return matchFailure(); ArrayRef<int64_t> filter_shape = filter_ty.getShape(); int num_spatial_dims = 2; @@ -1859,7 +2175,7 @@ class ConvertConv2DBackpropInputOp auto paddings_attr = DenseIntElementsAttr::get(paddings_ty, conv_paddings); auto spatial_dims_attr = GetI64ElementsAttr(spatial_dims, &rewriter); - Value *filter = op.filter(); + Value filter = op.filter(); if (feature_group_count != 1) { /* @@ -1876,7 +2192,7 @@ class ConvertConv2DBackpropInputOp // activation gradients // = gradients (with padding and dilation) mirrored_weights - Value *result = rewriter.create<ConvOp>( + Value result = rewriter.create<ConvOp>( loc, op.getType(), op.out_backprop(), filter, /*window_strides=*/GetI64ElementsAttr(ones, &rewriter), /*padding=*/paddings_attr, GetI64ElementsAttr(lhs_dilation, &rewriter), @@ -1927,11 +2243,11 @@ class ConvertConv2DBackpropFilterOp return Pattern::matchFailure(); auto out_backprop_ty = - op.out_backprop()->getType().dyn_cast<RankedTensorType>(); + op.out_backprop().getType().dyn_cast<RankedTensorType>(); if (!out_backprop_ty || !out_backprop_ty.hasStaticShape()) return matchFailure(); ArrayRef<int64_t> out_backprop_shape =
out_backprop_ty.getShape(); - auto input_ty = op.input()->getType().dyn_cast<RankedTensorType>(); + auto input_ty = op.input().getType().dyn_cast<RankedTensorType>(); if (!input_ty || !input_ty.hasStaticShape()) return matchFailure(); ArrayRef<int64_t> input_shape = input_ty.getShape(); @@ -2077,7 +2393,7 @@ class ConvertConv2DBackpropFilterOp auto feature_dim_attr = rewriter.getI64IntegerAttr(feature_dim); Location loc = op.getLoc(); - Value *result = rewriter.create<ConvOp>( + Value result = rewriter.create<ConvOp>( loc, op.getType(), op.input(), op.out_backprop(), /*window_strides=*/GetI64ElementsAttr(window_strides, &rewriter), /*padding=*/paddings_attr, GetI64ElementsAttr(lhs_dilation, &rewriter), @@ -2116,7 +2432,7 @@ class ConvertOneHotOp : public OpRewritePattern<TF::OneHotOp> { PatternMatchResult matchAndRewrite(TF::OneHotOp op, PatternRewriter &rewriter) const override { - auto indices_ty = op.indices()->getType().dyn_cast<RankedTensorType>(); + auto indices_ty = op.indices().getType().dyn_cast<RankedTensorType>(); if (!indices_ty || !indices_ty.hasStaticShape()) return matchFailure(); ArrayRef<int64_t> indices_shape = indices_ty.getShape(); Type element_type = indices_ty.getElementType(); @@ -2140,21 +2456,21 @@ class ConvertOneHotOp : public OpRewritePattern<TF::OneHotOp> { Location loc = op.getLoc(); auto index_type = RankedTensorType::get(output_dims, element_type); - Value *compare = rewriter.create<CompareOp>( + Value compare = rewriter.create<CompareOp>( loc, op.indices(), rewriter.create<IotaOp>( loc, index_type, IntegerAttr::get(rewriter.getIntegerType(64), axis)), GetI64ElementsAttr(broadcast_dims, &rewriter), StringAttr::get("EQ", rewriter.getContext())); - Value *on_value = rewriter.create<BroadcastOp>( + Value on_value = rewriter.create<BroadcastOp>( loc, op.getType(), op.on_value(), GetI64ElementsAttr(output_dims, &rewriter)); - Value *off_value = rewriter.create<BroadcastOp>( + Value off_value = rewriter.create<BroadcastOp>( loc, op.getType(), op.off_value(), GetI64ElementsAttr(output_dims, &rewriter)); - Value *result = rewriter.create<SelectOp>(loc, op.getType(), compare, - on_value, off_value); + Value result = rewriter.create<SelectOp>(loc, op.getType(), compare, + on_value, off_value); rewriter.replaceOp( op, {result}, @@ -2206,7 +2522,7 @@ class ConvertTopKV2Op : public OpRewritePattern<TF::TopKV2Op> { // The last dimension of the input tensor's shape should be known so we can // have clamped end_indices for slices. - TensorType input_type = op.input()->getType().cast<TensorType>(); + TensorType input_type = op.input().getType().cast<TensorType>(); if (!input_type.hasRank()) return matchFailure(); int64_t input_rank = input_type.getRank(); int64_t last_dim_index = input_rank - 1; @@ -2216,14 +2532,14 @@ class ConvertTopKV2Op : public OpRewritePattern<TF::TopKV2Op> { // Create an Iota op for indices. auto i32_type = rewriter.getIntegerType(32); Type iota_type = RankedTensorType::get(input_type.getShape(), i32_type); - Value *iota_op = rewriter.create<IotaOp>( + Value iota_op = rewriter.create<IotaOp>( op.getLoc(), iota_type, rewriter.getI64IntegerAttr(last_dim_index)); // Create the sort op. It takes two inputs, one for the original input, the // other for the indices. auto sort_op = rewriter.create<SortOp>( - op.getLoc(), llvm::ArrayRef<Value *>{op.input(), iota_op}, - last_dim_index, /*is_stable=*/true); + op.getLoc(), llvm::ArrayRef<Value>{op.input(), iota_op}, last_dim_index, + /*is_stable=*/true); BuildSortComparisonBody({input_type.getElementType(), i32_type}, /*direction=*/"GT", &sort_op.comparator(), &rewriter); @@ -2242,13 +2558,13 @@ class ConvertTopKV2Op : public OpRewritePattern<TF::TopKV2Op> { // Get the slice for the top K elements.
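// (Illustrative: the sorted values and indices are sliced down to the leading k entries along the last dimension, with all leading dimensions kept whole, so both slices below have shape [..., k].)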
-    Value *values = rewriter.create<SliceOp>(
+    Value values = rewriter.create<SliceOp>(
         op.getLoc(), tuple_first_element,
         GetI64ElementsAttr(begin_indices, &rewriter),
         GetI64ElementsAttr(end_indices, &rewriter),
         GetI64ElementsAttr(strides, &rewriter));
-    Value *indices = rewriter.create<SliceOp>(
+    Value indices = rewriter.create<SliceOp>(
         op.getLoc(), tuple_second_element,
         GetI64ElementsAttr(begin_indices, &rewriter),
         GetI64ElementsAttr(end_indices, &rewriter),
@@ -2271,7 +2587,7 @@ class ConvertUnpackOp : public OpRewritePattern<TF::UnpackOp> {
   PatternMatchResult matchAndRewrite(TF::UnpackOp op,
                                      PatternRewriter &rewriter) const override {
-    auto value_type = op.value()->getType().cast<RankedTensorType>();
+    auto value_type = op.value().getType().cast<RankedTensorType>();
     if (!value_type) return matchFailure();

     int64_t value_rank = value_type.getRank();
@@ -2284,7 +2600,7 @@ class ConvertUnpackOp : public OpRewritePattern<TF::UnpackOp> {
     SmallVector<int64_t, 4> strides(value_rank, 1);

     // All HLO slice+reshape results used to replace the original tf.Unpack op.
-    SmallVector<Value *, 4> results;
+    SmallVector<Value, 4> results;
     results.reserve(op.getNumResults());

     for (int i = 0; i < op.getNumResults(); ++i) {
@@ -2329,12 +2645,12 @@ class GenericConvertUnsortedSegmentReductionOp : public OpRewritePattern<OpTy> {
   PatternMatchResult matchAndRewrite(OpTy op,
                                      PatternRewriter &rewriter) const override {
-    auto data_type = op.data()->getType().template dyn_cast<RankedTensorType>();
+    auto data_type = op.data().getType().template dyn_cast<RankedTensorType>();
     if (!data_type) return this->matchFailure();
     int64_t data_rank = data_type.getRank();

     auto segment_ids_type =
-        op.segment_ids()->getType().template dyn_cast<RankedTensorType>();
+        op.segment_ids().getType().template dyn_cast<RankedTensorType>();
     if (!segment_ids_type) return this->matchFailure();
     int64_t segment_ids_rank = segment_ids_type.getRank();
@@ -2353,22 +2669,20 @@ class GenericConvertUnsortedSegmentReductionOp : public OpRewritePattern<OpTy> {
     // Broadcast the initial value for reduction. This will become the
     // 'operand' parameter that the final scatter op scatters into.
-    Value *init = ConcreteClass::GetInitialValue(data_type.getElementType(),
-                                                 op.getLoc(), rewriter);
+    Value init = ConcreteClass::GetInitialValue(data_type.getElementType(),
+                                                op.getLoc(), rewriter);
     auto broadcasted_init = rewriter.create<BroadcastOp>(
         op.getLoc(), output_type, init,
         GetI64ElementsAttr(output_shape, &rewriter));

     // Parameters for the generated scatter op.
-    auto range = llvm::seq<int64_t>(segment_ids_rank, data_rank);
-    SmallVector<int64_t, 4> update_window_dims(range.begin(), range.end());
     SmallVector<int64_t, 4> inserted_window_dims(1, 0);
     SmallVector<int64_t, 4> scatter_dims_to_operand_dims(1, 0);
     int64_t index_vector_dim = segment_ids_rank;

     // Put all parameters in a StructAttr.
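GenericConvertUnsortedSegmentReductionOp broadcasts the reduction's initial value to the full output shape, then scatters each data row into the slot named by its segment id, combining with the reduction's operator; the scatter dimension numbers are assembled in the StructAttr code that follows. A NumPy sketch of the 1-D-ids case (the real op also handles higher-rank ids via the dimension numbers):

```
import numpy as np

def unsorted_segment_reduce(data, segment_ids, num_segments, combine, init):
    # Broadcast the initial value to the output shape; this is the "operand"
    # the scatter then folds updates into.
    output = np.full((num_segments,) + data.shape[1:], init, dtype=data.dtype)
    for i, seg in enumerate(segment_ids):
        if 0 <= seg < num_segments:  # out-of-range ids are dropped
            output[seg] = combine(output[seg], data[i])
    return output

data = np.array([1, 2, 3, 4])
ids = np.array([0, 1, 0, 1])
print(unsorted_segment_reduce(data, ids, 2, np.maximum,
                              np.iinfo(np.int64).min))  # [3 4]
```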
auto dims_attr = ScatterDimensionNumbers::get( - GetI64ElementsAttr(update_window_dims, &rewriter), + GetI64ElementsAttrForSeq(segment_ids_rank, data_rank, &rewriter), GetI64ElementsAttr(inserted_window_dims, &rewriter), GetI64ElementsAttr(scatter_dims_to_operand_dims, &rewriter), rewriter.getI64IntegerAttr(index_vector_dim), rewriter.getContext()); @@ -2391,8 +2705,8 @@ class ConvertUnsortedSegmentMaxOp using GenericConvertUnsortedSegmentReductionOp:: GenericConvertUnsortedSegmentReductionOp; - static Value *GetInitialValue(Type reduce_element_type, Location loc, - PatternRewriter &rewriter) { + static Value GetInitialValue(Type reduce_element_type, Location loc, + PatternRewriter &rewriter) { return GetMinValueForType(reduce_element_type, loc, &rewriter); } }; @@ -2404,8 +2718,8 @@ class ConvertUnsortedSegmentMinOp using GenericConvertUnsortedSegmentReductionOp:: GenericConvertUnsortedSegmentReductionOp; - static Value *GetInitialValue(Type reduce_element_type, Location loc, - PatternRewriter &rewriter) { + static Value GetInitialValue(Type reduce_element_type, Location loc, + PatternRewriter &rewriter) { return GetMaxValueForType(reduce_element_type, loc, &rewriter); } }; @@ -2417,8 +2731,8 @@ class ConvertUnsortedSegmentProdOp using GenericConvertUnsortedSegmentReductionOp:: GenericConvertUnsortedSegmentReductionOp; - static Value *GetInitialValue(Type reduce_element_type, Location loc, - PatternRewriter &rewriter) { + static Value GetInitialValue(Type reduce_element_type, Location loc, + PatternRewriter &rewriter) { return GetScalarConstOfType(reduce_element_type, loc, 1, &rewriter); } }; @@ -2430,8 +2744,8 @@ class ConvertUnsortedSegmentSumOp using GenericConvertUnsortedSegmentReductionOp:: GenericConvertUnsortedSegmentReductionOp; - static Value *GetInitialValue(Type reduce_element_type, Location loc, - PatternRewriter &rewriter) { + static Value GetInitialValue(Type reduce_element_type, Location loc, + PatternRewriter &rewriter) { return GetScalarConstOfType(reduce_element_type, loc, 0, &rewriter); } }; @@ -2450,16 +2764,18 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) { // here for lowering to HLO. 
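Each GetInitialValue specialization above supplies the identity element of its combiner, so output slots that no segment id touches keep a neutral value. Before the pattern list below, a quick check of those identities in NumPy (float32 chosen arbitrarily; the exact extremal constant the helpers emit, lowest finite value versus infinity, may differ, and either is an identity for max/min):

```
import numpy as np

# combine(identity, x) must return x for every x of the element type.
identities = {
    "max":  (np.maximum,  np.finfo(np.float32).min),
    "min":  (np.minimum,  np.finfo(np.float32).max),
    "prod": (np.multiply, np.float32(1)),
    "sum":  (np.add,      np.float32(0)),
}
x = np.float32(7.5)
for name, (combine, init) in identities.items():
    assert combine(init, x) == x, name
```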
TF::PopulateLoweringTFPatterns(context, &patterns); patterns.insert< - ConvertArgMaxOp, ConvertBF16FloorDivOp, ConvertConv2D, ConvertEinsumOp, - ConvertMaxPoolOp, ConvertRangeOp, ConvertSigmoidOp, ConvertSizeOp, - ConvertMaxPoolOp, ConvertRangeOp, ConvertSigmoidOp, + ConvertAllOp, ConvertAnyOp, ConvertArgMaxOp, ConvertBF16FloorDivOp, + ConvertConv2D, ConvertConv2DBackpropFilterOp, + ConvertConv2DBackpropInputOp, ConvertEinsumOp, + ConvertFusedBatchNormGradOp, ConvertFusedBatchNormGradV2Op, + ConvertFusedBatchNormGradV3Op, ConvertFusedBatchNormV3Op, ConvertMaxOp, + ConvertMaxPoolOp, ConvertMaxPoolGradOp, ConvertMeanOp, ConvertOneHotOp, + ConvertRangeOp, ConvertSigmoidOp, ConvertSizeOp, ConvertSoftmaxOp, ConvertSoftmaxOp, ConvertSplitOp, ConvertSplitVOp, - ConvertStridedSliceOp, ConvertStridedSliceGradOp, ConvertTopKV2Op, - ConvertUnpackOp, ConvertMeanOp, ConvertSumOp, ConvertMaxOp, ConvertAllOp, - ConvertAnyOp, ConvertTileOp, ConvertMaxPoolGradOp, ConvertOneHotOp, - ConvertConv2DBackpropInputOp, ConvertConv2DBackpropFilterOp, - ConvertUnsortedSegmentMaxOp, ConvertUnsortedSegmentMinOp, + ConvertStridedSliceOp, ConvertStridedSliceGradOp, ConvertSumOp, + ConvertTensorScatterUpdateOp, ConvertTileOp, ConvertTopKV2Op, + ConvertUnpackOp, ConvertUnsortedSegmentMaxOp, ConvertUnsortedSegmentMinOp, ConvertUnsortedSegmentProdOp, ConvertUnsortedSegmentSumOp>( op->getContext()); @@ -2482,7 +2798,7 @@ void LegalizeTF::runOnFunction() { signalPassFailure(); } -static PassRegistration pass( +static PassRegistration pass( "xla-legalize-tf", "Legalize from TensorFlow to the XLA dialect"); } // end namespace diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc index ac14bca6b2b..ee7cd7ea6db 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc @@ -28,19 +28,19 @@ limitations under the License. 
#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/StringSet.h" #include "llvm/ADT/iterator_range.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/BlockAndValueMapping.h" // TF:local_config_mlir -#include "mlir/IR/Function.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/Module.h" // TF:local_config_mlir -#include "mlir/IR/Operation.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir -#include "mlir/IR/Types.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir -#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir -#include "mlir/Transforms/DialectConversion.h" // TF:local_config_mlir +#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project +#include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/TypeUtilities.h" // TF:llvm-project +#include "mlir/IR/Types.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Pass/PassRegistry.h" // TF:llvm-project +#include "mlir/Transforms/DialectConversion.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" @@ -64,13 +64,12 @@ createLegalizeTFControlFlowPass() { namespace { -void Detuple(Value* tuple, Operation::result_range replace, - OpBuilder* builder) { +void Detuple(Value tuple, Operation::result_range replace, OpBuilder* builder) { // De-tuple the results of the xla hlo conditional result. for (auto result_it : llvm::enumerate(replace)) { auto get_tuple_value = builder->create( - result_it.value()->getLoc(), tuple, result_it.index()); - result_it.value()->replaceAllUsesWith(get_tuple_value); + result_it.value().getLoc(), tuple, result_it.index()); + result_it.value().replaceAllUsesWith(get_tuple_value); } } @@ -87,7 +86,7 @@ void ImportXlaRegion(mlir::FuncOp func, Region* dest_region, Location loc, auto entry_block = builder.createBlock(dest_region); auto tuple_arg = entry_block->addArgument( builder.getTupleType(func.getType().getInputs())); - llvm::SmallVector detupled_args; + llvm::SmallVector detupled_args; detupled_args.reserve(func.getNumArguments()); for (int64_t i = 0, s = func.getNumArguments(); i < s; i++) { @@ -110,12 +109,12 @@ void LowerIf(TF::IfOp op, ModuleOp module) { // XLA prefers tuple arguments for control flow due to XLA not supporting // multiple return values. - SmallVector inputs(op.input()); + SmallVector inputs(op.input()); builder.setInsertionPoint(op); auto tuple_input = builder.create(loc, inputs); // Create the new conditional op with tuple inputs. - SmallVector operands(op.getOperands()); + SmallVector operands(op.getOperands()); SmallVector types(op.getResultTypes()); auto result_type = builder.getTupleType(types); auto conditional = builder.create( @@ -142,12 +141,12 @@ void LowerWhile(TF::WhileOp op, ModuleOp module) { // XLA prefers tuple arguments for control flow due to XLA not supporting // multiple return values. 
- SmallVector inputs(op.input()); + SmallVector inputs(op.input()); builder.setInsertionPoint(op); - Value* tuple_input = builder.create(loc, inputs); + Value tuple_input = builder.create(loc, inputs); // Create the new while op with tuple inputs. - SmallVector operands(op.getOperands()); + SmallVector operands(op.getOperands()); SmallVector types(op.getResultTypes()); auto while_op = builder.create( loc, builder.getTupleType(types), tuple_input); diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td index 34c55e7218b..eeccf788dac 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td @@ -20,6 +20,11 @@ include "mlir/Dialect/StandardOps/Ops.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td" +def SignedIntTensor : TensorOf<[I1, I8, I16, I32, I64]>; + +// IEEE compliant floating point tensors. +def IEEEFloatTensor : TensorOf<[F16, F32, F64]>; + //===----------------------------------------------------------------------===// // BatchNorm op patterns. //===----------------------------------------------------------------------===// @@ -30,19 +35,19 @@ def FalseBoolAttr : AttrConstraint>; def TrueBoolAttr : AttrConstraint>; def CastValueToI64: NativeCodeCall< - "CastValueToI64($0->getLoc(), $1, &$_builder)">; + "CastValueToI64($0.getLoc(), $1, &$_builder)">; // Here, $0 is an ElementsAttr with exactly one element of type integer. $1 is // the corresponding value of ranked tensor type whose axis is referred in $0. def GetHLOAxisFromTFAxis : NativeCodeCall< "GetHLOAxisFromTFAxis(" - "$0, $1->getType().cast().getRank(), &$_builder)">; + "$0, $1.getType().cast().getRank(), &$_builder)">; // Same as the above but with $1 of type operand_range from variadic TensorFlow // input. def GetHLOAxisFromTFAxisVariadic : NativeCodeCall< "GetHLOAxisFromTFAxis(" - "$0, (*$1.begin())->getType().cast().getRank(), " + "$0, (*$1.begin()).getType().cast().getRank(), " "&$_builder)">; def : Pattern< @@ -87,12 +92,13 @@ def AreBroadcastCompatible : Constraint, "types must be broadcastable">; class DirectBinaryPat - : Pat<(FromOp $l, $r), + : Pat<(FromOp AnyRankedTensor:$l, AnyRankedTensor:$r), (ToOp $l, $r, (BinBroadcastDimensions $l, $r))>; foreach fromToBinPair = [[TF_AddOp, HLO_AddOp], [TF_AddV2Op, HLO_AddOp], [TF_DivOp, HLO_DivOp], + [TF_LeftShiftOp, HLO_ShiftLeftOp], [TF_MaximumOp, HLO_MaxOp], [TF_MinimumOp, HLO_MinOp], [TF_MulOp, HLO_MulOp], @@ -101,18 +107,22 @@ foreach fromToBinPair = [[TF_AddOp, HLO_AddOp], [TF_SubOp, HLO_SubOp]] in def : DirectBinaryPat; +def LowerRightShiftSigned : + Pat<(TF_RightShiftOp AnyRankedTensor:$l, AnyRankedTensor:$r), + (HLO_ShiftRightArithmeticOp $l, $r, (BinBroadcastDimensions $l, $r)), + [(SignedIntTensor $r)]>; + +// TODO(hinsu): Lower unsigned types to HLO_ShiftRightLogical once the HLO op +// supports unsigned integers. + def : Pat<(TF_ComplexOp $r, $i), (HLO_ComplexOp $r, $i)>; -def IntegerTensor : TensorOf<[I1, I8, I16, I32, I64]>; - -// IEEE compliant floating point tensors. 
-def IEEEFloatTensor : TensorOf<[F16, F32, F64]>;
-
 // Performs a substitution of FloorDiv, pseudo code below:
 //
 //  return floor(div(x, y))
-def : Pat<(TF_FloorDivOp IEEEFloatTensor:$l, IEEEFloatTensor:$r),
-          (HLO_FloorOp (HLO_DivOp $l, $r, (BinBroadcastDimensions $l, $r)))>;
+def : Pat<(TF_FloorDivOp AnyRankedTensor:$l, AnyRankedTensor:$r),
+          (HLO_FloorOp (HLO_DivOp $l, $r, (BinBroadcastDimensions $l, $r))),
+          [(IEEEFloatTensor $l)]>;

 // Performs a substitution of FloorDiv for integer tensors, which requires
 // additional correction for a negative numerator / denominator. Equivalent
@@ -131,7 +141,9 @@ def : Pat<(TF_FloorDivOp IEEEFloatTensor:$l, IEEEFloatTensor:$r),
 // without returning the broadcast of 'r' to broadcast('l', 'r').
 //
 // NOTE: This should be optimized for unsigned integers.
-def : Pat<(TF_FloorDivOp IntegerTensor:$l, IntegerTensor:$r),
+// Requires static shaped inputs to create constant splats and computation of
+// broadcast attributes.
+def : Pat<(TF_FloorDivOp AnyStaticShapeTensor:$l, AnyStaticShapeTensor:$r),
           (HLO_SelectOp
            (HLO_CompareOp
             (HLO_CompareOp $l, (HLO_ConstOp (ConstantSplat<"0"> $l)),
@@ -146,14 +158,17 @@ def : Pat<(TF_FloorDivOp IntegerTensor:$l, IntegerTensor:$r),
              (HLO_ConstOp (ConstantSplat<"1"> $r)),
              (NullDenseIntElementsAttr)),
             (BinBroadcastDimensions $l, $r))),
-           (HLO_AbsOp:$abs $r), (BinBroadcastDimensions $neg, $abs)))>;
+           (HLO_AbsOp:$abs $r), (BinBroadcastDimensions $neg, $abs))),
+          [(SignedIntTensor $l)]>;

 // Performs a substitution of FloorMod designed to correct for possibly negative
 // values. Pseudocode shown below:
 //
 //   T trunc_mod = std::fmod(x, y);
 //   return trunc_mod != 0 && (y < 0 != trunc_mod < 0) ? trunc_mod + y
 //                                                     : trunc_mod;
-def : Pat<(TF_FloorModOp $l, $r),
+// Requires static shaped inputs to create constant splats and computation of
+// broadcast attributes.
+def : Pat<(TF_FloorModOp AnyStaticShapeTensor:$l, AnyStaticShapeTensor:$r),
           (HLO_SelectOp
            (HLO_AndOp
             (HLO_CompareOp
@@ -186,8 +201,9 @@ def : Pat<(TF_BroadcastToOp:$result AnyRankedTensor:$input, $shape),
 //===----------------------------------------------------------------------===//

 class DirectLogicalBinaryPat<Op FromOp, Op ToOp>
-  : Pat<(FromOp IntegerTensor:$l, IntegerTensor:$r),
-        (ToOp $l, $r, (BinBroadcastDimensions $l, $r))>;
+  : Pat<(FromOp AnyRankedTensor:$l, AnyRankedTensor:$r),
+        (ToOp $l, $r, (BinBroadcastDimensions $l, $r)),
+        [(SignedIntTensor $l)]>;

 foreach fromToBinPair = [[TF_LogicalAndOp, HLO_AndOp],
                          [TF_LogicalOrOp, HLO_OrOp],
@@ -199,7 +215,7 @@ foreach fromToBinPair = [[TF_LogicalAndOp, HLO_AndOp],
 //===----------------------------------------------------------------------===//

 class DirectComparePat<Op FromOp, StrEnumAttrCase direction>
-  : Pat<(FromOp $l, $r),
+  : Pat<(FromOp AnyRankedTensor:$l, AnyRankedTensor:$r),
        (HLO_CompareOp $l, $r, (BinBroadcastDimensions $l, $r), direction)>;

 def : DirectComparePat;
@@ -208,7 +224,7 @@ def : DirectComparePat;
 def : DirectComparePat;

 class EqualityPat<Op FromOp, StrEnumAttrCase direction>
-  : Pat<(FromOp $l, $r,
+  : Pat<(FromOp AnyRankedTensor:$l, AnyRankedTensor:$r,
          TrueBoolAttr:$incompatible_shape_error),
        (HLO_CompareOp $l, $r, (BinBroadcastDimensions $l, $r), direction),
        [(AreBroadcastCompatible $l, $r)]>;
@@ -235,10 +251,10 @@ def OneElementAttr
     "Scalar ElementsAttr">;

 def HasRankedFirstOperand
-  : Constraint<CPred<"(*$0.begin())->getType().isa<RankedTensorType>()">>;
+  : Constraint<CPred<"(*$0.begin()).getType().isa<RankedTensorType>()">>;

 def IsShapedTensor
-  : Constraint<CPred<"$0->getType().isa<ShapedType>()">>;
+  : Constraint<CPred<"$0.getType().isa<ShapedType>()">>;

 // This pattern converts TensorFlow axis format to HLO axis format which
 // doesn't wrap around like TensorFlow and is always positive.
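The integer FloorDiv/FloorMod patterns above implement the sign corrections sketched in their comments. The helper below restates them in plain Python on top of truncating division and checks against Python's floor semantics; the abs-based formula is a paraphrase of the select/abs/neg DAG, not a literal transcription of it:

```
def trunc_div(x, y):
    # HLO integer division truncates toward zero.
    return abs(x) // abs(y) * (1 if (x < 0) == (y < 0) else -1)

def floor_div(x, y):
    if (x < 0) == (y < 0):
        return trunc_div(x, y)
    # Signs differ: floor(x/y) == -((|x| + |y| - 1) / |y|) in trunc arithmetic.
    return -((abs(x) + abs(y) - 1) // abs(y))

def floor_mod(x, y):
    # trunc_mod, corrected when nonzero and its sign differs from y's sign.
    t = x - trunc_div(x, y) * y
    return t + y if t != 0 and (t < 0) != (y < 0) else t

for x in range(-7, 8):
    for y in (-3, -2, 2, 3):
        assert floor_div(x, y) == x // y
        assert floor_mod(x, y) == x % y
```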
For this @@ -389,7 +405,7 @@ def : Pat<(TF_SliceOp:$op HLO_Tensor:$input, HLO_Tensor:$starting_indices, // Ternary op patterns. //===----------------------------------------------------------------------===// -def BothTypesMatch : ConstraintgetType() == $1->getType()">, +def BothTypesMatch : Constraint, "types must be equal">; foreach src = [TF_SelectOp, TF_SelectV2Op] in @@ -412,6 +428,8 @@ foreach Mapping = [ [TF_ImagOp, HLO_ImagOp], [TF_IsFiniteOp, HLO_IsFiniteOp], [TF_LogOp, HLO_LogOp], + [TF_Log1pOp, HLO_Log1pOp], + [TF_LogicalNotOp, HLO_NotOp], [TF_NegOp, HLO_NegOp], [TF_RealOp, HLO_RealOp], [TF_RsqrtOp, HLO_RsqrtOp], @@ -440,6 +458,19 @@ foreach TfOp = [TF_ExpandDimsOp, TF_ReshapeOp, TF_SqueezeOp, ] in { (HLO_ReshapeOp $arg), [(AnyStaticShapeTensor $res)]>; } +// Returns 0 if x is NaN, 0 if x is 0, -1 if x < 0 and 1 if x > 0. +def : Pat<(TF_SignOp $x), + (HLO_SelectOp + (HLO_CompareOp + $x, + $x, + (NullDenseIntElementsAttr), + HLO_COMPARISON_DIRECTION_NE + ), + (HLO_ConstOp (ConstantSplat<"0"> $x)), + (HLO_SignOp $x) + )>; + //===----------------------------------------------------------------------===// // RngUniform. //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc index 29f3eb9a8f5..5e12abc466c 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc @@ -16,10 +16,10 @@ limitations under the License. // This file implements logic for lowering XLA dialect to Standard dialect. #include "llvm/ADT/StringSwitch.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir -#include "mlir/IR/Function.h" // TF:local_config_mlir -#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir +#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/PatternMatch.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" #include "tensorflow/compiler/mlir/xla/transforms/rewriters.h" @@ -47,8 +47,8 @@ struct CompareIConvert : public RewritePattern { auto lhs = compare_op.lhs(); auto rhs = compare_op.rhs(); - auto lhs_type = lhs->getType().cast(); - auto rhs_type = rhs->getType().cast(); + auto lhs_type = lhs.getType().cast(); + auto rhs_type = rhs.getType().cast(); // Broadcasting not supported by this rewrite. if (lhs_type.getShape() != rhs_type.getShape()) return matchFailure(); @@ -86,8 +86,8 @@ struct CompareFConvert : public RewritePattern { auto lhs = compare_op.lhs(); auto rhs = compare_op.rhs(); - auto lhs_type = lhs->getType().cast(); - auto rhs_type = rhs->getType().cast(); + auto lhs_type = lhs.getType().cast(); + auto rhs_type = rhs.getType().cast(); // Broadcasting not supported by this rewrite. 
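The TF_SignOp pattern a few hunks up relies on compare(x, x, NE), which is true only for NaN, to patch HLO's sign op: NaN maps to 0 instead of propagating. The same trick in NumPy terms:

```
import numpy as np

def tf_sign(x):
    # not_equal(x, x) holds only for NaN, so the select returns 0 there
    # and sign(x) everywhere else.
    return np.where(np.not_equal(x, x), np.zeros_like(x), np.sign(x))

print(tf_sign(np.array([np.nan, -2.0, 0.0, 3.5])))  # [ 0. -1.  0.  1.]
```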
if (lhs_type.getShape() != rhs_type.getShape()) return matchFailure(); diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard_patterns.td index 43c57b9bf7f..a15b28193cd 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard_patterns.td @@ -31,8 +31,8 @@ def : Pat<(HLO_ConstOp ElementsAttr:$value), //===----------------------------------------------------------------------===// def IsSameSizePred : CPred< - "$0->getType().cast().getShape() " - "== $1->getType().cast().getShape()">; + "$0.getType().cast().getShape() " + "== $1.getType().cast().getShape()">; def IsSameSizeConstraint : Constraint; @@ -74,9 +74,9 @@ def : Pat<(HLO_MulOp HLO_IntTensor:$l, HLO_IntTensor:$r, [(IsSameSizeConstraint $l, $r)]>; def : Pat<(HLO_DivOp HLO_IntTensor:$l, HLO_IntTensor:$r, IsNullAttr:$broadcast_dimensions), - (DivISOp $l, $r), + (SignedDivIOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; def : Pat<(HLO_RemOp HLO_IntTensor:$l, HLO_IntTensor:$r, IsNullAttr:$broadcast_dimensions), - (RemISOp $l, $r), + (SignedRemIOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc index a8a2eb77586..8ad6717a3f1 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc @@ -18,8 +18,8 @@ limitations under the License. #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" #include "absl/memory/memory.h" -#include "mlir/Dialect/Linalg/Utils/Utils.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir +#include "mlir/Dialect/Linalg/Utils/Utils.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project namespace mlir { namespace xla_lhlo { @@ -42,7 +42,7 @@ struct LhloFuseLinalg : public FunctionPass { // tiled. In order to greedily fuse the ops, we have to start from the tiled // root linalg ops, i.e. linalg ops that write to output buffers of the // function. - llvm::SmallDenseSet func_args; + llvm::SmallDenseSet func_args; for (auto func_arg : func.getArguments()) { func_args.insert(func_arg); } @@ -52,7 +52,7 @@ struct LhloFuseLinalg : public FunctionPass { const SmallVector tile_sizes( generic_op.getNumInputsAndOutputs(), 1); auto op = cast(generic_op.getOperation()); - for (const Value* result : op.getOutputs()) { + for (const Value result : op.getOutputs()) { if (!func_args.count(result)) continue; if (linalg::tileLinalgOp(b, op, tile_sizes, /*permutation=*/{}, &folder)) { diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc index f3b8ab9c311..647c304c686 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc @@ -16,14 +16,14 @@ limitations under the License. // This file implements logic for lowering LHLO dialect to Affine dialect. 
#include "absl/memory/memory.h" -#include "mlir/Dialect/AffineOps/AffineOps.h" // TF:local_config_mlir -#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Location.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir +#include "mlir/Dialect/AffineOps/AffineOps.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Location.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/PatternMatch.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project #include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h" #include "tensorflow/compiler/mlir/xla/transforms/map_lhlo_to_scalar_op.h" @@ -39,15 +39,15 @@ struct BinaryOpConverter : public OpRewritePattern { PatternRewriter& rewriter) const override { const auto& lhs = op.lhs(); const auto& rhs = op.rhs(); - const auto& lhs_type = lhs->getType().template cast(); - const auto& rhs_type = rhs->getType().template cast(); + const auto& lhs_type = lhs.getType().template cast(); + const auto& rhs_type = rhs.getType().template cast(); const auto& element_type = lhs_type.getElementType(); if (lhs_type.getShape() != rhs_type.getShape()) { return this->matchFailure(); } const auto& shape = lhs_type.getShape(); - SmallVector induction_vars; + SmallVector induction_vars; const auto loc = op.getLoc(); for (int i = 0; i < shape.size(); ++i) { auto forOp = rewriter.create(loc, 0, shape[i]); diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_gpu.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_gpu.cc index 9f1f90cb2f0..28413041ac4 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_gpu.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_gpu.cc @@ -19,21 +19,21 @@ limitations under the License. 
#include "absl/memory/memory.h" #include "llvm/ADT/ArrayRef.h" -#include "mlir/Dialect/GPU/GPUDialect.h" // TF:local_config_mlir -#include "mlir/Dialect/Linalg/IR/LinalgOps.h" // TF:local_config_mlir -#include "mlir/Dialect/LoopOps/LoopOps.h" // TF:local_config_mlir -#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/BlockAndValueMapping.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Function.h" // TF:local_config_mlir -#include "mlir/IR/Location.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/Operation.h" // TF:local_config_mlir -#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir -#include "mlir/Transforms/DialectConversion.h" // TF:local_config_mlir +#include "mlir/Dialect/GPU/GPUDialect.h" // TF:llvm-project +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" // TF:llvm-project +#include "mlir/Dialect/LoopOps/LoopOps.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/Location.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/PatternMatch.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Transforms/DialectConversion.h" // TF:llvm-project #include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h" #include "tensorflow/compiler/mlir/xla/transforms/map_lhlo_to_scalar_op.h" @@ -49,13 +49,13 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; PatternMatchResult matchAndRewrite( - ReduceOp reduce_op, ArrayRef args, + ReduceOp reduce_op, ArrayRef args, ConversionPatternRewriter& rewriter) const final { auto loc = reduce_op.getLoc(); // Only support 1d reductions for now. int64_t size = 0; for (auto result : reduce_op.out()) { - auto shaped_type = result->getType().dyn_cast(); + auto shaped_type = result.getType().dyn_cast(); if (!shaped_type || shaped_type.getRank() != 1) { return matchFailure(); } @@ -71,7 +71,7 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern { // Require all inputs to have the same shape. int64_t reduce_dim_size = 0; for (auto input : reduce_op.operands()) { - auto shaped_type = input->getType().dyn_cast(); + auto shaped_type = input.getType().dyn_cast(); if (!shaped_type || !shaped_type.hasStaticShape()) { return matchFailure(); } @@ -105,7 +105,7 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern { loc, mapping.lookup(std::get<0>(pair))); rewriter.create(loc, init_value, mapping.lookup(std::get<1>(pair)), - ArrayRef{index}); + ArrayRef{index}); } // Insert a loop into the body to compute the reduction. The loop ranges @@ -128,15 +128,15 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern { auto output = mapping.lookup(*reduce_op.out().begin()); // TODO(herhut) Move this to the SliceOp builder. 
auto resType = MemRefType::get( - llvm::None, output->getType().cast().getElementType(), + llvm::None, output.getType().cast().getElementType(), makeStridedLinearLayoutMap(llvm::None, MemRefType::getDynamicStrideOrOffset(), rewriter.getContext())); auto accumulator = rewriter.create( - loc, resType, output, ArrayRef{launch_op.getThreadIds().x}); - llvm::SmallVector indexings; + loc, resType, output, ArrayRef{launch_op.getThreadIds().x}); + llvm::SmallVector indexings; auto input_buffer = *reduce_op.operands().begin(); - auto input_type = input_buffer->getType().cast(); + auto input_type = input_buffer.getType().cast(); for (int64_t dim = 0; dim < input_type.getRank(); ++dim) { indexings.push_back(dim == reducing_dimension ? loop.getInductionVar() diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_linalg.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_linalg.cc index af7383c5101..739b9f3554d 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_linalg.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_linalg.cc @@ -17,20 +17,20 @@ limitations under the License. #include "absl/memory/memory.h" #include "llvm/ADT/APInt.h" -#include "mlir/Dialect/Linalg/IR/LinalgOps.h" // TF:local_config_mlir -#include "mlir/Dialect/Linalg/IR/LinalgTypes.h" // TF:local_config_mlir -#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir -#include "mlir/IR/AffineExpr.h" // TF:local_config_mlir -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Function.h" // TF:local_config_mlir -#include "mlir/IR/Location.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/Operation.h" // TF:local_config_mlir -#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir -#include "mlir/Transforms/DialectConversion.h" // TF:local_config_mlir +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" // TF:llvm-project +#include "mlir/Dialect/Linalg/IR/LinalgTypes.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/IR/AffineExpr.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/Location.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/PatternMatch.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Transforms/DialectConversion.h" // TF:llvm-project #include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h" #include "tensorflow/compiler/mlir/xla/transforms/map_lhlo_to_scalar_op.h" @@ -53,11 +53,11 @@ class PointwiseToLinalgConverter : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; PatternMatchResult matchAndRewrite( - LhloOp lhlo_op, ArrayRef args, + LhloOp lhlo_op, ArrayRef args, ConversionPatternRewriter& rewriter) const final { auto loc = lhlo_op.getLoc(); auto argType = - lhlo_op.getOperand(0)->getType().template dyn_cast(); + lhlo_op.getOperand(0).getType().template dyn_cast(); if (!argType || !argType.hasStaticShape()) { emitError(loc, "lhlo to linalg conversion expects statically shaped args"); @@ -73,7 +73,7 @@ class 
PointwiseToLinalgConverter : public OpConversionPattern { unsigned nloops = 0; int operandCount = args.size() - 1; for (const auto& arg : llvm::enumerate(args)) { - auto memrefType = arg.value()->getType().dyn_cast(); + auto memrefType = arg.value().getType().dyn_cast(); if (!memrefType) return ConversionPattern::matchFailure(); unsigned rank = memrefType.getRank(); if (!rank || (nloops && nloops != rank)) { @@ -101,7 +101,7 @@ class PointwiseToLinalgConverter : public OpConversionPattern { block->addArguments(bodyArgTypes); block->addArguments(bodyResultTypes); - SmallVector bodyArgs; + SmallVector bodyArgs; for (int i = 0, e = bodyArgTypes.size(); i < e; ++i) { bodyArgs.push_back(block->getArgument(i)); } @@ -121,11 +121,11 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; PatternMatchResult matchAndRewrite( - LhloOp lhlo_op, ArrayRef args, + LhloOp lhlo_op, ArrayRef args, ConversionPatternRewriter& rewriter) const final { auto loc = lhlo_op.getLoc(); auto argType = - lhlo_op.getOperand(0)->getType().template dyn_cast(); + lhlo_op.getOperand(0).getType().template dyn_cast(); if (!argType || !argType.getElementType().isIntOrFloat() || (argType.getRank() != 0)) { return ConversionPattern::matchFailure(); @@ -136,7 +136,7 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern { auto rhs = rewriter.create(loc, lhlo_op.rhs()); Operation* op = MapLhloOpToStdScalarOp( llvm::cast(lhlo_op), argType.getElementType(), - llvm::ArrayRef{lhs, rhs}, rewriter); + llvm::ArrayRef{lhs, rhs}, rewriter); rewriter.create(loc, op->getResult(0), lhlo_op.out()); rewriter.eraseOp(lhlo_op); return ConversionPattern::matchSuccess(); @@ -148,12 +148,12 @@ class BroadcastInDimConverter : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; PatternMatchResult matchAndRewrite( - BroadcastInDimOp broadcastOp, ArrayRef args, + BroadcastInDimOp broadcastOp, ArrayRef args, ConversionPatternRewriter& rewriter) const final { auto operandMemrefType = - broadcastOp.operand()->getType().dyn_cast(); + broadcastOp.operand().getType().dyn_cast(); auto resultMemrefType = - broadcastOp.output()->getType().dyn_cast(); + broadcastOp.output().getType().dyn_cast(); if (!operandMemrefType || !resultMemrefType) return matchFailure(); auto broadcastDims = broadcastOp.broadcast_dimensions(); if (!broadcastDims.hasValue()) return matchFailure(); @@ -167,7 +167,7 @@ class BroadcastInDimConverter : public OpConversionPattern { private: PatternMatchResult emitScalarBroadcast( - BroadcastInDimOp broadcastOp, ArrayRef args, + BroadcastInDimOp broadcastOp, ArrayRef args, MemRefType resultMemrefType, ConversionPatternRewriter* rewriter) const { unsigned nloops = resultMemrefType.getRank(); SmallVector indexingMaps{ @@ -195,7 +195,7 @@ class BroadcastInDimConverter : public OpConversionPattern { } PatternMatchResult emitNonScalarBroadcast( - BroadcastInDimOp broadcastOp, ArrayRef args, + BroadcastInDimOp broadcastOp, ArrayRef args, MemRefType operandMemrefType, MemRefType resultMemrefType, ConversionPatternRewriter* rewriter) const { SmallVector bodyArgTypes{operandMemrefType.getElementType()}; @@ -250,10 +250,10 @@ class IotaConverter : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; PatternMatchResult matchAndRewrite( - IotaOp iotaOp, ArrayRef args, + IotaOp iotaOp, ArrayRef args, ConversionPatternRewriter& rewriter) const final { auto resultMemrefType = - iotaOp.getOperand()->getType().dyn_cast(); + 
iotaOp.getOperand().getType().dyn_cast(); if (!resultMemrefType) return matchFailure(); auto resultElementType = resultMemrefType.getElementType(); @@ -301,7 +301,7 @@ class ConstConverter : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; PatternMatchResult matchAndRewrite( - ConstOp constOp, ArrayRef args, + ConstOp constOp, ArrayRef args, ConversionPatternRewriter& rewriter) const final { auto loc = constOp.getLoc(); auto valueAttr = constOp.value().cast(); diff --git a/tensorflow/compiler/mlir/xla/transforms/lower_complex.cc b/tensorflow/compiler/mlir/xla/transforms/lower_complex.cc index e09350f4f74..672398672de 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lower_complex.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lower_complex.cc @@ -23,14 +23,14 @@ limitations under the License. #include #include "llvm/ADT/STLExtras.h" -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/Operation.h" // TF:local_config_mlir -#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir -#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir -#include "mlir/IR/Types.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir -#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/PatternMatch.h" // TF:llvm-project +#include "mlir/IR/TypeUtilities.h" // TF:llvm-project +#include "mlir/IR/Types.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Pass/PassRegistry.h" // TF:llvm-project #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/ir/hlo_utils.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" diff --git a/tensorflow/compiler/mlir/xla/transforms/lower_general_dot.cc b/tensorflow/compiler/mlir/xla/transforms/lower_general_dot.cc index 515f818749e..c956cd6b277 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lower_general_dot.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lower_general_dot.cc @@ -17,15 +17,15 @@ limitations under the License. 
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringSwitch.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Function.h" // TF:local_config_mlir -#include "mlir/IR/Location.h" // TF:local_config_mlir -#include "mlir/IR/Operation.h" // TF:local_config_mlir -#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir +#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/Location.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/PatternMatch.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/TypeUtilities.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" #include "tensorflow/compiler/mlir/xla/transforms/rewriters.h" @@ -44,12 +44,12 @@ using mlir::Value; namespace { -Value *TransposeReshape(Value *arg, mlir::Location loc, - llvm::ArrayRef left_dims, - llvm::ArrayRef right_dims, - llvm::ArrayRef arg_shape, - PatternRewriter *rewriter) { - auto element_type = mlir::getElementTypeOrSelf(arg->getType()); +Value TransposeReshape(Value arg, mlir::Location loc, + llvm::ArrayRef left_dims, + llvm::ArrayRef right_dims, + llvm::ArrayRef arg_shape, + PatternRewriter *rewriter) { + auto element_type = mlir::getElementTypeOrSelf(arg.getType()); int64_t left_size = 1; for (auto dim : left_dims) { @@ -91,10 +91,10 @@ Value *TransposeReshape(Value *arg, mlir::Location loc, transpose_result); } -Value *ProcessDotArg(Value *arg, mlir::Location loc, - ElementsAttr contract_dims_attr, bool outer_dims_first, - PatternRewriter *rewriter) { - auto shape = arg->getType().cast().getShape(); +Value ProcessDotArg(Value arg, mlir::Location loc, + ElementsAttr contract_dims_attr, bool outer_dims_first, + PatternRewriter *rewriter) { + auto shape = arg.getType().cast().getShape(); llvm::SmallVector is_outer_dim; is_outer_dim.resize(shape.size(), true); @@ -154,8 +154,8 @@ struct GeneralDotConvert /*outer_dims_first=*/false, &rewriter); // Dot resulting shape. - auto lhs_shape = lhs->getType().cast().getShape(); - auto rhs_shape = rhs->getType().cast().getShape(); + auto lhs_shape = lhs.getType().cast().getShape(); + auto rhs_shape = rhs.getType().cast().getShape(); auto new_dot_type = RankedTensorType::get({lhs_shape[0], rhs_shape[1]}, dot_element_type); diff --git a/tensorflow/compiler/mlir/xla/transforms/map_lhlo_to_scalar_op.h b/tensorflow/compiler/mlir/xla/transforms/map_lhlo_to_scalar_op.h index 11e3af7649b..d61d3e35afc 100644 --- a/tensorflow/compiler/mlir/xla/transforms/map_lhlo_to_scalar_op.h +++ b/tensorflow/compiler/mlir/xla/transforms/map_lhlo_to_scalar_op.h @@ -18,7 +18,7 @@ limitations under the License. 
#include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSwitch.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir +#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project #include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h" namespace mlir { @@ -40,7 +40,7 @@ struct ScalarOp { template <> struct ScalarOp { using FOp = ::mlir::DivFOp; - using IOp = ::mlir::DivISOp; + using IOp = ::mlir::SignedDivIOp; }; template <> struct ScalarOp { @@ -60,8 +60,8 @@ using ScalarIOp = typename ScalarOp::IOp; template Operation* MapLhloOpToStdScalarOp(LhloOp lhlo_op, ArrayRef result_types, - ArrayRef block_args, OpBuilder b) { - Type element_type = block_args.front()->getType(); + ArrayRef block_args, OpBuilder b) { + Type element_type = block_args.front().getType(); if (element_type.isa()) { return b.template create>(lhlo_op.getLoc(), result_types, block_args, mlir::None); @@ -76,10 +76,10 @@ Operation* MapLhloOpToStdScalarOp(LhloOp lhlo_op, ArrayRef result_types, template <> inline Operation* MapLhloOpToStdScalarOp( xla_lhlo::MaxOp lhlo_op, ArrayRef result_types, - ArrayRef block_args, OpBuilder b) { + ArrayRef block_args, OpBuilder b) { const auto& lhs = block_args[0]; const auto& rhs = block_args[1]; - Type element_type = lhs->getType(); + Type element_type = lhs.getType(); if (element_type.isa()) { auto lhs_gt_rhs = b.create>( lhlo_op.getLoc(), CmpIPredicate::sgt, lhs, rhs); @@ -96,10 +96,10 @@ inline Operation* MapLhloOpToStdScalarOp( template <> inline Operation* MapLhloOpToStdScalarOp( xla_lhlo::MinOp lhlo_op, ArrayRef result_types, - ArrayRef block_args, OpBuilder b) { + ArrayRef block_args, OpBuilder b) { const auto& lhs = block_args[0]; const auto& rhs = block_args[1]; - Type element_type = lhs->getType(); + Type element_type = lhs.getType(); if (element_type.isa()) { auto lhs_lt_rhs = b.create>( lhlo_op.getLoc(), CmpIPredicate::slt, lhs, rhs); @@ -116,8 +116,8 @@ inline Operation* MapLhloOpToStdScalarOp( template <> inline Operation* MapLhloOpToStdScalarOp( xla_lhlo::AndOp lhlo_op, ArrayRef result_types, - ArrayRef block_args, OpBuilder b) { - Type element_type = block_args.front()->getType(); + ArrayRef block_args, OpBuilder b) { + Type element_type = block_args.front().getType(); return element_type.isa() ? b.create<::mlir::AndOp>(lhlo_op.getLoc(), result_types, block_args, mlir::None) @@ -150,10 +150,10 @@ inline Optional getIntCmpPredicate( template <> inline Operation* MapLhloOpToStdScalarOp( xla_lhlo::CompareOp lhlo_op, ArrayRef result_types, - ArrayRef block_args, OpBuilder b) { + ArrayRef block_args, OpBuilder b) { const auto& lhs = block_args[0]; const auto& rhs = block_args[1]; - Type element_type = lhs->getType(); + Type element_type = lhs.getType(); if (element_type.isa()) { Optional predicate = getIntCmpPredicate(lhlo_op.comparison_direction()); @@ -172,7 +172,7 @@ inline Operation* MapLhloOpToStdScalarOp( template <> inline Operation* MapLhloOpToStdScalarOp( xla_lhlo::SelectOp lhlo_op, ArrayRef result_types, - ArrayRef block_args, OpBuilder b) { + ArrayRef block_args, OpBuilder b) { return b.create<::mlir::SelectOp>(lhlo_op.getLoc(), result_types, block_args, mlir::None); } @@ -180,8 +180,8 @@ inline Operation* MapLhloOpToStdScalarOp( template <> inline Operation* MapLhloOpToStdScalarOp( xla_lhlo::ExpOp lhlo_op, ArrayRef result_types, - ArrayRef block_args, OpBuilder b) { - Type element_type = block_args.front()->getType(); + ArrayRef block_args, OpBuilder b) { + Type element_type = block_args.front().getType(); return element_type.isa() ? 
b.create<::mlir::ExpOp>(lhlo_op.getLoc(), result_types, block_args, mlir::None) diff --git a/tensorflow/compiler/mlir/xla/transforms/passes.h b/tensorflow/compiler/mlir/xla/transforms/passes.h index d659a3a87f4..21d1f08f3ea 100644 --- a/tensorflow/compiler/mlir/xla/transforms/passes.h +++ b/tensorflow/compiler/mlir/xla/transforms/passes.h @@ -18,8 +18,8 @@ limitations under the License. #include -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project namespace mlir { diff --git a/tensorflow/compiler/mlir/xla/transforms/rewriters.h b/tensorflow/compiler/mlir/xla/transforms/rewriters.h index e4a014f137f..5f546d4651e 100644 --- a/tensorflow/compiler/mlir/xla/transforms/rewriters.h +++ b/tensorflow/compiler/mlir/xla/transforms/rewriters.h @@ -18,8 +18,8 @@ limitations under the License. #include -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/PatternMatch.h" // TF:llvm-project namespace mlir { namespace xla_hlo { diff --git a/tensorflow/compiler/mlir/xla/type_to_shape.cc b/tensorflow/compiler/mlir/xla/type_to_shape.cc index 37c657c99ae..d82b2d33779 100644 --- a/tensorflow/compiler/mlir/xla/type_to_shape.cc +++ b/tensorflow/compiler/mlir/xla/type_to_shape.cc @@ -17,11 +17,11 @@ limitations under the License. #include -#include "mlir/IR/AffineMap.h" // TF:local_config_mlir -#include "mlir/IR/Diagnostics.h" // TF:local_config_mlir -#include "mlir/IR/Location.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/Support/DebugStringHelper.h" // TF:local_config_mlir +#include "mlir/IR/AffineMap.h" // TF:llvm-project +#include "mlir/IR/Diagnostics.h" // TF:llvm-project +#include "mlir/IR/Location.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/Support/DebugStringHelper.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" diff --git a/tensorflow/compiler/mlir/xla/type_to_shape.h b/tensorflow/compiler/mlir/xla/type_to_shape.h index 4bc3fac9b1c..c9989def939 100644 --- a/tensorflow/compiler/mlir/xla/type_to_shape.h +++ b/tensorflow/compiler/mlir/xla/type_to_shape.h @@ -17,7 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_XLA_TYPE_TO_SHAPE_H_ #include "llvm/ADT/STLExtras.h" -#include "mlir/IR/Types.h" // TF:local_config_mlir +#include "mlir/IR/Types.h" // TF:llvm-project #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor_shape.h" diff --git a/tensorflow/compiler/mlir/xla/type_to_shape_test.cc b/tensorflow/compiler/mlir/xla/type_to_shape_test.cc index fc4eea79347..98f9b36c84b 100644 --- a/tensorflow/compiler/mlir/xla/type_to_shape_test.cc +++ b/tensorflow/compiler/mlir/xla/type_to_shape_test.cc @@ -17,9 +17,9 @@ limitations under the License. 
#include -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc b/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc index e79c03447c8..16be296ce6c 100644 --- a/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc +++ b/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc @@ -17,8 +17,8 @@ limitations under the License. #include "llvm/Support/CommandLine.h" #include "llvm/Support/MemoryBuffer.h" -#include "mlir/IR/Module.h" // TF:local_config_mlir -#include "mlir/Translation.h" // TF:local_config_mlir +#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/Translation.h" // TF:llvm-project #include "tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h" #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h" #include "tensorflow/compiler/xla/debug_options_flags.h" diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 01a0f0a86f2..4c3dcd81eb7 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -316,6 +316,11 @@ tf_xla_py_test( timeout = "moderate", srcs = ["matrix_inverse_op_test.py"], python_version = "PY3", + tags = [ + "noasan", + "nomsan", + "notsan", + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", diff --git a/tensorflow/compiler/tests/dense_layer_test.py b/tensorflow/compiler/tests/dense_layer_test.py index 8020aa28ce4..8e653d2511c 100644 --- a/tensorflow/compiler/tests/dense_layer_test.py +++ b/tensorflow/compiler/tests/dense_layer_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import os + import numpy as np from tensorflow.compiler.tests import test_utils diff --git a/tensorflow/compiler/tests/depthwise_conv_op_test.py b/tensorflow/compiler/tests/depthwise_conv_op_test.py index a49985f0446..0f0ea50fde9 100644 --- a/tensorflow/compiler/tests/depthwise_conv_op_test.py +++ b/tensorflow/compiler/tests/depthwise_conv_op_test.py @@ -68,21 +68,21 @@ def ConfigsToTest(): Tuple (input_size, filter_size, out_size, stride, padding), the depthwise convolution parameters. 
""" - input_sizes = [[4, 5, 5, 48], [4, 8, 8, 84], [4, 17, 17, 48], [4, 9, 27, 8], - [4, 31, 31, 7], [4, 35, 35, 2], [4, 147, 147, 2], - [3, 299, 299, 3], [5, 183, 183, 1]] - filter_sizes = [[1, 1, 48, 2], [1, 3, 84, 1], [3, 1, 48, 4], [3, 3, 8, 1], - [3, 3, 7, 1], [5, 5, 2, 1], [3, 3, 2, 8], [2, 2, 3, - 8], [5, 5, 1, 2]] - out_sizes = [[4, 5, 5, 96], [4, 8, 8, 84], [4, 17, 17, 192], [4, 9, 27, 8], - [4, 31, 31, 7], [4, 35, 35, 2], [4, 49, 49, 16], + input_sizes = [[4, 5, 5, 48], [2, 5, 5, 48], [4, 8, 8, 84], [4, 17, 17, 48], + [4, 9, 27, 8], [4, 31, 31, 7], [4, 35, 35, 2], + [4, 147, 147, 2], [3, 299, 299, 3], [5, 183, 183, 1]] + filter_sizes = [[1, 1, 48, 2], [2, 2, 48, 8], [1, 3, 84, 1], [3, 1, 48, 4], + [3, 3, 8, 1], [3, 3, 7, 1], [5, 5, 2, 1], [3, 3, 2, 8], + [2, 2, 3, 8], [5, 5, 1, 2]] + out_sizes = [[4, 5, 5, 96], [2, 5, 5, 384], [4, 8, 8, 84], [4, 17, 17, 192], + [4, 9, 27, 8], [4, 31, 31, 7], [4, 35, 35, 2], [4, 49, 49, 16], [3, 150, 150, 24], [5, 92, 92, 2]] - strides = [1, 1, 1, 1, 1, 1, 3, 2, 2] + strides = [1, 1, 1, 1, 1, 1, 1, 3, 2, 2] # pylint: disable=invalid-name VALID = "VALID" SAME = "SAME" # pylint: enable=invalid-name - paddings = [SAME, SAME, SAME, SAME, SAME, SAME, VALID, SAME, SAME, SAME] + paddings = [SAME, SAME, SAME, SAME, SAME, SAME, SAME, VALID, SAME, SAME, SAME] for i, f, o, s, p in zip(input_sizes, filter_sizes, out_sizes, strides, paddings): yield i, f, o, s, p diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py index 109a7932c20..3bde1574f0e 100644 --- a/tensorflow/compiler/tests/jit_test.py +++ b/tensorflow/compiler/tests/jit_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import os + import numpy as np from tensorflow.compiler.tests import test_utils diff --git a/tensorflow/compiler/tests/matrix_diag_ops_test.py b/tensorflow/compiler/tests/matrix_diag_ops_test.py index 1ca9b157fa1..4c03211da5a 100644 --- a/tensorflow/compiler/tests/matrix_diag_ops_test.py +++ b/tensorflow/compiler/tests/matrix_diag_ops_test.py @@ -21,19 +21,10 @@ from __future__ import print_function import numpy as np from tensorflow.compiler.tests import xla_test -from tensorflow.python.compat import compat from tensorflow.python.ops import array_ops from tensorflow.python.platform import googletest -# LINT.IfChange -matrix_diag_v3_forward_compat_date = (2019, 12, 6) -# LINT.ThenChange( -# //tensorflow/python/kernel_tests/diag_op_test.py, -# //tensorflow/python/ops/array_ops.py, -# //tensorflow/python/ops/parallel_for/array_test.py -# ) - default_v2_alignment = "LEFT_LEFT" alignment_list = ["RIGHT_LEFT", "LEFT_RIGHT"] @@ -404,25 +395,20 @@ class MatrixDiagTest(xla_test.XLATestCase): # From here onwards are v2-only tests. 
def testSquare(self): - if compat.forward_compatible(*matrix_diag_v3_forward_compat_date): - for align in alignment_list: - for _, tests in [square_cases(align)]: - for diag_index, (vecs, solution) in tests.items(): - params = {"diagonal": vecs[0], "k": diag_index, "align": align} - self._assertOpOutputMatchesExpected(params, solution[0]) + for align in alignment_list: + for _, tests in [square_cases(align)]: + for diag_index, (vecs, solution) in tests.items(): + params = {"diagonal": vecs[0], "k": diag_index, "align": align} + self._assertOpOutputMatchesExpected(params, solution[0]) def testSquareBatch(self): - if compat.forward_compatible(*matrix_diag_v3_forward_compat_date): - for align in alignment_list: - for _, tests in [square_cases(align)]: - for diag_index, (vecs, solution) in tests.items(): - params = {"diagonal": vecs, "k": diag_index, "align": align} - self._assertOpOutputMatchesExpected(params, solution) + for align in alignment_list: + for _, tests in [square_cases(align)]: + for diag_index, (vecs, solution) in tests.items(): + params = {"diagonal": vecs, "k": diag_index, "align": align} + self._assertOpOutputMatchesExpected(params, solution) def testRectangularBatch(self): - if not compat.forward_compatible(*matrix_diag_v3_forward_compat_date): - return - # Stores expected num_rows and num_cols (when the other is given). # expected[(d_lower, d_upper)] = (expected_num_rows, expected_num_cols) test_list = list() @@ -513,22 +499,21 @@ class MatrixDiagTest(xla_test.XLATestCase): }, solution_given_num_cols) def testPadding(self): - if compat.forward_compatible(*matrix_diag_v3_forward_compat_date): - for padding_value, align in zip_to_first_list_length([555, -11], - alignment_list): - for _, tests in all_tests(align): - for diag_index, (vecs, solution) in tests.items(): - mask = (solution == 0) - solution = solution + (mask * padding_value) - self._assertOpOutputMatchesExpected( - { - "diagonal": vecs, - "k": diag_index, - "num_rows": solution.shape[-2], - "num_cols": solution.shape[-1], - "padding_value": padding_value, - "align": align - }, solution) + for padding_value, align in zip_to_first_list_length([555, -11], + alignment_list): + for _, tests in all_tests(align): + for diag_index, (vecs, solution) in tests.items(): + mask = (solution == 0) + solution = solution + (mask * padding_value) + self._assertOpOutputMatchesExpected( + { + "diagonal": vecs, + "k": diag_index, + "num_rows": solution.shape[-2], + "num_cols": solution.shape[-1], + "padding_value": padding_value, + "align": align + }, solution) class MatrixSetDiagTest(xla_test.XLATestCase): @@ -634,36 +619,34 @@ class MatrixSetDiagTest(xla_test.XLATestCase): # From here onwards are v2-only tests. 
def testSingleMatrix(self): - if compat.forward_compatible(*matrix_diag_v3_forward_compat_date): - for align in alignment_list: - for _, tests in all_tests(align): - for diag_index, (vecs, banded_mat) in tests.items(): - mask = (banded_mat[0] == 0) - input_mat = np.random.randint(10, size=mask.shape) - solution = input_mat * mask + banded_mat[0] - self._assertOpOutputMatchesExpected( - { - "input": input_mat, - "diagonal": vecs[0], - "k": diag_index, - "align": align - }, solution) + for align in alignment_list: + for _, tests in all_tests(align): + for diag_index, (vecs, banded_mat) in tests.items(): + mask = (banded_mat[0] == 0) + input_mat = np.random.randint(10, size=mask.shape) + solution = input_mat * mask + banded_mat[0] + self._assertOpOutputMatchesExpected( + { + "input": input_mat, + "diagonal": vecs[0], + "k": diag_index, + "align": align + }, solution) def testBatch(self): - if compat.forward_compatible(*matrix_diag_v3_forward_compat_date): - for align in alignment_list: - for _, tests in all_tests(align): - for diag_index, (vecs, banded_mat) in tests.items(): - mask = (banded_mat == 0) - input_mat = np.random.randint(10, size=mask.shape) - solution = input_mat * mask + banded_mat - self._assertOpOutputMatchesExpected( - { - "input": input_mat, - "diagonal": vecs, - "k": diag_index, - "align": align - }, solution) + for align in alignment_list: + for _, tests in all_tests(align): + for diag_index, (vecs, banded_mat) in tests.items(): + mask = (banded_mat == 0) + input_mat = np.random.randint(10, size=mask.shape) + solution = input_mat * mask + banded_mat + self._assertOpOutputMatchesExpected( + { + "input": input_mat, + "diagonal": vecs, + "k": diag_index, + "align": align + }, solution) class MatrixDiagPartTest(xla_test.XLATestCase): @@ -705,45 +688,42 @@ class MatrixDiagPartTest(xla_test.XLATestCase): # From here onwards are v2-only tests. 
def testSingleMatrix(self): - if compat.forward_compatible(*matrix_diag_v3_forward_compat_date): - for align in alignment_list: - test_list = [square_cases(align), tall_cases(align), fat_cases(align)] - for mat, tests in test_list: - for diag_index, (solution, _) in tests.items(): - self._assertOpOutputMatchesExpected( - { - "input": mat[0], - "k": diag_index, - "align": align - }, solution[0]) + for align in alignment_list: + test_list = [square_cases(align), tall_cases(align), fat_cases(align)] + for mat, tests in test_list: + for diag_index, (solution, _) in tests.items(): + self._assertOpOutputMatchesExpected( + { + "input": mat[0], + "k": diag_index, + "align": align + }, solution[0]) def testBatch(self): - if compat.forward_compatible(*matrix_diag_v3_forward_compat_date): - for align in alignment_list: - for mat, tests in all_tests(align): - for diag_index, (solution, _) in tests.items(): - self._assertOpOutputMatchesExpected( - { - "input": mat, - "k": diag_index, - "align": align - }, solution) + for align in alignment_list: + for mat, tests in all_tests(align): + for diag_index, (solution, _) in tests.items(): + self._assertOpOutputMatchesExpected( + { + "input": mat, + "k": diag_index, + "align": align + }, solution) def testPadding(self): - if compat.forward_compatible(*matrix_diag_v3_forward_compat_date): - for padding_value, align in zip_to_first_list_length([555, -11], - alignment_list): - for mat, tests in all_tests(align): - for diag_index, (solution, _) in tests.items(): - mask = (solution == 0) - solution = solution + (mask * padding_value) - self._assertOpOutputMatchesExpected( - { - "input": mat, - "k": diag_index, - "padding_value": padding_value, - "align": align - }, solution) + for padding_value, align in zip_to_first_list_length([555, -11], + alignment_list): + for mat, tests in all_tests(align): + for diag_index, (solution, _) in tests.items(): + mask = (solution == 0) + solution = solution + (mask * padding_value) + self._assertOpOutputMatchesExpected( + { + "input": mat, + "k": diag_index, + "padding_value": padding_value, + "align": align + }, solution) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/quantized_ops_test.py b/tensorflow/compiler/tests/quantized_ops_test.py index 100be3b9aa5..5d4fb39f2ea 100644 --- a/tensorflow/compiler/tests/quantized_ops_test.py +++ b/tensorflow/compiler/tests/quantized_ops_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import math + import numpy as np from tensorflow.compiler.tests import xla_test diff --git a/tensorflow/compiler/tests/reduce_ops_test.py b/tensorflow/compiler/tests/reduce_ops_test.py index a39f633858a..57709c2cd10 100644 --- a/tensorflow/compiler/tests/reduce_ops_test.py +++ b/tensorflow/compiler/tests/reduce_ops_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import functools import itertools + from absl.testing import parameterized import numpy as np diff --git a/tensorflow/compiler/tests/reverse_ops_test.py b/tensorflow/compiler/tests/reverse_ops_test.py index 7dc323b0ab5..abfb73ade38 100644 --- a/tensorflow/compiler/tests/reverse_ops_test.py +++ b/tensorflow/compiler/tests/reverse_ops_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import itertools + import numpy as np from tensorflow.compiler.tests import xla_test diff --git a/tensorflow/compiler/tests/segment_reduction_ops_test.py b/tensorflow/compiler/tests/segment_reduction_ops_test.py index 500617bc38b..ae86b6c30da 100644 
--- a/tensorflow/compiler/tests/segment_reduction_ops_test.py +++ b/tensorflow/compiler/tests/segment_reduction_ops_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import functools + import numpy as np from tensorflow.compiler.tests import xla_test diff --git a/tensorflow/compiler/tests/self_adjoint_eig_op_test.py b/tensorflow/compiler/tests/self_adjoint_eig_op_test.py index 0c1a1d145d4..9507a8c9c92 100644 --- a/tensorflow/compiler/tests/self_adjoint_eig_op_test.py +++ b/tensorflow/compiler/tests/self_adjoint_eig_op_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import itertools + from absl.testing import parameterized import numpy as np diff --git a/tensorflow/compiler/tests/svd_op_test.py b/tensorflow/compiler/tests/svd_op_test.py index 7791b409a37..7e05eeb4c0a 100644 --- a/tensorflow/compiler/tests/svd_op_test.py +++ b/tensorflow/compiler/tests/svd_op_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import itertools + from absl.testing import parameterized import numpy as np diff --git a/tensorflow/compiler/tests/tensor_list_ops_test.py b/tensorflow/compiler/tests/tensor_list_ops_test.py index 7d2425ee205..d49a6a37785 100644 --- a/tensorflow/compiler/tests/tensor_list_ops_test.py +++ b/tensorflow/compiler/tests/tensor_list_ops_test.py @@ -19,6 +19,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function import os + from absl.testing import parameterized import numpy as np from tensorflow.compiler.tests import xla_test diff --git a/tensorflow/compiler/tests/while_test.py b/tensorflow/compiler/tests/while_test.py index 3ef12ced704..420dc04bec3 100644 --- a/tensorflow/compiler/tests/while_test.py +++ b/tensorflow/compiler/tests/while_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import os + import numpy as np from tensorflow.compiler.tests import xla_test diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD index f6e9780eabc..65679bd021a 100644 --- a/tensorflow/compiler/tf2tensorrt/BUILD +++ b/tensorflow/compiler/tf2tensorrt/BUILD @@ -500,7 +500,8 @@ cc_library( deps = [ "//tensorflow/core:framework", "//tensorflow/core:lib_proto_parsing", - ], + "//tensorflow/core:lib", + ] + if_tensorrt([":tensorrt_lib"]), ) tf_proto_library( diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index 855e5d4285f..4e76287a953 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -51,6 +51,7 @@ limitations under the License. #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/tensor_coding.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/public/version.h" #include "tensorflow/core/util/strided_slice_op.h" #if GOOGLE_CUDA @@ -200,18 +201,6 @@ int64 TFAttrs::get(const string& key) const { return this->at(key)->i(); } -template -inline nvinfer1::Dims TensorShapeToTrtDims(const TensorShapeType& shape, - bool ignore_first_dim) { - nvinfer1::Dims trt_dims; - const int offset = (ignore_first_dim ? 
1 : 0); - for (int i = offset; i < shape.dims(); i++) { - trt_dims.d[i - offset] = shape.dim_size(i); - } - trt_dims.nbDims = shape.dims() - offset; - return trt_dims; -} - template Status TensorShapeArrayToTrtDims(const Container& shape, nvinfer1::Dims* out, bool ignore_first_dim = false) { @@ -286,7 +275,7 @@ Status ValidateTensorProperties(const string& producer_node_type, } *trt_dims = TensorShapeToTrtDims(shape, /*ignore_first_dim=*/use_implicit_batch); - // Get batch size for tensor if it will not be included the the shape. + // Get batch size for tensor if it will not be included the shape. if (use_implicit_batch) { *batch_size = shape.dim_size(0); } @@ -314,66 +303,6 @@ Status ValidateTensorProperties(const string& producer_node_type, return Status::OK(); } -string DebugString(const nvinfer1::DimensionType type) { - switch (type) { - case nvinfer1::DimensionType::kSPATIAL: - return "kSPATIAL"; - case nvinfer1::DimensionType::kCHANNEL: - return "kCHANNEL"; - case nvinfer1::DimensionType::kINDEX: - return "kINDEX"; - case nvinfer1::DimensionType::kSEQUENCE: - return "kSEQUENCE"; - default: - return StrCat(static_cast(type), "=unknown"); - } -} - -string DebugString(const nvinfer1::DataType trt_dtype) { - switch (trt_dtype) { - case nvinfer1::DataType::kFLOAT: - return "kFLOAT"; - case nvinfer1::DataType::kHALF: - return "kHALF"; - case nvinfer1::DataType::kINT8: - return "kINT8"; - case nvinfer1::DataType::kINT32: - return "kINT32"; - default: - return "Invalid TRT data type"; - } -} - -string DebugString(const nvinfer1::Dims& dims) { - string out = StrCat("nvinfer1::Dims(nbDims=", dims.nbDims, ", d="); - for (int i = 0; i < dims.nbDims; ++i) { - StrAppend(&out, dims.d[i]); - if (VLOG_IS_ON(2)) { - StrAppend(&out, "[", DebugString(dims.type[i]), "],"); - } else { - StrAppend(&out, ","); - } - } - StrAppend(&out, ")"); - return out; -} - -string DebugString(const nvinfer1::Permutation& permutation, int len) { - string out = "nvinfer1::Permutation("; - for (int i = 0; i < len; ++i) { - StrAppend(&out, permutation.order[i], ","); - } - StrAppend(&out, ")"); - return out; -} - -string DebugString(const nvinfer1::ITensor& tensor) { - return StrCat("nvinfer1::ITensor(@", reinterpret_cast(&tensor), - ", name=", tensor.getName(), - ", dtype=", DebugString(tensor.getType()), - ", dims=", DebugString(tensor.getDimensions()), ")"); -} - Status GetTrtBroadcastShape(const TRT_TensorOrWeights& operand_l, const TRT_TensorOrWeights& operand_r, const bool check_feasibility, @@ -581,14 +510,6 @@ inline nvinfer1::Dims GetTrtDimsForTensor(const Tensor& tensor) { return dims; } -inline bool HasStaticShape(const nvinfer1::Dims& dims) { - if (dims.nbDims < 0) return false; - for (int d = 0; d < dims.nbDims; ++d) { - if (dims.d[d] < 0) return false; - } - return true; -} - int64_t Prod(const nvinfer1::Dims& dims) { int64_t count = 1; for (int d = 0; d < dims.nbDims; ++d) { @@ -732,9 +653,10 @@ size_t TRT_ShapedWeights::size_bytes() const { } string TRT_ShapedWeights::DebugString() const { - return StrCat("TRT_ShapedWeights(shape=", convert::DebugString(shape_), - ", type=", convert::DebugString(type_), - ", values=", reinterpret_cast(GetValues()), ")"); + return StrCat( + "TRT_ShapedWeights(shape=", tensorflow::tensorrt::DebugString(shape_), + ", type=", tensorflow::tensorrt::DebugString(type_), + ", values=", reinterpret_cast(GetValues()), ")"); } // A fake ITensor implementation used to check whether the TF-TRT converter can @@ -858,7 +780,7 @@ nvinfer1::Dims TRT_TensorOrWeights::GetTrtDims() const { 
string TRT_TensorOrWeights::DebugString() const { string output = "TRT_TensorOrWeights(type="; if (is_tensor()) { - StrAppend(&output, "tensor=", convert::DebugString(*tensor()), + StrAppend(&output, "tensor=", tensorflow::tensorrt::DebugString(*tensor()), ", batch_size=", batch_size_); } else { StrAppend(&output, "weights=", weights_.DebugString()); @@ -1210,11 +1132,8 @@ static void InitializeTrtPlugins(nvinfer1::ILogger* trt_logger) { mutex_lock lock(plugin_mutex); if (plugin_initialized) return; - LOG(INFO) << "Linked TensorRT version: " << NV_TENSORRT_MAJOR << "." - << NV_TENSORRT_MINOR << "." << NV_TENSORRT_PATCH; - const int loaded_version = getInferLibVersion(); - LOG(INFO) << "Loaded TensorRT version: " << loaded_version / 1000 << "." - << (loaded_version / 100) % 10 << "." << loaded_version % 100; + LOG(INFO) << "Linked TensorRT version: " << GetLinkedTensorRTVersion(); + LOG(INFO) << "Loaded TensorRT version: " << GetLoadedTensorRTVersion(); plugin_initialized = initLibNvInferPlugins(trt_logger, ""); if (!plugin_initialized) { @@ -1451,6 +1370,19 @@ Status Converter::BuildCudaEngine( } } +#if IS_TRT_VERSION_GE(6, 0, 0, 0) + string precision_mode_str; + TF_RETURN_IF_ERROR( + TrtPrecisionModeToName(precision_mode_, &precision_mode_str)); + string trt_network_name = StrCat( + "TF:", TF_VERSION_STRING, ", ", "TRT:", GetLoadedTensorRTVersion(), "-", + "Precision:", precision_mode_str, ", ", "Calibration:", use_calibration_, + ", ", "Max-Batch-Size:", max_batch_size, ", ", + "Max-Workspace-Size:", max_workspace_size_bytes); + VLOG(1) << "Setting TensorRT network name to " << trt_network_name; + network()->setName(trt_network_name.c_str()); +#endif // #if IS_TRT_VERSION_GE(6, 0, 0, 0) + VLOG(1) << "Building TensorRT engine"; engine->reset(trt_builder_->buildCudaEngine(*network())); #endif @@ -2230,7 +2162,37 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group, conv_layer = layer; } nvinfer1::ITensor* output_tensor = conv_layer->getOutput(0); - + // Add an extra padding for Deconv because TRT doesn't accept the + // argument output_shape and thus the TRT output shape could be wrong + // in case of strides>1. + if (is_conv2d_backprop_input) { + auto tf_output_shape = + static_cast(backprop_output_size.weights().GetValues()); + nvinfer1::Dims trt_output_shape = output_tensor->getDimensions(); + // What determines the padding size is the difference between the given + // input_sizes (tf_output_shape) and TRT computed size. + const int height_diff = tf_output_shape[h_index] - trt_output_shape.d[1]; + const int width_diff = tf_output_shape[w_index] - trt_output_shape.d[2]; + if ((height_diff < 0) || (width_diff < 0)) { + return errors::InvalidArgument( + "input_sizes argument of Conv2DBackprop (i.e. output_shape argument " + "of conv2d_transpose) ", + "is too small for the given out_backprop argument of Conv2DBackprop " + "(i.e. input argument of conv2d_transpose). Expect: ", + "(", tf_output_shape[h_index], ", ", tf_output_shape[w_index], + ") >= ", "(", trt_output_shape.d[1], ", ", trt_output_shape.d[2], + ") for op ", node_def.name()); + } + // Only add a padding layer if padding sizes are larger than 0 + if ((height_diff > 0) || (width_diff > 0)) { + nvinfer1::DimsHW pre_padding(0, 0); + nvinfer1::DimsHW post_padding(height_diff, width_diff); + nvinfer1::IPaddingLayer* padding_layer = + params->converter->network()->addPadding(*output_tensor, pre_padding, + post_padding); + output_tensor = padding_layer->getOutput(0); + } + } // Restore transpose. 
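// Worked example for the deconvolution padding added above (illustrative,
// matching the "Transpose Strided NHWC with VALID padding" test case later
// in this change): for conv2d_transpose with input height 3, kernel height
// 2, stride 2, and VALID padding, TRT computes an output height of
// (3 - 1) * 2 + 2 = 6, while TF's input_sizes argument requests 7; so
// height_diff = 7 - 6 = 1 and one row of zero post-padding is appended to
// reproduce TensorFlow's expected output shape.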
   if (need_transpose) {
     TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
@@ -5145,6 +5107,17 @@ Status ConvertTopK(OpConverterParams* params) {
       CheckInputsWeights(*params, {{"input", false}, {"k", true}}));
   TF_RETURN_IF_ERROR(
       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
+  TFAttrs attrs(node_def);
+  const bool sorted = attrs.get<bool>("sorted");
+  if (!sorted) {
+    // TensorRT only supports sorted output. Although the TensorFlow API
+    // doesn't specify the order of output elements when sorted=false, it's
+    // safer not to convert, because the TensorRT output might differ from
+    // the TensorFlow output and cause confusion.
+    return errors::InvalidArgument("Only sorted=True is supported, at ",
+                                   node_def.name());
+  }
+
   nvinfer1::ITensor* tensor = inputs.at(0).tensor();
   const int num_dims = tensor->getDimensions().nbDims;
   if (num_dims == 0) {
diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h
index 6fb3620bf81..a9f579c9ed7 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h
+++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h
@@ -42,14 +42,6 @@ namespace tensorrt {
 namespace convert {
 using ::stream_executor::port::StatusOr;
-#define IS_TRT_VERSION_GE(major, minor, patch, build)           \
-  ((NV_TENSORRT_MAJOR > major) ||                               \
-   (NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR > minor) || \
-   (NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR == minor && \
-    NV_TENSORRT_PATCH > patch) ||                               \
-   (NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR == minor && \
-    NV_TENSORRT_PATCH == patch && NV_TENSORRT_BUILD >= build))
-
 struct EngineConnection {
   // Constructs a non-control edge.
   EngineConnection(const string& outside, int out_id, int out_port,
@@ -164,11 +156,6 @@ class OutputEdgeValidator {
   bool operator()(const Edge* out_edge) const;
 };
-string DebugString(const nvinfer1::DimensionType type);
-string DebugString(const nvinfer1::DataType trt_dtype);
-string DebugString(const nvinfer1::Dims& dims);
-string DebugString(const nvinfer1::Permutation& permutation, int len);
-string DebugString(const nvinfer1::ITensor& tensor);
 int64_t TrtWeightDimsNumElements(const nvinfer1::Dims& dims);
 int64_t TrtTensorDimsNumElements(const nvinfer1::Dims& dims);
@@ -341,7 +328,7 @@ class TRT_TensorOrWeights {
   // size represented in the shapes or the batch sizes are different. See
   // b/118387490 for more details.
   //
-  // if use_implicit_batch is false, batch_size_ is unused and
+  // If use_implicit_batch is false, batch_size_ is unused and
   // tensor_->getDimensions() will contain the entire shape (A,B,C).
   int batch_size_ = -1;
diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc
index 358004abac7..fa361c29933 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc
+++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc
@@ -1714,15 +1714,14 @@ TEST_F(OpConverterTest, ConvertReshape) {
   };
   // Reshape at batch dimension, should fail.
-  const int kReshapeBatchDimsCases = 5;
-  TestParams params[kReshapeBatchDimsCases] = {
+  std::vector<TestParams> params = {
       TestParams{1, {1, 2, 3}, {3, 1, 1, 2}},
       TestParams{1, {1, 2, -1}, {-1, 1, 1, 2}},
       TestParams{1, {1, 2, 3}, {-1, 1, 1, 2}},
       TestParams{-1, {1, 2, 3}, {1, 1, 1, 2}},
      TestParams{-1, {-1, 2, 3}, {1, 1, 1, 6}},  // TODO(laigd): it should pass.
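      // Note (illustrative): in implicit-batch mode the TRT tensor dimensions
      // exclude the batch dimension, so the converter rejects reshapes that
      // touch the batch size -- which is what each case above does (the TODO
      // marks one that is rejected today but could in principle be allowed).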
}; - for (int i = 0; i < kReshapeBatchDimsCases; ++i) { + for (int i = 0; i < params.size(); ++i) { Reset(); const std::vector& dims = params[i].tensor_dims; AddTestTensor("input", dims, params[i].batch_size); @@ -1734,8 +1733,7 @@ TEST_F(OpConverterTest, ConvertReshape) { } // Reshape on non batch dimensions, ok. - const int kReshapeOKCases = 8; - TestParams ok_params[kReshapeOKCases] = { + std::vector ok_params = { TestParams{-1, {1, 2, 3}, {-1, 1, 3, 2}}, TestParams{1, {1, 2, 3}, {-1, 1, 3, 2}}, TestParams{1, {1, 2, 3}, {1, 1, 3, 2}}, @@ -1745,7 +1743,7 @@ TEST_F(OpConverterTest, ConvertReshape) { TestParams{2, {1, 1}, {2}}, TestParams{2, {}, {2, 1}}, }; - for (int i = 0; i < kReshapeOKCases; ++i) { + for (int i = 0; i < ok_params.size(); ++i) { const int batch_size = std::max(1, ok_params[i].batch_size); const auto& shape = ok_params[i].shape; Reset(); @@ -2549,14 +2547,13 @@ TEST_F(OpConverterTest, ConvertCombinedNMS) { }; // Ok. - const int kCombinedNMSOKCases = 1; - TestParams ok_params[kCombinedNMSOKCases] = { + std::vector ok_params = { // TODO(aaroey): there is a bug in TRT's CombinedNonMaxSuppression // implementation that, the extra output classes that are outside of the // range specified by valid_detections[i] are not zeros but -1s. TestParams{{1, 1, 4}, {1, 3}, 3, 2, .5f, 0, {2, 4}, {2}, {2}}}; - for (int i = 0; i < kCombinedNMSOKCases; ++i) { + for (int i = 0; i < ok_params.size(); ++i) { Reset(); AddTestTensor("boxes", ok_params[i].boxes_tensor_dims); @@ -2814,14 +2811,13 @@ TEST_F(OpConverterTest, ConvertExpandDims) { }; // Ok. - const int kExpandDimsOKCases = 8; - TestParams ok_params[kExpandDimsOKCases] = { + std::vector ok_params = { TestParams{{2, 3}, 1, {1, 2, 3}}, TestParams{{2, 3}, -3, {1, 2, 3}}, TestParams{{2, 3}, 3, {2, 3, 1}}, TestParams{{2, 3}, -1, {2, 3, 1}}, TestParams{{2, 3}, 2, {2, 1, 3}}, TestParams{{2, 3}, -2, {2, 1, 3}}, TestParams{{6}, 1, {1, 6}}, TestParams{{6}, -1, {6, 1}}, }; - for (int i = 0; i < kExpandDimsOKCases; ++i) { + for (int i = 0; i < ok_params.size(); ++i) { Reset(); AddTestTensor("input", ok_params[i].input_dims); AddTestWeights("weights", {1}, {ok_params[i].axis}); @@ -2931,8 +2927,7 @@ TEST_F(OpConverterTest, ConvertSqueeze) { }; // Ok. - const int kSqueezeOKCases = 10; - TestParams ok_params[kSqueezeOKCases] = { + std::vector ok_params = { TestParams{{1, 2, 3}, {1}, {2, 3}}, TestParams{{1, 2, 3}, {-3}, {2, 3}}, TestParams{{2, 3, 1}, {3}, {2, 3}}, @@ -2944,7 +2939,7 @@ TEST_F(OpConverterTest, ConvertSqueeze) { TestParams{{1, 6}, {1}, {6}}, TestParams{{6, 1}, {2}, {6}}, }; - for (int i = 0; i < kSqueezeOKCases; ++i) { + for (int i = 0; i < ok_params.size(); ++i) { Reset(); NodeDef node_def = get_squeeze_nodedef(ok_params[i].axis); AddTestTensor("input", ok_params[i].input_dims); @@ -3114,13 +3109,8 @@ TEST_F(OpConverterTest, ConvertStridedSlice) { // Same input is used for all tests. const std::vector ok_input = {1, 2, 3, 4, 5, 6}; -#if IS_TRT_VERSION_GE(5, 1, 3, 1) - const int kStridedSliceOKCases = 31; -#else - const int kStridedSliceOKCases = 27; -#endif // Ok. - TestParams ok_params[kStridedSliceOKCases] = { + std::vector ok_params = { // 2D Crop. 
TestParams{ /*input_dims=*/{1, 2, 3}, @@ -3484,6 +3474,7 @@ TEST_F(OpConverterTest, ConvertStridedSlice) { /*expected_output_dims=*/{1, 2, 1}, /*expected_output=*/{2, 5}, }, +#if IS_TRT_VERSION_GE(5, 1, 3, 1) TestParams{ /*input_dims=*/{1, 2, 3}, /*begin=*/{0, 0, 0, 0, 1}, @@ -3537,9 +3528,10 @@ TEST_F(OpConverterTest, ConvertStridedSlice) { /*expected_output_dims=*/{}, /*expected_output=*/{1}, }, +#endif // IS_TRT_VERSION_GE(5, 1, 3, 1) }; - for (int i = 0; i < kStridedSliceOKCases; i++) { + for (int i = 0; i < ok_params.size(); i++) { Reset(); NodeDef node_def = get_strided_slice_nodedef( ok_params[i].begin_mask, ok_params[i].end_mask, @@ -3672,8 +3664,7 @@ TEST_F(OpConverterTest, ConvertSlice) { }; // Ok. - const int kSliceOKCases = 5; - TestParams ok_params[kSliceOKCases] = { + std::vector ok_params = { TestParams{{1, 2, 3}, {0, 0, 0, 0}, {-1, -1, -1, -1}, @@ -3687,7 +3678,7 @@ TEST_F(OpConverterTest, ConvertSlice) { TestParams{{6}, {0, 1}, {-1, 3}, {3}, {2, 3, 4}}, }; - for (int i = 0; i < kSliceOKCases; i++) { + for (int i = 0; i < ok_params.size(); i++) { Reset(); NodeDef node_def = get_slice_nodedef(); AddTestTensor("input", ok_params[i].input_dims); @@ -3856,8 +3847,7 @@ TEST_F(OpConverterTest, ConvertConv2D) { }; // Ok. - const int kConv2DOKCases = 7; - TestParams ok_params[kConv2DOKCases] = { + std::vector ok_params = { // Basic TestParams{/*input_dims=*/{1, 2, 3}, /*input=*/{0, 1, 2, 3, 3, 4}, @@ -3942,9 +3932,34 @@ TEST_F(OpConverterTest, ConvertConv2D) { /*is_conv2d_backprop_input=*/true, /*expected_output_dims=*/{1, 2, 4}, /*expected_output=*/{0, 0, -1, 1, -2, 2, -3, 3}}, + // Transpose Strided NHWC + TestParams{/*input_dims=*/{2, 2, 1}, + /*input=*/{0, 1, 2, 3}, + /*filter_dims=*/{1, 2, 1, 1}, + /*filter=*/{-1, 1}, + /*strides=*/{1, 1, 2, 1}, + /*padding=*/"SAME", + /*data_format=*/"NHWC", + /*dilations=*/{1, 1, 1, 1}, + /*is_conv2d_backprop_input=*/true, + /*expected_output_dims=*/{2, 4, 1}, + /*expected_output=*/{0, 0, -1, 1, -2, 2, -3, 3}}, + // Transpose Strided NHWC with VALID padding + TestParams{/*input_dims=*/{3, 1, 1}, + /*input=*/{0, 1, 2}, + /*filter_dims=*/{2, 1, 1, 1}, + /*filter=*/{-1, 1}, + /*strides=*/{1, 2, 1, 1}, + /*padding=*/"VALID", + /*data_format=*/"NHWC", + /*dilations=*/{1, 1, 1, 1}, + /*is_conv2d_backprop_input=*/true, + /*expected_output_dims=*/{7, 1, 1}, + /*expected_output=*/{0, 0, -1, 1, -2, 2, 0}}, + }; - for (int i = 0; i < kConv2DOKCases; i++) { + for (int i = 0; i < ok_params.size(); i++) { Reset(); NodeDef node_def = get_conv2d_nodedef( ok_params[i].strides, ok_params[i].padding, ok_params[i].data_format, @@ -3953,10 +3968,10 @@ TEST_F(OpConverterTest, ConvertConv2D) { AddTestWeights("weights", ok_params[i].filter_dims, ok_params[i].filter); if (ok_params[i].is_conv2d_backprop_input) { - AddTestWeights( - "input_sizes", - {static_cast(ok_params[i].expected_output.size())}, - ok_params[i].expected_output); + std::vector tf_input_sizes = ok_params[i].expected_output_dims; + tf_input_sizes.insert(tf_input_sizes.begin(), 1); // Add batch dimension. 
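      // Note (illustrative): TF's Conv2DBackpropInput takes input_sizes as
      // the full 4-D output shape; the test params store expected_output_dims
      // without the batch dimension, so a leading batch of 1 is prepended
      // before the shape is passed as the 4-element input_sizes weight below.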
+ QCHECK_EQ(4, tf_input_sizes.size()); + AddTestWeights("input_sizes", {4}, tf_input_sizes); } RunValidationAndConversion(node_def); TRT_TensorOrWeights output; @@ -4141,8 +4156,7 @@ TEST_F(OpConverterTest, ConvertConv3D) { }; // Start here - const int kConv3DOKCases = 8; - TestParams ok_params[kConv3DOKCases] = { + std::vector ok_params = { // Basic - just 1x1 conv - input = output TestParams{ /*input_dims=*/{1, 3, 3, 3}, // CDHW @@ -4277,7 +4291,7 @@ TEST_F(OpConverterTest, ConvertConv3D) { }; - for (int i = 0; i < kConv3DOKCases; i++) { + for (int i = 0; i < ok_params.size(); i++) { Reset(); NodeDef node_def = get_conv3d_nodedef( ok_params[i].strides, ok_params[i].padding, ok_params[i].data_format, @@ -4361,8 +4375,7 @@ TEST_F(OpConverterTest, ConvertPool3D) { const std::vector common_array{-4, 2, 15, 3, 6, -3, 22, 1, 88, 56, 36, 1, 1, 105, 1, 16, -28, 1, 42, 9, 3, 1, 7, 1, 11, 61, 5}; - const int kPool3DOKCases = 10; - TestParams ok_params[kPool3DOKCases] = { + std::vector ok_params = { // Basic - just 1x1 max pooling - input = output TestParams{/*input_dims=*/{1, 3, 3, 3}, /*input=*/common_array, @@ -4472,7 +4485,7 @@ TEST_F(OpConverterTest, ConvertPool3D) { // the corners }}; - for (int i = 0; i < kPool3DOKCases; i++) { + for (int i = 0; i < ok_params.size(); i++) { Reset(); NodeDef node_def = get_pool3d_nodedef( ok_params[i].ksize, ok_params[i].strides, ok_params[i].padding, @@ -4572,10 +4585,9 @@ void TestConvertGather(OpConverterTest* test) { }; // Input is the same {1, 2, 3, 4, 5, 6} for all cases. - const int kGatherOKCases = 11; const std::vector params_input = {CType(1), CType(2), CType(3), CType(4), CType(5), CType(6)}; - TestParams ok_params[kGatherOKCases] = { + std::vector ok_params = { // Vector indices, and output rank is rank(params). TestParams{ /*params_shape=*/{1, 1, 2, 3}, @@ -4680,7 +4692,7 @@ void TestConvertGather(OpConverterTest* test) { }; // Ok. - for (int i = 0; i < kGatherOKCases; i++) { + for (int i = 0; i < ok_params.size(); i++) { test->Reset(); const auto& params_shape = ok_params[i].params_shape; if (ok_params[i].params_is_tensor) { @@ -4993,8 +5005,7 @@ void TestConvertConcat(OpConverterTest* test) { InitTestVector(6, /*start_value=*/CType(6))}; // TODO(hinsu): Use std::vector instead of an array to avoid use of explicit // size. - const int kConcatOKCases = 4; - TestParams ok_params[kConcatOKCases] = { + std::vector ok_params = { { /*input_shapes=*/{{1, 2, 3}, {1, 2, 3}}, /*input_values=*/common_input, @@ -5034,7 +5045,7 @@ void TestConvertConcat(OpConverterTest* test) { }, }; - for (int i = 0; i < kConcatOKCases; ++i) { + for (int i = 0; i < ok_params.size(); ++i) { test->Reset(); const int num_inputs = ok_params[i].input_shapes.size(); EXPECT_EQ(num_inputs, ok_params[i].input_values.size()); @@ -5167,8 +5178,7 @@ void TestConvertSplit(OpConverterTest* test) { }; const std::vector common_input = InitTestVector(6); - const int kSplitOKCases = 4; - TestParams ok_params[kSplitOKCases] = { + std::vector ok_params = { // Identity (num_split = 1) {/*input_shape=*/{1, 2, 3}, /*value=*/common_input, /*axis=*/1, /*num_split=*/1, /*expected_output_dims=*/{1, 2, 3}, @@ -5201,7 +5211,7 @@ void TestConvertSplit(OpConverterTest* test) { {InitTestVector(3), InitTestVector(3, CType(3))}}, }; - for (int i = 0; i < kSplitOKCases; ++i) { + for (int i = 0; i < ok_params.size(); ++i) { test->Reset(); NodeDef node_def = get_split_nodedef(dtype, ok_params[i].num_split); // Create inputs. 
@@ -5343,8 +5353,7 @@ void TestConvertUnpack(OpConverterTest* test) { }; const std::vector common_input = InitTestVector(6); - const int kUnpackOKCases = 4; - TestParams ok_params[kUnpackOKCases] = { + std::vector ok_params = { {/*input_shape=*/{1, 2, 3}, /*value=*/common_input, /*axis=*/1, /*num=*/1, /*expected_output_dims=*/{2, 3}, /*expected_outputs=*/{InitTestVector(6)}}, @@ -5381,7 +5390,7 @@ void TestConvertUnpack(OpConverterTest* test) { {CType(5)}}}, }; - for (int i = 0; i < kUnpackOKCases; ++i) { + for (int i = 0; i < ok_params.size(); ++i) { test->Reset(); NodeDef node_def = get_unpack_nodedef(dtype, ok_params[i].num, ok_params[i].axis); diff --git a/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc index 40fd3a7b65f..757ddd159c9 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc @@ -71,7 +71,7 @@ Status TRTOptimizationPass::Init( trt_logger_name_ = params.at("trt_logger").s(); } if (params.count("use_implicit_batch")) { - use_implicit_batch = params.at("use_implicit_batch").b(); + use_implicit_batch_ = params.at("use_implicit_batch").b(); } return Status::OK(); } @@ -264,7 +264,7 @@ Status TRTOptimizationPass::Optimize(grappler::Cluster* cluster, cp.is_dyn_op = is_dynamic_op_; cp.max_cached_engines = max_cached_batches_; cp.use_calibration = use_calibration_; - cp.use_implicit_batch = use_implicit_batch; + cp.use_implicit_batch = use_implicit_batch_; auto status = ConvertAfterShapes(cp); VLOG(1) << "Returning from " << name_; return status; diff --git a/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h index cc17b3409e6..3ce0d09b7c0 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h +++ b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h @@ -42,7 +42,7 @@ class TRTOptimizationPass : public grappler::CustomGraphOptimizer { max_cached_batches_(1), max_workspace_size_bytes_(256LL << 20), use_calibration_(true), - use_implicit_batch(true) { + use_implicit_batch_(true) { VLOG(1) << "Constructing " << name_; } @@ -74,7 +74,7 @@ class TRTOptimizationPass : public grappler::CustomGraphOptimizer { int max_cached_batches_; int64_t max_workspace_size_bytes_; bool use_calibration_; - bool use_implicit_batch; + bool use_implicit_batch_; }; } // namespace convert diff --git a/tensorflow/compiler/tf2tensorrt/convert/utils.cc b/tensorflow/compiler/tf2tensorrt/convert/utils.cc index ca21c193d63..d142bc58bef 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/utils.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/utils.cc @@ -17,6 +17,8 @@ limitations under the License. 
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" namespace tensorflow { namespace tensorrt { @@ -51,5 +53,101 @@ Status TrtPrecisionModeFromName(const string& name, TrtPrecisionMode* mode) { return Status::OK(); } +#if GOOGLE_CUDA && GOOGLE_TENSORRT +using absl::StrAppend; +using absl::StrCat; + +string DebugString(const nvinfer1::DimensionType type) { + switch (type) { + case nvinfer1::DimensionType::kSPATIAL: + return "kSPATIAL"; + case nvinfer1::DimensionType::kCHANNEL: + return "kCHANNEL"; + case nvinfer1::DimensionType::kINDEX: + return "kINDEX"; + case nvinfer1::DimensionType::kSEQUENCE: + return "kSEQUENCE"; + default: + return StrCat(static_cast(type), "=unknown"); + } +} + +string DebugString(const nvinfer1::Dims& dims) { + string out = StrCat("nvinfer1::Dims(nbDims=", dims.nbDims, ", d="); + for (int i = 0; i < dims.nbDims; ++i) { + StrAppend(&out, dims.d[i]); + if (VLOG_IS_ON(2)) { + StrAppend(&out, "[", DebugString(dims.type[i]), "],"); + } else { + StrAppend(&out, ","); + } + } + StrAppend(&out, ")"); + return out; +} + +string DebugString(const nvinfer1::DataType trt_dtype) { + switch (trt_dtype) { + case nvinfer1::DataType::kFLOAT: + return "kFLOAT"; + case nvinfer1::DataType::kHALF: + return "kHALF"; + case nvinfer1::DataType::kINT8: + return "kINT8"; + case nvinfer1::DataType::kINT32: + return "kINT32"; + default: + return "Invalid TRT data type"; + } +} + +string DebugString(const nvinfer1::Permutation& permutation, int len) { + string out = "nvinfer1::Permutation("; + for (int i = 0; i < len; ++i) { + StrAppend(&out, permutation.order[i], ","); + } + StrAppend(&out, ")"); + return out; +} + +string DebugString(const nvinfer1::ITensor& tensor) { + return StrCat("nvinfer1::ITensor(@", reinterpret_cast(&tensor), + ", name=", tensor.getName(), + ", dtype=", DebugString(tensor.getType()), + ", dims=", DebugString(tensor.getDimensions()), ")"); +} + +#endif + +string GetLinkedTensorRTVersion() { + int major, minor, patch; +#if GOOGLE_CUDA && GOOGLE_TENSORRT + major = NV_TENSORRT_MAJOR; + minor = NV_TENSORRT_MINOR; + patch = NV_TENSORRT_PATCH; +#else + major = 0; + minor = 0; + patch = 0; +#endif + return absl::StrCat(major, ".", minor, ".", patch); +} + +string GetLoadedTensorRTVersion() { + int major, minor, patch; +#if GOOGLE_CUDA && GOOGLE_TENSORRT + int ver = getInferLibVersion(); + major = ver / 1000; + ver = ver - major * 1000; + minor = ver / 100; + patch = ver - minor * 100; +#else + major = 0; + minor = 0; + patch = 0; +#endif + return absl::StrCat(major, ".", minor, ".", patch); +} + } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/compiler/tf2tensorrt/convert/utils.h b/tensorflow/compiler/tf2tensorrt/convert/utils.h index eb60829d31d..9015c24b1f4 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/utils.h +++ b/tensorflow/compiler/tf2tensorrt/convert/utils.h @@ -17,9 +17,15 @@ limitations under the License. 
#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_UTILS_H_ #include +#include +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/lib/core/status.h" +#if GOOGLE_CUDA && GOOGLE_TENSORRT +#include "third_party/tensorrt/NvInfer.h" +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT + namespace tensorflow { namespace tensorrt { @@ -45,6 +51,60 @@ Status TrtPrecisionModeToName(TrtPrecisionMode mode, string* name); Status TrtPrecisionModeFromName(const string& name, TrtPrecisionMode* mode); +// Define a hash function for vector because it is used as the key +// for the engine cache. +struct VectorTensorShapeHasher { + std::size_t operator()(const std::vector& key) const { + return std::hash()(TensorShapeUtils::ShapeListString(key)); + } +}; + +#if GOOGLE_CUDA && GOOGLE_TENSORRT + +#define IS_TRT_VERSION_GE(major, minor, patch, build) \ + ((NV_TENSORRT_MAJOR > major) || \ + (NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR > minor) || \ + (NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR == minor && \ + NV_TENSORRT_PATCH > patch) || \ + (NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR == minor && \ + NV_TENSORRT_PATCH == patch && NV_TENSORRT_BUILD >= build)) + +string DebugString(const nvinfer1::DimensionType type); +string DebugString(const nvinfer1::Dims& dims); +string DebugString(const nvinfer1::DataType trt_dtype); +string DebugString(const nvinfer1::Permutation& permutation, int len); +string DebugString(const nvinfer1::ITensor& tensor); + +inline bool HasStaticShape(const nvinfer1::Dims& dims) { + if (dims.nbDims < 0) return false; + for (int d = 0; d < dims.nbDims; ++d) { + if (dims.d[d] < 0) return false; + } + return true; +} + +template +inline nvinfer1::Dims TensorShapeToTrtDims(const TensorShapeType& shape, + bool ignore_first_dim) { + nvinfer1::Dims trt_dims; + const int offset = (ignore_first_dim ? 1 : 0); + for (int i = offset; i < shape.dims(); i++) { + trt_dims.d[i - offset] = shape.dim_size(i); + } + trt_dims.nbDims = shape.dims() - offset; + return trt_dims; +} + +// Return a string that includes compile time +// TensorRT library version information {Maj, Min, Patch}. +string GetLinkedTensorRTVersion(); + +// Return a string that includes runtime time +// TensorRT library version information {Maj, Min, Patch}. 
+string GetLoadedTensorRTVersion(); + +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT + } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc index ca591460c65..c14de3a6736 100644 --- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc @@ -529,6 +529,25 @@ bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx, EngineContext* engine_context) { VLOG(1) << "Executing TRT engine: " << name(); auto& cuda_engine = engine_context->cuda_engine; + + if (VLOG_IS_ON(2)) { +#if IS_TRT_VERSION_GE(6, 0, 0, 0) + VLOG(2) << " Network name: " << cuda_engine->getName(); +#endif // #if IS_TRT_VERSION_GE(6, 0, 0, 0) + VLOG(2) << " Activation size: " << cuda_engine->getDeviceMemorySize() + << " bytes"; + VLOG(2) << " Workspace size: " << cuda_engine->getWorkspaceSize() + << " bytes"; + VLOG(2) << " Datatype of " << cuda_engine->getNbBindings() + << " inputs/outputs"; + string binding_types = ""; + for (int i = 0; i < cuda_engine->getNbBindings(); i++) { + binding_types += " " + string(cuda_engine->getBindingName(i)) + ": " + + DebugString(cuda_engine->getBindingDataType(i)) + "\n"; + } + VLOG(2) << binding_types; + } + const bool kRetry = true; // All inputs must have the same batch size, so just get it from the first // input. @@ -694,6 +713,8 @@ StatusOr TRTEngineOp::GetEngine( // single element containing the only engine. if (static_engine_) { if (cache.size()) { + // TODO(laigd): need a better shape compatibility check for the case where + // implicit batch is disabled. if (!use_implicit_batch_ || AreShapesCompatible(input_shapes, cache.begin()->first)) { return cache.begin()->second.get(); diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h index 8d603ac4d55..808b689127e 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h @@ -114,14 +114,6 @@ class LRUCache { } }; -// Define a hash function for vector because it is used as the key -// for the engine cache. 
-struct VectorTensorShapeHasher { - std::size_t operator()(const std::vector& key) const { - return std::hash()(TensorShapeUtils::ShapeListString(key)); - } -}; - #if GOOGLE_CUDA #if GOOGLE_TENSORRT diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index c509afbc33a..afe96952358 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -627,7 +627,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", "//tensorflow/core:core_cpu_lib", "//tensorflow/core:session_options", - "@llvm//:support", + "@llvm-project//llvm:support", ], alwayslink = 1, ) diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 242448e443e..dbc8397441f 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -48,6 +48,7 @@ tf_kernel_library( "function_ops.cc", "gather_op.cc", "gather_op_helpers.h", + "gather_scatter_ops.cc", "identity_op.cc", "image_ops.cc", "image_resize_ops.cc", diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc index 4f79ce109fb..dda0d79337a 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc @@ -512,22 +512,26 @@ xla::StatusOr MakeXlaBackpropFilterConvOp( filter_in_depth = filter_shape.dimensions(attrs.num_spatial_dims), feature_group_count = in_depth / filter_in_depth; + // In the case of depthwise convolutions, the computation can be done by the + // batch_group_count parameter. + bool use_batch_group_count = in_depth > 1 && in_depth == filter_in_depth && + (feature_group_count != 1 || attrs.depthwise); + + if (use_batch_group_count) { + feature_group_count = 1; + } + // The activations (inputs) form the LHS of the convolution. // Activations have shape: [batch, in_rows, in_cols, ..., in_depth] // For the gradient computation, we need to: // 1. In the case of group convolution, move the num_groups dimension before // the batch dimension // 2. Swap the roles of the batch and feature dimensions. - if (feature_group_count != 1 && !attrs.depthwise) { + if (!use_batch_group_count && feature_group_count != 1 && !attrs.depthwise) { activations = TransposeInputForGroupConvolutionBackpropFilter( activations, input_shape, feature_group_count, n_dim, c_dim); } - // In the case of depthwise convolution with no multiplier, - // the computation can be done by the batch_group_count parameter. - bool use_batch_group_count = - filter_tensor_shape.dim_size(num_dims - 1) == 1 && attrs.depthwise; - std::vector> padding(attrs.num_spatial_dims); std::vector rhs_dilation(attrs.num_spatial_dims); std::vector window_strides(attrs.num_spatial_dims); diff --git a/tensorflow/compiler/tf2xla/kernels/gather_scatter_ops.cc b/tensorflow/compiler/tf2xla/kernels/gather_scatter_ops.cc new file mode 100644 index 00000000000..19aa85f9d42 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/gather_scatter_ops.cc @@ -0,0 +1,102 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class GatherOp : public XlaOpKernel { + public: + explicit GatherOp(OpKernelConstruction* context) : XlaOpKernel(context) { + string dnums_attr; + OP_REQUIRES_OK(context, context->GetAttr("dimension_numbers", &dnums_attr)); + OP_REQUIRES( + context, dnums_.ParsePartialFromString(dnums_attr), + errors::InvalidArgument("Error parsing gather dimension numbers")); + OP_REQUIRES_OK( + context, context->GetAttr("indices_are_sorted", &indices_are_sorted_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + std::vector slice_sizes; + OP_REQUIRES_OK(ctx, + ctx->ConstantInputAsIntVector("slice_sizes", &slice_sizes)); + xla::XlaOp result = + xla::Gather(ctx->Input("operand"), ctx->Input("start_indices"), dnums_, + slice_sizes, indices_are_sorted_); + ctx->SetOutput(0, result); + } + + private: + xla::GatherDimensionNumbers dnums_; + bool indices_are_sorted_; +}; + +REGISTER_XLA_OP(Name("XlaGather"), GatherOp); + +class ScatterOp : public XlaOpKernel { + public: + explicit ScatterOp(OpKernelConstruction* context) : XlaOpKernel(context) { + OP_REQUIRES_OK( + context, context->GetAttr("update_computation", &update_computation_)); + string dnums_attr; + OP_REQUIRES_OK(context, context->GetAttr("dimension_numbers", &dnums_attr)); + OP_REQUIRES( + context, dnums_.ParsePartialFromString(dnums_attr), + errors::InvalidArgument("Error parsing scatter dimension numbers")); + OP_REQUIRES_OK( + context, context->GetAttr("indices_are_sorted", &indices_are_sorted_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + const DataType dtype = ctx->input_type(0); + + XlaCompiler::Argument update_computation_arg; + update_computation_arg.kind = XlaCompiler::Argument::kParameter; + update_computation_arg.type = dtype; + update_computation_arg.shape = TensorShape(); + + XlaCompiler::CompileOptions compile_options; + compile_options.use_tuple_arg = false; + compile_options.always_return_tuple = false; + compile_options.is_entry_computation = false; + XlaCompiler::CompilationResult update_computation; + OP_REQUIRES_OK(ctx, ctx->compiler()->CompileFunction( + compile_options, *update_computation_, + {update_computation_arg, update_computation_arg}, + &update_computation)); + + xla::XlaOp result = + xla::Scatter(ctx->Input("operand"), ctx->Input("scatter_indices"), + ctx->Input("updates"), *update_computation.computation, + dnums_, indices_are_sorted_); + ctx->SetOutput(0, result); + } + + private: + const NameAttrList* update_computation_; + xla::ScatterDimensionNumbers dnums_; + bool indices_are_sorted_; +}; + +REGISTER_XLA_OP(Name("XlaScatter"), ScatterOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc index 
e6076907980..83a894e91fe 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -111,6 +111,11 @@ XLAJIT_MAKE_UNARY(Real, xla::Real(x)); XLAJIT_MAKE_UNARY(Imag, xla::Imag(x)); XLAJIT_MAKE_UNARY(Erf, xla::Erf(x)); XLAJIT_MAKE_UNARY(Erfc, xla::Erfc(x)); +XLAJIT_MAKE_UNARY(Erfinv, xla::ErfInv(x)); +// ndtri = sqrt(2) * erfinv(2 * x - 1) +XLAJIT_MAKE_UNARY(Ndtri, xla::ScalarLike(x, std::sqrt(2.0)) * + xla::ErfInv(xla::ScalarLike(x, 2.0) * x - + xla::ScalarLike(x, 1.0))); XLAJIT_MAKE_UNARY(Lgamma, xla::Lgamma(x)); XLAJIT_MAKE_UNARY(Digamma, xla::Digamma(x)); XLAJIT_MAKE_UNARY(BesselI0e, xla::BesselI0e(x)); diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index 33b740a706c..6b71cca9c2a 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -665,5 +665,50 @@ REGISTER_OP("XlaReplicaId") }) .Doc("Replica ID."); +REGISTER_OP("XlaGather") + .Input("operand: T") + .Input("start_indices: Tindices") + .Input("slice_sizes: Tindices") + .Attr("dimension_numbers: string") + .Attr("indices_are_sorted: bool") + .Attr("T: numbertype") + .Attr("Tindices: {int32, int64}") + .Output("output: T") + .SetShapeFn(UnchangedRank) + .Doc(R"doc( +Wraps the XLA Gather operator documented at + https://www.tensorflow.org/xla/operation_semantics#gather +operand: The array we're gathering from. +start_indices: Array containing the starting indices of the slices we gather. +dimension_numbers: A serialized xla::GatherDimensionNumbers proto. +slice_sizes: slice_sizes[i] is the bounds for the slice on dimension i. +indices_are_sorted: Boolean indicating if the indices are sorted. +)doc"); + +REGISTER_OP("XlaScatter") + .Input("operand: T") + .Input("scatter_indices: Tindices") + .Input("updates: T") + .Attr("update_computation: func") + .Attr("dimension_numbers: string") + .Attr("indices_are_sorted: bool") + .Attr("T: numbertype") + .Attr("Tindices: {int32, int64}") + .Output("output: T") + .SetShapeFn(UnchangedRank) + .Doc(R"doc( +Wraps the XLA Scatter operator documented at + https://www.tensorflow.org/xla/operation_semantics#scatter. + +operand: Array to be scattered into. +scatter_indices: Array containing the starting indices of the slices that must + be scattered to. +updates: Array containing the values that must be used for scattering. +update_computation: Computation to be used for combining the existing values in + the input array and the updates during scatter. +dimension_numbers: A serialized xla::ScatterDimensionNumbers proto. +indices_are_sorted: Boolean indicating if the indices are sorted. 
+)doc"); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py index 24f1e7b41ec..bf258482e56 100644 --- a/tensorflow/compiler/tf2xla/python/xla.py +++ b/tensorflow/compiler/tf2xla/python/xla.py @@ -81,7 +81,8 @@ ceil = _unary_op(math_ops.ceil) digamma = _unary_op(math_ops.digamma) erf = _unary_op(math_ops.erf) erfc = _unary_op(math_ops.erfc) -# TODO(phawkins): implement erfinv +erfinv = _unary_op(math_ops.erfinv) +ndtri = _unary_op(math_ops.ndtri) exp = _unary_op(math_ops.exp) expm1 = _unary_op(math_ops.expm1) floor = _unary_op(math_ops.floor) @@ -415,3 +416,27 @@ sort = gen_xla_ops.xla_sort key_value_sort = gen_xla_ops.xla_key_value_sort while_loop = gen_xla_ops.xla_while dequantize = gen_xla_ops.xla_dequantize + + +def gather(operand, start_indices, dimension_numbers, slice_sizes, + indices_are_sorted=False, name=None): + return gen_xla_ops.xla_gather( + operand, + start_indices, + slice_sizes=slice_sizes, + dimension_numbers=dimension_numbers.SerializeToString(), + indices_are_sorted=indices_are_sorted, + name=name) + + +def scatter(operand, scatter_indices, updates, update_computation, + dimension_numbers, indices_are_sorted=False, name=None): + return gen_xla_ops.xla_scatter( + operand, + scatter_indices, + updates, + update_computation=update_computation, + dimension_numbers=dimension_numbers.SerializeToString(), + indices_are_sorted=indices_are_sorted, + name=name) + diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 4e2866865a2..3a430c36a82 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -417,7 +417,6 @@ cc_library( ":array3d", ":array4d", ":shape_util", - ":sparse_index_array", ":status_macros", ":types", ":util", @@ -463,7 +462,6 @@ cc_library( ":array4d", ":literal", ":shape_util", - ":sparse_index_array", ":status_macros", ":types", ":util", @@ -840,29 +838,6 @@ tf_cc_test( ], ) -cc_library( - name = "sparse_index_array", - srcs = ["sparse_index_array.cc"], - hdrs = ["sparse_index_array.h"], - deps = [ - ":array2d", - ":shape_util", - ":xla_data_proto_cc", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/types:span", - ], -) - -tf_cc_test( - name = "sparse_index_array_test", - srcs = ["sparse_index_array_test.cc"], - deps = [ - ":sparse_index_array", - ":test", - "//tensorflow/core:test_main", - ], -) - cc_library( name = "parse_flags_from_env", srcs = ["parse_flags_from_env.cc"], diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index fd31fb17bba..47fe026385e 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -129,7 +129,7 @@ cc_library( "//tensorflow/stream_executor:device_memory_allocator", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:span", - "@llvm//:support", + "@llvm-project//llvm:support", ], ) @@ -147,7 +147,7 @@ cc_library( "//tensorflow/compiler/xla/service:compiler", "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/memory", - "@llvm//:support", + "@llvm-project//llvm:support", ], ) @@ -253,6 +253,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], ) diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc index 8c85482c8f8..9153ac9e524 
100644
--- a/tensorflow/compiler/xla/client/lib/math.cc
+++ b/tensorflow/compiler/xla/client/lib/math.cc
@@ -15,9 +15,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/lib/math.h"
-// This macro is required to make MSVC defines math constants in math.h
-#define _USE_MATH_DEFINES
-#include <math.h>
+#include <cmath>
 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
 #include "tensorflow/compiler/xla/client/lib/constants.h"
diff --git a/tensorflow/compiler/xla/client/lib/tridiagonal.cc b/tensorflow/compiler/xla/client/lib/tridiagonal.cc
index d2ea6d57069..13cc3630137 100644
--- a/tensorflow/compiler/xla/client/lib/tridiagonal.cc
+++ b/tensorflow/compiler/xla/client/lib/tridiagonal.cc
@@ -36,6 +36,8 @@ namespace {
 struct TridiagonalSystemShape {
   const int64 rank;
   const int64 num_equations;
+  TridiagonalSystemShape(int64 rk, int64 num_eqs)
+      : rank(rk), num_equations(num_eqs) {}
 };
 Status CheckSecondToLastDimension(const Shape& op_shape, int64 rank,
@@ -109,9 +111,7 @@ StatusOr<TridiagonalSystemShape> CheckSystemAndReturnShape(XlaOp lower_diagonal,
   TF_RETURN_IF_ERROR(CheckSecondToLastDimension(upper_diagonal_shape, rank, 1,
                                                 "upper diagonal"));
-  TridiagonalSystemShape result = {.rank = rank,
-                                   .num_equations = num_equations};
-  return result;
+  return TridiagonalSystemShape(rank, num_equations);
 }
 XlaOp Coefficient(XlaOp operand, int64 i) {
diff --git a/tensorflow/compiler/xla/g3doc/index.md b/tensorflow/compiler/xla/g3doc/index.md
index 39715fbe7a9..38c6672685d 100644
--- a/tensorflow/compiler/xla/g3doc/index.md
+++ b/tensorflow/compiler/xla/g3doc/index.md
@@ -81,32 +81,19 @@ For a detailed usage example, see the
 ### Explicit compilation
 Explicit compilation API offers a more fine-grained control for choosing which
-functions should be compiled with XLA. However, it requires restructuring source
-code, as not all TensorFlow operations can be represented in XLA. That is, using
-explicit compilation on API on functions which can not be represented in XLA
-results in an exception.
+functions should be compiled with XLA. However, it might require restructuring
+of the source code, as not all TensorFlow operations can be represented in XLA.
-#### TF2: Use `@tf.function(experimental_compile=True)`
+Note: Using the explicit compilation API on functions which cannot be
+represented in XLA results in an exception.
 Optimizing sections of the program using
 [`tf.function`](https://www.tensorflow.org/api_docs/python/tf/function) is a
-standard approach for
-[improving performance](https://www.tensorflow.org/tutorials/customization/performance)
-of TF2 programs. You can enable compilation with XLA by setting the
-`experimental_compile` argument of `tf.function` to `True`.
-
-Note: `experimental_compile` only works in
-[eager](https://www.tensorflow.org/guide/eager) mode.
-
-#### TF1: Use `xla.compile`
-
-If you are using TF1, you can use the `xla.compile` API for explicit compilation
-using XLA. See the [tutorial colab](./tutorials/xla_compile.ipynb) for usage
-examples.
-
-Note: Gradient computation of graph in `xla.compile()` is prohibited because it
-can cause performance degradation. To avoid this issue, move gradient
-computation inside `xla.compile()`.
+standard approach for [improving
+performance](https://www.tensorflow.org/tutorials/customization/performance) of
+TF2 programs. You can enable compilation with XLA by setting the
+`experimental_compile` argument of `tf.function` to `True`. See the [tutorial
+colab](./tutorials/experimental_compile.ipynb) for usage examples.
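As a quick illustration of the workflow the revised guide describes, here is a minimal sketch (assuming TF 2.x; the function name and shapes are made up for this example, not taken from the guide):

```
import tensorflow as tf

@tf.function(experimental_compile=True)  # ask XLA to compile this function
def dense_relu(x, w, b):
  # XLA can fuse the matmul, bias add, and relu into one optimized kernel.
  return tf.nn.relu(tf.matmul(x, w) + b)

x = tf.random.normal([4, 8])
w = tf.random.normal([8, 2])
b = tf.zeros([2])
print(dense_relu(x, w, b).shape)  # (4, 2)
```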
### AOT (Ahead-of-time) compilation for CPU with `tfcompile` diff --git a/tensorflow/compiler/xla/g3doc/operation_semantics.md b/tensorflow/compiler/xla/g3doc/operation_semantics.md index ee7b2b20928..0185bb4bb2f 100644 --- a/tensorflow/compiler/xla/g3doc/operation_semantics.md +++ b/tensorflow/compiler/xla/g3doc/operation_semantics.md @@ -2053,8 +2053,8 @@ window_strides, padding)` : : : as to have the same output shape : : : : as input if the stride is 1, or : : : : Padding\:\:kValid, which uses no : -: : : no padding and "stops" the : -: : : window once it no longer fits) : +: : : padding and "stops" the window : +: : : once it no longer fits) : Below code and figure shows an example of using `ReduceWindow`. Input is a matrix of size [4x6] and both window_dimensions and window_stride_dimensions are diff --git a/tensorflow/compiler/xla/g3doc/tutorials/experimental_compile.ipynb b/tensorflow/compiler/xla/g3doc/tutorials/experimental_compile.ipynb new file mode 100644 index 00000000000..c8c08fc3ffa --- /dev/null +++ b/tensorflow/compiler/xla/g3doc/tutorials/experimental_compile.ipynb @@ -0,0 +1,268 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "Using XLA with tf.function", + "provenance": [], + "collapsed_sections": [], + "toc_visible": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + } + }, + "cells": [ + { + "metadata": { + "colab_type": "text", + "id": "f4TSNCvpENrW" + }, + "cell_type": "markdown", + "source": [ + "##### Copyright 2019 The TensorFlow Authors." + ] + }, + { + "metadata": { + "cellView": "form", + "colab_type": "code", + "id": "vamNSA0vEP-m", + "colab": {} + }, + "cell_type": "code", + "source": [ + "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "e1oSi4lHFt3z" + }, + "source": [ + "# Using XLA via `tf.function` and `experimental_compile`" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "sDy5lSBd4BDE", + "colab_type": "text" + }, + "source": [ + "In this colab, we train a TensorFlow model to classify the MNIST dataset, where the training function is compiled using XLA.\n", + "\n", + "We start by loading TensorFlow, with eager execution enabled." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "b7noD9NjFRL-" + }, + "source": [ + "\n", + " \n", + " \n", + " \n", + "
\n", + " View on TensorFlow.org\n", + " \n", + " Run in Google Colab\n", + " \n", + " View source on GitHub\n", + "
" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab_type": "code", + "id": "45kUPj5ZFrRa" + }, + "source": [ + "import tensorflow as tf\n", + "\n", + "tf.enable_eager_execution()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "GZVNiRmTDV-5" + }, + "source": [ + "Then, we define some necessary constants and prepare the MNIST dataset." + ] + }, + { + "cell_type": "code", + "metadata": { + "colab_type": "code", + "id": "f37TSEGvGX4_", + "colab": {} + }, + "source": [ + "# Size of each input image, 28 x 28 pixels\n", + "IMAGE_SIZE = 28 * 28\n", + "# Number of distinct number labels, [0..9]\n", + "NUM_CLASSES = 10\n", + "# Number of examples in each training batch (step)\n", + "TRAIN_BATCH_SIZE = 100\n", + "# Number of training steps to run\n", + "TRAIN_STEPS = 1000\n", + "\n", + "# Loads MNIST dataset.\n", + "train, test = tf.keras.datasets.mnist.load_data()\n", + "train_ds = tf.data.Dataset.from_tensor_slices(train).batch(TRAIN_BATCH_SIZE).repeat()\n", + "\n", + "# Casting from raw data to the required datatypes.\n", + "def cast(images, labels):\n", + " images = tf.cast(\n", + " tf.reshape(images, [-1, IMAGE_SIZE]), tf.float32)\n", + " labels = tf.cast(labels, tf.int64)\n", + " return (images, labels)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "lv7I-u_82v1S", + "colab_type": "text" + }, + "source": [ + "Finally, we define the model and the optimizer. For the model, we shall use a single dense layer." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "7O2NcEfG206Q", + "colab_type": "code", + "colab": {} + }, + "source": [ + "layer = tf.keras.layers.Dense(NUM_CLASSES)\n", + "optimizer = tf.keras.optimizers.Adam()\n" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "x_ZehpZP-SfS" + }, + "source": [ + "# Define the training function\n", + "\n", + "In the training function, we get predicted labels using the layer defined above, and then we minimize the gradient of the loss using the optimizer. In order to compile the computation using XLA, we place it inside `tf.function` with `experimental_compile=True`." + ] + }, + { + "cell_type": "code", + "metadata": { + "colab_type": "code", + "id": "ZbhJl_WvGa3g", + "colab": {} + }, + "source": [ + "@tf.function(experimental_compile=True)\n", + "def train_mnist(images, labels):\n", + " images, labels = cast(images, labels)\n", + "\n", + " with tf.GradientTape() as tape:\n", + " predicted_labels = layer(images)\n", + " loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(\n", + " logits=predicted_labels, labels=labels\n", + " ))\n", + " layer_variables = layer.trainable_variables\n", + " grads = tape.gradient(loss, layer_variables)\n", + " optimizer.apply_gradients(zip(grads, layer_variables))\n" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "EZD1m_n1DxAF" + }, + "source": [ + "# Train and test the model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gukC2Hol3sFZ", + "colab_type": "text" + }, + "source": [ + "Once we have defined the training function, we can define the model." 
+ ] + }, + { + "cell_type": "code", + "metadata": { + "colab_type": "code", + "id": "qe28bAHNHUG2", + "colab": {} + }, + "source": [ + "for images, labels in train_ds:\n", + " if optimizer.iterations > TRAIN_STEPS:\n", + " break\n", + " train_mnist(images, labels)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "qgsKmz3n2UiW" + }, + "source": [ + "And, finally, check the accuracy:" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab_type": "code", + "id": "_GxF6jTRHVuA" + }, + "source": [ + "images, labels = cast(test[0], test[1])\n", + "predicted_labels = layer(images)\n", + "correct_prediction = tf.equal(tf.argmax(predicted_labels, 1), labels)\n", + "accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n", + "print(\"Prediction accuracy after training: %s\" % accuracy)" + ], + "execution_count": 0 + } + ] +} diff --git a/tensorflow/compiler/xla/g3doc/tutorials/xla_compile.ipynb b/tensorflow/compiler/xla/g3doc/tutorials/xla_compile.ipynb deleted file mode 100644 index 715585db337..00000000000 --- a/tensorflow/compiler/xla/g3doc/tutorials/xla_compile.ipynb +++ /dev/null @@ -1,373 +0,0 @@ -{ - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "The XLA compile API", - "version": "0.3.2", - "provenance": [], - "collapsed_sections": [], - "toc_visible": true - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - } - }, - "cells": [ - { - "metadata": { - "colab_type": "text", - "id": "f4TSNCvpENrW" - }, - "cell_type": "markdown", - "source": [ - "##### Copyright 2018 The TensorFlow Authors." - ] - }, - { - "metadata": { - "cellView": "form", - "colab_type": "code", - "id": "vamNSA0vEP-m", - "colab": {} - }, - "cell_type": "code", - "source": [ - "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", - "# you may not use this file except in compliance with the License.\n", - "# You may obtain a copy of the License at\n", - "#\n", - "# https://www.apache.org/licenses/LICENSE-2.0\n", - "#\n", - "# Unless required by applicable law or agreed to in writing, software\n", - "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", - "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", - "# See the License for the specific language governing permissions and\n", - "# limitations under the License." - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "e1oSi4lHFt3z" - }, - "cell_type": "markdown", - "source": [ - "# The XLA compile API" - ] - }, - { - "metadata": { - "colab_type": "text", - "id": "b7noD9NjFRL-" - }, - "cell_type": "markdown", - "source": [ - "\n", - " \n", - " \n", - " \n", - "
\n", - " View on TensorFlow.org\n", - " \n", - " Run in Google Colab\n", - " \n", - " View source on GitHub\n", - "
" - ] - }, - { - "metadata": { - "colab_type": "text", - "id": "v9YbsuLZaBXy" - }, - "cell_type": "markdown", - "source": [ - "\n", - "\n", - "Import TensorFlow and the XLA library. XLA contains `xla.compile()`, an API that compiles part or all of a model with [XLA](https://www.tensorflow.org/extend/xla/)." - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "45kUPj5ZFrRa", - "colab": {} - }, - "cell_type": "code", - "source": [ - "import tensorflow as tf\n", - "\n", - "from tensorflow.contrib.compiler import xla" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "GZVNiRmTDV-5" - }, - "cell_type": "markdown", - "source": [ - "Define some necessary constants and prepare the MNIST dataset." - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "f37TSEGvGX4_", - "colab": {} - }, - "cell_type": "code", - "source": [ - "# Size of each input image, 28 x 28 pixels\n", - "IMAGE_SIZE = 28 * 28\n", - "# Number of distinct number labels, [0..9]\n", - "NUM_CLASSES = 10\n", - "# Number of examples in each training batch (step)\n", - "TRAIN_BATCH_SIZE = 100\n", - "# Number of training steps to run\n", - "TRAIN_STEPS = 1000" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "code", - "id": "TiVXchblG5hK", - "colab": {} - }, - "cell_type": "code", - "source": [ - "# Loads MNIST dataset.\n", - "train, test = tf.keras.datasets.mnist.load_data()\n", - "train_ds = tf.data.Dataset.from_tensor_slices(train).batch(TRAIN_BATCH_SIZE).repeat()\n", - "test_ds = tf.data.Dataset.from_tensor_slices(test).batch(TRAIN_BATCH_SIZE)\n", - "\n", - "iterator = tf.data.Iterator.from_structure(train_ds.output_types, train_ds.output_shapes)\n", - "images, labels = iterator.get_next()\n", - "images = tf.reshape(images, [-1, IMAGE_SIZE])\n", - "images, labels = tf.cast(images, tf.float32), tf.cast(labels, tf.int64)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "x_ZehpZP-SfS" - }, - "cell_type": "markdown", - "source": [ - "# Define the model constructing function\n", - "\n", - "Following code block contains a function that constructs a simple model with one dense layer, including both forward and backward propagation.\n", - "\n", - "When called, it returns two values. `y` is a `tf.Tensor` representing predicted probability of each target class, `train_step` is a `tf.Operation` that increments `global_step` and applies variable update." - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "ZbhJl_WvGa3g", - "colab": {} - }, - "cell_type": "code", - "source": [ - "def build_mnist_model(x, y_):\n", - " y = tf.keras.layers.Dense(NUM_CLASSES).apply(x)\n", - "\n", - " cross_entropy = tf.losses.sparse_softmax_cross_entropy(labels=y_, logits=y)\n", - " train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)\n", - "\n", - " return y, train_step" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "7Jh3lyQHDfM9" - }, - "cell_type": "markdown", - "source": [ - "# Enable XLA\n", - "\n", - "Use `xla.compile` with the `build_mnist_model` function to enable XLA. Following code block wraps the model with `xla.compile()`, which allows the target function with provided inputs to be executed by XLA." 
- ] - }, - { - "metadata": { - "colab_type": "code", - "id": "kYpCXCdRHNuN", - "colab": {} - }, - "cell_type": "code", - "source": [ - "[y] = xla.compile(build_mnist_model, inputs=[images, labels])" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "4giQh62IrZGF" - }, - "cell_type": "markdown", - "source": [ - "When compiling the graph, XLA replaces all the graph nodes constructed in the target function with a few XLA ops.\n", - "\n", - "xla.compile does not return any\n", - "`tf.Operation` nodes that can be executed independently from the generated XLA ops. Instead, returned `tf.Operation` nodes from the target function are added as control dependencies of all returned `tf.Tensor` values. This triggers execution of the `tf.Operation` nodes when the returned tensors are evaluated.\n", - "\n", - "In pseudo-code, xla.compile's implementation looks as follows:\n", - "\n", - "---\n", - "```\n", - "# Ask Tensorflow to execute code in XLA-friendly manner\n", - "\n", - "y, train_step = build_mnist_model(images, labels)\n", - "with tf.control_dependencies([train_step]):\n", - " y = tf.identity(y)\n", - "\n", - "# Ask Tensorflow to STOP executing code in XLA-friendly manner\n", - "```\n", - "---\n", - "\n", - "xla.compile() always returns a list of `tf.Tensor`'s (even if there is only one-element)." - ] - }, - { - "metadata": { - "colab_type": "text", - "id": "TPGas4jjFLZl" - }, - "cell_type": "markdown", - "source": [ - "If you were to print the constructed graph now, you will see that it is not much different from a normal Tensorflow graph and you won't be able to find XLA ops mentioned before. This is because the actual compilation happens later when you try to execute the graph with `sess.run()`. At that time, Tensorflow triggers a series of graph rewrite passes that actually generate XLA ops, which compiles and executes computation when all inputs are ready." - ] - }, - { - "metadata": { - "colab_type": "text", - "id": "EZD1m_n1DxAF" - }, - "cell_type": "markdown", - "source": [ - "# Train and test the model" - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "qe28bAHNHUG2", - "colab": {} - }, - "cell_type": "code", - "source": [ - "# Creates session and initialize all variables.\n", - "# xla.compile() doesn't work with Keras model.fit() API or TF eager mode yet.\n", - "sess = tf.Session()\n", - "sess.run(tf.global_variables_initializer())" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "qgsKmz3n2UiW" - }, - "cell_type": "markdown", - "source": [ - "Following code block trains model. Evaluating `y` also triggers its control dependency node `train_step`, which updates model variables." 
- ] - }, - { - "metadata": { - "colab_type": "code", - "id": "_GxF6jTRHVuA", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 34 - }, - "outputId": "fbf299ca-02d5-4e95-f9fe-8f3c0432d132" - }, - "cell_type": "code", - "source": [ - "# Feeds training dataset\n", - "sess.run(iterator.make_initializer(train_ds))\n", - "\n", - "# Runs TRAIN_STEPS steps\n", - "for i in range(TRAIN_STEPS):\n", - " sess.run(y)\n", - "\n", - "print(\"Model trained for %s steps.\" % TRAIN_STEPS)" - ], - "execution_count": 21, - "outputs": [ - { - "output_type": "stream", - "text": [ - "Model trained for 1000 steps.\n" - ], - "name": "stdout" - } - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "dHlQlRSRHXD1", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 34 - }, - "outputId": "9c3677a2-ec84-406f-9d2c-d722844f3093" - }, - "cell_type": "code", - "source": [ - "# Tests trained model\n", - "\n", - "# Feeds testing dataset\n", - "sess.run(iterator.make_initializer(test_ds))\n", - "\n", - "# Calculates accuracy\n", - "correct_prediction = tf.equal(tf.argmax(y, 1), labels)\n", - "accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n", - "print(\"Prediction accuracy after training: %s\" % sess.run(accuracy))" - ], - "execution_count": 22, - "outputs": [ - { - "output_type": "stream", - "text": [ - "Prediction accuracy after training: 0.91\n" - ], - "name": "stdout" - } - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "ynJQIuzjHYOb", - "colab": {} - }, - "cell_type": "code", - "source": [ - "# Cleans up session\n", - "sess.close()" - ], - "execution_count": 0, - "outputs": [] - } - ] -} diff --git a/tensorflow/compiler/xla/layout.cc b/tensorflow/compiler/xla/layout.cc index 5f0b5c62187..d234e729688 100644 --- a/tensorflow/compiler/xla/layout.cc +++ b/tensorflow/compiler/xla/layout.cc @@ -52,7 +52,6 @@ string Tile::ToString() const { for (const int64 dimension : proto.minor_to_major()) { layout.add_minor_to_major(dimension); } - layout.set_max_sparse_elements(proto.max_sparse_elements()); for (const TileProto& tile_proto : proto.tiles()) { *layout.add_tiles() = Tile::CreateFromProto(tile_proto); } @@ -68,7 +67,6 @@ LayoutProto Layout::ToProto() const { for (const int64 dimension : minor_to_major()) { proto.add_minor_to_major(dimension); } - proto.set_max_sparse_elements(max_sparse_elements_); for (const Tile& tile : tiles()) { *proto.add_tiles() = tile.ToProto(); } @@ -78,10 +76,7 @@ LayoutProto Layout::ToProto() const { } string Layout::ToString() const { - if (format() == SPARSE) { - CHECK_EQ(tiles_size(), 0) << "Sparse layout should not be tiled."; - return absl::StrCat("sparse{", max_sparse_elements(), "}"); - } else if (format() == DENSE) { + if (format() == DENSE) { string colon_string = tiles().empty() ? 
"" : "T"; for (Tile tile : tiles()) { absl::StrAppend(&colon_string, tile.ToString()); @@ -107,10 +102,6 @@ bool Layout::Equal::operator()(const Layout& lhs, const Layout& rhs) { if (lhs.format() == DENSE && lhs.minor_to_major() != rhs.minor_to_major()) { return false; } - if (lhs.format() == SPARSE && - lhs.max_sparse_elements() != rhs.max_sparse_elements()) { - return false; - } if (!ignore_tiles_ && lhs.tiles() != rhs.tiles()) { return false; } diff --git a/tensorflow/compiler/xla/layout.h b/tensorflow/compiler/xla/layout.h index 1234d01755b..fd6d62ac2f7 100644 --- a/tensorflow/compiler/xla/layout.h +++ b/tensorflow/compiler/xla/layout.h @@ -203,12 +203,6 @@ class Layout { absl::Span tiles() const { return tiles_; } absl::InlinedVector* mutable_tiles() { return &tiles_; } - // Methods for accessing the int64 fields. - int64 max_sparse_elements() const { return max_sparse_elements_; } - Layout& set_max_sparse_elements(int64 value) { - max_sparse_elements_ = value; - return *this; - } int64 element_size_in_bits() const { return element_size_in_bits_; } Layout& set_element_size_in_bits(int64 value) { element_size_in_bits_ = value; @@ -233,8 +227,7 @@ class Layout { template friend H AbslHashValue(H h, const Layout& l) { - return H::combine(std::move(h), l.format_, l.minor_to_major_, - l.max_sparse_elements_, l.tiles_, + return H::combine(std::move(h), l.format_, l.minor_to_major_, l.tiles_, l.element_size_in_bits_); } @@ -255,11 +248,6 @@ class Layout { // And the major dim is [8,100,100,3][1], which is size 100. absl::InlinedVector minor_to_major_; - // The maximum number of elements that can be stored for SPARSE formats. This - // can be used to determine the maximum size in bytes of arrays stored in - // memory. This field must be zero unless the format is SPARSE. - int64 max_sparse_elements_ = 0; - // The tiles used in tiling-based layout. 
absl::InlinedVector tiles_; diff --git a/tensorflow/compiler/xla/layout_test.cc b/tensorflow/compiler/xla/layout_test.cc index 26805c5c0a2..7bcc19c9725 100644 --- a/tensorflow/compiler/xla/layout_test.cc +++ b/tensorflow/compiler/xla/layout_test.cc @@ -34,8 +34,6 @@ class LayoutTest : public ::testing::Test {}; TEST_F(LayoutTest, ToString) { EXPECT_EQ(Layout().ToString(), "invalid{}"); EXPECT_EQ(Layout({4, 5, 6}).ToString(), "{4,5,6}"); - EXPECT_EQ(Layout().set_format(SPARSE).set_max_sparse_elements(123).ToString(), - "sparse{123}"); EXPECT_EQ(Layout({4, 5, 6}).ToString(), "{4,5,6}"); EXPECT_EQ(Layout({3, 2, 1, 0}, {Tile({42, 123}), Tile({4, 5})}).ToString(), "{3,2,1,0:T(42,123)(4,5)}"); @@ -65,11 +63,6 @@ TEST_F(LayoutTest, StreamOut) { } } -TEST_F(LayoutTest, SparseLayoutMaxElements) { - EXPECT_EQ(LayoutUtil::MaxSparseElements(LayoutUtil::MakeSparseLayout(101)), - 101); -} - TEST_F(LayoutTest, Equality) { EXPECT_EQ(Layout(), Layout()); const std::vector empty_dims; @@ -90,12 +83,6 @@ TEST_F(LayoutTest, Equality) { Layout({0, 1, 2}).set_memory_space(3)); EXPECT_NE(Layout({0, 1, 2}).set_memory_space(1), Layout({0, 1, 2}).set_memory_space(3)); - EXPECT_EQ(Layout().set_format(SPARSE), Layout().set_format(SPARSE)); - EXPECT_EQ(Layout().set_format(SPARSE).set_max_sparse_elements(42), - Layout().set_format(SPARSE).set_max_sparse_elements(42)); - EXPECT_NE(Layout().set_format(SPARSE).set_max_sparse_elements(42), - Layout().set_format(SPARSE).set_max_sparse_elements(24)); - EXPECT_FALSE( Layout::Equal()(Layout({0, 1, 2}, {Tile({42, 44})}), Layout({0, 1, 2}))); EXPECT_TRUE(Layout::Equal().IgnoreTiles()(Layout({0, 1, 2}, {Tile({42, 44})}), @@ -117,8 +104,6 @@ TEST_F(LayoutTest, LayoutToFromProto) { expect_unchanged(Layout()); expect_unchanged(Layout({1, 3, 2, 0})); - expect_unchanged(Layout().set_format(SPARSE)); - expect_unchanged(Layout().set_format(SPARSE).set_max_sparse_elements(123)); expect_unchanged(Layout({0, 1}).set_element_size_in_bits(42)); expect_unchanged(Layout({3, 2, 1, 0}, {Tile({42, 123}), Tile({4, 5})})); } diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index 45572d9062e..6f8ece1bb10 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -94,13 +94,6 @@ void SetDefaultLayoutToContainer(T* minor_to_major) { return layout; } -/* static */ Layout LayoutUtil::MakeSparseLayout(int64 max_sparse_elements) { - Layout layout; - layout.set_format(SPARSE); - layout.set_max_sparse_elements(max_sparse_elements); - return layout; -} - namespace { // Internal helper that creates a default layout for an array of the given rank. @@ -293,19 +286,6 @@ Layout CreateDefaultLayoutForRank(int64 rank) { layout.minor_to_major().end(), std::greater()); } -/* static */ bool LayoutUtil::IsSparseArray(const Shape& shape) { - return shape.IsArray() && shape.has_layout() && IsSparse(shape.layout()); -} - -/* static */ bool LayoutUtil::IsSparse(const Layout& layout) { - return layout.format() == SPARSE; -} - -/* static */ int64 LayoutUtil::MaxSparseElements(const Layout& layout) { - CHECK(IsSparse(layout)); - return layout.max_sparse_elements(); -} - /* static */ bool LayoutUtil::HasLayout(const Shape& shape) { if (shape.IsTuple()) { // Tuple shape: all subshapes must have a layout. 
@@ -461,8 +441,6 @@ Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) { for (int64 minor_to_major : layout.minor_to_major()) { hash_value = Hash64Combine(hash_value, hash()(minor_to_major)); } - hash_value = Hash64Combine(hash_value, layout.max_sparse_elements()); - for (Tile tile : layout.tiles()) { for (int64 tile_dim : tile.dimensions()) { hash_value = Hash64Combine(hash_value, hash()(tile_dim)); diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h index b391220ade9..60e135de354 100644 --- a/tensorflow/compiler/xla/layout_util.h +++ b/tensorflow/compiler/xla/layout_util.h @@ -49,10 +49,6 @@ class LayoutUtil { // dimensions. static Layout MakeDescendingLayout(int64 rank); - // Creates a sparse layout with the given maximum number of elements. (This is - // a convenience function for protobuf construction.) - static Layout MakeSparseLayout(int64 max_sparse_elements); - // Returns default layout for the given shape. static Layout GetDefaultLayoutForShape(const Shape& shape); @@ -109,17 +105,6 @@ class LayoutUtil { // more minor, and so on until dimension N-1 which is the minor. static bool IsMonotonicWithDim0Major(const Layout& layout); - // Returns whether the given Shape is an array (i.e. not a tuple) and has a - // sparse format layout. - static bool IsSparseArray(const Shape& shape); - - // Returns whether the given Layout has a sparse format. - static bool IsSparse(const Layout& layout); - - // Returns the maximum number of elements that can be stored in a sparse - // layout. - static int64 MaxSparseElements(const Layout& layout); - // Returns whether the given shape has a layout. For tuple shapes, true is // returned only if all elements have layouts. static bool HasLayout(const Shape& shape); diff --git a/tensorflow/compiler/xla/layout_util_test.cc b/tensorflow/compiler/xla/layout_util_test.cc index 12da2140636..398baa13fca 100644 --- a/tensorflow/compiler/xla/layout_util_test.cc +++ b/tensorflow/compiler/xla/layout_util_test.cc @@ -33,14 +33,6 @@ class LayoutUtilTest : public ::testing::Test { *shape.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major); return shape; } - - Shape MakeShapeWithSparseLayout(PrimitiveType element_type, - absl::Span dimensions, - int64 max_sparse_elements) { - Shape shape = ShapeUtil::MakeShape(element_type, dimensions); - *shape.mutable_layout() = LayoutUtil::MakeSparseLayout(max_sparse_elements); - return shape; - } }; TEST_F(LayoutUtilTest, TupleLayoutComparison) { @@ -92,29 +84,6 @@ TEST_F(LayoutUtilTest, CopyLayoutArray) { EXPECT_FALSE(dst.has_layout()); } -TEST_F(LayoutUtilTest, CopyLayoutSparse) { - Shape src = MakeShapeWithSparseLayout(F32, {2, 3}, 2); - Shape dst = MakeShapeWithLayout(F32, {2, 3}, {1, 0}); - - EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst)); - EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); - EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); - - // Should work if destination has no layout. - dst.clear_layout(); - EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst)); - EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); - EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); - - // If source is cleared, then destination should be cleared. 
- src.clear_layout(); - EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst)); - EXPECT_TRUE(dst.has_layout()); - EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); - EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); - EXPECT_FALSE(dst.has_layout()); -} - TEST_F(LayoutUtilTest, CopyLayoutTuple) { Shape src = ShapeUtil::MakeTupleShape( {MakeShapeWithLayout(F32, {2, 3}, {0, 1}), @@ -134,25 +103,6 @@ TEST_F(LayoutUtilTest, CopyLayoutTuple) { EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); } -TEST_F(LayoutUtilTest, CopyLayoutTupleSparse) { - Shape src = ShapeUtil::MakeTupleShape( - {MakeShapeWithSparseLayout(F32, {2, 3}, 4), - MakeShapeWithSparseLayout(F32, {42, 123}, 4), - ShapeUtil::MakeTupleShape( - {MakeShapeWithLayout(F32, {}, {}), - MakeShapeWithSparseLayout(F32, {1, 2, 3}, 6)})}); - Shape dst = ShapeUtil::MakeTupleShape( - {MakeShapeWithLayout(F32, {2, 3}, {1, 0}), - MakeShapeWithLayout(F32, {42, 123}, {1, 0}), - ShapeUtil::MakeTupleShape( - {MakeShapeWithLayout(F32, {}, {}), - MakeShapeWithLayout(F32, {1, 2, 3}, {1, 2, 0})})}); - - EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst)); - EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); - EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); -} - TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleSameRank) { Shape src = MakeShapeWithLayout(F32, {123, 42, 7}, {2, 0, 1}); Shape dst = MakeShapeWithLayout(F32, {2, 3, 5}, {1, 0}); @@ -160,13 +110,6 @@ TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleSameRank) { EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); } -TEST_F(LayoutUtilTest, CopyLayoutSparseNotCompatibleSameRank) { - Shape src = MakeShapeWithSparseLayout(F32, {123, 42, 7}, 6); - Shape dst = MakeShapeWithLayout(F32, {2, 3, 5}, {1, 0}); - ASSERT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); - EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); -} - TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleDifferentRank) { Shape src = MakeShapeWithLayout(F32, {123, 42, 7}, {2, 0, 1}); Shape dst = MakeShapeWithLayout(F32, {2, 3}, {1, 0}); @@ -176,15 +119,6 @@ TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleDifferentRank) { ::testing::ContainsRegex("cannot copy layout from shape")); } -TEST_F(LayoutUtilTest, CopyLayoutSparseNotCompatibleDifferentRank) { - Shape src = MakeShapeWithLayout(F32, {123, 42, 7}, {2, 0, 1}); - Shape dst = MakeShapeWithSparseLayout(F32, {2, 3}, 4); - auto status = LayoutUtil::CopyLayoutBetweenShapes(src, &dst); - EXPECT_FALSE(status.ok()); - EXPECT_THAT(status.error_message(), - ::testing::ContainsRegex("cannot copy layout from shape")); -} - TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleTuple) { Shape src = ShapeUtil::MakeTupleShape({MakeShapeWithLayout(F32, {2, 3}, {0, 1}), diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index 3d6310c1e17..6c7aff3b11e 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -80,7 +80,7 @@ bool LiteralProtoHasValues(const LiteralProto& proto) { proto.c64s_size() || proto.c128s_size() || proto.tuple_literals_size() || !proto.f16s().empty() || !proto.bf16s().empty() || !proto.u16s().empty() || - !proto.s16s().empty() || proto.sparse_indices_size(); + !proto.s16s().empty(); } } // namespace @@ -135,21 +135,8 @@ void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) { // Literals can be used as DMA targets, which can require alignment. We // force a 16-byte minimum alignment. 
constexpr int kMinimumAlignment = 16; - if (LayoutUtil::IsSparseArray(shape)) { - // For sparse arrays, the buffer must be of the size of the maximum - // number of sparse elements possible. - const int64 max_sparse_elements = - LayoutUtil::MaxSparseElements(shape.layout()); - piece->set_buffer(static_cast(tensorflow::port::AlignedMalloc( - max_sparse_elements * - ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type()), - kMinimumAlignment))); - piece->set_sparse_indices( - new SparseIndexArray(max_sparse_elements, shape.rank())); - } else { - piece->set_buffer(static_cast(tensorflow::port::AlignedMalloc( - piece->size_bytes(), kMinimumAlignment))); - } + piece->set_buffer(static_cast(tensorflow::port::AlignedMalloc( + piece->size_bytes(), kMinimumAlignment))); } } else { // If the shape is neither an array nor tuple, then it must be @@ -181,7 +168,6 @@ void Literal::DeallocateBuffers() { [&](const ShapeIndex& index, Piece* piece) { if (piece->buffer() != nullptr) { tensorflow::port::AlignedFree(piece->buffer()); - delete piece->sparse_indices(); } }); } @@ -211,16 +197,6 @@ Literal LiteralBase::CreateFromShape(const Shape& shape) { return literal; } -const SparseIndexArray* LiteralBase::sparse_indices( - const ShapeIndex& shape_index) const { - return piece(shape_index).sparse_indices(); -} - -SparseIndexArray* MutableLiteralBase::sparse_indices( - const ShapeIndex& shape_index) { - return piece(shape_index).sparse_indices(); -} - template Status MutableLiteralBase::CopySliceFromInternal( const LiteralBase& src_literal, absl::Span src_base, @@ -373,12 +349,9 @@ std::vector Literal::DecomposeTuple() { } Piece& src_piece = piece(src_index); - // Move the respective buffer and sparse indices over to the element - // Literal. + // Move the respective buffer over to the element Literal. dest_piece->set_buffer(src_piece.buffer()); src_piece.set_buffer(nullptr); - dest_piece->set_sparse_indices(src_piece.sparse_indices()); - src_piece.set_sparse_indices(nullptr); }); } // Set this literal to be nil-shaped. @@ -512,8 +485,6 @@ Status Literal::MoveFrom(Literal&& src_literal, Piece& dest_piece = piece(dest_index); tensorflow::port::AlignedFree(dest_piece.buffer()); dest_piece.set_buffer(src_piece.buffer()); - delete dest_piece.sparse_indices(); - dest_piece.set_sparse_indices(src_piece.sparse_indices()); }); src_literal.shape_ = absl::make_unique(ShapeUtil::MakeNil()); @@ -738,14 +709,14 @@ Literal LiteralBase::SliceInternal( const Shape& result_shape, absl::Span start_indices) const { Literal result_literal(result_shape); DimensionVector new_indices(result_shape.rank()); - result_literal.EachCell( - [&](absl::Span indices, NativeT /*value*/) { - for (int64 i = 0; i < result_shape.rank(); ++i) { - new_indices[i] = indices[i] + start_indices[i]; - } - NativeT value = Get(new_indices); - result_literal.Set(indices, value); - }); + CHECK(result_literal + .Populate([&](absl::Span indices) { + for (int64 i = 0; i < result_shape.rank(); ++i) { + new_indices[i] = indices[i] + start_indices[i]; + } + return Get(new_indices); + }) + .ok()); return result_literal; } @@ -854,66 +825,6 @@ string LiteralBase::GetAsString(absl::Span multi_index, } } -string LiteralBase::GetSparseElementAsString( - int64 sparse_element_number, const ShapeIndex& shape_index) const { - const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index); - CHECK(LayoutUtil::IsSparseArray(subshape)); - switch (subshape.element_type()) { - case PRED: - return GetSparseElement(sparse_element_number, shape_index) - ? 
"true" - : "false"; - case S8: - return StrCat(GetSparseElement(sparse_element_number, shape_index)); - case S16: - return StrCat( - GetSparseElement(sparse_element_number, shape_index)); - case S32: - return StrCat( - GetSparseElement(sparse_element_number, shape_index)); - case S64: - return StrCat( - GetSparseElement(sparse_element_number, shape_index)); - case U8: - return StrCat( - GetSparseElement(sparse_element_number, shape_index)); - case U16: - return StrCat( - GetSparseElement(sparse_element_number, shape_index)); - case U32: - return StrCat( - GetSparseElement(sparse_element_number, shape_index)); - case U64: - return StrCat( - GetSparseElement(sparse_element_number, shape_index)); - case F16: - return StrCat(static_cast( - GetSparseElement(sparse_element_number, shape_index))); - case F32: - return StrCat( - GetSparseElement(sparse_element_number, shape_index)); - case BF16: - return StrCat(static_cast( - GetSparseElement(sparse_element_number, shape_index))); - case F64: - return StrCat( - GetSparseElement(sparse_element_number, shape_index)); - case C64: { - complex64 c = - GetSparseElement(sparse_element_number, shape_index); - return StrCat("(", c.real(), ", ", c.imag(), ")"); - } - case C128: { - complex128 c = - GetSparseElement(sparse_element_number, shape_index); - return StrCat("(", c.real(), ", ", c.imag(), ")"); - } - default: - LOG(FATAL) << "Invalid element type for sparse arrays: " - << PrimitiveType_Name(subshape.element_type()); - } -} - absl::optional LiteralBase::GetIntegralAsS64( absl::Span multi_index) const { CHECK(LayoutUtil::IsDenseArray(shape())); @@ -1047,81 +958,6 @@ Status MutableLiteralBase::SetFromDouble(absl::Span multi_index, return Status::OK(); } -absl::Span LiteralBase::GetSparseIndex( - int64 sparse_element_number, const ShapeIndex& shape_index) const { - const Piece& p = piece(shape_index); - CHECK_GE(sparse_element_number, 0); - CHECK_LT(sparse_element_number, p.sparse_indices()->index_count()); - return p.sparse_indices()->At(sparse_element_number); -} - -void MutableLiteralBase::SortSparseElements(const ShapeIndex& shape_index) { - piece(shape_index).SortSparseElements(); -} - -void LiteralBase::Piece::SortSparseElements() { - switch (subshape().element_type()) { - case PRED: - SortSparseElementsInternal(); - break; - case S8: - SortSparseElementsInternal(); - break; - case U8: - SortSparseElementsInternal(); - break; - case S16: - SortSparseElementsInternal(); - break; - case U16: - SortSparseElementsInternal(); - break; - case S32: - SortSparseElementsInternal(); - break; - case U32: - SortSparseElementsInternal(); - break; - case S64: - SortSparseElementsInternal(); - break; - case U64: - SortSparseElementsInternal(); - break; - case F32: - SortSparseElementsInternal(); - break; - case F64: - SortSparseElementsInternal(); - break; - case C64: - SortSparseElementsInternal(); - break; - case C128: - SortSparseElementsInternal(); - break; - case F16: - SortSparseElementsInternal(); - break; - case BF16: - SortSparseElementsInternal(); - break; - default: - LOG(FATAL) << "Element type not valid for sparse array: " - << PrimitiveType_Name(subshape().element_type()); - } -} - -template -void LiteralBase::Piece::SortSparseElementsInternal() { - CHECK(LayoutUtil::IsSparseArray(subshape())); - int64 num_elements = sparse_indices()->index_count(); - auto values = data(); - CHECK_LE(num_elements, values.size()); - sparse_indices()->SortWithValues( - absl::Span(values.data(), num_elements)); -} - namespace { string ShapeToString(bool print_layout, 
const Shape& shape) { @@ -1151,32 +987,6 @@ void TupleToStringHelper(const LiteralBase& literal, pieces->push_back("\n)"); } -void SparseArrayToStringHelper(const LiteralBase& literal, - const Shape& subshape, bool print_shape, - bool print_layout, std::vector* pieces) { - if (print_shape) { - pieces->push_back(ShapeToString(print_layout, subshape)); - } - pieces->push_back("{"); - int64 rank = subshape.rank(); - int64 num_elements = literal.sparse_element_count(); - for (int64 i = 0; i < num_elements; ++i) { - if (i > 0) { - pieces->push_back(", "); - } - if (rank == 1) { - pieces->push_back(StrCat(literal.GetSparseIndex(i)[0])); - pieces->push_back(": "); - } else { - pieces->push_back("["); - pieces->push_back(absl::StrJoin(literal.GetSparseIndex(i), ", ")); - pieces->push_back("]: "); - } - pieces->push_back(literal.GetSparseElementAsString(i)); - } - pieces->push_back("}"); -} - void DenseArrayToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, bool print_shape, bool print_layout, std::vector* pieces) { @@ -1261,9 +1071,6 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, pieces); } else if (subshape.IsToken()) { pieces->push_back("token"); - } else if (LayoutUtil::IsSparseArray(subshape)) { - SparseArrayToStringHelper(literal, subshape, print_shape, print_layout, - pieces); } else { CHECK(LayoutUtil::IsDenseArray(subshape)); DenseArrayToStringHelper(literal, shape_index, print_shape, print_layout, @@ -1273,11 +1080,6 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, } // namespace -int64 LiteralBase::sparse_element_count() const { - CHECK(LayoutUtil::IsSparseArray(shape())); - return sparse_indices()->index_count(); -} - string LiteralBase::ToString() const { std::vector pieces; CHECK(LayoutUtil::HasLayout(this->shape())); @@ -2053,22 +1855,6 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { TF_RET_CHECK(LayoutUtil::HasLayout(shape)); TF_RET_CHECK(ShapeUtil::Equal(shape, subshape())); - if (LayoutUtil::IsSparseArray(subshape())) { - // Compute the number of elements (indices) in the sparse shape and reserve - // the necessary space in spare_indices. - TF_RET_CHECK(subshape().rank() != 0) << "Scalar shapes cannot be sparse"; - TF_RET_CHECK(proto.sparse_indices_size() % subshape().rank() == 0) - << "Unexpected number of indices in proto (" - << proto.sparse_indices_size() << ") for shape of rank " - << subshape().rank(); - const int64 index_count = proto.sparse_indices_size() / subshape().rank(); - sparse_indices()->Resize(index_count); - - // Copy the indices from the proto into the SparseIndexArray object. 
- TF_RETURN_IF_ERROR(CopyFromRepeatedField(sparse_indices()->mutable_data(), - proto.sparse_indices())); - } - switch (subshape().element_type()) { case PRED: TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.preds())); @@ -2175,11 +1961,6 @@ LiteralProto LiteralBase::ToProto() const { piece.WriteToProto(proto_piece); }); - if (LayoutUtil::IsSparseArray(shape())) { - CopyToRepeatedField(proto.mutable_sparse_indices(), - sparse_indices()->data()); - } - return proto; } @@ -2295,12 +2076,6 @@ MutableBorrowingLiteral::MutableBorrowingLiteral(const char* src_buf_ptr, MutableBorrowingLiteral::~MutableBorrowingLiteral() { if (root_piece_ != nullptr) { - root_piece_->ForEachMutableSubpiece( - [&](const ShapeIndex& index, Piece* piece) { - if (piece->buffer() != nullptr) { - delete piece->sparse_indices(); - } - }); delete root_piece_; } } diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h index 2d27f8eb7f6..7aee34437e6 100644 --- a/tensorflow/compiler/xla/literal.h +++ b/tensorflow/compiler/xla/literal.h @@ -35,7 +35,6 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/sparse_index_array.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" @@ -77,11 +76,6 @@ class LiteralBase { template absl::Span data(const ShapeIndex& shape_index = {}) const; - // Returns a const pointer to the sparse index array. Returns nullptr if the - // literal is not a sparse array. - const SparseIndexArray* sparse_indices( - const ShapeIndex& shape_index = {}) const; - // Returns a const pointer to (or size of) the underlying buffer holding the // array at the given shape index. CHECKs if the subshape of the literal at // the given ShapeIndex is not array. @@ -126,10 +120,6 @@ class LiteralBase { // into text. string GetAsString(absl::Span multi_index, const ShapeIndex& shape_index = {}) const; - // As GetSparseElement(), but determines the correct type and converts the - // value into text. - string GetSparseElementAsString(int64 sparse_element_number, - const ShapeIndex& shape_index = {}) const; // Return whether the value at the specified index is equal to the provided // generic `value` (T must be an arithmetic type). @@ -172,21 +162,6 @@ class LiteralBase { absl::optional GetAsComplex128( absl::Span multi_index) const; - // Returns the multi-index of the element in a sparse literal at the given - // sparse element number. The sparse element number is the position with in - // the sparse array's list of (index, value) pairs, and is checked against the - // total number of (index, value) pairs in the sparse array. - absl::Span GetSparseIndex( - int64 sparse_element_number, const ShapeIndex& shape_index = {}) const; - - // Returns the value of the element in a sparse literal at the given sparse - // element number. The sparse element number is the position with in the - // sparse array's list of (index, value) pairs, and is checked against the - // total number of (index, value) pairs in the sparse array. - template - NativeT GetSparseElement(int64 sparse_element_number, - const ShapeIndex& shape_index = {}) const; - // Invokes the "per cell" callback for each element in the provided // literal with the element's indices and a string representation of // the element's value. 
@@ -259,13 +234,7 @@ class LiteralBase { return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index)); } - // Returns the count of the elements in the sparse array at the given shape - // index in this literal, which will be no larger than - // LayoutUtil::MaxSparseElements(SetSubshape(shape(), index).layout()). - int64 sparse_element_count() const; - - // Compute a hash for this literal. This literal must not be a sparse tensor - // or a tuple containing a sparse tensor. + // Compute a hash for this literal. size_t Hash() const; // Converts this literal to the given shape. Returns an error is the @@ -385,14 +354,6 @@ class LiteralBase { char* buffer() const { return buffer_; } void set_buffer(char* buffer) { buffer_ = buffer; } - // The array of multi-indices that provide the locations of non-zero - // elements in a sparse array. Only used if - // LayoutUtil::IsSparseArray(shape()) is true. - SparseIndexArray* sparse_indices() const { return sparse_indices_; } - void set_sparse_indices(SparseIndexArray* sparse_indices) { - sparse_indices_ = sparse_indices; - } - // Gets or sets the subshape of this piece. This reference points to a // subshape within the shape in the containing Literal (Literal::shape_). const Shape& subshape() const { return *subshape_; } @@ -402,13 +363,7 @@ class LiteralBase { int64 size_bytes() const { return ShapeUtil::ByteSizeOf(subshape()); } // Returns the number of elements in this piece's array. - int64 element_count() const { - // If this is a sparse array, use the number of elements represented by - // the indices in the associated SparseIndexArray. - return LayoutUtil::IsSparseArray(subshape()) - ? sparse_indices()->index_count() - : ShapeUtil::ElementsIn(subshape()); - } + int64 element_count() const { return ShapeUtil::ElementsIn(subshape()); } // Returns the child piece at 'index' of this piece. Piece& child(int64 index) { return children_[index]; } @@ -489,9 +444,6 @@ class LiteralBase { // piece must be equal (not just compatible) to the shape of the proto. Status CopyFromProto(const LiteralProto& proto); - // Sorts the elements in a sparse array. - void SortSparseElements(); - private: // Helpers for traversing the piece via ForEachSubpiece rooted at 'index'. // The first non-OK (or non-true) value is returned by the function. @@ -541,17 +493,9 @@ class LiteralBase { bool EqualElementsInternal(const Piece& other, std::vector* multi_index) const; - // Helper for SortSparseElements that has the element type as a template - // parameter. - template - void SortSparseElementsInternal(); - // For array-shaped pieces, this is the buffer holding the literal data. char* buffer_ = nullptr; - // For sparse arrays, this is the array of indices. - SparseIndexArray* sparse_indices_ = nullptr; - // The shape of piece. This points into the shape of the containing Literal // (Literal::shape_). const Shape* subshape_ = nullptr; @@ -598,10 +542,6 @@ class MutableLiteralBase : public LiteralBase { // Unhide const method from parent class. using LiteralBase::data; - // Returns a pointer to the sparse index array. Returns nullptr if the literal - // is not a sparse array. - SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {}); - // TODO(b/67651157): Remove this accessor. Literal users should not be able to // mutate the shape as this can produce malformed Literals. Shape* mutable_shape_do_not_use() { return shape_.get(); } @@ -613,16 +553,6 @@ class MutableLiteralBase : public LiteralBase { // Unhide const method from parent class. 
using LiteralBase::untyped_data; - // Populates a literal with a sparse layout with the given indices and values. - // Each index in the indices array is CHECKed against the dimensions in the - // literal's shape. If sort is true, then the indices and values will be - // sorted. If sort is false, then the indices and values are assumed to - // already be in sorted order. See CreateSparse for an example of how data - // are populated. - template - void PopulateSparse(SparseIndexArray indices, - absl::Span values, bool sort = true); - // Copy values from 'src_literal' rooted at 'src_shape_index' into this // literal rooted at 'dest_shape_index'. The subshape of this literal rooted // at 'dest_shape_index' must be compatible with the subshape of 'src_literal' @@ -661,16 +591,6 @@ class MutableLiteralBase : public LiteralBase { template void Set(absl::Span multi_index, NativeT value); - // Appends the given element to the literal. If the elements are not appended - // in sorted order, then SortSparseElements should be called before calling - // other methods. This literal must have a sparse layout. - template - void AppendSparseElement(absl::Span multi_index, NativeT value, - const ShapeIndex& shape_index = {}); - - // Sorts the elements in a sparse array. - void SortSparseElements(const ShapeIndex& shape_index = {}); - // As Set(), but truncates `value` to the literal element type before storing. // This literal must be an array. Status SetIntegralAsS64(absl::Span multi_index, int64 value); @@ -988,34 +908,6 @@ NativeT LiteralBase::GetFirstElement() const { return data().at(0); } -template -NativeT LiteralBase::GetSparseElement(int64 sparse_element_number, - const ShapeIndex& shape_index) const { - CHECK( - LayoutUtil::IsSparseArray(ShapeUtil::GetSubshape(shape(), shape_index))); - return data(shape_index)[sparse_element_number]; -} - -template -void MutableLiteralBase::AppendSparseElement( - absl::Span multi_index, NativeT value, - const ShapeIndex& shape_index) { - Piece& p = piece(shape_index); - const Shape& subshape = p.subshape(); - CHECK(LayoutUtil::IsSparseArray(subshape)); - int64 rank = subshape.rank(); - CHECK_EQ(multi_index.size(), rank); - for (int64 i = 0; i < rank; ++i) { - CHECK_GE(multi_index[i], 0); - CHECK_LT(multi_index[i], subshape.dimensions(i)); - } - int64 last_element = p.sparse_indices()->index_count(); - CHECK_LT(last_element, LayoutUtil::MaxSparseElements(subshape.layout())); - p.sparse_indices()->Append(multi_index); - CHECK_LT(last_element, p.data().size()); - p.data()[last_element] = value; -} - template void LiteralBase::EachCell( std::function indices, NativeT value)> @@ -1094,31 +986,6 @@ void MutableLiteralBase::PopulateR4FromArray4D(const Array4D& values) { PopulateFromArray(values); } -template -void MutableLiteralBase::PopulateSparse(SparseIndexArray indices, - absl::Span values, - bool sort) { - CHECK(LayoutUtil::IsSparseArray(shape())); - int rank = shape().rank(); - CHECK_EQ(indices.rank(), rank); - int64 max_elements = LayoutUtil::MaxSparseElements(shape().layout()); - CHECK_LE(indices.max_indices(), max_elements); - int64 num_elements = values.size(); - CHECK_LE(num_elements, max_elements); - CHECK_EQ(num_elements, indices.index_count()); - auto root_data = root_piece().data(); - // Piece::data() returns a Span of size equal to the number of indices - // in the SparseIndexArray. So there is no need to adjust the size of the data - // here. It is enough to just copy the incoming values into the data buffer. 
- std::copy(values.begin(), values.end(), root_data.begin()); - *this->root_piece().sparse_indices() = std::move(indices); - if (sort) { - auto root_data = this->root_piece().data(); - this->root_piece().sparse_indices()->SortWithValues(root_data); - } - DCHECK(this->root_piece().sparse_indices()->Validate(shape())); -} - template Status MutableLiteralBase::PopulateInternal(const FnType& generator, bool parallel) { diff --git a/tensorflow/compiler/xla/literal_test.cc b/tensorflow/compiler/xla/literal_test.cc index 9b17cb762c8..6afbcce40b0 100644 --- a/tensorflow/compiler/xla/literal_test.cc +++ b/tensorflow/compiler/xla/literal_test.cc @@ -252,42 +252,6 @@ TEST_F(LiteralUtilTest, CreateR3FromArray3d) { EXPECT_EQ(expected, result); } -TEST_F(LiteralUtilTest, CreateSparse) { - std::vector dimensions = {8, 8, 8}; - Array2D indices = { - {3, 4, 5}, - {1, 2, 3}, - {2, 3, 4}, - {3, 5, 6}, - }; - std::vector values = {7, 8, 9, 10}; - auto literal = LiteralUtil::CreateSparse( - dimensions, SparseIndexArray(indices.n1() + 3, indices), values); - - Array2D expected_indices = { - {1, 2, 3}, - {2, 3, 4}, - {3, 4, 5}, - {3, 5, 6}, - }; - std::vector expected_values = {8, 9, 7, 10}; - - EXPECT_EQ(literal.sparse_indices()->data(), - absl::Span(expected_indices.data(), - expected_indices.num_elements())); - EXPECT_EQ(literal.data(), absl::Span(expected_values)); - - // Serialize then deserialize and verify the resulting literal. - TF_ASSERT_OK_AND_ASSIGN(Literal literal_from_proto, - Literal::CreateFromProto(literal.ToProto())); - - EXPECT_EQ(literal_from_proto.sparse_indices()->data(), - absl::Span(expected_indices.data(), - expected_indices.num_elements())); - EXPECT_EQ(literal_from_proto.data(), - absl::Span(expected_values)); -} - TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) { // clang-format off auto literal = LiteralUtil::CreateR4Projected({ @@ -1978,43 +1942,6 @@ TEST_F(LiteralUtilTest, InvalidProtoTooManyTupleElements) { EXPECT_THAT(status.error_message(), HasSubstr("Expected 2 tuple elements")); } -TEST_F(LiteralUtilTest, SortSparseElements) { - auto literal = LiteralUtil::CreateSparse({10, 10, 10}, - SparseIndexArray(10, 3), {}); - literal.AppendSparseElement({2, 3, 4}, 2.0); - literal.AppendSparseElement({3, 4, 5}, 3.0); - literal.AppendSparseElement({1, 2, 3}, 1.0); - literal.SortSparseElements(); - EXPECT_EQ(literal.ToString(), - "f32[10,10,10]{[1, 2, 3]: 1, [2, 3, 4]: 2, [3, 4, 5]: 3}"); -} - -TEST_F(LiteralUtilTest, GetSparseElementAsString) { - std::vector dimensions = {10, 10, 10}; - SparseIndexArray indices(10, {{1, 2, 3}, {2, 3, 4}, {3, 4, 5}}); - - EXPECT_EQ( - LiteralUtil::CreateSparse(dimensions, indices, {true, false, true}) - .GetSparseElementAsString(1), - "false"); - EXPECT_EQ(LiteralUtil::CreateSparse(dimensions, indices, {1, 2, 3}) - .GetSparseElementAsString(1), - absl::StrCat(int64{2})); - EXPECT_EQ( - LiteralUtil::CreateSparse(dimensions, indices, {1.0, 2.0, 3.0}) - .GetSparseElementAsString(1), - absl::StrCat(double{2.0})); - EXPECT_EQ(LiteralUtil::CreateSparse(dimensions, indices, - {half{1.0}, half{2.0}, half{3.0}}) - .GetSparseElementAsString(1), - absl::StrCat(static_cast(half{2.0}))); - EXPECT_EQ(LiteralUtil::CreateSparse( - dimensions, indices, - std::vector{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}) - .GetSparseElementAsString(1), - absl::StrCat("(", float{3.0}, ", ", float{4.0}, ")")); -} - TEST_F(LiteralUtilTest, BroadcastVectorToMatrix0) { Literal literal = LiteralUtil::CreateR1({1, 2}); TF_ASSERT_OK_AND_ASSIGN( @@ -2061,6 +1988,11 @@ 
TEST_F(LiteralUtilTest, GetAsComplex128) { EXPECT_FALSE(c6.GetAsComplex128({}).has_value()); } +TEST_F(LiteralUtilTest, SliceOnBool) { + Literal c1 = LiteralUtil::CreateR1({true, true, false}); + EXPECT_EQ(c1, c1.Slice({0}, {3})); +} + TEST_F(LiteralUtilTest, IsEqualAt) { double val_double = 10.0; int val_integral = 10; diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index c4535badafa..b22b71a2ec0 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -38,7 +38,6 @@ limitations under the License. #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/sparse_index_array.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" @@ -102,46 +101,6 @@ class LiteralUtil { values, const Layout& layout); - // Creates a literal with a sparse layout and the given indices and values. - // The shape is initialized from the given dimensions. The minor dimension of - // the indices array must equal the rank of the shape (i.e. size of the - // dimensions array). The major dimension of the indices array must equal the - // number of elements in the values array. The maximum number of elements in - // the array is taken from the max_indices() value of the index array. - // - // XLA assumes that sparse literals are in sorted order for all operations. If - // the `sort` argument is true, then the indices and values will be sorted - // while copying them into the literal. If you have ensured that the indices - // and values are already sorted, then you may set the `sort` argument to - // false to skip the sorting step. - // - // For example: - // - // CreateSparse( - // {12, 12, 12}, - // SparseIndexArray(10, 3, - // Array2D{ - // {0, 1, 2}, - // {3, 4, 5}, - // {6, 7, 8}, - // {9, 10, 11}, - // }), - // {1.0, 2.0 3.0, 4.0}) - // - // This creates an array with shape F64[12,12,12]sparse{10}, that has the - // following non-zero values: - // - // [0, 1, 2]: 1.0 - // [3, 4, 5]: 2.0 - // [6, 7, 8]: 3.0 - // [9, 10, 11]: 4.0 - // - template - static Literal CreateSparse(absl::Span dimensions, - SparseIndexArray indices, - absl::Span values, - bool sort = true); - // Creates a scalar literal value zero of the given primitive type. static Literal Zero(PrimitiveType primitive_type); // Creates a scalar literal value one of the given primitive type. 
@@ -417,21 +376,6 @@ template return CreateR4FromArray4DWithLayout(tmp, layout); } -template -/* static */ Literal LiteralUtil::CreateSparse( - absl::Span dimensions, SparseIndexArray indices, - absl::Span values, bool sort) { - int64 num_elements = values.size(); - int64 rank = dimensions.size(); - CHECK_EQ(num_elements, indices.index_count()); - CHECK_EQ(rank, indices.rank()); - Literal literal(ShapeUtil::MakeShapeWithSparseLayout( - primitive_util::NativeToPrimitiveType(), dimensions, - indices.max_indices())); - literal.PopulateSparse(indices, values, sort); - return literal; -} - template /* static */ Literal LiteralUtil::CreateR4( std::initializer_list Device::GetLocalDeviceState() const { + if (local_device_state_) { + return local_device_state_.get(); + } + return InvalidArgument("Device %s is not a local device.", DebugString()); +} + std::string CpuDevice::DebugString() const { return absl::StrCat("CPU_", id()); } @@ -115,7 +122,7 @@ std::string GpuDevice::DebugString() const { static StatusOr> CreateBFCAllocator( se::Platform* platform, - absl::Span> device_states, + absl::Span> local_devices, LocalClient* client, double memory_fraction, bool preallocate) { CHECK_GT(client->backend().device_count(), 0); std::vector allocators; @@ -148,19 +155,24 @@ static StatusOr> CreateBFCAllocator( /*allow_growth=*/!preallocate, absl::StrCat("GPU_", device_ordinal, "_bfc")); allocators.emplace_back(std::move(gpu_bfc_allocator), - device_states.at(device_ordinal)->compute_stream()); + local_devices.at(device_ordinal) + ->local_device_state() + ->compute_stream()); } return absl::make_unique(platform, std::move(allocators)); } -static std::shared_ptr MakeDevice(const std::string& platform_name, - int id, int local_device_ordinal) { +static std::shared_ptr MakeDevice( + const std::string& platform_name, int id, + std::unique_ptr local_device_state) { if (platform_name == "cpu") { - return std::make_shared(id, local_device_ordinal, platform_name); + return std::make_shared(id, std::move(local_device_state), + platform_name); } else { CHECK_EQ(platform_name, "gpu"); - return std::make_shared(id, local_device_ordinal, platform_name); + return std::make_shared(id, std::move(local_device_state), + platform_name); } } @@ -179,16 +191,15 @@ StatusOr> PyLocalClient::Get( ClientLibrary::GetOrCreateLocalClient(options)); bool gpu_platform = platform_name == "gpu"; - std::vector> device_states; std::vector> devices; bool synchronous_deallocation = platform_name == "cpu"; for (int i = 0; i < client->device_count(); ++i) { se::StreamExecutor* executor = client->backend().stream_executor(i).ValueOrDie(); - device_states.push_back(absl::make_unique( + auto device_state = absl::make_unique( executor, synchronous_deallocation, asynchronous, - /*allow_event_reuse=*/gpu_platform)); - devices.push_back(MakeDevice(platform_name, i, i)); + /*allow_event_reuse=*/gpu_platform); + devices.push_back(MakeDevice(platform_name, i, std::move(device_state))); } std::unique_ptr allocator; @@ -196,7 +207,7 @@ StatusOr> PyLocalClient::Get( if (gpu_platform) { if (allocator_config.kind != AllocatorConfig::Kind::kPlatform) { TF_ASSIGN_OR_RETURN(allocator, - CreateBFCAllocator(platform, device_states, client, + CreateBFCAllocator(platform, devices, client, allocator_config.memory_fraction, allocator_config.preallocate)); } @@ -217,21 +228,18 @@ StatusOr> PyLocalClient::Get( return std::make_shared( platform_name, client, std::move(devices), /*host_id=*/0, - std::move(device_states), std::move(allocator), - 
std::move(host_memory_allocator)); + std::move(allocator), std::move(host_memory_allocator)); } PyLocalClient::PyLocalClient( std::string platform_name, LocalClient* client, std::vector> devices, int host_id, - std::vector> device_states, std::unique_ptr allocator, std::unique_ptr host_memory_allocator) : platform_name_(std::move(platform_name)), client_(client), devices_(std::move(devices)), host_id_(host_id), - device_states_(std::move(device_states)), owned_allocator_(std::move(allocator)), host_memory_allocator_(std::move(host_memory_allocator)), h2d_transfer_pool_(tensorflow::Env::Default(), "py_xla_h2d_transfer", @@ -242,15 +250,16 @@ PyLocalClient::PyLocalClient( allocator_ = client_->backend().memory_allocator(); } - local_devices_.resize(device_states_.size()); for (const std::shared_ptr& device : devices_) { CHECK(id_to_device_.insert({device->id(), device}).second) << "Duplicate device id: " << device->id(); - if (device->local_device_ordinal() != -1) { - int idx = device->local_device_ordinal(); + if (device->local_device_state()) { + int idx = device->local_device_state()->device_ordinal(); + if (idx >= local_devices_.size()) { + local_devices_.resize(idx + 1); + } CHECK(local_devices_[idx] == nullptr) << idx; - CHECK_LT(idx, local_devices_.size()); local_devices_[idx] = device; } } @@ -274,17 +283,19 @@ PyLocalClient::DeserializeExecutable( } Status PyLocalClient::TransferToInfeed(const LiteralSlice& literal, - int device_ordinal) { - TF_RETURN_IF_ERROR( - CheckDeviceOrdinal(device_ordinal, "PyLocalClient::TransferToInfeed")); - return client_->TransferToInfeedLocal(literal, device_ordinal); + std::shared_ptr device) { + TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, + device->GetLocalDeviceState()); + return client_->TransferToInfeedLocal(literal, + local_device->device_ordinal()); } -StatusOr PyLocalClient::TransferFromOutfeed(const Shape& shape, - int device_ordinal) { - TF_RETURN_IF_ERROR( - CheckDeviceOrdinal(device_ordinal, "PyLocalClient::TransferFromOutfeed")); - return client_->TransferFromOutfeedLocal(shape, device_ordinal); +StatusOr PyLocalClient::TransferFromOutfeed( + const Shape& shape, std::shared_ptr device) { + TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, + device->GetLocalDeviceState()); + return client_->TransferFromOutfeedLocal(shape, + local_device->device_ordinal()); } StatusOr PyLocalClient::GetDefaultDeviceAssignment( @@ -293,36 +304,26 @@ StatusOr PyLocalClient::GetDefaultDeviceAssignment( num_replicas, /*computation_count=*/1); } -Status PyLocalClient::CheckDeviceOrdinal(int device_ordinal, - absl::string_view caller_name) { - if (device_ordinal < 0 || device_ordinal >= local_device_count()) { - return InvalidArgument( - "%s got bad device_ordinal: %d (num_local_devices=%d)", caller_name, - device_ordinal, local_device_count()); - } - return Status::OK(); -} - /* static */ StatusOr> PyLocalBuffer::FromLiterals( std::vector leaves_literals, const Shape& tuple_shape, std::shared_ptr leaves_reference, - std::shared_ptr client, int device_ordinal) { + std::shared_ptr client, std::shared_ptr device) { tensorflow::profiler::TraceMe traceme("PyLocalBuffer::FromLiterals"); VLOG(1) << "PyLocalBuffer::FromLiterals: shape: " << tuple_shape.ToString() - << " device ordinal: " << device_ordinal; - TF_RETURN_IF_ERROR(client->CheckDeviceOrdinal(device_ordinal, - "PyLocalBuffer::FromLiterals")); - DeviceState* device = &client->device_state(device_ordinal); + << " device: " << device->DebugString(); + TF_ASSIGN_OR_RETURN(LocalDeviceState * 
local_device, + device->GetLocalDeviceState()); TransferManager* transfer_manager = client->client()->backend().transfer_manager(); se::DeviceMemoryAllocator* allocator = client->allocator(); TF_ASSIGN_OR_RETURN( Shape compact_shape, transfer_manager->ChooseCompactLayoutForShape(tuple_shape)); - TF_ASSIGN_OR_RETURN(ScopedShapedBuffer scoped_buffer, - transfer_manager->AllocateScopedShapedBuffer( - compact_shape, allocator, device_ordinal)); + TF_ASSIGN_OR_RETURN( + ScopedShapedBuffer scoped_buffer, + transfer_manager->AllocateScopedShapedBuffer( + compact_shape, allocator, local_device->device_ordinal())); // Make the host to device stream wait for the newly allocated buffer to be // available on the compute stream. We schedule this wait synchronously; while @@ -331,8 +332,9 @@ StatusOr> PyLocalBuffer::FromLiterals( // computations that depend on this transfer being enqueued on the compute // stream. if (!transfer_manager->CanShapedBufferBeAccessedNow( - device->host_to_device_stream()->parent(), scoped_buffer)) { - device->host_to_device_stream()->ThenWaitFor(device->compute_stream()); + local_device->host_to_device_stream()->parent(), scoped_buffer)) { + local_device->host_to_device_stream()->ThenWaitFor( + local_device->compute_stream()); } std::shared_ptr definition_event = @@ -344,16 +346,15 @@ StatusOr> PyLocalBuffer::FromLiterals( // TODO(makro): Use move capture once C++ 14 features are available. auto leaves = std::make_shared>( std::move(leaves_literals)); - auto transfer_h2d = [client, transfer_manager, device, device_ordinal, - device_buffer, compact_shape, leaves, - leaves_reference]() { + auto transfer_h2d = [client, transfer_manager, local_device, device_buffer, + compact_shape, leaves, leaves_reference]() { // This function uses TF_CHECK_OK and ValueOrDie() since we have no way to // report failures from a callback. However, the operations here are // unlikely to fail and not recoverable even if we were to fail: DMAs to // memory that has already been allocated, and a possible Event allocation. ShapedBuffer buffer = device_buffer->AsShapedBuffer(compact_shape); TF_CHECK_OK(transfer_manager->WriteTupleIndexTablesAsync( - device->host_to_device_stream(), buffer)); + local_device->host_to_device_stream(), buffer)); std::vector> staging_buffers; staging_buffers.reserve(leaves->size()); auto it = leaves->begin(); @@ -363,7 +364,7 @@ StatusOr> PyLocalBuffer::FromLiterals( ShapedBuffer leaf( indexed_shape.shape, transfer_manager->HostShapeToDeviceShape(indexed_shape.shape), - client->client()->platform(), device_ordinal); + client->client()->platform(), local_device->device_ordinal()); leaf.buffers().CopySubtreeFrom(buffer.buffers(), indexed_shape.index, {}); // If applicable on the backend, stage the transfer via host memory @@ -379,51 +380,53 @@ StatusOr> PyLocalBuffer::FromLiterals( BorrowingLiteral literal(static_cast(staging_buffer.get()), it->shape()); TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync( - device->host_to_device_stream(), literal, leaf)); + local_device->host_to_device_stream(), literal, leaf)); staging_buffers.push_back(std::move(staging_buffer)); } else { // Otherwise, just transfer the literal. 
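For orientation, a hedged sketch of the new FromLiterals call shape: the caller now passes a Device handle instead of a raw ordinal, and the ordinal is recovered internally through GetLocalDeviceState(). The argument names below are placeholders, and local_devices() is assumed to expose the client's local device list:

```cpp
// Hypothetical call site; 'leaves', 'tuple_shape' and 'keepalive' stand in
// for real arguments prepared by the Python bindings.
std::shared_ptr<Device> device = client->local_devices().front();
auto buffer_or = PyLocalBuffer::FromLiterals(
    std::move(leaves), tuple_shape, std::move(keepalive), client, device);
```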
TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync( - device->host_to_device_stream(), *it, leaf)); + local_device->host_to_device_stream(), *it, leaf)); } ++it; } EventPool::Handle event = - device->event_pool() - .ThenAllocateAndRecordEvent(device->host_to_device_stream()) + local_device->event_pool() + .ThenAllocateAndRecordEvent(local_device->host_to_device_stream()) .ValueOrDie(); // Sets the buffer definition event. Note: this has the side effect of // unblocking any host threads that may have been waiting to consume the // buffer. device_buffer->definition_event()->SetDefinitionEvent( - std::move(event), device->host_to_device_stream()); + std::move(event), local_device->host_to_device_stream()); - if (device->synchronous_deallocation()) { - device->ThenRelease(device->host_to_device_stream(), device_buffer); + if (local_device->synchronous_deallocation()) { + local_device->ThenRelease(local_device->host_to_device_stream(), + device_buffer); } - device->ThenRelease( - device->host_to_device_stream(), + local_device->ThenRelease( + local_device->host_to_device_stream(), std::make_pair(leaves_reference, std::move(staging_buffers))); }; client->h2d_transfer_pool()->Schedule(transfer_h2d); - return absl::make_unique( - compact_shape, std::move(device_buffer), std::move(client)); + return absl::make_unique(compact_shape, + std::move(device_buffer), + std::move(client), std::move(device)); } /* static */ StatusOr> PyLocalBuffer::MakeTuple( const std::vector buffers, - std::shared_ptr client, int device_ordinal) { - TF_RETURN_IF_ERROR( - client->CheckDeviceOrdinal(device_ordinal, "PyLocalBuffer::MakeTuple")); + std::shared_ptr client, std::shared_ptr device) { + TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, + device->GetLocalDeviceState()); std::vector host_shapes; std::vector> device_buffers; host_shapes.reserve(buffers.size()); device_buffers.reserve(buffers.size()); for (const PyLocalBuffer* buffer : buffers) { - TF_RET_CHECK(buffer->device_ordinal() == device_ordinal); + TF_RET_CHECK(buffer->device().get() == device.get()); std::shared_ptr device_buffer = buffer->DeviceBuffer(); if (!device_buffer) { return InvalidArgument( @@ -436,45 +439,48 @@ StatusOr> PyLocalBuffer::FromLiterals( se::DeviceMemoryAllocator* allocator = client->allocator(); TransferManager* transfer_manager = client->client()->backend().transfer_manager(); - DeviceState& device = client->device_state(device_ordinal); auto definition_event = std::make_shared(); - TF_ASSIGN_OR_RETURN( - std::shared_ptr tuple_buffer, - SharedDeviceBuffer::MakeTuple(device_buffers, transfer_manager, allocator, - device_ordinal, definition_event)); + TF_ASSIGN_OR_RETURN(std::shared_ptr tuple_buffer, + SharedDeviceBuffer::MakeTuple( + device_buffers, transfer_manager, allocator, + local_device->device_ordinal(), definition_event)); auto buffer = absl::make_unique( - ShapeUtil::MakeTupleShape(host_shapes), tuple_buffer, std::move(client)); + ShapeUtil::MakeTupleShape(host_shapes), tuple_buffer, std::move(client), + std::move(device)); // TODO(phawkins): extend TransferManager so we do not need to form a full // ShapedBuffer just to write the root tuple index table. TF_ASSIGN_OR_RETURN(ShapedBuffer shaped_buffer, buffer->AsShapedBuffer()); if (!transfer_manager->CanShapedBufferBeAccessedNow( - device.host_to_device_stream()->parent(), shaped_buffer)) { + local_device->host_to_device_stream()->parent(), shaped_buffer)) { // Wait for the compute stream so that memory allocations are synchronized. 
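MakeTuple follows the same convention: every element buffer must already live on the Device handle passed in, as the TF_RET_CHECK above enforces. A minimal hypothetical call:

```cpp
// 'lhs' and 'rhs' are assumed to be buffers previously created on 'device';
// the result is a tuple buffer resident on that same device.
std::vector<PyLocalBuffer*> elements = {lhs.get(), rhs.get()};
auto tuple_or = PyLocalBuffer::MakeTuple(elements, client, device);
```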
- device.host_to_device_stream()->ThenWaitFor(device.compute_stream()); + local_device->host_to_device_stream()->ThenWaitFor( + local_device->compute_stream()); } TF_RETURN_IF_ERROR(transfer_manager->WriteRootTupleIndexTable( - device.host_to_device_stream(), shaped_buffer)); + local_device->host_to_device_stream(), shaped_buffer)); TF_ASSIGN_OR_RETURN(EventPool::Handle event, - device.event_pool().ThenAllocateAndRecordEvent( - device.host_to_device_stream())); + local_device->event_pool().ThenAllocateAndRecordEvent( + local_device->host_to_device_stream())); definition_event->SetDefinitionEvent(std::move(event), - device.host_to_device_stream()); + local_device->host_to_device_stream()); - if (device.synchronous_deallocation()) { - device.ThenRelease(device.host_to_device_stream(), std::move(tuple_buffer)); + if (local_device->synchronous_deallocation()) { + local_device->ThenRelease(local_device->host_to_device_stream(), + std::move(tuple_buffer)); } return buffer; } PyLocalBuffer::PyLocalBuffer(Shape on_host_shape, std::shared_ptr device_buffer, - std::shared_ptr client) + std::shared_ptr client, + std::shared_ptr device) : client_(std::move(client)), on_host_shape_(std::move(on_host_shape)), - device_ordinal_(device_buffer->device_ordinal()), + device_(std::move(device)), device_buffer_(std::move(device_buffer)) {} void PyLocalBuffer::Delete() { @@ -499,8 +505,7 @@ Status PyLocalBuffer::CopyToHostAsync() { } host_value = host_value_ = std::make_shared(); } - se::Stream* stream = - client_->device_state(device_ordinal_).device_to_host_stream(); + se::Stream* stream = device_->local_device_state()->GetDeviceToHostStream(); WaitForBufferDefinitionEventsOnStream(*device_buffer, stream); host_value->value = std::make_shared(on_host_shape_); TF_ASSIGN_OR_RETURN(ShapedBuffer shaped_buffer, AsShapedBuffer()); @@ -564,36 +569,38 @@ PyLocalBuffer::DestructureTuple() { for (int64 i = 0; i < num_children; ++i) { results.push_back(absl::make_unique( on_host_shape_.tuple_shapes(i), device_buffer_->children().at(i), - client_)); + client_, device_)); } return results; } StatusOr> PyLocalBuffer::CopyToDevice( - int dst_device_ordinal) { + std::shared_ptr dst_device) { tensorflow::profiler::TraceMe traceme("PyLocalBuffer::CopyToDevice"); std::shared_ptr src_device_buffer = DeviceBuffer(); - if (dst_device_ordinal == device_ordinal_) { - return absl::make_unique(on_host_shape_, src_device_buffer, - client_); - } - int transfer_device_ordinal = client_->EnqueueD2DTransfersOnSrcStream() - ? device_ordinal_ - : dst_device_ordinal; - DeviceState& transfer_device = client_->device_state(transfer_device_ordinal); - const DeviceState& dst_device = client_->device_state(dst_device_ordinal); + TF_ASSIGN_OR_RETURN(LocalDeviceState * dst_local_device, + dst_device->GetLocalDeviceState()); - se::Stream* transfer_stream = transfer_device.GetDeviceToDeviceStream(); + if (dst_device.get() == device_.get()) { + return absl::make_unique(on_host_shape_, src_device_buffer, + client_, device_); + } + LocalDeviceState* transfer_local_device = + client_->EnqueueD2DTransfersOnSrcStream() ? 
device_->local_device_state() + : dst_local_device; + + se::Stream* transfer_stream = + transfer_local_device->GetDeviceToDeviceStream(); TransferManager* transfer_manager = client_->client()->backend().transfer_manager(); - TF_ASSIGN_OR_RETURN( - ScopedShapedBuffer dst_buffer, - transfer_manager->AllocateScopedShapedBuffer( - on_host_shape_, client_->allocator(), dst_device_ordinal)); + TF_ASSIGN_OR_RETURN(ScopedShapedBuffer dst_buffer, + transfer_manager->AllocateScopedShapedBuffer( + on_host_shape_, client_->allocator(), + dst_local_device->device_ordinal())); if (!transfer_manager->CanShapedBufferBeAccessedNow( - dst_device.compute_stream()->parent(), dst_buffer)) { - transfer_stream->ThenWaitFor(dst_device.compute_stream()); + dst_local_device->compute_stream()->parent(), dst_buffer)) { + transfer_stream->ThenWaitFor(dst_local_device->compute_stream()); } TF_ASSIGN_OR_RETURN(ShapedBuffer src_buffer, AsShapedBuffer()); @@ -607,37 +614,39 @@ StatusOr> PyLocalBuffer::CopyToDevice( TF_RET_CHECK(input_buffer.size() == output_buffer.size()) << "input: " << input_buffer.size() << " output: " << output_buffer.size(); - TF_RETURN_IF_ERROR(transfer_device.ThenMemcpyDeviceToDevice( - transfer_stream, dst_device.compute_stream(), input_buffer, + TF_RETURN_IF_ERROR(transfer_local_device->ThenMemcpyDeviceToDevice( + transfer_stream, dst_local_device->compute_stream(), input_buffer, output_buffer)); } // We hold on to the `src_device_buffer` until the transfer is finished. - transfer_device.ThenRelease(transfer_stream, std::move(src_device_buffer)); + transfer_local_device->ThenRelease(transfer_stream, + std::move(src_device_buffer)); // Write new tuple buffers. The destination buffers have different addresses, // so we must construct tuple buffers from scratch instead of copying them. if (dst_buffer.on_device_shape().IsTuple()) { TF_RETURN_IF_ERROR(transfer_manager->WriteTupleIndexTablesAsync( - dst_device.host_to_device_stream(), dst_buffer)); + dst_local_device->host_to_device_stream(), dst_buffer)); // We need a single definition event, so make the device to device stream // wait for the stream that wrote the tuple index tables on the destination // device. - transfer_stream->ThenWaitFor(dst_device.host_to_device_stream()); + transfer_stream->ThenWaitFor(dst_local_device->host_to_device_stream()); } auto definition_event = std::make_shared(); TF_ASSIGN_OR_RETURN( EventPool::Handle event, - transfer_device.event_pool().ThenAllocateAndRecordEvent(transfer_stream)); + transfer_local_device->event_pool().ThenAllocateAndRecordEvent( + transfer_stream)); definition_event->SetDefinitionEvent(std::move(event), transfer_stream); std::shared_ptr dst_device_buffer = SharedDeviceBuffer::FromScopedShapedBuffer(std::move(dst_buffer), definition_event); return absl::make_unique( - on_host_shape_, std::move(dst_device_buffer), client_); + on_host_shape_, std::move(dst_device_buffer), client_, dst_device); } Status PyLocalBuffer::BlockHostUntilReady() { @@ -652,7 +661,7 @@ Status PyLocalBuffer::BlockHostUntilReady() { // be an issue, we could either use a separate stream for this purpose, or // poll for the buffer definition events. 
se::Stream* stream = client_->device_state(device_buffer->device_ordinal()) - .device_to_host_stream(); + .GetDeviceToHostStream(); WaitForBufferDefinitionEventsOnStream(*device_buffer, stream); return stream->BlockHostUntilDone(); } @@ -694,7 +703,7 @@ StatusOr> PyLocalExecutable::ExecuteHelper( const int device_id = (*device_assignment_)(replica, 0); std::shared_ptr device = LookupDevice(*client_, device_id); CHECK_EQ(device->host_id(), client_->host_id()); - int device_ordinal = device->local_device_ordinal(); + int device_ordinal = device->local_device_state()->device_ordinal(); tensorflow::profiler::TraceMe traceme("LocalExecutable::Execute"); VLOG(3) << "Replica " << replica << " mapped to device ordinal for execution: " << device_ordinal; @@ -729,7 +738,7 @@ StatusOr> PyLocalExecutable::ExecuteHelper( << " buffer: " << argument_buffers.back().ToString(); } - DeviceState* device_state = &client_->device_state(device_ordinal); + LocalDeviceState* device_state = &client_->device_state(device_ordinal); // The choice of where we wait is arbitrary; the reason for the wait is pacing // to avoid problems such as memory fragmentation and running ahead too far, // not for correctness. Placing it before the executable launch allows the @@ -782,7 +791,7 @@ StatusOr> PyLocalExecutable::ExecuteHelper( device_state->compute_stream(), std::make_tuple(executable_, compute_reservation, device_assignment_)); return absl::make_unique(on_host_shape, std::move(out_buffer), - client_); + client_, device); } StatusOr> PyLocalExecutable::Execute( @@ -833,8 +842,7 @@ PyLocalExecutable::ExecutePerReplica( for (int i = 0; i < num_local_replicas; ++i) { const int replica = local_replicas_[i]; std::shared_ptr device = local_devices_[i]; - const DeviceState& device_state = - client_->device_state(device->local_device_ordinal()); + const LocalDeviceState& device_state = *device->local_device_state(); device_state.execute_thread()->Schedule([&, replica, i] { results[i] = ExecuteHelper(argument_handles[i], replica, run_id); diff --git a/tensorflow/compiler/xla/python/local_client.h b/tensorflow/compiler/xla/python/local_client.h index 3f13f62241f..e0a21ad6f1e 100644 --- a/tensorflow/compiler/xla/python/local_client.h +++ b/tensorflow/compiler/xla/python/local_client.h @@ -27,7 +27,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_computation.h" -#include "tensorflow/compiler/xla/python/device_state.h" +#include "tensorflow/compiler/xla/python/local_device_state.h" #include "tensorflow/compiler/xla/python/shared_device_buffer.h" #include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" @@ -43,10 +43,10 @@ class PyLocalExecutable; class Device { public: - explicit Device(int id, int local_device_ordinal, + explicit Device(int id, std::unique_ptr local_device_state, absl::string_view platform_name, int host_id = 0) : id_(id), - local_device_ordinal_(local_device_ordinal), + local_device_state_(std::move(local_device_state)), host_id_(host_id), platform_name_(platform_name) {} virtual ~Device() {} @@ -56,13 +56,17 @@ class Device { // hosts' devices. This is the ID that should be used in a DeviceAssignment. int id() const { return id_; } - // If this is a device local to this host, the local index of this device as - // according to the underlying backend. 
Unlike id(), this will always be in - // the range [0, num_local_devices), and can be used with the xla::LocalClient - // and xla::Backend APIs. - // - // -1 if this device is not local to this host. - int local_device_ordinal() const { return local_device_ordinal_; } + // If this is a device local to this host, returns a LocalDeviceState object + // that can be used to manipulate the device. Returns nullptr if the device is + // not local to this host. + LocalDeviceState* local_device_state() const { + return local_device_state_.get(); + } + + // If this is a device local to this host, returns a LocalDeviceState object + // that can be used to manipulate the device. Returns an error if the device + // is not local to this host. + StatusOr GetLocalDeviceState() const; // The ID of this device's host. This is always 0 on single-host platforms. int host_id() const { return host_id_; } @@ -73,7 +77,7 @@ class Device { private: const int id_; - const int local_device_ordinal_; + const std::unique_ptr local_device_state_; const int host_id_; const std::string platform_name_; }; @@ -123,13 +127,14 @@ class PyLocalClient { explicit PyLocalClient( std::string platform_name, LocalClient* client, std::vector> devices, int host_id, - std::vector> device_states, std::unique_ptr allocator, std::unique_ptr host_memory_allocator); virtual ~PyLocalClient() = default; - Status TransferToInfeed(const LiteralSlice& literal, int device_ordinal); - StatusOr TransferFromOutfeed(const Shape& shape, int device_ordinal); + Status TransferToInfeed(const LiteralSlice& literal, + std::shared_ptr device); + StatusOr TransferFromOutfeed(const Shape& shape, + std::shared_ptr device); virtual StatusOr GetDefaultDeviceAssignment( int num_replicas) const; @@ -146,8 +151,8 @@ class PyLocalClient { int host_id() const { return host_id_; } const std::string& platform_name() const { return platform_name_; } - DeviceState& device_state(int device_ordinal) const { - return *device_states_.at(device_ordinal); + LocalDeviceState& device_state(int device_ordinal) const { + return *local_devices_.at(device_ordinal)->local_device_state(); } LocalClient* client() const { return client_; } @@ -178,10 +183,6 @@ class PyLocalClient { const std::string& serialized, std::shared_ptr this_shared) const; - // Returns a bad status containing `caller_name` if `device_ordinal` doesn't - // correspond to a local device. - Status CheckDeviceOrdinal(int device_ordinal, absl::string_view caller_name); - protected: std::string platform_name_; LocalClient* client_; @@ -194,8 +195,6 @@ class PyLocalClient { std::vector> local_devices_; int host_id_; - // Device states local to this host. Indexed by local device ordinal. 
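The accessor pair declared above gives callers a choice between a nullable pointer and a Status-carrying lookup; a sketch of the intended contract:

```cpp
// local_device_state() yields nullptr for devices on other hosts, while
// GetLocalDeviceState() surfaces that case as an InvalidArgument status.
if (LocalDeviceState* state = device->local_device_state()) {
  int ordinal = state->device_ordinal();  // safe: the device is host-local
  (void)ordinal;
} else {
  // Remote device: GetLocalDeviceState() would return an error Status here.
}
```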
- std::vector> device_states_; se::DeviceMemoryAllocator* allocator_; std::unique_ptr owned_allocator_; @@ -219,16 +218,16 @@ class PyLocalBuffer { static StatusOr> FromLiterals( std::vector leaves_literals, const Shape& tuple_shape, std::shared_ptr leaves_reference, - std::shared_ptr client, int device_ordinal); + std::shared_ptr client, std::shared_ptr device); static StatusOr> MakeTuple( const std::vector buffers, - std::shared_ptr client, int device_ordinal); + std::shared_ptr client, std::shared_ptr device); - PyLocalBuffer() = default; PyLocalBuffer(Shape on_host_shape, std::shared_ptr device_buffer, - std::shared_ptr client); + std::shared_ptr client, + std::shared_ptr device); PyLocalBuffer(const PyLocalBuffer&) = delete; PyLocalBuffer(PyLocalBuffer&&) = delete; @@ -236,7 +235,7 @@ class PyLocalBuffer { PyLocalBuffer& operator=(PyLocalBuffer&&) = delete; const Shape& on_host_shape() const { return on_host_shape_; } - int device_ordinal() const { return device_ordinal_; } + std::shared_ptr device() const { return device_; } const std::string& platform_name() const { return client_->platform_name(); } std::shared_ptr client() const { return client_; } @@ -266,8 +265,9 @@ class PyLocalBuffer { // Destructures a tuple-valued PyLocalBuffer into its constituent elements. StatusOr>> DestructureTuple(); - // Copies the buffer to device `dst_device_ordinal`. - StatusOr> CopyToDevice(int dst_device_ordinal); + // Copies the buffer to device `dst_device`. + StatusOr> CopyToDevice( + std::shared_ptr dst_device); // Blocks the host until the buffer's value has been computed and is ready for // immediate use on the device. Useful in particular for timing benchmarks. @@ -276,7 +276,7 @@ class PyLocalBuffer { private: const std::shared_ptr client_; const Shape on_host_shape_; - const int device_ordinal_; + const std::shared_ptr device_; mutable absl::Mutex mu_; std::shared_ptr device_buffer_ GUARDED_BY(mu_); diff --git a/tensorflow/compiler/xla/python/device_state.cc b/tensorflow/compiler/xla/python/local_device_state.cc similarity index 72% rename from tensorflow/compiler/xla/python/device_state.cc rename to tensorflow/compiler/xla/python/local_device_state.cc index 3403d882e92..0373d4b642b 100644 --- a/tensorflow/compiler/xla/python/device_state.cc +++ b/tensorflow/compiler/xla/python/local_device_state.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/python/device_state.h" +#include "tensorflow/compiler/xla/python/local_device_state.h" #include #include @@ -24,20 +24,25 @@ limitations under the License. namespace xla { -DeviceState::DeviceState(se::StreamExecutor* executor, - bool synchronous_deallocation, bool asynchronous, - bool allow_event_reuse) +LocalDeviceState::LocalDeviceState(se::StreamExecutor* executor, + bool synchronous_deallocation, + bool asynchronous, bool allow_event_reuse) : synchronous_deallocation_(synchronous_deallocation), event_pool_(allow_event_reuse), - compute_semaphore_(/*capacity=*/asynchronous ? 32 : 1) { + compute_semaphore_(/*capacity=*/asynchronous ? 
32 : 1), + executor_(executor) { compute_stream_ = absl::make_unique(executor); host_to_device_stream_ = absl::make_unique(executor); - device_to_host_stream_ = absl::make_unique(executor); callback_stream_ = absl::make_unique(executor); compute_stream_->Init(); host_to_device_stream_->Init(); - device_to_host_stream_->Init(); callback_stream_->Init(); + device_to_host_streams_.reserve(kNumDeviceToHostStreams); + for (int i = 0; i < kNumDeviceToHostStreams; ++i) { + auto stream = absl::make_unique(executor); + stream->Init(); + device_to_host_streams_.push_back(std::move(stream)); + } device_to_device_streams_.reserve(kNumDeviceToDeviceStreams); for (int i = 0; i < kNumDeviceToDeviceStreams; ++i) { auto stream = absl::make_unique(executor); @@ -50,14 +55,14 @@ DeviceState::DeviceState(se::StreamExecutor* executor, "py_xla_callback"); } -DeviceState::~DeviceState() { +LocalDeviceState::~LocalDeviceState() { Status status = SynchronizeAllActivity(); if (!status.ok()) { LOG(ERROR) << "Error when closing device: " << status; } } -Status DeviceState::SynchronizeAllActivity() { +Status LocalDeviceState::SynchronizeAllActivity() { Status status; // TODO(phawkins): in theory the call to SynchronizeAllActivity below should // suffice. However on the Host platform SynchronizeAllActivity is a dummy @@ -73,10 +78,9 @@ Status DeviceState::SynchronizeAllActivity() { return status; } -Status DeviceState::ThenMemcpyDeviceToDevice(se::Stream* transfer_stream, - se::Stream* dst_stream, - se::DeviceMemoryBase src_buffer, - se::DeviceMemoryBase dst_buffer) { +Status LocalDeviceState::ThenMemcpyDeviceToDevice( + se::Stream* transfer_stream, se::Stream* dst_stream, + se::DeviceMemoryBase src_buffer, se::DeviceMemoryBase dst_buffer) { // The default implementation simply calls ThenMemcpyD2D, and assumes that // the buffer addresses identify the devices. This does not work // on all platforms; this method is virtual so it can be overridden. @@ -84,14 +88,22 @@ Status DeviceState::ThenMemcpyDeviceToDevice(se::Stream* transfer_stream, return Status::OK(); } -void DeviceState::ThenExecuteOnCallbackThread( +void LocalDeviceState::ThenExecuteOnCallbackThread( se::Stream* stream, std::function callback) const { stream->ThenDoHostCallback([this, callback]() mutable { callback_thread_->Schedule(std::move(callback)); }); } -se::Stream* DeviceState::GetDeviceToDeviceStream() { +se::Stream* LocalDeviceState::GetDeviceToHostStream() { + absl::MutexLock lock(&mu_); + int i = next_device_to_host_stream_; + next_device_to_host_stream_ = + (next_device_to_host_stream_ + 1) % device_to_host_streams_.size(); + return device_to_host_streams_.at(i).get(); +} + +se::Stream* LocalDeviceState::GetDeviceToDeviceStream() { absl::MutexLock lock(&mu_); int i = next_device_to_device_stream_; next_device_to_device_stream_ = diff --git a/tensorflow/compiler/xla/python/device_state.h b/tensorflow/compiler/xla/python/local_device_state.h similarity index 82% rename from tensorflow/compiler/xla/python/device_state.h rename to tensorflow/compiler/xla/python/local_device_state.h index 3772c03fc59..7348b9c59f0 100644 --- a/tensorflow/compiler/xla/python/device_state.h +++ b/tensorflow/compiler/xla/python/local_device_state.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_DEVICE_STATE_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_DEVICE_STATE_H_ +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_DEVICE_STATE_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_DEVICE_STATE_H_ #include #include @@ -29,9 +29,9 @@ limitations under the License. namespace xla { // Class that encapsulates state relating to a device (e.g., a GPU) on which we -// can perform computation and transfers. DeviceState objects only exist for -// devices local to this host. -class DeviceState { +// can perform computation and transfers. LocalDeviceState objects only exist +// for devices local to this host. +class LocalDeviceState { public: // If synchronous_deallocation is true, the host must not free buffers until // compute/transfers that use those buffers have completed. For example, this @@ -40,9 +40,12 @@ class DeviceState { // // If asynchronous is false, the host will synchronize to the device after // each execution or transfer. This is intended for debugging only. - DeviceState(se::StreamExecutor* executor, bool synchronous_deallocation, - bool asynchronous, bool allow_event_reuse); - virtual ~DeviceState(); + LocalDeviceState(se::StreamExecutor* executor, bool synchronous_deallocation, + bool asynchronous, bool allow_event_reuse); + virtual ~LocalDeviceState(); + + // StreamExecutor (local) device ordinal. + int device_ordinal() const { return executor_->device_ordinal(); } bool synchronous_deallocation() const { return synchronous_deallocation_; } @@ -52,9 +55,10 @@ class DeviceState { se::Stream* host_to_device_stream() const { return host_to_device_stream_.get(); } - se::Stream* device_to_host_stream() const { - return device_to_host_stream_.get(); - } + + // Returns a device to host stream. Allocates streams in a round-robin fashion + // amongst the available streams. + se::Stream* GetDeviceToHostStream(); // Returns a device to device stream. Allocates streams in a round-robin // fashion amongst the available streams. @@ -104,15 +108,18 @@ class DeviceState { // stream by the host ahead of the device. Semaphore compute_semaphore_; + se::StreamExecutor* executor_; std::unique_ptr compute_stream_; std::unique_ptr host_to_device_stream_; - std::unique_ptr device_to_host_stream_; + std::vector> device_to_host_streams_; std::vector> device_to_device_streams_; - // Number of device-to-device streams to create in the multistream case. + // Number of device-to-host and device-to-device streams. 
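Both stream getters use the same mutex-guarded round-robin over a small fixed pool (the constants below set both pool sizes to 4); the observable effect, sketched:

```cpp
// Successive calls cycle through the pool, so up to four independent
// device-to-host transfers can be in flight on distinct streams.
se::Stream* s0 = local_device->GetDeviceToHostStream();  // stream 0
se::Stream* s1 = local_device->GetDeviceToHostStream();  // stream 1
// ...two more calls return streams 2 and 3, then the sequence wraps to 0.
```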
+ static constexpr int kNumDeviceToHostStreams = 4; static constexpr int kNumDeviceToDeviceStreams = 4; absl::Mutex mu_; + int next_device_to_host_stream_ GUARDED_BY(mu_) = 0; int next_device_to_device_stream_ GUARDED_BY(mu_) = 0; // Callback stream is used for running short host-side callbacks after device @@ -132,4 +139,4 @@ class DeviceState { } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_DEVICE_STATE_H_ +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_DEVICE_STATE_H_ diff --git a/tensorflow/compiler/xla/python/tpu_driver/BUILD b/tensorflow/compiler/xla/python/tpu_driver/BUILD index 99a07c31256..b796fe8c541 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/BUILD +++ b/tensorflow/compiler/xla/python/tpu_driver/BUILD @@ -31,11 +31,6 @@ tf_proto_library_cc( use_grpc_namespace = True, ) -cc_library( - name = "c_api", - hdrs = ["c_api.h"], -) - cc_library( name = "tpu_driver", srcs = [ @@ -66,6 +61,7 @@ cc_library( hdrs = ["grpc_tpu_driver.h"], deps = [ ":tpu_driver", + "//tensorflow:grpc++", "//tensorflow/core/platform:logging", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:util", @@ -77,6 +73,25 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "external_tpu_driver", + srcs = ["external_tpu_driver.cc"], + deps = [ + ":tpu_driver", + "@com_google_absl//absl/strings:str_format", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/core/platform:logging", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/service:hlo_proto_cc", + ":tpu_service_proto_cc", + ":tpu_driver_proto_cc", + "//tensorflow/compiler/xla/python/tpu_driver/client:c_api", + ] + external_deps(), + alwayslink = 1, +) + cc_library( name = "recording_tpu_driver", srcs = [ diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/BUILD b/tensorflow/compiler/xla/python/tpu_driver/client/BUILD index d5d492de054..932bee43ffc 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/BUILD +++ b/tensorflow/compiler/xla/python/tpu_driver/client/BUILD @@ -19,7 +19,6 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/client:executable_build_options", - "//tensorflow/compiler/xla/python:device_state", "//tensorflow/compiler/xla/python:local_client", "//tensorflow/compiler/xla/python:semaphore", "//tensorflow/compiler/xla/python/tpu_driver", @@ -76,3 +75,8 @@ py_library( "//third_party/py/numpy", ], ) + +cc_library( + name = "c_api", + hdrs = ["c_api.h"], +) diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/c_api.h b/tensorflow/compiler/xla/python/tpu_driver/client/c_api.h new file mode 100644 index 00000000000..8c967d6e0a1 --- /dev/null +++ b/tensorflow/compiler/xla/python/tpu_driver/client/c_api.h @@ -0,0 +1,227 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_TPU_DRIVER_CLIENT_C_API_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_TPU_DRIVER_CLIENT_C_API_H_ + +#include + +#define TPUDRIVER_CAPI_EXPORT __attribute__((visibility("default"))) + +#ifdef __cplusplus +extern "C" { +#endif + +struct TpuDriverFn; + +typedef struct TpuDriver TpuDriver; + +typedef struct TpuEvent TpuEvent; + +typedef struct TpuBufferHandleInternal TpuBufferHandleInternal; + +typedef struct TpuCompiledProgramHandleInternal + TpuCompiledProgramHandleInternal; + +typedef struct TpuLoadedProgramHandleInternal TpuLoadedProgramHandleInternal; +typedef struct HloProtoInternal HloProtoInternal; + +typedef struct TpuBufferHandle { + TpuBufferHandleInternal* internal_handle; + TpuEvent* event; + int64_t size_in_bytes; +} TpuBufferHandle; + +typedef struct TpuCompiledProgramHandle { + TpuCompiledProgramHandleInternal* internal_handle; + TpuEvent* event; +} TpuCompiledProgramHandle; + +typedef struct TpuLoadedProgramHandle { + TpuLoadedProgramHandleInternal* internal_handle; + TpuEvent* event; +} TpuLoadedProgramHandle; + +typedef struct HloProto { + HloProtoInternal* internal_hlo_proto; +} HloProto; + +typedef struct DeviceAssignment { + int replica_count; + int computation_count; +} DeviceAssignment; + +typedef struct TpuStatus { + int32_t code; + char* msg; +} TpuStatus; + +typedef struct CompiledProgramShape { + struct TpuStatus* status; + void* bytes; + int32_t size; +} CompiledProgramShape; + +typedef void(PrototypeTpuDriver_Initialize)(struct TpuDriverFn* driver_fn); +typedef struct TpuDriver*(PrototypeTpuDriver_Open)(const char* worker); +typedef void(PrototypeTpuDriver_Close)(struct TpuDriver* driver); + +// TODO(frankchn): Make this not a hard-coded constant. 
+const int32_t MemoryRegion_HBM = 1; + +typedef struct TpuCompiledProgramHandle*(PrototypeTpuDriver_CompileProgram)( + struct TpuDriver* driver, const struct HloProto hlo_proto, + int32_t num_replicas, int32_t eventc, struct TpuEvent** eventv); + +typedef struct TpuCompiledProgramHandle*( + PrototypeTpuDriver_CompileProgramFromText)(struct TpuDriver* driver, + const char* hlo_text, + int32_t num_replicas, + int32_t eventc, + struct TpuEvent** eventv); + +typedef struct TpuLoadedProgramHandle*(PrototypeTpuDriver_LoadProgram)( + struct TpuDriver* driver, int32_t core_id, + const struct TpuCompiledProgramHandle* compiled_program_handle, + int32_t eventc, struct TpuEvent** eventv); + +typedef struct TpuEvent*(PrototypeTpuDriver_UnloadProgram)( + struct TpuDriver* driver, + struct TpuLoadedProgramHandle* loaded_program_handle, int32_t eventc, + struct TpuEvent** eventv); + +typedef struct TpuEvent*(PrototypeTpuDriver_ExecuteProgram)( + struct TpuDriver* driver, struct TpuLoadedProgramHandle* handle, + int32_t inputc, struct TpuBufferHandle** input_buffer_handle, + int32_t outputc, struct TpuBufferHandle** output_buffer_handle, + struct DeviceAssignment device_assignment, int32_t eventc, + struct TpuEvent** eventv); + +typedef struct TpuBufferHandle*(PrototypeTpuDriver_AllocateTuple)( + struct TpuDriver* driver, int32_t core_id, int32_t memory_region, + int32_t bufferc, struct TpuBufferHandle** buffer_handle, int32_t eventc, + struct TpuEvent** eventv); + +typedef struct TpuBufferHandle*(PrototypeTpuDriver_Allocate)( + struct TpuDriver* driver, int32_t core_id, int32_t memory_region, + int64_t num_bytes, int32_t eventc, struct TpuEvent** eventv); + +typedef struct TpuEvent*(PrototypeTpuDriver_Deallocate)( + struct TpuDriver* driver, struct TpuBufferHandle* buffer_handle, + int32_t eventc, struct TpuEvent** eventv); + +typedef struct TpuEvent*(PrototypeTpuDriver_TransferToDevice)( + struct TpuDriver* driver, const void* src, struct TpuBufferHandle* dst, + int32_t eventc, struct TpuEvent** eventv); + +typedef struct TpuEvent*(PrototypeTpuDriver_TransferFromDevice)( + struct TpuDriver* driver, struct TpuBufferHandle* src, void* dst, + int32_t eventc, struct TpuEvent** eventv); + +typedef struct TpuEvent*(PrototypeTpuDriver_TransferFromDeviceToDevice)( + struct TpuDriver* driver, struct TpuBufferHandle* src, + struct TpuBufferHandle* dst, int32_t eventc, struct TpuEvent** eventv); + +typedef void(PrototypeTpuDriver_CreateDeviceAssignment)(int replica_count, + int computation_count); + +typedef struct CompiledProgramShape*( + PrototypeTpuDriver_GetCompiledProgramShape)( + struct TpuCompiledProgramHandle* handle); + +typedef void(PrototypeTpuDriver_FreeCompiledProgramShape)( + struct CompiledProgramShape* shape); + +typedef void(PrototypeTpuDriver_EventAddCallback)( + struct TpuEvent* event, + void (*callback_fn)(struct TpuStatus*, void* additional_info), + void* additional_info); + +typedef struct TpuStatus*(PrototypeTpuDriver_EventAwait)(struct TpuEvent* event, + int64_t timeout_in_us); + +typedef void(PrototypeTpuDriver_FreeEvent)(struct TpuEvent* event); + +typedef void(PrototypeTpuDriver_FreeStatus)(struct TpuStatus* status); + +typedef const char*(PrototypeTpuDriver_Version)(); + +TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Initialize TpuDriver_Initialize; +TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Open TpuDriver_Open; +TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Close TpuDriver_Close; +TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_CompileProgram + TpuDriver_CompileProgram; 
+TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_CompileProgramFromText + TpuDriver_CompileProgramFromText; +TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_LoadProgram + TpuDriver_LoadProgram; +TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_UnloadProgram + TpuDriver_UnloadProgram; +TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_ExecuteProgram + TpuDriver_ExecuteProgram; +TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_AllocateTuple + TpuDriver_AllocateTuple; +TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Allocate TpuDriver_Allocate; +TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Deallocate TpuDriver_Deallocate; +TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_TransferToDevice + TpuDriver_TransferToDevice; +TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_TransferFromDevice + TpuDriver_TransferFromDevice; +TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_TransferFromDeviceToDevice + TpuDriver_TransferFromDeviceToDevice; +TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_GetCompiledProgramShape + TpuDriver_GetCompiledProgramShape; +TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_FreeCompiledProgramShape + TpuDriver_FreeCompiledProgramShape; +TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_EventAddCallback + TpuDriver_EventAddCallback; +TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_EventAwait TpuDriver_EventAwait; +TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_FreeEvent TpuDriver_FreeEvent; +TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_FreeStatus TpuDriver_FreeStatus; +TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Version TpuDriver_Version; + +#ifdef __cplusplus +} +#endif + +struct TpuDriverFn { + PrototypeTpuDriver_Open* TpuDriver_Open; // NOLINT + PrototypeTpuDriver_Close* TpuDriver_Close; // NOLINT + PrototypeTpuDriver_CompileProgram* TpuDriver_CompileProgram; // NOLINT + PrototypeTpuDriver_CompileProgramFromText* + TpuDriver_CompileProgramFromText; // NOLINT + PrototypeTpuDriver_LoadProgram* TpuDriver_LoadProgram; // NOLINT + PrototypeTpuDriver_UnloadProgram* TpuDriver_UnloadProgram; // NOLINT + PrototypeTpuDriver_ExecuteProgram* TpuDriver_ExecuteProgram; // NOLINT + PrototypeTpuDriver_AllocateTuple* TpuDriver_AllocateTuple; // NOLINT + PrototypeTpuDriver_Allocate* TpuDriver_Allocate; // NOLINT + PrototypeTpuDriver_Deallocate* TpuDriver_Deallocate; // NOLINT + PrototypeTpuDriver_TransferToDevice* TpuDriver_TransferToDevice; // NOLINT + PrototypeTpuDriver_TransferFromDevice* + TpuDriver_TransferFromDevice; // NOLINT + PrototypeTpuDriver_TransferFromDeviceToDevice* + TpuDriver_TransferFromDeviceToDevice; // NOLINT + PrototypeTpuDriver_GetCompiledProgramShape* + TpuDriver_GetCompiledProgramShape; // NOLINT + PrototypeTpuDriver_FreeCompiledProgramShape* + TpuDriver_FreeCompiledProgramShape; // NOLINT + PrototypeTpuDriver_EventAddCallback* TpuDriver_EventAddCallback; // NOLINT + PrototypeTpuDriver_EventAwait* TpuDriver_EventAwait; // NOLINT + PrototypeTpuDriver_FreeEvent* TpuDriver_FreeEvent; // NOLINT + PrototypeTpuDriver_FreeStatus* TpuDriver_FreeStatus; // NOLINT + PrototypeTpuDriver_Version* TpuDriver_Version; // NOLINT +}; + +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_TPU_DRIVER_CLIENT_C_API_H_ diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/c_api_client.c b/tensorflow/compiler/xla/python/tpu_driver/client/c_api_client.c index 70ab4af85fd..5fabc8380a5 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/c_api_client.c +++ b/tensorflow/compiler/xla/python/tpu_driver/client/c_api_client.c @@ -13,15 +13,20 @@ See the License for the specific language 
governing permissions and limitations under the License. ==============================================================================*/ +// Before you start, make sure c_api.so, c_api.h, and c_api_client.c are in +// the same working directory. +// +// To compile: gcc -o c_api_client c_api_client.c -ldl -// To run, make sure c_api.so and c_api_client in the same directory, and then -// sudo ./c_api_client +// To run: sudo ./c_api_client #include <dlfcn.h> #include <stdio.h> #include <stdlib.h> -int main(int argc, char** argv) { +#include "c_api.h" + +void* LoadAndInitializeDriver(const char* shared_lib, + struct TpuDriverFn* driver_fn) { void* handle; handle = dlopen("./c_api.so", RTLD_NOW); if (!handle) { @@ -29,21 +34,124 @@ int main(int argc, char** argv) { - const char* (*TpuDriver_Version)(void); - void (*TpuDriver_Initialize)(void); - void (*TpuDriver_Open)(const char* worker); + PrototypeTpuDriver_Initialize* initialize_fn; + *(void**)(&initialize_fn) = dlsym(handle, "TpuDriver_Initialize"); + initialize_fn(driver_fn); - fprintf(stdout, "------ Going to Find Out Version ------\n"); - *(void**)(&TpuDriver_Version) = dlsym(handle, "TpuDriver_Version"); - fprintf(stdout, "TPU Driver Version: %s\n", TpuDriver_Version()); + return handle; +} - fprintf(stdout, "------ Going to Initialize ------\n"); - *(void**)(&TpuDriver_Initialize) = dlsym(handle, "TpuDriver_Initialize"); - TpuDriver_Initialize(); +int main(int argc, char** argv) { + struct TpuDriverFn driver_fn; + void* handle = LoadAndInitializeDriver("./c_api.so", &driver_fn); + + fprintf(stdout, "------ Going to Query Version ------\n"); + fprintf(stdout, "TPU Driver Version: %s\n", driver_fn.TpuDriver_Version()); fprintf(stdout, "------ Going to Open a TPU Driver ------\n"); + struct TpuDriver* driver = driver_fn.TpuDriver_Open("local://"); + + // An example of a simple program that sums two parameters.
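The sample resolves a single entry point, TpuDriver_Initialize, and lets it populate a TpuDriverFn table of function pointers through which every later call dispatches. A condensed sketch of the same idiom (error handling elided):

```cpp
#include <dlfcn.h>

// Condensed form of LoadAndInitializeDriver above; assumes c_api.h is
// included for PrototypeTpuDriver_Initialize and TpuDriverFn.
TpuDriverFn LoadDriver(const char* so_path) {
  void* handle = dlopen(so_path, RTLD_NOW);
  auto* init = reinterpret_cast<PrototypeTpuDriver_Initialize*>(
      dlsym(handle, "TpuDriver_Initialize"));
  TpuDriverFn driver_fn;
  init(&driver_fn);  // the library fills in every function pointer
  return driver_fn;
}
```

Note also the event protocol the rest of the sample relies on: each driver call accepts an (eventc, eventv) list of prerequisite events and returns a handle carrying an event, which is how allocate, transfer, execute, and read-back are chained below, starting from the HLO module that follows.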
+ const char* hlo_module_text = R"(HloModule add_vec_module + ENTRY %add_vec (a: s32[256], b: s32[256]) -> s32[256] { + %a = s32[256] parameter(0) + %b = s32[256] parameter(1) + ROOT %sum = s32[256] add(%a, %b) + } + )"; + + fprintf(stdout, "------ Going to Compile a TPU program ------\n"); + struct TpuCompiledProgramHandle* cph = + driver_fn.TpuDriver_CompileProgramFromText(driver, hlo_module_text, + /*num_replicas=*/1, /*eventc=*/0, /*eventv=*/NULL); + + fprintf(stdout, "------ Going to Load a TPU program ------\n"); + + struct TpuLoadedProgramHandle* lph = + driver_fn.TpuDriver_LoadProgram(driver, /*core_id=*/0, cph, + /*eventc=*/0, /*eventv=*/NULL); + + const int size = 1024; + + fprintf(stdout, "------ Going to Allocate a TPU Buffer ------\n"); + struct TpuBufferHandle* buf_a_handle = + driver_fn.TpuDriver_Allocate(driver, /*core_id=*/0, /*memory_region=*/1, + /*bytes=*/size, /*eventc=*/0, /*eventv=*/NULL); + fprintf(stdout, "------ Going to Allocate a TPU Buffer ------\n"); + struct TpuBufferHandle* buf_b_handle = + driver_fn.TpuDriver_Allocate(driver, /*core_id=*/0, /*memory_region=*/1, + /*bytes=*/size, /*eventc=*/0, /*eventv=*/NULL); + fprintf(stdout, "------ Going to Allocate a TPU Buffer ------\n"); + struct TpuBufferHandle* buf_sum_handle = + driver_fn.TpuDriver_Allocate(driver, /*core_id=*/0, /*memory_region=*/1, + /*bytes=*/size, /*eventc=*/0, /*eventv=*/NULL); + + char a_src[size], b_src[size], sum_src[size]; + for (int i = 0; i < size; ++i) { + a_src[i] = 1; + b_src[i] = 2; + sum_src[i] = 0; + } + + TpuEvent* allocate_buf_a_events[] = {buf_a_handle->event}; + fprintf(stdout, "------ Going to Transfer To Device ------\n"); + struct TpuEvent* transfer_ev1 = + driver_fn.TpuDriver_TransferToDevice(driver, a_src, buf_a_handle, + /*eventc=*/1, /*eventv=*/allocate_buf_a_events); + TpuEvent* allocate_buf_b_events[] = {buf_b_handle->event}; + fprintf(stdout, "------ Going to Transfer To Device ------\n"); + struct TpuEvent* transfer_ev2 = + driver_fn.TpuDriver_TransferToDevice(driver, b_src, buf_b_handle, + /*eventc=*/1, /*eventv=*/allocate_buf_b_events); + + fprintf(stdout, "------ Going to Execute a TPU program ------\n"); + DeviceAssignment device_assignment = {1, 1}; + TpuBufferHandle* input_buffer_handle[] = {buf_a_handle, buf_b_handle}; + TpuBufferHandle* output_buffer_handle[] = {buf_sum_handle}; + TpuEvent* transfer_events[] = {transfer_ev1, transfer_ev2}; + struct TpuEvent* execute_event = + driver_fn.TpuDriver_ExecuteProgram(driver, lph, + /*inputc=*/2, /*input_buffer_handle=*/input_buffer_handle, + /*outputc=*/1, /*output_buffer_handle=*/output_buffer_handle, + device_assignment, + /*eventc=*/2, /*eventv=*/transfer_events); + + fprintf(stdout, "------ Going to Transfer From Device ------\n"); + TpuEvent* execute_events[] = {execute_event}; + struct TpuEvent* transfer_sum_event = + driver_fn.TpuDriver_TransferFromDevice(driver, buf_sum_handle, sum_src, + /*eventc=*/1, /*eventv=*/execute_events); + + TpuStatus* status = driver_fn.TpuDriver_EventAwait(transfer_sum_event, + 10000000); + if (status->code != 0) { + fprintf(stdout, "Transfer Event Await: Code: %d, Message: %s\n", + status->code, status->msg); + } + + fprintf(stdout, "------ Going to Unload a TPU program ------\n"); + struct TpuEvent* unload_program_event = driver_fn.TpuDriver_UnloadProgram( + driver, lph, /*eventc=*/1, /*eventv=*/execute_events); + + fprintf(stdout, "------ Going to Deallocate a TPU Buffer ------\n"); + struct TpuEvent* dealloc_ev1 = driver_fn.TpuDriver_Deallocate(driver, + buf_a_handle,
/*eventc=*/0, /*eventv=*/NULL); + driver_fn.TpuDriver_FreeEvent(dealloc_ev1); + + fprintf(stdout, "------ Going to Deallocate a TPU Buffer ------\n"); + struct TpuEvent* dealloc_ev2 = driver_fn.TpuDriver_Deallocate(driver, + buf_b_handle, /*eventc=*/0, /*eventv=*/NULL); + driver_fn.TpuDriver_FreeEvent(dealloc_ev2); + + fprintf(stdout, "------ Going to Deallocate a TPU Buffer ------\n"); + struct TpuEvent* dealloc_ev3 = driver_fn.TpuDriver_Deallocate(driver, + buf_sum_handle, /*eventc=*/0, /*eventv=*/NULL); + driver_fn.TpuDriver_FreeEvent(dealloc_ev3); + + fprintf(stdout, "sum:\n"); + for (size_t i = 0; i < size; ++i) { + fprintf(stdout, "%d ", sum_src[i]); + } dlclose(handle); exit(EXIT_SUCCESS); diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc index 2b69239bb7a..48f89b5cf2f 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc @@ -39,10 +39,9 @@ std::string TpuDevice::DebugString() const { } static std::shared_ptr MakeDevice(const std::string& platform_name, - int id, int local_device_ordinal) { + int id) { CHECK_EQ(platform_name, "tpu"); - CHECK_EQ(id, local_device_ordinal); // Every device must be local for now. - return std::make_shared(id, local_device_ordinal, "tpu"); + return std::make_shared(id, /*local_device_state=*/nullptr, "tpu"); } StatusOr> PyTpuClient::Get( @@ -67,7 +66,7 @@ StatusOr> PyTpuClient::Get( LOG(INFO) << "Creating " << num_cores << " TPU device(s)."; devices.reserve(num_cores); for (int i = 0; i < num_cores; ++i) { - devices.push_back(MakeDevice("tpu", i, i)); + devices.push_back(MakeDevice("tpu", i)); } return std::make_shared("tpu", std::move(client), @@ -87,8 +86,8 @@ PyTpuClient::PyTpuClient(std::string platform_name, CHECK(id_to_device_.insert({device->id(), device}).second) << "Duplicate device id: " << device->id(); - if (device->local_device_ordinal() != -1) { - int idx = device->local_device_ordinal(); + if (device->id() != -1) { + int idx = device->id(); CHECK(local_devices_[idx] == nullptr) << idx; CHECK_LT(idx, local_devices_.size()); local_devices_[idx] = device; @@ -509,7 +508,7 @@ PyTpuExecutable::ExecuteResult PyTpuExecutable::ExecuteHelper( const int device_id = device_assignment_(replica, 0); std::shared_ptr device = LookupDevice(*client_, device_id); CHECK_EQ(device->host_id(), client_->host_id()); - int device_ordinal = device->local_device_ordinal(); + int device_ordinal = device->id(); tensorflow::profiler::TraceMe traceme("PyTpuExecutable::Execute"); VLOG(3) << "Replica " << replica << " mapped to device ordinal for execution: " << device_ordinal; @@ -742,7 +741,7 @@ PyTpuExecutable::ExecutePerReplica( const int device_id = (*device_assignment)(replica, 0); std::shared_ptr device = LookupDevice(*client, device_id); CHECK_EQ(device->host_id(), client->host_id()); - int device_ordinal = device->local_device_ordinal(); + int device_ordinal = device->id(); loaded_programs[replica] = client->driver()->LoadProgram( device_ordinal, compiled_program.get(), {}); } diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h index 7624a14943f..49d4182b719 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h @@ -24,7 +24,6 @@ limitations under the License. 
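On the TPU path every device is host-local but carries no StreamExecutor state, so MakeDevice above passes a null LocalDeviceState and the client falls back to device->id() as the driver ordinal. That convention, sketched:

```cpp
// TPU devices carry no LocalDeviceState, so the device id doubles as the
// ordinal handed to the TPU driver (see ExecuteHelper above).
auto device = std::make_shared<TpuDevice>(
    /*id=*/0, /*local_device_state=*/nullptr, "tpu");
int device_ordinal = device->id();
```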
#include "absl/synchronization/notification.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" -#include "tensorflow/compiler/xla/python/device_state.h" #include "tensorflow/compiler/xla/python/local_client.h" #include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h" #include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.pb.h" diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc index 60886416a62..2b7082d40c9 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc @@ -96,9 +96,9 @@ PYBIND11_MODULE(tpu_client_extension, m) { std::make_move_iterator(tree.leaves.end())); py::gil_scoped_release gil_release; - return PyTpuBuffer::FromLiterals( - std::move(leaves), tree.shape, std::move(py_buffer_ref), - std::move(client), device->local_device_ordinal()); + return PyTpuBuffer::FromLiterals(std::move(leaves), tree.shape, + std::move(py_buffer_ref), + std::move(client), device->id()); }) .def_static( "from_python", @@ -135,8 +135,8 @@ PYBIND11_MODULE(tpu_client_extension, m) { "Cannot make tuple on device '%s' with '%s' backend", device->DebugString(), client->platform_name()); } - return PyTpuBuffer::MakeTuple( - buffers, client, device->local_device_ordinal()); + return PyTpuBuffer::MakeTuple(buffers, client, + device->id()); }) .def_static("make_tuple", &PyTpuBuffer::MakeTuple) .def("copy_to_device", @@ -144,7 +144,7 @@ PYBIND11_MODULE(tpu_client_extension, m) { CHECK(dst_device != nullptr); GlobalPyRefManager()->CollectGarbage(); py::gil_scoped_release gil_release; - return buffer->CopyToDevice(dst_device->local_device_ordinal()); + return buffer->CopyToDevice(dst_device->id()); }) .def("copy_to_device", [](PyTpuBuffer* buffer, int dst_device_ordinal) { @@ -193,7 +193,7 @@ PYBIND11_MODULE(tpu_client_extension, m) { [](const PyTpuExecutable& executable) { std::vector device_ordinals; for (std::shared_ptr device : executable.local_devices()) { - device_ordinals.push_back(device->local_device_ordinal()); + device_ordinals.push_back(device->id()); } return device_ordinals; }) diff --git a/tensorflow/compiler/xla/python/tpu_driver/external_tpu_driver.cc b/tensorflow/compiler/xla/python/tpu_driver/external_tpu_driver.cc new file mode 100644 index 00000000000..8a8e868b2b8 --- /dev/null +++ b/tensorflow/compiler/xla/python/tpu_driver/external_tpu_driver.cc @@ -0,0 +1,387 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================== + +#include + +#include "absl/strings/str_format.h" +#include "absl/time/time.h" +#include "tensorflow/compiler/xla/python/tpu_driver/client/c_api.h" +#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h" +#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.pb.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace tpu_driver { +namespace { + +class ExternalTpuDriver; + +class ExternalEvent : public Event { + public: + explicit ExternalEvent(::TpuDriverFn* driver_fn, ::TpuEvent* event) + : driver_fn_(driver_fn), event_(event) {} + + ~ExternalEvent() override { driver_fn_->TpuDriver_FreeEvent(event_); } + + xla::Status Await() override { + auto tpu_status = driver_fn_->TpuDriver_EventAwait(event_, -1); + auto ret = xla::Status(tensorflow::error::Code(tpu_status->code), + absl::StrFormat("%s", tpu_status->msg)); + driver_fn_->TpuDriver_FreeStatus(tpu_status); + return ret; + } + + absl::optional AwaitWithTimeout( + absl::Duration duration) override { + auto tpu_status_or = driver_fn_->TpuDriver_EventAwait( + event_, absl::ToInt64Microseconds(duration)); + if (tpu_status_or == nullptr) { + return absl::nullopt; + } else { + auto ret = xla::Status(tensorflow::error::Code(tpu_status_or->code), + absl::StrFormat("%s", tpu_status_or->msg)); + driver_fn_->TpuDriver_FreeStatus(tpu_status_or); + return ret; + } + } + + void AddCallback(std::function callback) override { + // We have to create a new copy of the fn on the heap to make it persist. + std::function* callback_addr = + new std::function(callback); + + // Using the callback_addr instead of capturing because C++11 lambdas with + // variable captures cannot be converted to C function pointers. 
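The heap-allocated std::function plus captureless-lambda trampoline is the standard way to carry a C++ closure across a C ABI; a generic sketch of the idiom, where RegisterCb is a hypothetical stand-in for TpuDriver_EventAddCallback:

```cpp
#include <functional>

// Hypothetical C API standing in for TpuDriver_EventAddCallback: a plain
// function pointer plus an opaque user-data pointer.
extern "C" void RegisterCb(void (*fn)(int code, void* user_data),
                           void* user_data);

using Callback = std::function<void(int)>;

void AddCallback(Callback cb) {
  // Heap-allocate the closure so it outlives this scope.
  auto* heap_cb = new Callback(std::move(cb));
  RegisterCb(
      +[](int code, void* user_data) {  // captureless lambda -> C pointer
        auto* f = static_cast<Callback*>(user_data);
        (*f)(code);
        delete f;  // one-shot: free after the single invocation
      },
      heap_cb);
}
```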
+
+ private:
+  ::TpuDriverFn* driver_fn_;
+  ::TpuEvent* event_;
+
+  friend ExternalTpuDriver;
+};
+
+class ExternalBufferHandle : public BufferHandle {
+ public:
+  explicit ExternalBufferHandle(::TpuDriverFn* driver_fn,
+                                ::TpuBufferHandle* handle)
+      : handle_(handle), event_(new ExternalEvent(driver_fn, handle->event)) {}
+
+  std::shared_ptr<Event> OnReady() override { return event_; }
+
+  int64_t size_in_bytes() override { return handle_->size_in_bytes; }
+
+  absl::optional<xla::ShapeProto> shape() override {
+    LOG(FATAL) << "Unimplemented.";
+    return absl::nullopt;
+  }
+
+ private:
+  ::TpuBufferHandle* handle_;
+  std::shared_ptr<ExternalEvent> event_;
+
+  friend ExternalTpuDriver;
+};
+
+class ExternalCompiledProgramHandle : public CompiledProgramHandle {
+ public:
+  explicit ExternalCompiledProgramHandle(::TpuDriverFn* driver_fn,
+                                         ::TpuCompiledProgramHandle* handle)
+      : handle_(handle),
+        driver_fn_(driver_fn),
+        event_(new ExternalEvent(driver_fn, handle->event)) {}
+
+  std::shared_ptr<Event> OnReady() override { return event_; }
+
+  int64_t size_in_bytes() override {
+    LOG(FATAL) << "Unimplemented.";
+    return 0;
+  }
+
+  xla::Status program_shape(xla::ProgramShapeProto* program_shape) override {
+    struct CompiledProgramShape* shape =
+        driver_fn_->TpuDriver_GetCompiledProgramShape(handle_);
+    program_shape->ParseFromArray(shape->bytes, shape->size);
+
+    auto status = xla::Status(tensorflow::error::Code(shape->status->code),
+                              absl::StrFormat("%s", shape->status->msg));
+    driver_fn_->TpuDriver_FreeCompiledProgramShape(shape);
+
+    return status;
+  }
+
+ private:
+  ::TpuCompiledProgramHandle* handle_;
+  ::TpuDriverFn* driver_fn_;
+  std::shared_ptr<ExternalEvent> event_;
+
+  friend ExternalTpuDriver;
+};
+
+class ExternalLoadedProgramHandle : public LoadedProgramHandle {
+ public:
+  explicit ExternalLoadedProgramHandle(::TpuDriverFn* driver_fn,
+                                       ::TpuLoadedProgramHandle* handle)
+      : handle_(handle), event_(new ExternalEvent(driver_fn, handle->event)) {}
+  std::shared_ptr<Event> OnReady() override { return event_; }
+
+  int64_t size_in_bytes() override {
+    LOG(FATAL) << "Unimplemented.";
+    return 0;
+  }
+
+ private:
+  ::TpuLoadedProgramHandle* handle_;
+  std::shared_ptr<ExternalEvent> event_;
+
+  friend ExternalTpuDriver;
+};
+
+class ExternalTpuDriver : public TpuDriver {
+ public:
+  explicit ExternalTpuDriver(const std::string& so_path) {
+    void* handle;
+    handle = dlopen(so_path.c_str(), RTLD_NOW);
+    if (!handle) {
+      LOG(FATAL) << "Unable to load shared library: " << dlerror();
+    }
+
+    PrototypeTpuDriver_Initialize* initialize_fn;
+    *reinterpret_cast<void**>(&initialize_fn) =
+        dlsym(handle, "TpuDriver_Initialize");
+    initialize_fn(&driver_fn_);
+
+    driver_ = driver_fn_.TpuDriver_Open("local://");
+  }
+
+  ~ExternalTpuDriver() override {}
+
+  void QuerySystemInfo(SystemInfo* system_info) override {
+    LOG(FATAL) << "Unimplemented.";
+  }
+
+  xla::Status Reset() override { LOG(FATAL) << "Unimplemented."; }
+
+  std::unique_ptr<BufferHandle> Allocate(
+      int32_t core_id, MemoryRegion region, int64_t num_bytes,
+      absl::Span<Event* const> wait_for) override {
+    auto tpu_events = MakeEventArray(wait_for);
+    auto bh = absl::make_unique<ExternalBufferHandle>(
+        &driver_fn_,
+        driver_fn_.TpuDriver_Allocate(driver_, core_id, region, num_bytes,
+                                      wait_for.size(), tpu_events));
+    delete[] tpu_events;
+    return bh;
+  }
+
+  std::unique_ptr<BufferHandle> Allocate(
+      int32_t core_id, MemoryRegion region, const xla::ShapeProto& shape,
+      absl::Span<Event* const> wait_for) override {
+    LOG(FATAL) << "Unimplemented.";
+    return nullptr;
+  }
+
+  std::unique_ptr<BufferHandle> AllocateTuple(
+      int32_t core_id, MemoryRegion region,
+      absl::Span<BufferHandle* const> children,
+      absl::Span<Event* const> wait_for) override {
+    LOG(FATAL) << "Unimplemented.";
+    return nullptr;
+  }
+
+  std::shared_ptr<Event> Deallocate(
+      std::unique_ptr<BufferHandle> handle,
+      absl::Span<Event* const> wait_for) override {
+    auto tpu_events = MakeEventArray(wait_for);
+    auto event = std::make_shared<ExternalEvent>(
+        &driver_fn_,
+        driver_fn_.TpuDriver_Deallocate(
+            driver_, static_cast<ExternalBufferHandle*>(handle.get())->handle_,
+            wait_for.size(), tpu_events));
+    delete[] tpu_events;
+    return event;
+  }
+
+  std::shared_ptr<Event> TransferToDevice(
+      const void* src, BufferHandle* dst,
+      absl::Span<Event* const> wait_for) override {
+    auto tpu_events = MakeEventArray(wait_for);
+    auto event = std::make_shared<ExternalEvent>(
+        &driver_fn_,
+        driver_fn_.TpuDriver_TransferToDevice(
+            driver_, src, static_cast<ExternalBufferHandle*>(dst)->handle_,
+            wait_for.size(), tpu_events));
+    delete[] tpu_events;
+    return event;
+  }
+
+  std::shared_ptr<Event> TransferFromDevice(
+      const BufferHandle* src, void* dst,
+      absl::Span<Event* const> wait_for) override {
+    auto tpu_events = MakeEventArray(wait_for);
+    auto event = std::make_shared<ExternalEvent>(
+        &driver_fn_,
+        driver_fn_.TpuDriver_TransferFromDevice(
+            driver_, static_cast<const ExternalBufferHandle*>(src)->handle_,
+            dst, wait_for.size(), tpu_events));
+    delete[] tpu_events;
+    return event;
+  }
+
+  std::shared_ptr<Event> TransferFromDeviceToDevice(
+      const BufferHandle* src, BufferHandle* dst,
+      absl::Span<Event* const> wait_for) override {
+    auto tpu_events = MakeEventArray(wait_for);
+    auto event = std::make_shared<ExternalEvent>(
+        &driver_fn_,
+        driver_fn_.TpuDriver_TransferFromDeviceToDevice(
+            driver_, static_cast<const ExternalBufferHandle*>(src)->handle_,
+            static_cast<ExternalBufferHandle*>(dst)->handle_, wait_for.size(),
+            tpu_events));
+    delete[] tpu_events;
+    return event;
+  }
+
+  std::unique_ptr<CompiledProgramHandle> CompileProgram(
+      const xla::HloProto& source, int32_t num_replicas,
+      absl::Span<Event* const> wait_for) override {
+    auto tpu_events = MakeEventArray(wait_for);
+
+    struct HloProto hlo;
+    hlo.size = source.ByteSizeLong();
+    hlo.bytes = malloc(hlo.size);
+    if (!source.SerializeToArray(hlo.bytes, hlo.size)) {
+      LOG(ERROR) << "Unable to serialize HLO to array.";
+      return nullptr;
+    }
+
+    auto handle = absl::make_unique<ExternalCompiledProgramHandle>(
+        &driver_fn_,
+        driver_fn_.TpuDriver_CompileProgram(driver_, hlo, num_replicas,
+                                            wait_for.size(), tpu_events));
+
+    free(hlo.bytes);
+    delete[] tpu_events;
+    return handle;
+  }
+  std::unique_ptr<LoadedProgramHandle> LoadProgram(
+      int32_t core_id, const CompiledProgramHandle* handle,
+      absl::Span<Event* const> wait_for) override {
+    auto tpu_events = MakeEventArray(wait_for);
+
+    auto loaded_handle = absl::make_unique<ExternalLoadedProgramHandle>(
+        &driver_fn_,
+        driver_fn_.TpuDriver_LoadProgram(
+            driver_, core_id,
+            static_cast<const ExternalCompiledProgramHandle*>(handle)->handle_,
+            wait_for.size(), tpu_events));
+
+    delete[] tpu_events;
+    return loaded_handle;
+  }
+
+  std::shared_ptr<Event> UnloadProgram(
+      std::unique_ptr<LoadedProgramHandle> handle,
+      absl::Span<Event* const> wait_for) override {
+    auto tpu_events = MakeEventArray(wait_for);
+    auto event = std::make_shared<ExternalEvent>(
+        &driver_fn_,
+        driver_fn_.TpuDriver_UnloadProgram(
+            driver_,
+            static_cast<ExternalLoadedProgramHandle*>(handle.get())->handle_,
+            wait_for.size(), tpu_events));
+    delete[] tpu_events;
+    return event;
+  }
+
+  std::shared_ptr<Event> ExecuteProgram(
+      LoadedProgramHandle* program, absl::Span<BufferHandle* const> inputs,
+      absl::Span<BufferHandle* const> outputs,
+      const xla::DeviceAssignmentProto& device_assignment,
+      absl::Span<Event* const> wait_for) override {
+    auto tpu_events = MakeEventArray(wait_for);
+
+    struct DeviceAssignmentProto da_proto;
+    da_proto.size = device_assignment.ByteSizeLong();
+    da_proto.bytes = malloc(da_proto.size);
+    if (!device_assignment.SerializeToArray(da_proto.bytes, da_proto.size)) {
+      LOG(ERROR) << "Unable to serialize device assignment to array.";
+      return nullptr;
+    }
+
+    std::vector<::TpuBufferHandle*> inputv;
+    inputv.reserve(inputs.size());
+    for (int i = 0; i < inputs.size(); i++) {
+      inputv.push_back(
+          static_cast<ExternalBufferHandle*>(inputs[i])->handle_);
+    }
+    std::vector<::TpuBufferHandle*> outputv;
+    outputv.reserve(outputs.size());
+    for (int i = 0; i < outputs.size(); i++) {
+      outputv.push_back(
+          static_cast<ExternalBufferHandle*>(outputs[i])->handle_);
+    }
+
+    auto event = std::make_shared<ExternalEvent>(
+        &driver_fn_,
+        driver_fn_.TpuDriver_ExecuteProgram(
+            driver_,
+            static_cast<ExternalLoadedProgramHandle*>(program)->handle_,
+            inputs.size(), inputv.data(), outputs.size(), outputv.data(),
+            da_proto, wait_for.size(), tpu_events));
+
+    free(da_proto.bytes);
+    return event;
+  }
+
+  std::unique_ptr<TpuLinearizer> GetLinearizer() override { return nullptr; }
+
+ private:
+  ::TpuDriverFn driver_fn_;
+  ::TpuDriver* driver_;
+
+  ::TpuEvent** MakeEventArray(absl::Span<Event* const> wait_for) {
+    if (wait_for.empty()) return nullptr;
+    ::TpuEvent** ret = new ::TpuEvent*[wait_for.size()];
+    for (int i = 0; i < wait_for.size(); i++) {
+      ret[i] = static_cast<ExternalEvent*>(wait_for[i])->event_;
+    }
+    return ret;
+  }
+};
+
+xla::StatusOr<std::unique_ptr<TpuDriver>> RegisterExternalTpuDriver(
+    const TpuDriverConfig& config) {
+  std::string shared_lib = config.worker().substr(strlen("external://"));
+  return xla::StatusOr<std::unique_ptr<TpuDriver>>(
+      absl::make_unique<ExternalTpuDriver>(shared_lib));
+}
+
+REGISTER_TPU_DRIVER("external://", RegisterExternalTpuDriver);
+
+}  // namespace
+}  // namespace tpu_driver
diff --git a/tensorflow/compiler/xla/python/tpu_driver/platform/external/tools.bzl b/tensorflow/compiler/xla/python/tpu_driver/platform/external/tools.bzl
index d2823aeb995..99b07b6c787 100644
--- a/tensorflow/compiler/xla/python/tpu_driver/platform/external/tools.bzl
+++ b/tensorflow/compiler/xla/python/tpu_driver/platform/external/tools.bzl
@@ -33,5 +33,4 @@ def external_deps():
         "@com_google_absl//absl/synchronization",
         "@com_google_absl//absl/time",
         "@com_google_absl//absl/types:span",
-        "//tensorflow:grpc++",
     ]
diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc
index f1776763796..b5eb6fa47da 100644
--- a/tensorflow/compiler/xla/python/xla.cc
+++ b/tensorflow/compiler/xla/python/xla.cc
@@ -142,6 +142,16 @@ Status PyRegisterCustomCallTarget(const std::string& fn_name,
   return Status::OK();
 }
 
+StatusOr<std::shared_ptr<Device>> LookupDeviceOrdinal(
+    PyLocalClient* client, int device_ordinal, absl::string_view caller_name) {
+  if (device_ordinal < 0 || device_ordinal >= client->local_device_count()) {
+    return InvalidArgument(
+        "%s got bad device_ordinal: %d (num_local_devices=%d)", caller_name,
+        device_ordinal, client->local_device_count());
+  }
+  return client->local_devices()[device_ordinal];
+}
+
 }  // namespace
 
 PYBIND11_MODULE(xla_extension, m) {
@@ -381,13 +391,27 @@ PYBIND11_MODULE(xla_extension, m) {
             }
             return result;
           })
+      // TODO(phawkins): delete overload that accepts a device_ordinal after
+      // all callers have been updated to pass a Device.
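That TODO describes a two-overload migration window: the old ordinal-based entry point validates and resolves its argument, then funnels into the new Device-based one, so both share a single implementation. The same shape in a stripped-down sketch (all names illustrative):

```
#include <cstdio>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

struct Device { int id; };

class Client {
 public:
  explicit Client(int n) {
    for (int i = 0; i < n; ++i) {
      devices_.push_back(std::make_shared<Device>(Device{i}));
    }
  }

  // New-style API: the caller passes the Device it already holds.
  void TransferToInfeed(const std::string& literal, std::shared_ptr<Device> d) {
    std::printf("infeed %s to device %d\n", literal.c_str(), d->id);
  }

  // Old-style API, kept during migration: validate the ordinal, resolve it,
  // and delegate, mirroring the LookupDeviceOrdinal helper above.
  void TransferToInfeed(const std::string& literal, int device_ordinal) {
    if (device_ordinal < 0 ||
        device_ordinal >= static_cast<int>(devices_.size())) {
      throw std::invalid_argument("TransferToInfeed got bad device_ordinal " +
                                  std::to_string(device_ordinal));
    }
    TransferToInfeed(literal, devices_[device_ordinal]);
  }

 private:
  std::vector<std::shared_ptr<Device>> devices_;
};

int main() {
  Client client(2);
  client.TransferToInfeed("x", 1);  // resolves ordinal 1, then delegates
  return 0;
}
```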
.def("TransferToInfeed", [](PyLocalClient* client, const LiteralSlice& literal, int device_ordinal) { GlobalPyRefManager()->CollectGarbage(); py::gil_scoped_release gil_release; - return client->TransferToInfeed(literal, device_ordinal); + TF_ASSIGN_OR_RETURN(std::shared_ptr device, + LookupDeviceOrdinal(client, device_ordinal, + "TransferToInfeed")); + return client->TransferToInfeed(literal, device); }) + .def("TransferToInfeed", + [](PyLocalClient* client, const LiteralSlice& literal, + std::shared_ptr device) { + GlobalPyRefManager()->CollectGarbage(); + py::gil_scoped_release gil_release; + return client->TransferToInfeed(literal, device); + }) + // TODO(phawkins): delete overload that accepts a device_ordinal after + // all callers have been updated to pass a Device. .def("TransferFromOutfeed", [](PyLocalClient* client, const Shape& shape, int device_ordinal) -> StatusOr { @@ -395,8 +419,24 @@ PYBIND11_MODULE(xla_extension, m) { std::shared_ptr literal_shared; { py::gil_scoped_release gil_release; - TF_ASSIGN_OR_RETURN(Literal literal, client->TransferFromOutfeed( - shape, device_ordinal)); + TF_ASSIGN_OR_RETURN(std::shared_ptr device, + LookupDeviceOrdinal(client, device_ordinal, + "TransferFromOutfeed")); + TF_ASSIGN_OR_RETURN(Literal literal, + client->TransferFromOutfeed(shape, device)); + literal_shared = std::make_shared(std::move(literal)); + } + return LiteralToPython(std::move(literal_shared)); + }) + .def("TransferFromOutfeed", + [](PyLocalClient* client, const Shape& shape, + std::shared_ptr device) -> StatusOr { + GlobalPyRefManager()->CollectGarbage(); + std::shared_ptr literal_shared; + { + py::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN(Literal literal, + client->TransferFromOutfeed(shape, device)); literal_shared = std::make_shared(std::move(literal)); } return LiteralToPython(std::move(literal_shared)); @@ -440,7 +480,7 @@ PYBIND11_MODULE(xla_extension, m) { py::gil_scoped_release gil_release; return PyLocalBuffer::FromLiterals( std::move(leaves), tree.shape, std::move(py_buffer_ref), - std::move(client), device->local_device_ordinal()); + std::move(client), std::move(device)); }) .def_static("make_tuple", [](const std::vector buffers, @@ -454,15 +494,15 @@ PYBIND11_MODULE(xla_extension, m) { "Cannot make tuple on device '%s' with '%s' backend", device->DebugString(), client->platform_name()); } - return PyLocalBuffer::MakeTuple( - buffers, client, device->local_device_ordinal()); + return PyLocalBuffer::MakeTuple(buffers, std::move(client), + std::move(device)); }) .def("copy_to_device", [](PyLocalBuffer* buffer, std::shared_ptr dst_device) { CHECK(dst_device != nullptr); GlobalPyRefManager()->CollectGarbage(); py::gil_scoped_release gil_release; - return buffer->CopyToDevice(dst_device->local_device_ordinal()); + return buffer->CopyToDevice(std::move(dst_device)); }) .def("delete", &PyLocalBuffer::Delete) .def("destructure", &PyLocalBuffer::DestructureTuple) @@ -485,10 +525,7 @@ PYBIND11_MODULE(xla_extension, m) { return LiteralToPython(std::move(literal)); }) .def("shape", &PyLocalBuffer::on_host_shape) - .def("device", - [](PyLocalBuffer* buffer) -> std::shared_ptr { - return buffer->client()->local_devices()[buffer->device_ordinal()]; - }) + .def("device", &PyLocalBuffer::device) .def("platform", &PyLocalBuffer::platform_name) .def("is_deleted", [](const PyLocalBuffer& buffer) { diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index ec5ca9a4a75..fb56e436aaa 100644 --- 
a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -444,7 +444,7 @@ def shape_from_pyval(pyval): return convert(pyval) -def transfer_to_infeed(value, device_ordinal=0): +def transfer_to_infeed(value, device=None): """Transfers the given value into the XLA infeed queue. XLA's infeed queue is a single queue that feeds the "XLA virtual machine" with @@ -454,29 +454,31 @@ def transfer_to_infeed(value, device_ordinal=0): Args: value: the value that the caller would like to enqueue into the XLA infeed queue - device_ordinal: the device to infeed the value to. Each device has a + device: the device to infeed the value to. Each device has a distinct infeed queue. """ # TODO(phawkins): support non-default backends. backend = get_local_backend() - backend.client.TransferToInfeed(value, device_ordinal) + device = device or backend.local_devices()[0] + backend.client.TransferToInfeed(value, device) -def transfer_from_outfeed(shape, device_ordinal=0): - """Transfers a literal of the given shape from `device_ordinal`'s outfeed. +def transfer_from_outfeed(shape, device=None): + """Transfers a literal of the given shape from `device`'s outfeed. Args: shape: The shape of the value to transfer from outfeed. - device_ordinal: The device ordinal to transfer the outfeed value from. Each - device has a distinct outfeed queue.. + device: The device from which to transfer the outfeed value. Each device has + a distinct outfeed queue.. Returns: The literal value that is produced from the outfeed queue. """ # TODO(phawkins): support non-default backends. backend = get_local_backend() + device = device or backend.local_devices()[0] return backend.client.TransferFromOutfeed( - shape.with_major_to_minor_layout_if_absent(), device_ordinal) + shape.with_major_to_minor_layout_if_absent(), device) DeviceAssignment = _xla.DeviceAssignment diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index b4ea4d9e263..9b24a583cd5 100755 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -608,7 +608,6 @@ cc_library( ":hlo", ":hlo_parser", "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], @@ -1079,7 +1078,7 @@ cc_library( deps = [ ":compiler", "//tensorflow/core:lib_internal", - "@llvm//:core", + "@llvm-project//llvm:core", ], ) @@ -1844,6 +1843,7 @@ tf_cc_test( ":hlo_creation_utils", ":hlo_parser", ":hlo_pass", + ":hlo_pass_pipeline", ":pattern_matcher", ":pattern_matcher_gmock", ":shape_inference", @@ -1982,6 +1982,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", ], @@ -2018,6 +2019,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep ], ) @@ -2053,6 +2055,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -2118,6 +2121,7 @@ tf_cc_test( ":while_loop_simplifier", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/tests:hlo_test_base", + 
"//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", "@com_google_absl//absl/strings", @@ -2179,6 +2183,7 @@ tf_cc_test( "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -2207,6 +2212,7 @@ tf_cc_test( ":hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep ], ) @@ -2236,6 +2242,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], ) @@ -2282,13 +2289,17 @@ cc_library( deps = [ ":dynamic_dimension_inference", ":hlo", + ":hlo_casting_utils", ":hlo_dce", ":hlo_pass", + "//tensorflow/compiler/xla:comparison_util", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -2319,6 +2330,7 @@ xla_test( "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], ) @@ -2339,6 +2351,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], ) @@ -2951,6 +2964,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], ) @@ -3309,6 +3323,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "@com_google_absl//absl/memory", ], @@ -3450,6 +3465,7 @@ tf_cc_test( ":hlo_element_type_converter", ":hlo_matchers", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -3528,8 +3544,8 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", - "@llvm//:core", - "@llvm//:transform_utils", + "@llvm-project//llvm:core", + "@llvm-project//llvm:transform_utils", ], ) @@ -3837,6 +3853,7 @@ tf_cc_test( ":sort_simplifier", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], ) @@ -3868,6 +3885,7 @@ tf_cc_test( ":stable_sort_expander", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], ) @@ -3959,6 +3977,7 @@ tf_cc_test( 
":while_loop_invariant_code_motion", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], ) @@ -3986,6 +4005,7 @@ tf_cc_test( ":while_loop_constant_sinking", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], ) @@ -4047,6 +4067,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", "@com_google_absl//absl/strings", ], @@ -4095,9 +4116,9 @@ tf_cc_test( "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/tests:verified_hlo_module", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", - "//tensorflow/core:test_main", # fixdeps: keep "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc old mode 100755 new mode 100644 index f145b447bef..0225d2d3bd6 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -80,6 +80,68 @@ bool IsAll(const HloInstruction* op, int8 value) { } } +bool IsAnyOperandComplex(const HloInstruction* hlo) { + for (auto operand : hlo->operands()) { + if (ShapeUtil::ElementIsComplex(operand->shape())) { + return true; + } + } + return false; +} + +bool IsPositive(const HloInstruction* hlo, + const AlgebraicSimplifierOptions& options) { + // Utility only handles real types. + if (IsAnyOperandComplex(hlo)) { + return false; + } + switch (hlo->opcode()) { + case HloOpcode::kGetTupleElement: { + const HloInstruction* gte_operand = hlo->operand(0); + switch (gte_operand->opcode()) { + case HloOpcode::kCustomCall: { + const auto& target = gte_operand->custom_call_target(); + return target == + options.get_cudnn_batchnorm_forward_training_metadata() && + hlo->tuple_index() == 2; + } + default: + return false; + } + } + case HloOpcode::kPower: + case HloOpcode::kAbs: + case HloOpcode::kRsqrt: + case HloOpcode::kSqrt: + return IsPositive(hlo->operand(0), options); + + case HloOpcode::kMultiply: { + return hlo->operand(0) == hlo->operand(1) && + IsPositive(hlo->operand(0), options); + } + default: + return false; + } +} + +bool IsNonNegative(const HloInstruction* hlo, + const AlgebraicSimplifierOptions& options) { + // Utility only handles real types. + if (IsAnyOperandComplex(hlo)) { + return false; + } + switch (hlo->opcode()) { + case HloOpcode::kMultiply: { + return hlo->operand(0) == hlo->operand(1); + } + case HloOpcode::kAbs: { + return true; + } + default: + return IsPositive(hlo, options); + } +} + // Checks whether `op` is a floating-point constant or broadcast of a constant // of the form +/- 2^k for some integer k positive, negative, or zero. 
Such // values are interesting because multiplying by a power of 2 just moves the @@ -212,6 +274,8 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor { AlgebraicSimplifier* simplifier) : options_(options), simplifier_(simplifier) {} + Status HandleAbs(HloInstruction* abs) override; + Status HandleAdd(HloInstruction* add) override; Status HandleAnd(HloInstruction* logical_and) override; @@ -279,8 +343,15 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor { Status HandleReduceWindow(HloInstruction* reduce_window) override; Status HandleReverse(HloInstruction* reverse) override; + + Status HandleRsqrt(HloInstruction* rsqrt) override; + Status HandleSlice(HloInstruction* slice) override; + + Status HandleSqrt(HloInstruction* sqrt) override; + Status HandleDynamicSlice(HloInstruction* dynamic_slice) override; + Status HandleDynamicUpdateSlice( HloInstruction* dynamic_update_slice) override; Status HandleScatter(HloInstruction* scatter) override; @@ -501,6 +572,16 @@ bool AlgebraicSimplifierVisitor::ReplaceInstructionIfSameShape( return true; } +Status AlgebraicSimplifierVisitor::HandleAbs(HloInstruction* abs) { + HloInstruction* abs_operand = abs->mutable_operand(0); + VLOG(10) << "trying transform [Abs(A) => A] " << abs->ToString() + << " Abs operand is: " << abs_operand->ToString(); + if (IsNonNegative(abs->operand(0), options_)) { + return ReplaceInstruction(abs, abs_operand); + } + return Status::OK(); +} + Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { HloInstruction *lhs, *rhs; CHECK(Match(add, m::Add(m::Op(&lhs), m::Op(&rhs)))); @@ -2127,24 +2208,24 @@ Status AlgebraicSimplifierVisitor::HandleClamp(HloInstruction* clamp) { Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) { HloInstruction *lhs, *rhs; CHECK(Match(multiply, m::Multiply(m::Op(&lhs), m::Op(&rhs)))); - // A*1 => A - VLOG(10) << "trying transform [A*1 => A]: " << multiply->ToString(); + // LHS*1 => LHS + VLOG(10) << "trying transform [LHS*1 => LHS]: " << multiply->ToString(); if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(multiply, lhs)) { return Status::OK(); } - // 1*A => A - VLOG(10) << "trying transform [1*A => A]: " << multiply->ToString(); + // 1*RHS => RHS + VLOG(10) << "trying transform [1*RHS => RHS]: " << multiply->ToString(); if (IsAll(lhs, 1) && ReplaceInstructionIfSameShape(multiply, rhs)) { return Status::OK(); } - // 0*A => 0. Only applies for integral types for correct NaN-handling. + // 0*RHS => 0. Only applies for integral types for correct NaN-handling. 
if (IsAll(lhs, 0) && primitive_util::IsIntegralType(multiply->shape().element_type()) && ReplaceInstructionIfSameShape(multiply, lhs)) { return Status::OK(); } - // A*0 => 0 + // LHS*0 => 0 if (IsAll(rhs, 0) && primitive_util::IsIntegralType(multiply->shape().element_type()) && ReplaceInstructionIfSameShape(multiply, rhs)) { @@ -2174,7 +2255,8 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) { product_of_constants)); } - // exp(A) * exp(B) => exp(A+B) + VLOG(10) << "trying to transform exp(LHS) * exp(RHS) => exp(LHS+RHS) " + << multiply->ToString(); if (Match(multiply, m::Multiply(m::Exp(m::Op(&lhs)), m::Exp(m::Op(&rhs))))) { auto add = computation_->AddInstruction(HloInstruction::CreateBinary( multiply->shape(), HloOpcode::kAdd, lhs, rhs)); @@ -2182,6 +2264,18 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) { multiply, HloInstruction::CreateUnary(multiply->shape(), HloOpcode::kExp, add)); } + + VLOG(10) << "trying transform [rsqrt(B) * rsqrt(B) => 1/B] " + << multiply->ToString(); + HloInstruction* b; + if (Match(multiply, m::Multiply(m::Rsqrt(m::Op(&b)), m::Rsqrt(m::Op(&b)))) && + IsPositive(b, options_)) { + return ReplaceWithNewInstruction( + multiply, + HloInstruction::CreateBinary(multiply->shape(), HloOpcode::kDivide, + MakeScalarLike(b, 1), b)); + } + return Status::OK(); } @@ -3329,6 +3423,31 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { return Status::OK(); } +Status AlgebraicSimplifierVisitor::HandleRsqrt(HloInstruction* rsqrt) { + VLOG(10) << "trying transform [rsqrt(Pow(A, -2)) => |A|] " + << rsqrt->ToString(); + HloInstruction* rsqrt_operand = rsqrt->mutable_operand(0); + if (rsqrt_operand->opcode() == HloOpcode::kPower && + IsAll(rsqrt_operand->operand(1), -2) && + IsPositive(rsqrt_operand, options_)) { + return ReplaceWithNewInstruction( + rsqrt, HloInstruction::CreateUnary(rsqrt->shape(), HloOpcode::kAbs, + rsqrt_operand->mutable_operand(0))); + } + + VLOG(10) << "trying transform [rsqrt(Divide(1, A)) => sqrt(A)] " + << rsqrt->ToString(); + if (rsqrt_operand->opcode() == HloOpcode::kDivide && + IsAll(rsqrt_operand->operand(0), 1) && + IsPositive(rsqrt_operand->operand(1), options_)) { + return ReplaceWithNewInstruction( + rsqrt, HloInstruction::CreateUnary(rsqrt->shape(), HloOpcode::kSqrt, + rsqrt_operand->mutable_operand(1))); + } + + return Status::OK(); +} + Status AlgebraicSimplifierVisitor::HandleDynamicSlice( HloInstruction* dynamic_slice) { auto operand = dynamic_slice->mutable_operand(0); @@ -3813,6 +3932,19 @@ Status AlgebraicSimplifierVisitor::HandleSort(HloInstruction* sort) { return Status::OK(); } +Status AlgebraicSimplifierVisitor::HandleSqrt(HloInstruction* sqrt) { + VLOG(10) << "trying transform [sqrt(A*A) => |A|] " << sqrt->ToString(); + HloInstruction* sqrt_operand = sqrt->mutable_operand(0); + if (sqrt_operand->opcode() == HloOpcode::kMultiply && + sqrt_operand->operand(0) == sqrt_operand->operand(1)) { + return ReplaceWithNewInstruction( + sqrt, HloInstruction::CreateUnary( + sqrt_operand->mutable_operand(0)->shape(), HloOpcode::kAbs, + sqrt_operand->mutable_operand(0))); + } + return Status::OK(); +} + namespace { bool OnlyPermutesDegenerateDims(const Shape& shape, absl::Span perm) { diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h index 74d8b1d4582..ce364a16134 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.h +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h 
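The sqrt/rsqrt rewrites added above each rest on a real-arithmetic identity that is only valid under the IsPositive/IsNonNegative preconditions; a quick XLA-independent numeric check of all of them:

```
#include <cassert>
#include <cmath>

int main() {
  for (double a : {-3.0, -0.5, 2.0, 7.0}) {
    // sqrt(A*A) => |A| holds for every real A, since A*A >= 0.
    assert(std::sqrt(a * a) == std::fabs(a));

    const double b = a * a;  // b > 0 here, matching the IsPositive guard.

    // rsqrt(B) * rsqrt(B) => 1/B requires B > 0.
    const double rsqrt_b = 1.0 / std::sqrt(b);
    assert(std::fabs(rsqrt_b * rsqrt_b - 1.0 / b) < 1e-12);

    // rsqrt(pow(A, -2)) => |A| requires pow(A, -2) > 0, i.e. A != 0.
    assert(std::fabs(1.0 / std::sqrt(std::pow(a, -2.0)) - std::fabs(a)) <
           1e-12);

    // rsqrt(1/B) => sqrt(B) requires B > 0.
    assert(std::fabs(1.0 / std::sqrt(1.0 / b) - std::sqrt(b)) < 1e-12);
  }
  return 0;
}
```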
@@ -99,7 +99,27 @@ class AlgebraicSimplifierOptions {
 
   int64 very_small_gather_size() const { return very_small_gather_size_; }
 
+  void set_cudnn_batchnorm_forward_training_metadata(const string& c) {
+    metadata_.cudnn_batchnorm_forward_training_metadata = c;
+  }
+
+  const string& get_cudnn_batchnorm_forward_training_metadata() const {
+    return metadata_.cudnn_batchnorm_forward_training_metadata;
+  }
+
  private:
+  // Metadata struct can be used to store any metadata information encapsulated
+  // with the AlgebraicSimplifierOptions that can be later used in an
+  // AlgebraicSimplifier pass. For example,
+  // cudnn_batchnorm_forward_training_metadata can be used to store the name of
+  // a custom call. If the custom call is
+  // __cudnn$batchNormalizationForwardTraining, the output with index 2 is
+  // guaranteed to be positive. This property is used to recursively determine
+  // whether the operand of an instruction is always positive.
+  struct Metadata {
+    string cudnn_batchnorm_forward_training_metadata{""};
+    Metadata() {}
+  };
   ReshapeIsBitcastCallback reshape_is_bitcast_callback_;
   bool is_layout_sensitive_{false};
   bool enable_dot_strength_reduction_{true};
@@ -107,6 +127,7 @@ class AlgebraicSimplifierOptions {
   bool enable_conv_simplification_{true};
   bool enable_window_reduce_to_reduce_replacement_{true};
   int64 very_small_gather_size_{4};
+  Metadata metadata_;
 };
 
 // A pass which performs algebraic simplifications.
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index f37ff5387ee..b4e66eb1ad7 100755
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -31,6 +31,7 @@ limitations under the License.
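The tests appended below exercise this metadata hook end to end. The analysis they depend on is, at its core, a small structural recursion over the expression; modeled in isolation and simplified (no custom-call metadata or power cases):

```
#include <cassert>
#include <vector>

// Simplified model of the IsPositive/IsNonNegative recursion (illustrative):
// an expression is provably positive if it is abs/sqrt/rsqrt of something
// positive, or x*x where both operands are the same positive expression.
struct Expr {
  enum Kind { kParam, kAbs, kSqrt, kRsqrt, kMul } kind;
  std::vector<const Expr*> operands;
};

bool IsPositive(const Expr* e);

bool IsNonNegative(const Expr* e) {
  switch (e->kind) {
    case Expr::kAbs:
      return true;  // |x| >= 0 regardless of x.
    case Expr::kMul:
      return e->operands[0] == e->operands[1];  // x*x >= 0.
    default:
      return IsPositive(e);
  }
}

bool IsPositive(const Expr* e) {
  switch (e->kind) {
    case Expr::kAbs:
    case Expr::kSqrt:
    case Expr::kRsqrt:
      return IsPositive(e->operands[0]);
    case Expr::kMul:
      return e->operands[0] == e->operands[1] && IsPositive(e->operands[0]);
    default:
      return false;  // kParam: sign unknown.
  }
}

int main() {
  Expr p{Expr::kParam, {}};
  Expr sq{Expr::kMul, {&p, &p}};  // p*p: non-negative, not provably positive
  Expr r{Expr::kRsqrt, {&sq}};    // rsqrt(p*p)
  assert(!IsPositive(&p));
  assert(IsNonNegative(&sq) && !IsPositive(&sq));
  assert(!IsPositive(&r));  // conservative: p could be zero
  return 0;
}
```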
#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" +#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" #include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h" #include "tensorflow/compiler/xla/service/shape_inference.h" @@ -5847,5 +5848,243 @@ TEST_F(AlgebraicSimplifierTest, SliceOfConcat) { GmockMatch(m::Parameter(1))); } +TEST_F(AlgebraicSimplifierTest, SqrtOfSelfMultiply) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[32]{0} parameter(0) + m0 = f32[32]{0} multiply(f32[32]{0} p0, f32[32]{0} p0) + ROOT s0 = f32[32]{0} sqrt(f32[32]{0} m0) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Abs(m::Parameter(0)))); +} + +TEST_F(AlgebraicSimplifierTest, RsqrtOfRPower) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[128,32,2,112]{3,2,1,0} parameter(0) + p1 = f32[32]{0} parameter(1) + p2 = f32[32]{0} parameter(2) + c0 = f32[] constant(0.001) + c1 = s64[] constant(1) + custom-call.1 = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) custom-call(p0, p1, p2, c0, c1), custom_call_target="__cudnn$batchNormalizationForwardTraining" + get-tuple-element.1 = f32[128,32,2,112]{3,2,1,0} get-tuple-element(custom-call.1), index=0 + get-tuple-element.2 = f32[32]{0} get-tuple-element(custom-call.1), index=1 + get-tuple-element = f32[32]{0} get-tuple-element(custom-call.1), index=2 + c2 = f32[] constant(-2) + broadcast = f32[32]{0} broadcast(f32[] c2), dimensions={} + power = f32[32]{0} power(get-tuple-element, broadcast) + rsqrt = f32[32]{0} rsqrt(f32[32]{0} power) + ROOT tuple = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) tuple(get-tuple-element.1, get-tuple-element.2, rsqrt) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + default_options_.set_cudnn_batchnorm_forward_training_metadata( + "__cudnn$batchNormalizationForwardTraining"); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + // Expect transformation: rsqrt(power(gte.2,-2)) -> abs(gte.2) + EXPECT_EQ(FindInstruction(m.get(), HloOpcode::kPower), nullptr); + EXPECT_EQ(FindInstruction(m.get(), HloOpcode::kRsqrt), nullptr); + auto computation = m->entry_computation(); + auto root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kTuple); + EXPECT_EQ(root->operand(2)->opcode(), HloOpcode::kAbs); + EXPECT_EQ(root->operand(2)->operand(0)->opcode(), + HloOpcode::kGetTupleElement); +} + +TEST_F(AlgebraicSimplifierTest, RsqrtDivide) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[128,32,2,112]{3,2,1,0} parameter(0) + p1 = f32[32]{0} parameter(1) + p2 = f32[32]{0} parameter(2) + constant = f32[] constant(0.001) + constant.1 = s64[] constant(1) + custom-call.1 = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) custom-call(p0, p1, p2, constant, constant.1), custom_call_target="__cudnn$batchNormalizationForwardTraining" + get-tuple-element.1 = f32[128,32,2,112]{3,2,1,0} get-tuple-element(custom-call.1), index=0 + get-tuple-element.2 = f32[32]{0} get-tuple-element(custom-call.1), index=1 + get-tuple-element = f32[32]{0} get-tuple-element(custom-call.1), index=2 + constant.2 = f32[] constant(1) + broadcast.1 = 
f32[32]{0} broadcast(constant.2), dimensions={} + divide = f32[32]{0} divide(broadcast.1, get-tuple-element) + rsqrt = f32[32]{0} rsqrt(divide) + ROOT tuple = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) tuple(get-tuple-element.1, get-tuple-element.2, rsqrt) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + default_options_.set_cudnn_batchnorm_forward_training_metadata( + "__cudnn$batchNormalizationForwardTraining"); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + // Expect transformation: rsqrt(divide(1,gte.2)) -> sqrt(gte.2) + EXPECT_EQ(FindInstruction(m.get(), HloOpcode::kDivide), nullptr); + EXPECT_EQ(FindInstruction(m.get(), HloOpcode::kRsqrt), nullptr); + auto computation = m->entry_computation(); + auto root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kTuple); + EXPECT_EQ(root->operand(2)->opcode(), HloOpcode::kSqrt); + EXPECT_EQ(root->operand(2)->operand(0)->opcode(), + HloOpcode::kGetTupleElement); +} + +TEST_F(AlgebraicSimplifierTest, MultiplySelfRsqrt) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[128,32,2,112]{3,2,1,0} parameter(0) + p1 = f32[32]{0} parameter(1) + p2 = f32[32]{0} parameter(2) + constant = f32[] constant(0.001) + constant.1 = s64[] constant(1) + custom-call.1 = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) custom-call(p0, p1, p2, constant, constant.1), custom_call_target="__cudnn$batchNormalizationForwardTraining" + get-tuple-element.1 = f32[128,32,2,112]{3,2,1,0} get-tuple-element(custom-call.1), index=0 + get-tuple-element.2 = f32[32]{0} get-tuple-element(custom-call.1), index=1 + get-tuple-element = f32[32]{0} get-tuple-element(custom-call.1), index=2 + rsqrt = f32[32]{0} rsqrt(get-tuple-element) + multiply = f32[32]{0} multiply(rsqrt, rsqrt) + ROOT tuple = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) tuple(get-tuple-element.1, get-tuple-element.2, multiply) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + default_options_.set_cudnn_batchnorm_forward_training_metadata( + "__cudnn$batchNormalizationForwardTraining"); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + + // Expect transformation: multiply(rsqrt(gte.2), rsqrt(gte.2)) -> divide(1, + // gte.2) + EXPECT_EQ(FindInstruction(m.get(), HloOpcode::kMultiply), nullptr); + EXPECT_EQ(FindInstruction(m.get(), HloOpcode::kRsqrt), nullptr); + + auto computation = m->entry_computation(); + auto root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kTuple); + EXPECT_EQ(root->operand(2)->opcode(), HloOpcode::kDivide); + EXPECT_EQ(root->operand(2)->operand(0)->opcode(), HloOpcode::kBroadcast); + EXPECT_EQ(root->operand(2)->operand(1)->opcode(), + HloOpcode::kGetTupleElement); +} + +TEST_F(AlgebraicSimplifierTest, MultiplySelfRsqrt_NegativeTestCase) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[128,32,2,112]{3,2,1,0} parameter(0) + p1 = f32[32]{0} parameter(1) + p2 = f32[32]{0} parameter(2) + constant = f32[] constant(0.001) + constant.1 = s64[] constant(1) + custom-call.1 = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) custom-call(p0, p1, p2, constant, constant.1), custom_call_target="__cudnn$batchNormalizationForwardTraining" + get-tuple-element.1 = f32[128,32,2,112]{3,2,1,0} get-tuple-element(custom-call.1), index=0 + get-tuple-element.2 = f32[32]{0} get-tuple-element(custom-call.1), index=1 + get-tuple-element = f32[32]{0} 
get-tuple-element(custom-call.1), index=2 + rsqrt = f32[32]{0} rsqrt(get-tuple-element) + multiply = f32[32]{0} multiply(rsqrt, rsqrt) + ROOT tuple = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) tuple(get-tuple-element.1, get-tuple-element.2, multiply) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + default_options_.set_cudnn_batchnorm_forward_training_metadata( + "__cudnn$batchNormalizationForward"); + ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_NE(FindInstruction(m.get(), HloOpcode::kMultiply), nullptr); + EXPECT_NE(FindInstruction(m.get(), HloOpcode::kRsqrt), nullptr); + EXPECT_EQ(FindInstruction(m.get(), HloOpcode::kDivide), nullptr); + EXPECT_EQ(FindInstruction(m.get(), HloOpcode::kBroadcast), nullptr); + EXPECT_EQ(m->entry_computation()->root_instruction()->operand(2)->opcode(), + HloOpcode::kMultiply); +} + +TEST_F(AlgebraicSimplifierTest, AbsEliminationBatchnormTraining) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[128,32,2,112]{3,2,1,0} parameter(0) + p1 = f32[32]{0} parameter(1) + p2 = f32[32]{0} parameter(2) + constant = f32[] constant(0.001) + constant.1 = s64[] constant(1) + custom-call.1 = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) custom-call(p0, p1, p2, constant, constant.1), custom_call_target="__cudnn$batchNormalizationForwardTraining" + get-tuple-element.1 = f32[128,32,2,112]{3,2,1,0} get-tuple-element(custom-call.1), index=0 + get-tuple-element.2 = f32[32]{0} get-tuple-element(custom-call.1), index=1 + get-tuple-element = f32[32]{0} get-tuple-element(custom-call.1), index=2 + abs = f32[32]{0} abs(get-tuple-element) + ROOT %tuple = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) tuple(get-tuple-element.1, get-tuple-element.2, abs) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + default_options_.set_cudnn_batchnorm_forward_training_metadata( + "__cudnn$batchNormalizationForwardTraining"); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + // Verify that the module doesn't have any abs node. 
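Further down, the ar_crs_combiner.cc change replaces erasing from all_reduce_map_ inside a range-for with the advance-before-erase idiom, so the erased entry can never invalidate the loop iterator. The same pattern in miniature on a std::map:

```
#include <iostream>
#include <map>

int main() {
  std::map<int, int> m{{1, 10}, {2, 20}, {3, 30}};
  for (auto it = m.begin(); it != m.end();) {
    auto copy_it = it++;  // Advance first; copy_it is free to be erased.
    if (copy_it->second % 20 == 0) {
      m.erase(copy_it);  // Erasing by iterator only invalidates copy_it.
    }
  }
  for (const auto& kv : m) {
    std::cout << kv.first << "=" << kv.second << "\n";  // 1=10 and 3=30
  }
  return 0;
}
```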
+ EXPECT_EQ(FindInstruction(m.get(), HloOpcode::kAbs), nullptr); + EXPECT_EQ(m->entry_computation()->root_instruction()->operand(2)->opcode(), + HloOpcode::kGetTupleElement); +} + +TEST_F(AlgebraicSimplifierTest, + AbsEliminationBatchnormTraining_NegativeTestCase) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[128,32,2,112]{3,2,1,0} parameter(0) + p1 = f32[32]{0} parameter(1) + p2 = f32[32]{0} parameter(2) + constant = f32[] constant(0.001) + constant.1 = s64[] constant(1) + custom-call.1 = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) custom-call(p0, p1, p2, constant, constant.1), custom_call_target="__cudnn$batchNormalizationForwardTraining" + get-tuple-element.1 = f32[128,32,2,112]{3,2,1,0} get-tuple-element(custom-call.1), index=0 + get-tuple-element.2 = f32[32]{0} get-tuple-element(custom-call.1), index=1 + get-tuple-element = f32[32]{0} get-tuple-element(custom-call.1), index=2 + abs = f32[32]{0} abs(get-tuple-element) + ROOT %tuple = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) tuple(get-tuple-element.1, get-tuple-element.2, abs) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + default_options_.set_cudnn_batchnorm_forward_training_metadata( + "__cudnn$batchNormalizationForwardInference"); + ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_NE(FindInstruction(m.get(), HloOpcode::kAbs), nullptr); +} + +TEST_F(AlgebraicSimplifierTest, AbsEliminationMultiply) { + const char* kModuleStr = R"( + HloModule m + test { + p = f32[32]{0} parameter(0) + m = f32[32]{0} multiply(p, p) + ROOT a = f32[32]{0} abs(m) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(0)))); +} + +TEST_F(AlgebraicSimplifierTest, AbsEliminationPower2) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[32]{0} parameter(0) + c0 = f32[] constant(2) + b0 = f32[32]{0} broadcast(c0), dimensions={} + pow = f32[32]{0} power(p0, b0) + ROOT a = f32[32]{0} abs(pow) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + // Pow(A, 2) is transformed to AA. As a result, Abs(Power(A, 2)) is + // transformed to AA. + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(0)))); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner.cc b/tensorflow/compiler/xla/service/ar_crs_combiner.cc index 06aaad351e6..ec8c391a542 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner.cc +++ b/tensorflow/compiler/xla/service/ar_crs_combiner.cc @@ -366,12 +366,13 @@ void ArCrsCombiner::GroupAllReducesById(HloModule* module) { } Status ArCrsCombiner::KeepProvablyEqualInstructionGroupsMPMD() { - for (auto it : all_reduce_map_) { - auto channel_id = it.first; + for (auto it = all_reduce_map_.begin(); it != all_reduce_map_.end();) { + auto copy_it = it++; // Advance `it` before invalidation from erase. + auto channel_id = copy_it->first; VLOG(2) << "KeepProvablyEqualInstructionGroups. 
Checking AllReduce channel id: " << channel_id << "\n"; - auto pairs_vec = it.second; + auto pairs_vec = copy_it->second; TF_RET_CHECK(pairs_vec.size() == num_spatial_partitions_); auto instr_0 = pairs_vec[0].ar; for (int i = 1; i < pairs_vec.size(); ++i) { @@ -381,7 +382,7 @@ Status ArCrsCombiner::KeepProvablyEqualInstructionGroupsMPMD() { absl::flat_hash_map visited_pairs; while (true) { if (!InstructionsComputeSameValue(next_0, next_i, &visited_pairs)) { - all_reduce_map_.erase(channel_id); + all_reduce_map_.erase(copy_it); VLOG(2) << "KeepProvablyEqualInstructionGroups. Erased AllReduce " "channel id: " << channel_id << "\n"; @@ -406,12 +407,13 @@ Status ArCrsCombiner::KeepProvablyEqualInstructionGroupsSPMD( auto replication_analysis, HloReplicationAnalysis::Run(module, /*cross_partition_spmd=*/true)); - for (auto it : all_reduce_map_) { - auto channel_id = it.first; + for (auto it = all_reduce_map_.begin(); it != all_reduce_map_.end();) { + auto copy_it = it++; // Advance `it` before invalidation from erase. + auto channel_id = copy_it->first; VLOG(2) << "KeepProvablyEqualInstructionGroups. Checking AllReduce channel id: " << channel_id << "\n"; - auto pairs_vec = it.second; + auto pairs_vec = copy_it->second; TF_RET_CHECK(pairs_vec.size() == 1); auto instr = pairs_vec[0].ar; auto next = instr->users()[0]; @@ -420,7 +422,7 @@ Status ArCrsCombiner::KeepProvablyEqualInstructionGroupsSPMD( // guarantee that the HLO produces an array. TF_RET_CHECK(next->shape().IsArray()); if (!replication_analysis->HloInstructionIsReplicatedAt(next, {})) { - all_reduce_map_.erase(channel_id); + all_reduce_map_.erase(copy_it); VLOG(2) << "KeepProvablyEqualInstructionGroups. Erased AllReduce " "channel id: " << channel_id << "\n"; diff --git a/tensorflow/compiler/xla/service/convolution_group_converter.cc b/tensorflow/compiler/xla/service/convolution_group_converter.cc index f942d6768df..06bcd773f44 100644 --- a/tensorflow/compiler/xla/service/convolution_group_converter.cc +++ b/tensorflow/compiler/xla/service/convolution_group_converter.cc @@ -218,14 +218,127 @@ Status ConvolutionVisitor::HandleBatchGroupCount(HloInstruction* convolution) { int64 input_batch_dimension = dim_numbers.input_batch_dimension(); int64 output_batch_dimension = dim_numbers.output_batch_dimension(); + const int64 kernel_output_feature_dimension = + dim_numbers.kernel_output_feature_dimension(); int64 output_feature_dimension = dim_numbers.output_feature_dimension(); int64 input_batch = activation->shape().dimensions(input_batch_dimension); + const int64 output_feature = + filter->shape().dimensions(kernel_output_feature_dimension); + + VLOG(2) << "is_cost_viable_ " << is_cost_viable_(convolution); + const bool cost_too_high = !is_cost_viable_(convolution); + + if (output_feature != batch_group_count) { + const int64 group_size = output_feature / batch_group_count; + + VLOG(2) << "Need to insert a spatial dimension in activations and in the " + "kernel to deal with backprop of grouped convolutions " + << " group size " << group_size; + + // Add spatial dimension to the activation, and reshape. + Shape reshaped_activation_shape = activation->shape(); + ShapeUtil::AppendMajorDimension(1, &reshaped_activation_shape); + const int64 new_spatial_dim = + reshaped_activation_shape.dimensions().size() - 1; + + activation = add( + HloInstruction::CreateReshape(reshaped_activation_shape, activation)); + + // Insert new spatial dimension after the output feature dimension on the + // kernel. 
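The kernel reshape described in the comment above splits the output-feature dimension of size output_feature into [batch_group_count, group_size] and shifts every later dimension number up by one. The index bookkeeping, sketched with an illustrative helper (not the XLA API):

```
#include <cstdint>
#include <iostream>
#include <vector>

// Split dims[feature_dim] into [batch_group_count, group_size], keeping all
// other dimensions in place; dimensions after feature_dim shift up by one.
std::vector<int64_t> SplitOutputFeatureDim(const std::vector<int64_t>& dims,
                                           int64_t feature_dim,
                                           int64_t batch_group_count) {
  const int64_t group_size = dims[feature_dim] / batch_group_count;
  std::vector<int64_t> new_dims;
  for (int64_t i = 0; i < static_cast<int64_t>(dims.size()); ++i) {
    if (i == feature_dim) {
      new_dims.push_back(batch_group_count);
      new_dims.push_back(group_size);  // the newly created dimension
    } else {
      new_dims.push_back(dims[i]);
    }
  }
  return new_dims;
}

int main() {
  // e.g. a kernel [H=3, W=3, I=16, O=64] with batch_group_count=8:
  for (int64_t d : SplitOutputFeatureDim({3, 3, 16, 64}, 3, 8)) {
    std::cout << d << " ";  // prints: 3 3 16 8 8
  }
  std::cout << "\n";
  return 0;
}
```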
+ auto dims = filter->shape().dimensions(); + std::vector new_dims; + for (int i = 0; i < dims.size(); i++) { + if (i == kernel_output_feature_dimension) { + new_dims.push_back(batch_group_count); + new_dims.push_back(group_size); + } else { + new_dims.push_back(dims[i]); + } + } + + Shape reshaped_filter_shape = ShapeUtil::MakeShapeWithDescendingLayout( + filter->shape().element_type(), new_dims); + + filter = add(HloInstruction::CreateReshape(reshaped_filter_shape, filter)); + + Shape new_output_shape = convolution->shape(); + ShapeUtil::AppendMajorDimension(1, &new_output_shape); + + // Edit convolution dimension numbers. Note that kernel_input_feature_dim + // now becomes a spatial dimension, and the newly added dimension of size + // 1 is the new kernel_input_feature_dim. + dim_numbers.add_input_spatial_dimensions(new_spatial_dim); + + // Update spatial dimension numbers if they show up after the newly added + // spatial dimension. + for (auto& d : *dim_numbers.mutable_kernel_spatial_dimensions()) { + if (d > kernel_output_feature_dimension) { + ++d; + } + } + + // Same for input feature dimension. + if (dim_numbers.kernel_input_feature_dimension() > + kernel_output_feature_dimension) { + dim_numbers.set_kernel_input_feature_dimension( + dim_numbers.kernel_input_feature_dimension() + 1); + } + + dim_numbers.add_kernel_spatial_dimensions(kernel_output_feature_dimension + + 1); + + dim_numbers.add_output_spatial_dimensions(output_batch_dimension); + + dim_numbers.set_output_batch_dimension(new_spatial_dim); + + // Add window for the new spatial dimension. + Window new_window = convolution->window(); + auto* dim = new_window.add_dimensions(); + dim->set_window_dilation(1); + dim->set_base_dilation(1); + dim->set_stride(1); + dim->set_size(group_size); + dim->set_padding_high(group_size - 1); + dim->set_padding_low(group_size - 1); + dim->set_window_reversal(false); + + auto new_convolution = add(HloInstruction::CreateConvolve( + new_output_shape, activation, filter, /*feature_group_count=*/1, + batch_group_count, new_window, dim_numbers, + convolution->precision_config())); + + VLOG(2) << "New convolution " << new_convolution->ToString(); + + // This reversal is not done via set_window_reversal because GPUs don't + // support it. + auto rev = add(HloInstruction::CreateReverse( + new_output_shape, new_convolution, {output_batch_dimension})); + + // Delete the extra spatial dimension, and reshape. + Shape reshaped_convolution_shape = + ShapeUtil::DeleteDimension(new_spatial_dim, rev->shape()); + auto reshaped_convolution = + HloInstruction::CreateReshape(reshaped_convolution_shape, rev); + + VLOG(2) << "Reshaped convolution " << reshaped_convolution->ToString(); + + TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( + convolution, std::move(reshaped_convolution))); + + changed_ = true; + + convolution = new_convolution; + dim_numbers = convolution->convolution_dimension_numbers(); + output_batch_dimension = new_spatial_dim; + } + // We are not yet supporting batch_group of sizes greater than 1. TF_RET_CHECK(input_batch == batch_group_count); - if (!is_cost_viable_(convolution) || filter_expansion_) { + if (cost_too_high || filter_expansion_) { // We first obtain the expanded the filter (which is the convolution // output). The batch dimension is the expanded one (which originally // represents kernel input feature dimension). 
We mask the filter to zero @@ -238,11 +351,17 @@ Status ConvolutionVisitor::HandleBatchGroupCount(HloInstruction* convolution) { auto expanded_filter_shape = ExpandedFilterShape( convolution->shape(), batch_group_count, output_batch_dimension); + VLOG(2) << "output_batch_dimension " << output_batch_dimension; + VLOG(2) << "New output shape of convolution " + << expanded_filter_shape.ToString(); + auto new_convolution = add(HloInstruction::CreateConvolve( expanded_filter_shape, activation, filter, /*feature_group_count=*/1, /*batch_group_count=*/1, convolution->window(), dim_numbers, convolution->precision_config())); + VLOG(2) << "Expanded convolution " << new_convolution->ToString(); + auto zero = add(HloInstruction::CreateConstant( LiteralUtil::Zero(expanded_filter_shape.element_type()))); auto zero_filter = @@ -354,6 +473,7 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { changed_ = false; return Status::OK(); } + VLOG(2) << "is_cost_viable_ " << is_cost_viable_(convolution); // We want to repeat 'filter' in the 'input_feature_dim' dimension // 'group_count' times. if (!is_cost_viable_(convolution) || filter_expansion_) { diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index bec66aea27f..713f10b146f 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -77,7 +77,6 @@ cc_library( ":buffer_info_util", ":conv_canonicalization", ":cpu_executable", - ":cpu_hlo_support_checker", ":cpu_instruction_fusion", ":cpu_layout_assignment", ":cpu_options", @@ -148,15 +147,15 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", - "@llvm//:core", - "@llvm//:mc", - "@llvm//:object", - "@llvm//:support", - "@llvm//:target", - "@llvm//:x86_code_gen", # fixdeps: keep + "@llvm-project//llvm:core", + "@llvm-project//llvm:mc", + "@llvm-project//llvm:object", + "@llvm-project//llvm:support", + "@llvm-project//llvm:target", + "@llvm-project//llvm:x86_code_gen", # fixdeps: keep ] + select({ "//tensorflow:linux_ppc64le": [ - "@llvm//:powerpc_code_gen", # fixdeps: keep + "@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep ], "//conditions:default": [ ], @@ -188,12 +187,12 @@ cc_library( ":runtime_single_threaded_fft", ":runtime_single_threaded_matmul", "@com_google_absl//absl/memory", - "@llvm//:execution_engine", - "@llvm//:core", - "@llvm//:mc", # fixdeps: keep - "@llvm//:orc_jit", - "@llvm//:support", - "@llvm//:target", # fixdeps: keep + "@llvm-project//llvm:execution_engine", + "@llvm-project//llvm:core", + "@llvm-project//llvm:mc", # fixdeps: keep + "@llvm-project//llvm:orc_jit", + "@llvm-project//llvm:support", + "@llvm-project//llvm:target", # fixdeps: keep "//tensorflow/compiler/xla/service:custom_call_target_registry", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", @@ -257,7 +256,7 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", - "@llvm//:orc_jit", + "@llvm-project//llvm:orc_jit", ], ) @@ -315,10 +314,10 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", - "@llvm//:code_gen", - "@llvm//:core", - "@llvm//:support", - "@llvm//:target", + "@llvm-project//llvm:code_gen", + "@llvm-project//llvm:core", + "@llvm-project//llvm:support", + "@llvm-project//llvm:target", ], ) @@ -332,8 +331,8 @@ cc_library( 
"//tensorflow/compiler/xla:shape_util", "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_map", - "@llvm//:analysis", - "@llvm//:target", + "@llvm-project//llvm:analysis", + "@llvm-project//llvm:target", ], ) @@ -362,7 +361,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@llvm//:core", + "@llvm-project//llvm:core", ], ) @@ -378,7 +377,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", "//tensorflow/core:lib", "@com_google_absl//absl/strings:str_format", - "@llvm//:core", + "@llvm-project//llvm:core", ], ) @@ -394,7 +393,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", - "@llvm//:core", + "@llvm-project//llvm:core", ], ) @@ -425,7 +424,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", "@com_google_absl//absl/strings", - "@llvm//:core", + "@llvm-project//llvm:core", ], ) @@ -463,13 +462,13 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", "@com_google_absl//absl/memory", - "@llvm//:analysis", - "@llvm//:core", - "@llvm//:ipo", - "@llvm//:mc", - "@llvm//:object", - "@llvm//:support", - "@llvm//:target", + "@llvm-project//llvm:analysis", + "@llvm-project//llvm:core", + "@llvm-project//llvm:ipo", + "@llvm-project//llvm:mc", + "@llvm-project//llvm:object", + "@llvm-project//llvm:support", + "@llvm-project//llvm:target", ], ) @@ -527,8 +526,8 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/service/llvm_ir:math_ops", "//tensorflow/core:lib", - "@llvm//:core", - "@llvm//:transform_utils", + "@llvm-project//llvm:core", + "@llvm-project//llvm:transform_utils", ], ) @@ -762,7 +761,7 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla/service:hlo", - "@llvm//:core", + "@llvm-project//llvm:core", ], ) @@ -818,6 +817,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "@com_google_absl//absl/types:span", ], @@ -914,6 +914,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", ], @@ -936,7 +937,7 @@ cc_library( hdrs = ["orc_jit_memory_mapper.h"], deps = [ "//tensorflow/core:lib", - "@llvm//:execution_engine", + "@llvm-project//llvm:execution_engine", ], ) @@ -953,34 +954,8 @@ cc_library( "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/types:span", - "@llvm//:core", - "@llvm//:support", - ], -) - -cc_library( - name = "cpu_hlo_support_checker", - srcs = ["cpu_hlo_support_checker.cc"], - hdrs = ["cpu_hlo_support_checker.h"], - deps = [ - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:xla_data_proto_cc", - "//tensorflow/compiler/xla/service:hlo_pass", - "//tensorflow/core:lib", - ], -) - -tf_cc_test( - name = "cpu_hlo_support_checker_test", - srcs = ["cpu_hlo_support_checker_test.cc"], - deps = [ - ":cpu_hlo_support_checker", - "//tensorflow/compiler/xla:shape_util", - 
"//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:test", + "@llvm-project//llvm:core", + "@llvm-project//llvm:support", ], ) @@ -1007,8 +982,8 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "@llvm//:core", - "@llvm//:support", - "@llvm//:target", + "@llvm-project//llvm:core", + "@llvm-project//llvm:support", + "@llvm-project//llvm:target", ], ) diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 6a331ba4f19..a04a39b4461 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -60,7 +60,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/compiler_functor.h" #include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h" #include "tensorflow/compiler/xla/service/cpu/cpu_executable.h" -#include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h" #include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h" #include "tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h" #include "tensorflow/compiler/xla/service/cpu/cpu_options.h" @@ -248,7 +247,6 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( pipeline.AddPass(); pipeline.AddPass(); - pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc deleted file mode 100644 index 4ac61f44d9f..00000000000 --- a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc +++ /dev/null @@ -1,46 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h" - -#include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/core/lib/core/errors.h" - -namespace xla { - -StatusOr CpuHloSupportChecker::Run(HloModule* module) { - for (auto* computation : module->computations()) { - for (const auto& instruction : computation->instructions()) { - TF_RETURN_IF_ERROR( - ShapeUtil::ValidateShapeWithOptionalLayout(instruction->shape())); - TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( - instruction->shape(), - [&instruction](const Shape& subshape, const ShapeIndex&) { - if (LayoutUtil::IsSparseArray(subshape)) { - return xla::Unimplemented( - "CPU backend does not support HLO instruction %s with shape " - "containing a sparse layout: %s", - instruction->ToString(), - ShapeUtil::HumanStringWithLayout(instruction->shape())); - } - return Status::OK(); - })); - } - } - return false; -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h deleted file mode 100644 index a39a9d47246..00000000000 --- a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h +++ /dev/null @@ -1,40 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_HLO_SUPPORT_CHECKER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_HLO_SUPPORT_CHECKER_H_ - -#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" - -namespace xla { - -// This pass should run early in the HLO pipeline and checks for HLO constructs -// which are not supported by the CPU backend and cannot be removed via HLO -// transformations (eg, sparse layouts). -class CpuHloSupportChecker : public HloModulePass { - public: - CpuHloSupportChecker() = default; - ~CpuHloSupportChecker() override = default; - - absl::string_view name() const override { return "cpu_hlo_support_checker"; } - - // Note: always returns false (no instructions are ever modified by this - // pass). - StatusOr Run(HloModule* module) override; -}; - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_HLO_SUPPORT_CHECKER_H_ diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc deleted file mode 100644 index 7a905928e6d..00000000000 --- a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc +++ /dev/null @@ -1,76 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h" - -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/protobuf/error_codes.pb.h" - -namespace xla { -namespace { - -using ::testing::HasSubstr; - -class CpuHloSupportCheckerTest : public HloTestBase { - protected: - CpuHloSupportChecker& checker() { return checker_; } - - private: - CpuHloSupportChecker checker_; -}; - -TEST_F(CpuHloSupportCheckerTest, Add) { - HloComputation::Builder builder(TestName()); - const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); - HloInstruction* param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, scalar_shape, "param0")); - HloInstruction* param1 = builder.AddInstruction( - HloInstruction::CreateParameter(1, scalar_shape, "param1")); - builder.AddInstruction(HloInstruction::CreateBinary( - scalar_shape, HloOpcode::kAdd, param0, param1)); - auto module = CreateNewVerifiedModule(); - module->AddEntryComputation(builder.Build()); - - TF_ASSERT_OK(checker().Run(module.get()).status()); -} - -TEST_F(CpuHloSupportCheckerTest, SparseUnimplemented) { - HloComputation::Builder builder(TestName()); - const Shape sparse_shape = ShapeUtil::MakeShapeWithSparseLayout(F32, {10}, 2); - HloInstruction* param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, sparse_shape, "param0")); - HloInstruction* param1 = builder.AddInstruction( - HloInstruction::CreateParameter(1, sparse_shape, "param1")); - builder.AddInstruction(HloInstruction::CreateBinary( - sparse_shape, HloOpcode::kAdd, param0, param1)); - // Since verifier is reporting sparse layouts as errors, we should - // use a regular HloModule instead of VerifiedHloModule to avoid - // verifier errors being triggered in the destructor. - auto module = CreateNewUnverifiedModule(); - module->AddEntryComputation(builder.Build()); - - Status status = checker().Run(module.get()).status(); - ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED); - EXPECT_THAT(status.error_message(), - HasSubstr("CPU backend does not support")); - EXPECT_THAT(status.error_message(), - HasSubstr(ShapeUtil::HumanStringWithLayout(sparse_shape))); -} - -} // namespace -} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 394d1fc979d..24718e16e22 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -37,6 +37,7 @@ limitations under the License. 
#include "llvm/IR/GlobalVariable.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Intrinsics.h" +#include "llvm/IR/IntrinsicsX86.h" #include "llvm/IR/LLVMContext.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/map_util.h" diff --git a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc index 70a6d0af02c..7831c1b1b5b 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc @@ -70,11 +70,11 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSort( index % sort_dimension_offset + (index - index % sort_dimension_offset) * sort_dimension_elements; auto compare_function = [&](int64 a, int64 b) -> bool { - int64 memory_index_lhs = (base_offset + a * sort_dimension_offset) * - values_primitive_type_size_in_bytes[0]; - int64 memory_index_rhs = (base_offset + b * sort_dimension_offset) * - values_primitive_type_size_in_bytes[0]; for (int32 i = 0; i < values_count; ++i) { + int64 memory_index_lhs = (base_offset + a * sort_dimension_offset) * + values_primitive_type_size_in_bytes[i]; + int64 memory_index_rhs = (base_offset + b * sort_dimension_offset) * + values_primitive_type_size_in_bytes[i]; comparison_values[i * 2] = values[i] + memory_index_lhs; comparison_values[i * 2 + 1] = values[i] + memory_index_rhs; } diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD index 51a12aee22f..f52de3394fe 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD @@ -95,7 +95,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "@com_google_absl//absl/memory", - "@llvm//:core", + "@llvm-project//llvm:core", ], ) @@ -110,9 +110,9 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "@com_google_absl//absl/strings", - "@llvm//:arm_code_gen", # fixdeps: keep - "@llvm//:target", - "@llvm//:x86_code_gen", # fixdeps: keep + "@llvm-project//llvm:arm_code_gen", # fixdeps: keep + "@llvm-project//llvm:target", + "@llvm-project//llvm:x86_code_gen", # fixdeps: keep ], ) @@ -142,9 +142,9 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "@com_google_absl//absl/strings", - "@llvm//:arm_code_gen", # fixdeps: keep - "@llvm//:target", - "@llvm//:x86_code_gen", # fixdeps: keep + "@llvm-project//llvm:arm_code_gen", # fixdeps: keep + "@llvm-project//llvm:target", + "@llvm-project//llvm:x86_code_gen", # fixdeps: keep ], ) @@ -246,8 +246,8 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "@com_google_absl//absl/strings", - "@llvm//:arm_code_gen", # fixdeps: keep - "@llvm//:target", - "@llvm//:x86_code_gen", # fixdeps: keep + "@llvm-project//llvm:arm_code_gen", # fixdeps: keep + "@llvm-project//llvm:target", + "@llvm-project//llvm:x86_code_gen", # fixdeps: keep ], ) diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc index 333626ef3b9..266e5be0d66 100644 --- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc @@ -189,7 +189,7 @@ Status DynamicDimensionInferenceVisitor::HandleCustomCall(HloInstruction* hlo) { // dimensions. 
ShapeIndex data_output = {0}; parent_->SetDynamicSize(hlo, data_output, i, dynamic_size, - {.stride = 1, .multiple_of = 1}); + DimensionConstraint(1, 1)); } } return Status::OK(); @@ -215,11 +215,6 @@ Status DynamicDimensionInferenceVisitor::HandleSort(HloInstruction* hlo) { int64 dynamic_dimension, int64 operand_index, HloInstruction* dynamic_size, DimensionConstraint constraint) { HloSortInstruction* sort = Cast(hlo); - int64 sort_dimension = sort->sort_dimension(); - if (sort_dimension == dynamic_dimension) { - return Unimplemented( - "Dynamic dimension on sorting dimension is not supported"); - } if (sort->values_count() == 0) { parent_->SetDynamicSize(hlo, {}, dynamic_dimension, dynamic_size, constraint); @@ -466,7 +461,7 @@ Status DynamicDimensionInferenceVisitor::HandleConcatenate( dim_size_total, dynamic_dim)); } parent_->SetDynamicSize(hlo, {}, hlo->concatenate_dimension(), - dim_size_total, {.stride = 1, .multiple_of = 1}); + dim_size_total, DimensionConstraint(1, 1)); } // Simply pass through non-concat dynamic dimensions. @@ -521,7 +516,7 @@ Status DynamicDimensionInferenceVisitor::HandleSetDimensionSize( // Propagate dynamic dimension indicated by this set dimension size // instruction. parent_->SetDynamicSize(hlo, {}, hlo->dimension(), hlo->mutable_operand(1), - {.stride = 1, .multiple_of = 1}); + DimensionConstraint(1, 1)); } // Also Propagate dynamic dimension already set by operands. @@ -865,7 +860,7 @@ Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) { parent_->SetDynamicSize( reshape, {}, output_dynamic_dimension, new_dynamic_size, - {.stride = 1, .multiple_of = constraint.multiple_of / divisor}); + DimensionConstraint(1, constraint.multiple_of / divisor)); } if (input_dim_size < output_dim_size) { @@ -902,12 +897,12 @@ Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) { hlo->parent()->AddInstruction(HloInstruction::CreateBinary( output_dynamic_size->shape(), HloOpcode::kMultiply, new_dynamic_size, operand_dynamic_size)); + int64 new_multiple_of_constraint = + constraint.multiple_of * output_dim_size / + operand->shape().dimensions(input_dynamic_dimension); parent_->SetDynamicSize( reshape, {}, output_dynamic_dimension, new_dynamic_size, - {.stride = 1, - .multiple_of = - constraint.multiple_of * output_dim_size / - operand->shape().dimensions(input_dynamic_dimension)}); + DimensionConstraint(1, new_multiple_of_constraint)); } return Status::OK(); @@ -1279,7 +1274,7 @@ Status DynamicDimensionInferenceVisitor::HandleParameter(HloInstruction* hlo) { parent_->SetDynamicSize(target_parameter, dynamic_dimension.parameter_index, dynamic_dimension.dimension, dynamic_size, - {.stride = 1, .multiple_of = 1}); + DimensionConstraint(1, 1)); return Status::OK(); }); } diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.h b/tensorflow/compiler/xla/service/dynamic_dimension_inference.h index 21808385ec2..070127796d6 100644 --- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.h +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.h @@ -149,6 +149,9 @@ class DynamicDimensionInference { // // struct DimensionConstraint { + explicit DimensionConstraint(int64 s, int64 m) + : stride(s), multiple_of(m) {} + DimensionConstraint() : stride(1), multiple_of(1) {} // Stride represents the distance of a newly placed element and the previous // placed element on this dynamic dimension. 
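// Illustration (not added by this change): a dynamic dimension whose elements
// sit at offsets 0, 2, 4, ... has stride 2; stride 1 means the elements are
// contiguous. multiple_of records a divisibility guarantee on the runtime
// size, as used by the reshape handling above, e.g.
//
//   DimensionConstraint(/*stride=*/1, /*multiple_of=*/4)
//
// asserts the dynamic size is a multiple of 4.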
int64 stride; diff --git a/tensorflow/compiler/xla/service/dynamic_padder.cc b/tensorflow/compiler/xla/service/dynamic_padder.cc index f41a965825d..e09138f3e11 100644 --- a/tensorflow/compiler/xla/service/dynamic_padder.cc +++ b/tensorflow/compiler/xla/service/dynamic_padder.cc @@ -21,16 +21,21 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/strings/str_format.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/comparison_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" namespace xla { @@ -569,6 +574,7 @@ Status RewriteDynamicReshapeSingleDim( } return Status::OK(); } + StatusOr RewriteDynamicConcat( HloInstruction* concat, DynamicDimensionInference* dynamic_dimension_inference) { @@ -618,6 +624,100 @@ StatusOr RewriteDynamicConcat( concat, rewritten_concat, {})); return true; } + +StatusOr RewriteDynamicSort( + HloInstruction* hlo, + DynamicDimensionInference* dynamic_dimension_inference) { + HloInstruction* dynamic_size = nullptr; + HloSortInstruction* sort = Cast(hlo); + HloComputation* comp = hlo->parent(); + int64 sort_dim = sort->sort_dimension(); + // Find the dynamic dimension in the operand. + for (auto* operand : sort->operands()) { + if (dynamic_size == nullptr) { + dynamic_size = + dynamic_dimension_inference->GetDynamicSize(operand, {}, sort_dim); + } + } + + if (dynamic_size == nullptr) { + // Not a dynamic sort, ignore. 
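+ // (Only a dynamic size on the sort dimension needs rewriting here; a
+ // dynamic size on any other dimension leaves each sorted row intact and is
+ // handled by the generic padding logic.)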
+ return false; + } + + Shape operand_shape = + ShapeUtil::ChangeElementType(sort->operand(0)->shape(), S32); + HloInstruction* iota = + comp->AddInstruction(HloInstruction::CreateIota(operand_shape, sort_dim)); + HloInstruction* dynamic_size_broadcasted = comp->AddInstruction( + HloInstruction::CreateBroadcast(operand_shape, dynamic_size, {})); + HloInstruction* lt = comp->AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::ChangeElementType(operand_shape, PRED), iota, + dynamic_size_broadcasted, ComparisonDirection::kLt)); + sort->AppendOperand(lt); + + const int64 param_number_before_rewritten = + sort->called_computations()[0]->num_parameters(); + auto new_param_0 = HloInstruction::CreateParameter( + param_number_before_rewritten, ShapeUtil::MakeScalarShape(PRED), + "inbound_lhs"); + auto new_param_1 = HloInstruction::CreateParameter( + param_number_before_rewritten + 1, ShapeUtil::MakeScalarShape(PRED), + "inbound_rhs"); + std::vector<HloInstruction*> extra_parameters{new_param_0.get(), + new_param_1.get()}; + HloComputation* sort_comp = sort->parent()->parent()->AddEmbeddedComputation( + sort->called_computations()[0]->CloneWithReplacements( + /*replacements=*/absl::flat_hash_map< + const HloInstruction*, std::unique_ptr<HloInstruction>>(), + extra_parameters)); + auto inbound_lhs = + sort_comp->parameter_instruction(param_number_before_rewritten); + auto inbound_rhs = + sort_comp->parameter_instruction(param_number_before_rewritten + 1); + sort->ReplaceCalledComputations( + [&](HloComputation* comp) { return sort_comp; }); + + // inbound_lhs & (old_comparator | !inbound_rhs) + // Select the lhs if it is in bounds and either the rhs is out of bounds or + // the original comparator returns true. + auto out_of_bound_rhs = sort_comp->AddInstruction(HloInstruction::CreateUnary( + ShapeUtil::MakeScalarShape(PRED), HloOpcode::kNot, inbound_rhs)); + auto sort_comp_or_out_of_bound_rhs = + sort_comp->AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeScalarShape(PRED), HloOpcode::kOr, + sort_comp->root_instruction(), out_of_bound_rhs)); + + auto new_root = sort_comp->AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeScalarShape(PRED), HloOpcode::kAnd, inbound_lhs, + sort_comp_or_out_of_bound_rhs)); + sort_comp->set_root_instruction(new_root); + Shape compare_shape = + ShapeUtil::ChangeElementType(sort->operand(0)->shape(), PRED); + if (sort->shape().IsTuple()) { + // For a sort that already returns a tuple, simply add another result to + // the tuple.
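+ // Worked illustration (matching the DynamicSort test below): operand
+ // s32[4] {1, 4, 3, 2} with dynamic size 3 and a descending comparator.
+ // The appended iota/compare operand marks {true, true, true, false}, and
+ // the rewritten comparator
+ //   inbound_lhs & (old_comparator | !inbound_rhs)
+ // orders every out-of-bounds element after all in-bounds elements, giving
+ // {4, 3, 1, 2}: the in-bounds prefix {1, 4, 3} sorted descending, then the
+ // padding element 2.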
+ *sort->mutable_shape()->add_tuple_shapes() = + ShapeUtil::ChangeElementType(operand_shape, PRED); + } else { + auto sort_users = sort->users(); + auto sort_clone = comp->AddInstruction(sort->Clone()); + *sort_clone->mutable_shape() = ShapeUtil::MakeTupleShape( + {sort->shape(), ShapeUtil::ChangeElementType(operand_shape, PRED)}); + auto rewritten_sort = comp->AddInstruction( + HloInstruction::CreateGetTupleElement(sort->shape(), sort_clone, 0)); + for (HloInstruction* user : sort_users) { + TF_RETURN_IF_ERROR(sort->ReplaceUseWith(user, rewritten_sort)); + } + TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize( + sort, rewritten_sort, {})); + if (comp->root_instruction() == sort) { + comp->set_root_instruction(rewritten_sort); + } + } + + return true; +} + StatusOr RewriteDynamicReshape( HloInstruction* reshape, DynamicDimensionInference* dynamic_dimension_inference) { @@ -920,12 +1020,17 @@ StatusOr DynamicPadder::Run(HloModule* module) { DynamicDimensionInference::Run(module)); for (HloComputation* computation : module->computations()) { - for (HloInstruction* inst : computation->instructions()) { + for (HloInstruction* inst : computation->MakeInstructionPostOrder()) { if (inst->opcode() == HloOpcode::kConcatenate) { TF_ASSIGN_OR_RETURN( changed, RewriteDynamicConcat(inst, &dynamic_dimension_inference)); continue; } + if (inst->opcode() == HloOpcode::kSort) { + TF_ASSIGN_OR_RETURN( + changed, RewriteDynamicSort(inst, &dynamic_dimension_inference)); + continue; + } for (int64 operand_num = 0; operand_num < inst->operand_count(); ++operand_num) { HloInstruction* original_operand = inst->mutable_operand(operand_num); diff --git a/tensorflow/compiler/xla/service/dynamic_padder_test.cc b/tensorflow/compiler/xla/service/dynamic_padder_test.cc index 0e60e420d47..57e4a4e9af3 100644 --- a/tensorflow/compiler/xla/service/dynamic_padder_test.cc +++ b/tensorflow/compiler/xla/service/dynamic_padder_test.cc @@ -827,5 +827,84 @@ ENTRY main { EXPECT_EQ(result, expected); } +XLA_TEST_F(ExecutionTest, DynamicSort) { + const string hlo_text = R"( +HloModule TEST + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(lhs, rhs) +} + +%compare-greater-than (lhs: s32[], rhs: s32[]) -> pred[] { + %lhs = s32[] parameter(0) + %rhs = s32[] parameter(1) + ROOT %compare = pred[] compare(s32[] %lhs, s32[] %rhs), direction=GT +} + +ENTRY main { + param = s32[4] parameter(0) + size = s32[] constant(3) + param_dynamic_size = s32[4] set-dimension-size(param, size), + dimensions={0} + sort = s32[4]{0} sort(s32[4]{0} %param_dynamic_size), + dimensions={0}, is_stable=false, to_apply=%compare-greater-than + full_size = s32[] constant(4) + ROOT result = s32[4] set-dimension-size(sort, full_size), dimensions={0} +} +)"; + + Literal operand = LiteralUtil::CreateR1({1, 4, 3, 2}); + auto module = GetHloModule(hlo_text); + + Literal result = PadAndExecute(std::move(module), {&operand}); + Literal expected = LiteralUtil::CreateR1({4, 3, 1, 2}); + + EXPECT_EQ(result, expected); +} + +XLA_TEST_F(ExecutionTest, DynamicTupleSort) { + const string hlo_text = R"( +HloModule TEST + +%compare-greater-than (lhs: s32[], rhs: s32[], lhs_2: s32[], lhs_2: s32[]) -> pred[] { + %lhs = s32[] parameter(0) + %rhs = s32[] parameter(1) + %lhs_2 = s32[] parameter(2) + %rhs_2 = s32[] parameter(3) + ROOT %compare = pred[] compare(s32[] %lhs, s32[] %rhs), direction=GT +} + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] 
parameter(1) + ROOT add = s32[] add(lhs, rhs) +} + +ENTRY main { + param = s32[3] parameter(0) + size = s32[] constant(2) + param_dynamic_size = s32[3] set-dimension-size(param, size), + dimensions={0} + sort = (s32[3]{0}, s32[3]{0}) sort(s32[3]{0} %param_dynamic_size, + s32[3]{0} %param_dynamic_size), + dimensions={0}, is_stable=true, to_apply=%compare-greater-than + get-tuple-element = s32[3]{0} get-tuple-element((s32[3]{0}, s32[3]{0}) %sort), + index=0 + full_size = s32[] constant(3) + ROOT result = s32[3] set-dimension-size(get-tuple-element, full_size), dimensions={0} +} +)"; + + Literal operand = LiteralUtil::CreateR1({0, 4, 2}); + auto module = GetHloModule(hlo_text); + + Literal result = PadAndExecute(std::move(module), {&operand}); + Literal expected = LiteralUtil::CreateR1({4, 0, 2}); + + EXPECT_EQ(result, expected); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/g3doc/hlo_parser.md b/tensorflow/compiler/xla/service/g3doc/hlo_parser.md index f0f3dd7785c..5c3b1540600 100644 --- a/tensorflow/compiler/xla/service/g3doc/hlo_parser.md +++ b/tensorflow/compiler/xla/service/g3doc/hlo_parser.md @@ -116,29 +116,6 @@ non_tuple | rank2345 ; rank2345 - : shape sparse_or_nested_array + : nested_array ; -sparse_or_nested_array - : sparse_array - | nested_array - ; -sparse_array - : '{' sparse_array1 '}' - ; -sparse_array1 - : sparse_array_item - | sparse_array1 ',' sparse_array_item - ; -sparse_array_item - : multi_index ':' scalar - ; -multi_index - : kInt - | '[' multi_index1 ']' - ; -multi_index1 - : kInt - | multi_index1 ',' kInt - ; - ``` diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 13e8a3f4409..fb085a237f1 100755 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -184,7 +184,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@llvm//:core", + "@llvm-project//llvm:core", ], ) @@ -198,8 +198,8 @@ cc_library( "//tensorflow/core:lib", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@llvm//:core", - "@llvm//:support", + "@llvm-project//llvm:core", + "@llvm-project//llvm:support", ], ) @@ -287,8 +287,8 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", - "@llvm//:core", - "@llvm//:support", + "@llvm-project//llvm:core", + "@llvm-project//llvm:support", ], ) @@ -306,7 +306,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", "//tensorflow/core:lib", - "@llvm//:core", + "@llvm-project//llvm:core", ], ) @@ -335,8 +335,8 @@ cc_library( "//tensorflow/core:lib", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@llvm//:core", - "@llvm//:support", + "@llvm-project//llvm:core", + "@llvm-project//llvm:support", ], ) @@ -594,7 +594,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", - "@llvm//:core", + "@llvm-project//llvm:core", ], ) @@ -1068,7 +1068,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/memory", - "@llvm//:core", + "@llvm-project//llvm:core", ], alwayslink = True, # Contains per-platform transfer manager registration ) @@ -1093,7 +1093,6 @@ cc_library( ":gpu_copy_insertion", ":gpu_executable", 
":gpu_hlo_schedule", - ":gpu_hlo_support_checker", ":gpu_layout_assignment", ":gpu_sanitize_constant_names", ":gpu_scatter_expander", @@ -1116,6 +1115,7 @@ cc_library( "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:call_inliner", "//tensorflow/compiler/xla/service:conditional_simplifier", + "//tensorflow/compiler/xla/service:convolution_group_converter", "//tensorflow/compiler/xla/service:depthwise_convolution_converter", "//tensorflow/compiler/xla/service:dot_decomposer", "//tensorflow/compiler/xla/service:dump", @@ -1161,7 +1161,7 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", - "@llvm//:core", + "@llvm-project//llvm:core", ], ) @@ -1196,6 +1196,7 @@ cc_library( ":gpu_conv_padding_legalization", ":gpu_conv_rewriter", ":gpu_layout_assignment", + ":ir_emission_utils", ":reduction_degenerate_dim_remover", ":reduction_dimension_grouper", ":reduction_layout_normalizer", @@ -1414,18 +1415,6 @@ tf_cc_test( ], ) -cc_library( - name = "gpu_hlo_support_checker", - srcs = ["gpu_hlo_support_checker.cc"], - hdrs = ["gpu_hlo_support_checker.h"], - deps = [ - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:xla_data_proto_cc", - "//tensorflow/compiler/xla/service:hlo_pass", - "//tensorflow/core:lib", - ], -) - cc_library( name = "stream_executor_util", srcs = ["stream_executor_util.cc"], @@ -1453,20 +1442,6 @@ cc_library( ], ) -tf_cc_test( - name = "gpu_hlo_support_checker_test", - srcs = ["gpu_hlo_support_checker_test.cc"], - deps = [ - ":gpu_hlo_support_checker", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:test", - ], -) - cc_library( name = "buffer_comparator", srcs = ["buffer_comparator.cc"], @@ -1604,6 +1579,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:pattern_matcher", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc index 37095adf7c6..4ecf6ed8007 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc +++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc @@ -577,10 +577,24 @@ static StatusOr DeviceCompare(se::Stream* stream, se::DeviceMemory rhs_typed(rhs); uint64 buffer_size = lhs_typed.ElementCount(); - TF_ASSIGN_OR_RETURN(absl::Span compiled_ptx, - se::CompileGpuAsmOrGetCached(executor->device_ordinal(), - buffer_compare_ptx, - PtxOptsFromConfig(config))); + absl::Span compiled_ptx = {}; + StatusOr> compiled_ptx_or = + se::CompileGpuAsmOrGetCached(executor->device_ordinal(), + buffer_compare_ptx, + PtxOptsFromConfig(config)); + if (compiled_ptx_or.ok()) { + compiled_ptx = compiled_ptx_or.ConsumeValueOrDie(); + } else { + static std::once_flag ptxas_not_found_logged; + std::call_once(ptxas_not_found_logged, [&]() { + LOG(WARNING) + << compiled_ptx_or.status().ToString() + << "\nRelying on driver to perform ptx compilation. 
" + << "\nSetting XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda " + << " or modifying $PATH can be used to set the location of ptxas" + << "\nThis message will only be logged once."; + }); + } TF_ASSIGN_OR_RETURN( std::unique_ptr> comparison_kernel, diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 30b204e6fd5..04761123127 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -36,6 +36,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/conditional_simplifier.h" +#include "tensorflow/compiler/xla/service/convolution_group_converter.h" #include "tensorflow/compiler/xla/service/depthwise_convolution_converter.h" #include "tensorflow/compiler/xla/service/dot_decomposer.h" #include "tensorflow/compiler/xla/service/dump.h" @@ -48,7 +49,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" #include "tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h" -#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h" #include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h" #include "tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names.h" #include "tensorflow/compiler/xla/service/gpu/gpu_scatter_expander.h" @@ -134,15 +134,31 @@ Status GpuCompiler::OptimizeHloModule( pipeline.AddPass(); pipeline.AddPass(); - pipeline.AddPass(); // TODO(b/64094172): make Call work on GPU instead of inlining. pipeline.AddPass(); + + pipeline.AddPass(); + + // We use the ConvolutionGroupConverter to convert backprops of filter + // grouped convolutions into non-grouped equivalents. + auto batch_group_cost_model = [](HloInstruction* conv) { + auto dim_numbers = conv->convolution_dimension_numbers(); + const int64 input_batch_size = conv->operand(0)->shape().dimensions( + dim_numbers.input_batch_dimension()); + return conv->batch_group_count() != input_batch_size; + }; + + pipeline.AddPass( + batch_group_cost_model, + /*convert_batch_groups_only=*/true, + /*canonicalize_depthwise_filter=*/false); + auto cost_model = [](HloInstruction* conv) { // We need a cost model for GPUs. Currently, do nothing. return false; }; - pipeline.AddPass(); + pipeline.AddPass(cost_model); // Expand the sort op to support stable sorting if required. pipeline.AddPass(); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc deleted file mode 100644 index 4765f67c4b1..00000000000 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc +++ /dev/null @@ -1,46 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h" - -#include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/core/lib/core/errors.h" - -namespace xla { - -StatusOr GpuHloSupportChecker::Run(HloModule* module) { - for (auto* computation : module->computations()) { - for (const auto& instruction : computation->instructions()) { - TF_RETURN_IF_ERROR( - ShapeUtil::ValidateShapeWithOptionalLayout(instruction->shape())); - TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( - instruction->shape(), - [&instruction](const Shape& subshape, const ShapeIndex&) { - if (LayoutUtil::IsSparseArray(subshape)) { - return xla::Unimplemented( - "GPU backend does not support HLO instruction %s with shape " - "containing a sparse layout: %s", - instruction->ToString(), - ShapeUtil::HumanStringWithLayout(instruction->shape())); - } - return Status::OK(); - })); - } - } - return false; -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h deleted file mode 100644 index 8b19769a781..00000000000 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h +++ /dev/null @@ -1,40 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SUPPORT_CHECKER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SUPPORT_CHECKER_H_ - -#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" - -namespace xla { - -// This pass should run early in the HLO pipeline and checks for HLO constructs -// which are not supported by the GPU backend and cannot be removed via HLO -// transformations (eg, sparse layouts). -class GpuHloSupportChecker : public HloModulePass { - public: - GpuHloSupportChecker() = default; - ~GpuHloSupportChecker() override = default; - - absl::string_view name() const override { return "gpu_hlo_support_checker"; } - - // Note: always returns false (no instructions are ever modified by this - // pass). - StatusOr Run(HloModule* module) override; -}; - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SUPPORT_CHECKER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc deleted file mode 100644 index 0bd43ec9b23..00000000000 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc +++ /dev/null @@ -1,76 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h" - -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/protobuf/error_codes.pb.h" - -namespace xla { -namespace { - -using ::testing::HasSubstr; - -class GpuHloSupportCheckerTest : public HloTestBase { - protected: - GpuHloSupportChecker& checker() { return checker_; } - - private: - GpuHloSupportChecker checker_; -}; - -TEST_F(GpuHloSupportCheckerTest, Add) { - HloComputation::Builder builder(TestName()); - const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); - HloInstruction* param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, scalar_shape, "param0")); - HloInstruction* param1 = builder.AddInstruction( - HloInstruction::CreateParameter(1, scalar_shape, "param1")); - builder.AddInstruction(HloInstruction::CreateBinary( - scalar_shape, HloOpcode::kAdd, param0, param1)); - auto module = CreateNewVerifiedModule(); - module->AddEntryComputation(builder.Build()); - - TF_ASSERT_OK(checker().Run(module.get()).status()); -} - -TEST_F(GpuHloSupportCheckerTest, SparseUnimplemented) { - HloComputation::Builder builder(TestName()); - const Shape sparse_shape = ShapeUtil::MakeShapeWithSparseLayout(F32, {10}, 2); - HloInstruction* param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, sparse_shape, "param0")); - HloInstruction* param1 = builder.AddInstruction( - HloInstruction::CreateParameter(1, sparse_shape, "param1")); - builder.AddInstruction(HloInstruction::CreateBinary( - sparse_shape, HloOpcode::kAdd, param0, param1)); - // Since verifier is reporting sparse layouts as errors, we should - // use a regular HloModule instead of VerifiedHloModule to avoid - // verifier errors being triggered in the destructor. - auto module = CreateNewUnverifiedModule(); - module->AddEntryComputation(builder.Build()); - - Status status = checker().Run(module.get()).status(); - ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED); - EXPECT_THAT(status.error_message(), - HasSubstr("GPU backend does not support")); - EXPECT_THAT(status.error_message(), - HasSubstr(ShapeUtil::HumanStringWithLayout(sparse_shape))); -} - -} // namespace -} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index b2067fe916d..2ff03354ea8 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -19,6 +19,7 @@ limitations under the License. 
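// (Header note: newer LLVM splits the target-specific intrinsics out of
// llvm/IR/Intrinsics.h into per-target headers, so references to
// llvm::Intrinsic::nvvm_* now need llvm/IR/IntrinsicsNVPTX.h; the
// IntrinsicsX86.h and IntrinsicsAMDGPU.h additions elsewhere in this change
// follow the same rule.)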
#include #include +#include "llvm/IR/IntrinsicsNVPTX.h" #include "llvm/IR/Module.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/service/gpu/target_util.h" @@ -234,6 +235,31 @@ bool IsReductionFromOrToContiguousDimensions(const HloInstruction& reduce) { return reduction_dimensions.dimensions[1] >= kWarpSize; } +bool IsInputFusibleSlices(const HloInstruction& unnested_hlo, + bool verify_no_strides) { + if (!unnested_hlo.IsInputFusion()) { + return false; + } + + auto is_non_strided = [](const std::vector<int64>& strides) -> bool { + return absl::c_all_of(strides, [](int stride) { return stride == 1; }); + }; + + const HloInstruction* root = unnested_hlo.fused_expression_root(); + if (root->opcode() == HloOpcode::kSlice) { + return !verify_no_strides || is_non_strided(root->slice_strides()); + } + + if (root->opcode() != HloOpcode::kTuple) { + return false; + } + + return absl::c_all_of(root->operands(), [&](const HloInstruction* instr) { + return instr->opcode() == HloOpcode::kSlice && + (!verify_no_strides || is_non_strided(instr->slice_strides())); + }); +} + ReductionDimensions GetReductionKindAndContiguousComponents( const HloInstruction& reduce) { const Shape& input_shape = reduce.operand(0)->shape(); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h index 2c37a63c05a..601a63ccede 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h @@ -157,6 +157,12 @@ bool ImplementedAsLibraryCall(const HloInstruction& hlo); // kept are contiguous in the input of the reduce instruction. bool IsReductionFromOrToContiguousDimensions(const HloInstruction& reduce); +// Returns whether unnested_hlo is an input fusion whose root is either a slice +// or a tuple of slices. If verify_no_strides is true, returns false unless all +// ROOT slices have no strides. +bool IsInputFusibleSlices(const HloInstruction& unnested_hlo, + bool verify_no_strides = false); + struct ReductionDimensions { // Indicates whether the reduction is a row reduction or a column reduction. bool is_row_reduction; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index b65c5c7461d..684a513bf1e 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -301,6 +301,44 @@ llvm::Type* GetIndexTypeForKernel(const HloInstruction* hlo, int64 launch_size, return b->getInt32Ty(); } +// Gets the input shape of the ROOT slices, which will be used as the kernel +// launch dims. The slice input fusion requires the input shapes of the ROOT +// slices to be the same although the (slice) output shapes can be different. +// +// Returns that common input shape if it is consistent across the ROOT slices +// and the slices are non-strided. Otherwise, returns FailedPrecondition. +StatusOr<Shape> GetConsistentInputShapeForRootSlices( + const HloInstruction& fusion) { + if (!IsInputFusibleSlices(fusion, /*verify_no_strides=*/true)) { + return FailedPrecondition( + "Unsupported root for slice input fusion.
" + "Only non-strided slices are supported."); + } + + const HloInstruction& root = *fusion.fused_expression_root(); + if (root.opcode() == HloOpcode::kSlice) { + return root.operands()[0]->shape(); + } + + CHECK_EQ(root.opcode(), HloOpcode::kTuple); + const Shape& first_slice_operand_shape = + root.operands()[0]->operands()[0]->shape(); + for (size_t i = 1; i < root.operands().size(); ++i) { + const HloInstruction* slice = root.operands()[i]; + const Shape& operand_shape = slice->operands()[0]->shape(); + if (!ShapeUtil::EqualIgnoringElementType(first_slice_operand_shape, + operand_shape)) { + return FailedPrecondition( + "Fused slices do not have the same input shape, fused computation = " + "%s.", + root.parent()->name()); + } + } + + return first_slice_operand_shape; +} + } // namespace Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) { @@ -388,7 +426,13 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { absl::make_unique(std::move(thunks), fusion)); return Status::OK(); } + // In the case of root tuple, it can be either reduce or slice input + // fusion. case HloOpcode::kTuple: { + if (IsInputFusibleSlices(*fusion)) { + return EmitInputFusibleNonStridedSlices(fusion); + } + CHECK_GE(root->operand_count(), 1); return EmitReductionFromOrToContiguousDimensions(fusion, root->operands()); @@ -404,6 +448,9 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { } return EmitReductionFromOrToContiguousDimensions(fusion, {root}); } + case HloOpcode::kSlice: { + return EmitInputFusibleNonStridedSlices(fusion); + } default: LOG(FATAL) << "Bad opcode for input fusion: " << fusion->fused_expression_root()->opcode(); @@ -3060,5 +3107,121 @@ Status IrEmitterUnnested::EmitConstantGlobals() { return Status::OK(); } +// Emits code for slices based on the below structure. An if statement with +// a guarding condition is generated for each ROOT slice. +// +// Pseudo code: +// +// Compute values of slice input operands +// +// Compute guarding_cond0 +// if (guarding_cond0) { +// Write to output of slice0 +// } +// +// Compute guarding_cond1 +// if (guarding_cond1) { +// Write to output of slice1 +// } +// +void IrEmitterUnnested::EmitElementForInputFusibleSlices( + HloInstruction* unnested_hlo, const llvm_ir::IrArray::Index& index) { + VLOG(10) << "Emitting slice input fusion for " << unnested_hlo->ToString(); + + HloInstruction* slice_or_tuple = unnested_hlo->fused_expression_root(); + auto slice_instructions = [&]() -> absl::Span { + if (slice_or_tuple->opcode() == HloOpcode::kSlice) { + return absl::Span(&slice_or_tuple, 1); + } + CHECK_EQ(slice_or_tuple->opcode(), HloOpcode::kTuple); + return slice_or_tuple->operands(); + }(); + + // Emit input operand values of slices. + std::vector input_ir_values; + GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_, + GetNestedComputer()); + FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(unnested_hlo), + &elem_emitter); + TF_CHECK_OK(unnested_hlo->fused_expression_root()->Accept(&fused_emitter)); + for (const HloInstruction* slice : slice_instructions) { + auto input_generator = fused_emitter.GetGenerator(slice->operand(0)); + input_ir_values.push_back(input_generator(index).ValueOrDie()); + } + + // Emit for slice_instructions. + KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll); + for (int64 i = 0; i < slice_instructions.size(); ++i) { + HloInstruction* slice = slice_instructions[i]; + + // guarding_cond := index >= start && index < limit, for each dim. 
+ std::vector<llvm::Value*> index_within_ranges; + for (size_t dim = 0; dim < slice->slice_starts().size(); ++dim) { + CHECK_EQ(slice->slice_strides(dim), 1); + auto larger_or_equal_than_start = b_.CreateICmpSGE( + index.multidim()[dim], + index.GetConstantWithIndexType(slice->slice_starts(dim))); + llvm::Value* smaller_than_limit = b_.CreateICmpSLT( + index.multidim()[dim], + index.GetConstantWithIndexType(slice->slice_limits(dim))); + llvm::Value* within_range = + b_.CreateAnd(larger_or_equal_than_start, smaller_than_limit); + index_within_ranges.push_back(within_range); + } + llvm::Value* guarding_cond = b_.CreateAnd(index_within_ranges); + + auto emit_slice_elem_func = [&] { + const std::vector<llvm::Value*>& src_multidim = index.multidim(); + std::vector<llvm::Value*> dst_multidim(src_multidim.size()); + for (size_t dim = 0; dim < src_multidim.size(); ++dim) { + dst_multidim[dim] = + Sub(src_multidim[dim], + index.GetConstantWithIndexType(slice->slice_starts(dim))); + } + ShapeIndex shape_index = (slice_or_tuple->opcode() == HloOpcode::kSlice) + ? ShapeIndex() + : ShapeIndex({i}); + llvm_ir::IrArray src_ir_array = + GetIrArray(*unnested_hlo, *unnested_hlo, shape_index); + IrArray::Index slice_dst_index(dst_multidim, slice->shape(), + index.GetType()); + llvm::Value* dst_addr = src_ir_array.EmitArrayElementAddress( + slice_dst_index, &b_, "slice.dest"); + b_.CreateStore(input_ir_values[i], dst_addr); + }; + + ksl.If(StrCat("slice", i), guarding_cond, emit_slice_elem_func); + } +} + +Status IrEmitterUnnested::EmitInputFusibleNonStridedSlices( + HloInstruction* unnested_hlo) { + constexpr int unroll_factor = 1; + std::unique_ptr<KernelThunk> kernel_thunk = BuildKernelThunk( + unnested_hlo, /*implements_whole_instruction=*/true, unroll_factor); + + TF_ASSIGN_OR_RETURN(Shape element_shape, + GetConsistentInputShapeForRootSlices(*unnested_hlo)); + LaunchDimensions launch_dimensions = CalculateLaunchDimensions( + element_shape, ir_emitter_context_->device_description(), unroll_factor); + UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(), + ir_emitter_context_->llvm_module()); + + Status emit_status = + ParallelLoopEmitter( + [&](const llvm_ir::IrArray::Index index) -> Status { + EmitElementForInputFusibleSlices(unnested_hlo, index); + return Status::OK(); + }, + element_shape, launch_dimensions, &b_) + .EmitLoop(IrName(unnested_hlo), + GetIndexTypeForKernel( + unnested_hlo, launch_dimensions.launch_bound(), &b_)); + + thunk_sequence_->emplace_back(std::move(kernel_thunk)); + + return emit_status; +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index fb64da6b43e..42a18e6547d 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -184,6 +184,19 @@ class IrEmitterUnnested : public IrEmitter, ReductionCodegenInfo ComputeReductionCodegenInfo( const HloInstruction* unnested_hlo, const HloInstruction* first_reduce); + // Generates code for input-fusible slices. + // + // Prerequisite: ROOT is either a slice or a tuple of slices. The input + // shapes of all ROOT slices need to be the same while their output shapes + // can be different; the input ranges of the slices may also overlap. Further + // generalization or specialization can be added as the need arises.
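+ //
+ // For example (legal): tuple(slice(A), slice(B)) where A and B are both
+ // f16[1024,512] element-wise results inside the fusion; the two slices may
+ // have different output shapes and overlapping input ranges. Illegal: the
+ // same tuple where A is f16[1024,512] but B is f16[512,512], since the
+ // common input shape defines the kernel launch dimensions (see
+ // GetConsistentInputShapeForRootSlices).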
+ Status EmitInputFusibleNonStridedSlices(HloInstruction* unnested_hlo); + + void EmitElementForInputFusibleSlices( + HloInstruction* unnested_hlo, + const llvm_ir::IrArray::Index& slice_input_index); + // Emits code for an in-place scatter, modifying `thunk`s launch dimensions in // the process. `scatter` may be fused, scatter indices are taken from // `scatter_indices_gen`, updates from`updates_gen`. The output buffer is diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD index db26d36c71a..9203664e4c7 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD @@ -38,20 +38,20 @@ cc_library( "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", - "@llvm//:amdgpu_code_gen", - "@llvm//:analysis", - "@llvm//:bit_reader", - "@llvm//:bit_writer", - "@llvm//:code_gen", - "@llvm//:core", - "@llvm//:ipo", - "@llvm//:ir_reader", - "@llvm//:linker", - "@llvm//:nvptx_code_gen", # buildcleaner: keep - "@llvm//:objc_arc", # buildcleaner: keep - "@llvm//:scalar", - "@llvm//:support", - "@llvm//:target", + "@llvm-project//llvm:amdgpu_code_gen", + "@llvm-project//llvm:analysis", + "@llvm-project//llvm:bit_reader", + "@llvm-project//llvm:bit_writer", + "@llvm-project//llvm:code_gen", + "@llvm-project//llvm:core", + "@llvm-project//llvm:ipo", + "@llvm-project//llvm:ir_reader", + "@llvm-project//llvm:linker", + "@llvm-project//llvm:nvptx_code_gen", # buildcleaner: keep + "@llvm-project//llvm:objc_arc", # buildcleaner: keep + "@llvm-project//llvm:scalar", + "@llvm-project//llvm:support", + "@llvm-project//llvm:target", ], ) @@ -68,7 +68,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", - "@llvm//:core", - "@llvm//:support", + "@llvm-project//llvm:core", + "@llvm-project//llvm:support", ], ) diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc old mode 100755 new mode 100644 index fa01d75d35a..d48c36b4b29 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/gpu_conv_padding_legalization.h" #include "tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" #include "tensorflow/compiler/xla/service/gpu/reduction_degenerate_dim_remover.h" #include "tensorflow/compiler/xla/service/gpu/reduction_dimension_grouper.h" @@ -134,6 +135,8 @@ Status NVPTXCompiler::OptimizeHloConvolutionCanonicalization( /*allow_mixed_precision=*/false); AlgebraicSimplifierOptions options; + options.set_cudnn_batchnorm_forward_training_metadata( + kCudnnBatchNormForwardTrainingCallTarget); pass.AddPass(options); } @@ -432,7 +435,7 @@ std::vector NVPTXCompiler::CompileGpuAsmOrGetCachedResult( "Can't find ptxas binary in ${CUDA_DIR}/bin. Will back to the " "GPU driver for PTX -> sass compilation. This is OK so long " "as you don't see a warning below about an out-of-date driver " - "version.", + "version. 
Custom ptxas location can be specified using $PATH.", hlo_module_config); } diff --git a/tensorflow/compiler/xla/service/gpu/target_util.cc b/tensorflow/compiler/xla/service/gpu/target_util.cc index 48c703183fc..49eadd8c6be 100644 --- a/tensorflow/compiler/xla/service/gpu/target_util.cc +++ b/tensorflow/compiler/xla/service/gpu/target_util.cc @@ -18,6 +18,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/target_util.h" #include "absl/strings/str_cat.h" +#include "llvm/IR/IntrinsicsAMDGPU.h" +#include "llvm/IR/IntrinsicsNVPTX.h" #include "llvm/IR/MDBuilder.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD index 51a12e1f2fe..d723a1a6927 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD @@ -338,6 +338,21 @@ tf_cc_test( ], ) +tf_cc_test( + name = "gpu_input_fusible_slice_test", + srcs = ["gpu_input_fusible_slice_test.cc"], + tags = tf_cuda_tests_tags(), + deps = [ + ":gpu_codegen_test", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + xla_test( name = "gpu_convolution_regression_test", srcs = ["gpu_convolution_regression_test.cc"], diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_input_fusible_slice_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_input_fusible_slice_test.cc new file mode 100644 index 00000000000..7f345c19331 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_input_fusible_slice_test.cc @@ -0,0 +1,158 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace gpu { +namespace { + +class GpuSliceInputFusionTest : public GpuCodegenTest { + protected: + GpuSliceInputFusionTest() {} + + HloModuleConfig ConfigWithoutLayoutAssignment() { + HloModuleConfig config; + auto debug_options = HloTestBase::GetDebugOptionsForTest(); + // Disable the layout_assignment pass to use the preassigned layouts; + // otherwise, the pass throws away the layouts in the fusion computation. 
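+ // (The same effect is available outside tests via the debug flag, e.g.
+ // XLA_FLAGS=--xla_disable_hlo_passes=layout-assignment; the string is the
+ // pass's name() as used here.)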
+ debug_options.add_xla_disable_hlo_passes("layout-assignment"); + config.set_debug_options(debug_options); + return config; + } +}; + +TEST_F(GpuSliceInputFusionTest, InputFusionWithOnlyOneSlice) { + const char *const kHloString = R"( + HloModule input_fusion_with_only_one_slice + + fused_computation { + arg.1 = f16[1024,512]{1,0} parameter(0) + arg.2 = f16[1024,512]{1,0} parameter(1) + arg1.conv = f32[1024,512]{1,0} convert(arg.1) + arg2.conv = f32[1024,512]{1,0} convert(arg.2) + add.1 = f32[1024,512]{1,0} add(arg1.conv, arg2.conv) + ROOT slice.1 = f32[512,511]{1,0} slice(add.1), slice={[512:1024], [1:512]} + } + + ENTRY kernel_entry { + arg.1 = f16[1024,512]{1,0} parameter(0) + arg.2 = f16[1024,512]{1,0} parameter(1) + ROOT fusion = f32[512, 511]{1,0} fusion(arg.1, arg.2), kind=kInput, + calls=fused_computation + })"; + + auto hlo_module = + ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment()) + .ValueOrDie(); + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: define void @fusion +; CHECK: slice0 +; CHECK: } +)", + /*match_optimized_ir=*/false); + // Check that the kernel runs correctly. + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0, 0})); +} + +TEST_F(GpuSliceInputFusionTest, InputFusionWithATupleOfSlices) { + const char *const kHloString = R"( + HloModule input_fusion_with_a_tuple_of_slices + + fused_computation { + arg.1 = f16[1024,512]{1,0} parameter(0) + arg.2 = f16[1024,512]{1,0} parameter(1) + mul.1 = f16[1024,512]{1,0} multiply(arg.1, arg.2) + add.1 = f16[1024,512]{1,0} add(mul.1, arg.2) + slice.1 = f16[512,511]{1,0} slice(arg.1), slice={[512:1024], [1:512]} + slice.2 = f16[0,512]{1,0} slice(add.1), slice={[512:512], [0:512]} + slice.3 = f16[1,1]{1,0} slice(add.1), slice={[512:513], [511:512]} + ROOT tuple.1 = (f16[512,511]{1,0}, f16[0,512]{1,0}, f16[1,1]{1,0}) + tuple(slice.1, slice.2, slice.3) + } + + ENTRY kernel_entry { + arg.1 = f16[1024,512]{1,0} parameter(0) + arg.2 = f16[1024,512]{1,0} parameter(1) + ROOT fusion = (f16[512,511]{1,0}, f16[0,512]{1,0}, f16[1,1]{1,0}) + fusion(arg.1, arg.2), kind=kInput, calls=fused_computation + })"; + + auto hlo_module = + ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment()) + .ValueOrDie(); + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: define void @fusion +; CHECK: slice2 +; CHECK: } +)", + /*match_optimized_ir=*/false); + // Check that the kernel runs correctly. 
+ EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0, 0})); +} + +TEST_F(GpuSliceInputFusionTest, ConcatThenSplit) { + const char *const kHloString = R"( + HloModule input_fusion_with_a_tuple_of_slices + + fused_computation { + arg.1 = f16[1024]{0} parameter(0) + arg.2 = f16[1024]{0} parameter(1) + arg.3 = f16[1023]{0} parameter(2) + arg.4 = f16[1023]{0} parameter(3) + mul.1 = f16[1024]{0} multiply(arg.1, arg.2) + add.1 = f16[1023]{0} add(arg.3, arg.4) + concat.1 = f16[2047]{0} concatenate(mul.1, add.1), dimensions={0} + slice.1 = f16[1024]{0} slice(concat.1), slice={[0:1024]} + slice.2 = f16[1023]{0} slice(concat.1), slice={[1024:2047]} + slice.3 = f16[0]{0} slice(concat.1), slice={[2047:2047]} + ROOT tuple.1 = (f16[1024]{0}, f16[1023]{0}, f16[0]{0}) + tuple(slice.1, slice.2, slice.3) + } + + ENTRY kernel_entry { + arg.1 = f16[1024]{0} parameter(0) + arg.2 = f16[1024]{0} parameter(1) + arg.3 = f16[1023]{0} parameter(2) + arg.4 = f16[1023]{0} parameter(3) + ROOT fusion = (f16[1024]{0}, f16[1023]{0}, f16[0]{0}) + fusion(arg.1, arg.2, arg.3, arg.4), kind=kInput, calls=fused_computation + })"; + + auto hlo_module = + ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment()) + .ValueOrDie(); + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: define void @fusion +; CHECK: slice2 +; CHECK: } +)", + /*match_optimized_ir=*/false); + // Check that the kernel runs correctly. + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0, 0})); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc index 65b813b2e24..962be890102 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.cc +++ b/tensorflow/compiler/xla/service/heap_simulator.cc @@ -31,6 +31,12 @@ namespace xla { using absl::flat_hash_map; using absl::flat_hash_set; +bool HeapSimulator::Chunk::OverlapsWith(Chunk other_chunk) const { + CHECK_NE(size, 0); + CHECK_NE(other_chunk.size, 0); + return offset < other_chunk.chunk_end() && other_chunk.offset < chunk_end(); +} + /*static*/ StatusOr HeapSimulator::MinimumMemoryForModule( const HloSchedule& schedule, @@ -591,8 +597,7 @@ void GlobalDecreasingSizeBestFitHeap::Free(const HloValue* buffer, int64 size) { using Chunk = HeapSimulator::Chunk; -void GlobalDecreasingSizeBestFitHeap::BufferIntervalTree::Add( - int64 start, int64 end, const Chunk& chunk) { +void BufferIntervalTree::Add(int64 start, int64 end, const Chunk& chunk) { node_storage_.emplace_back( BufferIntervalTreeNode{start, end, end, chunk, nullptr, nullptr}); @@ -620,8 +625,7 @@ void GlobalDecreasingSizeBestFitHeap::BufferIntervalTree::Add( } } -std::vector -GlobalDecreasingSizeBestFitHeap::BufferIntervalTree::ChunksOverlappingInTime( +std::vector BufferIntervalTree::ChunksOverlappingInTime( int64 start, int64 end) const { std::vector result; if (node_storage_.empty()) { diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h index ac047de3ec7..2bb0eda249f 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.h +++ b/tensorflow/compiler/xla/service/heap_simulator.h @@ -57,6 +57,8 @@ class HeapSimulator { int64 size; int64 chunk_end() const { return offset + size; } + + bool OverlapsWith(Chunk other_chunk) const; }; // Result represents the result of the heap simulation. 
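The Chunk::OverlapsWith helper added in heap_simulator.cc above treats each chunk as a half-open address range [offset, offset + size). The sketch below is a standalone illustration of that predicate, not the XLA type itself; the simplified Chunk struct and the main() checks are ours:

```
// Standalone sketch of the overlap predicate added as Chunk::OverlapsWith.
// Chunk here is a simplified stand-in for the HeapSimulator type; offset and
// size model the half-open address range [offset, offset + size).
#include <cassert>
#include <cstdint>

struct Chunk {
  int64_t offset;
  int64_t size;
  int64_t chunk_end() const { return offset + size; }
  // Two half-open ranges overlap iff each one starts before the other ends.
  // Zero-sized chunks are rejected, mirroring the CHECK_NE in the real code,
  // because an empty range has no well-defined overlap under this test.
  bool OverlapsWith(Chunk other) const {
    assert(size != 0 && other.size != 0);
    return offset < other.chunk_end() && other.offset < chunk_end();
  }
};

int main() {
  assert((Chunk{0, 10}.OverlapsWith(Chunk{5, 10})));   // [0,10) vs [5,15)
  assert(!(Chunk{0, 10}.OverlapsWith(Chunk{10, 5})));  // adjacent, no overlap
  return 0;
}
```

Two ranges overlap exactly when each starts strictly before the other ends, which is why adjacent chunks such as [0,10) and [10,15) are correctly not flagged.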
@@ -284,6 +286,39 @@ class NoFragmentationStatsHeap : public HeapAlgorithm { int64 max_heap_size_ = 0; }; +// Node in BufferIntervalTree that stores the alloc and free times of a buffer, +// and the chunk assigned to it. +struct BufferIntervalTreeNode { + // Alloc time. + int64 start; + // Free time. + int64 end; + // Maximum free time of all nodes in the subtree where this node is the root. + int64 subtree_end; + // Allocated chunk for the buffer. + HeapSimulator::Chunk chunk; + // Left child. + BufferIntervalTreeNode* left; + // Right child. + BufferIntervalTreeNode* right; +}; + +// An interval tree that can query buffers overlapping in time. +class BufferIntervalTree { + public: + using Chunk = HeapSimulator::Chunk; + // Adds a buffer to the interval tree, with the time interval and allocated + // chunk specified. + void Add(int64 start, int64 end, const Chunk& chunk); + + // Returns vector of allocated chunks that overlap with the given time + // interval. + std::vector<Chunk> ChunksOverlappingInTime(int64 start, int64 end) const; + + private: + std::list<BufferIntervalTreeNode> node_storage_; +}; + // GlobalDecreasingSizeBestFitHeap collects the live intervals of all buffers, // then allocates them in decreasing spatial or temporal size regardless of the // alloc/free time. It internally tracks the allocated buffers and their live @@ -334,39 +369,6 @@ class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm { static BufferIntervalCompare GetSpatialBufferIntervalCompare(); protected: - // Node in BufferIntervalTree that stores the alloc and free times of a - // buffer, and the chunk assigned to it. - struct BufferIntervalTreeNode { - // Alloc time. - int64 start; - // Free time. - int64 end; - // Maximum free time of all nodes in the subtree where this node is the - // root. - int64 subtree_end; - // Allocated chunk for the buffer. - HeapSimulator::Chunk chunk; - // Left child. - BufferIntervalTreeNode* left; - // Right child. - BufferIntervalTreeNode* right; - }; - - // An interval tree that can query buffers overlapping in time. - class BufferIntervalTree { - public: - // Adds a buffer to the interval tree, with the time interval and allocated - // chunk specified. - void Add(int64 start, int64 end, const Chunk& chunk); - - // Returns vector of allocated chunks that overlap with the given time - // interval. - std::vector<Chunk> ChunksOverlappingInTime(int64 start, int64 end) const; - - private: - std::list<BufferIntervalTreeNode> node_storage_; - }; - // The candidate contains a chunk and the resultant heap size if this // chunk is to be committed.
struct ChunkCandidate { diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index f8fbaf19c5c..4322c26b2de 100755 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -2987,8 +2987,8 @@ static Status PostOrderDFS(HloInstruction* root, Visitor* visitor, visitor->GetVisitState(current_id); if (visit_state == Visitor::kVisited) { dfs_stack.pop_back(); - VLOG(3) << "Not visiting HLO %" << current_node->name() - << " as it was already visited."; + VLOG(3) << "Not visiting HLO (id = " << current_id + << ") as it was already visited."; continue; } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 2ab606d7100..104bca8e876 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -1792,6 +1792,10 @@ class HloInstruction { // Delegates to HloCholeskyInstruction::cholesky_options(). const CholeskyOptions& cholesky_options() const; + // Appends operand to the list of operands and adds this instruction as a user + // of the operand. + void AppendOperand(HloInstruction* operand); + // Old methods kept for smooth subclassing transition END. protected: @@ -1831,10 +1835,6 @@ class HloInstruction { // by factory methods. HloInstruction(HloOpcode opcode, const Shape& shape); - // Appends operand to the list of operands and adds this instruction as a user - // of the operand. - void AppendOperand(HloInstruction* operand); - void RemoveOperandAt(int index) { operands_.erase(operands_.begin() + index); } diff --git a/tensorflow/compiler/xla/service/hlo_lexer.cc b/tensorflow/compiler/xla/service/hlo_lexer.cc index 5de3717e26c..bc1745a0791 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.cc +++ b/tensorflow/compiler/xla/service/hlo_lexer.cc @@ -280,7 +280,6 @@ TokKind HloLexer::LexIdentifier() { KEYWORD(ROOT); KEYWORD(maximal); KEYWORD(replicated); - KEYWORD(sparse); #undef KEYWORD @@ -496,8 +495,6 @@ string TokKindToString(TokKind kind) { return "kw_inf"; case TokKind::kNegInf: return "kNegInf"; - case TokKind::kw_sparse: - return "kw_sparse"; case TokKind::kPrimitiveType: return "kPrimitiveType"; case TokKind::kName: diff --git a/tensorflow/compiler/xla/service/hlo_lexer.h b/tensorflow/compiler/xla/service/hlo_lexer.h index d4a49fea200..6a59f180ad8 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.h +++ b/tensorflow/compiler/xla/service/hlo_lexer.h @@ -63,7 +63,6 @@ enum class TokKind { kw_replicated, kw_nan, kw_inf, - kw_sparse, kNegInf, // -inf diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index b05f76a1d29..ecb25298288 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -72,10 +72,6 @@ HloSchedule ScheduleFromInstructionOrder(HloModule* module) { return schedule; } -// Some functions accept either a linear index or a multi-dimensional index -// (used for indexing into sparse literals). -using LinearOrMultiIndex = absl::variant>; - // Parser for the HloModule::ToString() format text. 
class HloParserImpl : public HloParser { public: @@ -137,24 +133,21 @@ class HloParserImpl : public HloParser { bool ParseTupleLiteral(Literal* literal, const Shape& shape); bool ParseNonTupleLiteral(Literal* literal, const Shape& shape); bool ParseDenseLiteral(Literal* literal, const Shape& shape); - bool ParseSparseLiteral(Literal* literal, const Shape& shape); - // Sets the sub-value of literal at the given linear or sparse index to the - // given value. If the literal is dense, it myst have the default layout. + // Sets the sub-value of literal at the given linear index to the + // given value. If the literal is dense, it must have the default layout. // // `loc` should be the source location of the value. - bool SetValueInLiteral(LocTy loc, int64 value, LinearOrMultiIndex index, + bool SetValueInLiteral(LocTy loc, int64 value, int64 index, Literal* literal); + bool SetValueInLiteral(LocTy loc, double value, int64 index, Literal* literal); - bool SetValueInLiteral(LocTy loc, double value, LinearOrMultiIndex index, + bool SetValueInLiteral(LocTy loc, bool value, int64 index, Literal* literal); + bool SetValueInLiteral(LocTy loc, std::complex value, int64 index, Literal* literal); - bool SetValueInLiteral(LocTy loc, bool value, LinearOrMultiIndex index, - Literal* literal); - bool SetValueInLiteral(LocTy loc, std::complex value, - LinearOrMultiIndex index, Literal* literal); // `loc` should be the source location of the value. template - bool SetValueInLiteralHelper(LocTy loc, ParsedElemT value, - LinearOrMultiIndex index, Literal* literal); + bool SetValueInLiteralHelper(LocTy loc, ParsedElemT value, int64 index, + Literal* literal); // Checks whether the given value is within the range of LiteralNativeT. // `loc` should be the source location of the value. 
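With sparse literals removed, the SetValueInLiteral overloads declared above take a plain int64 linear index rather than the old LinearOrMultiIndex variant, since the parser only ever writes into dense, default-layout literals. Below is a minimal sketch of the resulting bounds-check-and-store pattern; the container, function name, and error text are illustrative stand-ins, not the XLA API:

```
// Sketch of the simplified dense-literal store path: a linear element index
// is validated against the element count and then written directly.
#include <cstdint>
#include <iostream>
#include <vector>

bool SetValueAtLinearIndex(std::vector<double>& elements, int64_t index,
                           double value) {
  // Mirrors the "index is out of range" error the parser reports.
  if (index < 0 || index >= static_cast<int64_t>(elements.size())) {
    std::cerr << "tries to set value " << value << " at linear index " << index
              << ", but the index is out of range\n";
    return false;
  }
  elements[index] = value;
  return true;
}

int main() {
  std::vector<double> literal(6, 0.0);    // stands in for an f64[2,3] literal
  SetValueAtLinearIndex(literal, 4, 2.5); // ok
  SetValueAtLinearIndex(literal, 9, 1.0); // reports out-of-range
  return 0;
}
```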
@@ -2125,8 +2118,7 @@ bool HloParserImpl::ParseInstructionNames( "expects '}' at the end of instruction name list"); } -bool HloParserImpl::SetValueInLiteral(LocTy loc, int64 value, - LinearOrMultiIndex index, +bool HloParserImpl::SetValueInLiteral(LocTy loc, int64 value, int64 index, Literal* literal) { const Shape& shape = literal->shape(); switch (shape.element_type()) { @@ -2160,8 +2152,7 @@ bool HloParserImpl::SetValueInLiteral(LocTy loc, int64 value, } } -bool HloParserImpl::SetValueInLiteral(LocTy loc, double value, - LinearOrMultiIndex index, +bool HloParserImpl::SetValueInLiteral(LocTy loc, double value, int64 index, Literal* literal) { const Shape& shape = literal->shape(); switch (shape.element_type()) { @@ -2180,8 +2171,7 @@ bool HloParserImpl::SetValueInLiteral(LocTy loc, double value, } } -bool HloParserImpl::SetValueInLiteral(LocTy loc, bool value, - LinearOrMultiIndex index, +bool HloParserImpl::SetValueInLiteral(LocTy loc, bool value, int64 index, Literal* literal) { const Shape& shape = literal->shape(); switch (shape.element_type()) { @@ -2194,8 +2184,7 @@ bool HloParserImpl::SetValueInLiteral(LocTy loc, bool value, } bool HloParserImpl::SetValueInLiteral(LocTy loc, std::complex value, - LinearOrMultiIndex index, - Literal* literal) { + int64 index, Literal* literal) { const Shape& shape = literal->shape(); switch (shape.element_type()) { case C64: @@ -2221,54 +2210,21 @@ std::string StringifyValue(std::complex val) { template bool HloParserImpl::SetValueInLiteralHelper(LocTy loc, ParsedElemT value, - LinearOrMultiIndex index, - Literal* literal) { + int64 index, Literal* literal) { if (!CheckParsedValueIsInRange(loc, value)) { return false; } // Check that the index is in range and assign into the literal - if (auto* linear_index = absl::get_if(&index)) { - if (*linear_index >= ShapeUtil::ElementsIn(literal->shape())) { - return Error(loc, StrCat("trys to set value ", StringifyValue(value), - " to a literal in shape ", - ShapeUtil::HumanString(literal->shape()), - " at linear index ", *linear_index, - ", but the index is out of range")); - } - literal->data().at(*linear_index) = - static_cast(value); - } else { - auto* multi_index = absl::get_if>(&index); - CHECK(multi_index != nullptr); - - auto invalid_idx = [&](std::string msg) { - return Error(loc, StrFormat("Invalid sparse index [%s]. 
%s", - absl::StrJoin(*multi_index, ", "), msg)); - }; - - const auto& shape = literal->shape(); - if (shape.rank() != multi_index->size()) { - return invalid_idx( - StrFormat("Has rank %d, but constant has shape %s, which has rank %d", - multi_index->size(), shape.ToString(), shape.rank())); - } - for (int64 i = 0; i < shape.rank(); ++i) { - auto idx = (*multi_index)[i]; - if (idx < 0) { - return invalid_idx(StrFormat( - "Sub-index value at %d, namely %d, cannot be negative.", i, idx)); - } - if (idx >= shape.dimensions(i)) { - return invalid_idx( - StrFormat("Sub-index at %d, namely %d, doesn't fit within shape " - "dimension %d in %s", - i, idx, shape.dimensions(i), shape.ToString())); - } - } - literal->AppendSparseElement(*multi_index, - static_cast(value)); + if (index >= ShapeUtil::ElementsIn(literal->shape())) { + return Error(loc, StrCat("trys to set value ", StringifyValue(value), + " to a literal in shape ", + ShapeUtil::HumanString(literal->shape()), + " at linear index ", index, + ", but the index is out of range")); } + literal->data().at(index) = + static_cast(value); return true; } @@ -2314,12 +2270,8 @@ bool HloParserImpl::ParseTupleLiteral(Literal* literal, const Shape& shape) { // non_tuple // ::= rank01 // ::= rank2345 -// rank2345 ::= shape sparse_or_nested_array +// rank2345 ::= shape nested_array bool HloParserImpl::ParseNonTupleLiteral(Literal* literal, const Shape& shape) { - if (LayoutUtil::IsSparseArray(shape)) { - return ParseSparseLiteral(literal, shape); - } - CHECK(LayoutUtil::IsDenseArray(shape)) << shape.ToString(true); return ParseDenseLiteral(literal, shape); } @@ -2500,98 +2452,6 @@ bool HloParserImpl::ParseDenseLiteral(Literal* literal, const Shape& shape) { return true; } -bool HloParserImpl::ParseSparseLiteral(Literal* literal, const Shape& shape) { - *literal = Literal(shape); - if (!ParseToken(TokKind::kLbrace, - "expects '{' at the beginning of a sparse literal")) { - return false; - } - - for (;;) { - if (lexer_.GetKind() == TokKind::kRbrace) { - lexer_.Lex(); - break; - } - - std::vector index; - if (lexer_.GetKind() == TokKind::kInt) { - int64 single_index = lexer_.GetInt64Val(); - lexer_.Lex(); - index.push_back(single_index); - } else { - if (!ParseInt64List(TokKind::kLsquare, TokKind::kRsquare, TokKind::kComma, - &index)) { - return false; - } - } - if (!ParseToken(TokKind::kColon, - "expects ':' after after the sparse array index and before " - "the sparse array value")) { - return false; - } - - LocTy value_loc = lexer_.GetLoc(); - if (lexer_.GetKind() == TokKind::kw_true || - lexer_.GetKind() == TokKind::kw_false) { - bool value = lexer_.GetKind() == TokKind::kw_true; - if (!SetValueInLiteral(lexer_.GetLoc(), value, index, literal)) { - return false; - } - lexer_.Lex(); - } else if (primitive_util::IsIntegralType(shape.element_type())) { - int64 value; - if (!ParseInt64(&value)) { - return Error(value_loc, - StrCat("expects integer for primitive type: ", - PrimitiveType_Name(shape.element_type()))); - } - if (!SetValueInLiteral(value_loc, value, index, literal)) { - return false; - } - } else if (primitive_util::IsFloatingPointType(shape.element_type())) { - double value; - if (!ParseDouble(&value)) { - return Error(value_loc, - StrCat("expects floating point value for primitive type: ", - PrimitiveType_Name(shape.element_type()))); - } - if (!SetValueInLiteral(value_loc, value, index, literal)) { - return false; - } - } else if (primitive_util::IsComplexType(shape.element_type())) { - std::complex value; - if (!ParseComplex(&value)) { - 
return Error(value_loc, - StrCat("expects complex value for primitive type: ", - PrimitiveType_Name(shape.element_type()))); - } - if (!SetValueInLiteral(value_loc, value, index, literal)) { - return false; - } - } else { - LOG(FATAL) << "Unexpected element type: " - << PrimitiveType_Name(shape.element_type()); - } - - if (lexer_.GetKind() != TokKind::kRbrace && - !ParseToken(TokKind::kComma, - "expects ',' separator between sparse array elements")) { - return false; - } - - if (literal->sparse_element_count() + 1 == - LayoutUtil::MaxSparseElements(shape.layout())) { - return Error( - lexer_.GetLoc(), - StrCat("number of sparse elements exceeds maximum for layout: ", - ShapeUtil::HumanStringWithLayout(shape))); - } - } - - literal->SortSparseElements(); - return true; -} - // MaxFiniteValue is a type-traits helper used by // HloParserImpl::CheckParsedValueIsInRange. template <typename T> struct MinMaxFiniteValue { @@ -2615,18 +2475,37 @@ struct MinMaxFiniteValue { static double min() { return -max(); } }; +// MSVC's standard C++ library does not define isnan/isfinite for integer types. +// To work around that we will need to provide our own. +template <typename T> +std::enable_if_t<std::is_floating_point<T>::value, bool> IsFinite(T val) { + return std::isfinite(val); +} +template <typename T> +std::enable_if_t<std::is_floating_point<T>::value, bool> IsNaN(T val) { + return std::isnan(val); +} +template <typename T> +std::enable_if_t<std::is_integral<T>::value, bool> IsFinite(T val) { + return std::isfinite(static_cast<double>(val)); +} +template <typename T> +std::enable_if_t<std::is_integral<T>::value, bool> IsNaN(T val) { + return std::isnan(static_cast<double>(val)); +} + template <typename LiteralNativeT, typename ParsedElemT> bool HloParserImpl::CheckParsedValueIsInRange(LocTy loc, ParsedElemT value) { if (std::is_floating_point<ParsedElemT>::value) { auto value_as_native_t = static_cast<LiteralNativeT>(value); auto value_double_converted = static_cast<ParsedElemT>(value_as_native_t); - if (!std::isfinite(value) || std::isfinite(value_double_converted)) { + if (!IsFinite(value) || IsFinite(value_double_converted)) { value = value_double_converted; } } PrimitiveType literal_ty = primitive_util::NativeToPrimitiveType<LiteralNativeT>(); - if (std::isnan(value) || + if (IsNaN(value) || (std::numeric_limits<ParsedElemT>::has_infinity && (std::numeric_limits<ParsedElemT>::infinity() == value || -std::numeric_limits<ParsedElemT>::infinity() == value))) { @@ -3820,21 +3699,6 @@ bool HloParserImpl::ParseShape(Shape* result) { } LayoutUtil::SetToDefaultLayout(result); - if (lexer_.GetKind() == TokKind::kw_sparse) { - lexer_.Lex(); - const std::string message = - "expects a brace-bracketed integer for sparse layout"; - int64 max_sparse_elements; - if (!ParseToken(TokKind::kLbrace, message) || - !ParseInt64(&max_sparse_elements) || - !ParseToken(TokKind::kRbrace, message)) { - return false; - } - *result->mutable_layout() = - LayoutUtil::MakeSparseLayout(max_sparse_elements); - return true; - } - // We need to lookahead to see if a following open brace is the start of a // layout.
The specific problematic case is: // @@ -4386,6 +4250,7 @@ bool HloParserImpl::ParseSingleInstruction(HloModule* module) { for (auto& comp : computations_) { module->AddEmbeddedComputation(std::move(comp)); } + TF_CHECK_OK(module->set_schedule(ScheduleFromInstructionOrder(module))); return true; } diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index d65613fc4b8..e3431a4731f 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -841,50 +841,6 @@ ENTRY %fusion.v3 () -> f32[3,2,1,1] { )" }, { -"Sparse", -R"(HloModule sparse_f32 - -ENTRY %sparse () -> f32[2,3,4] { - ROOT %foo = f32[2,3,4]sparse{10} constant({[0, 1, 2]: 1, [1, 2, 2]: 2, [1, 2, 3]: 3}) -} - -)", -/*enable_verification=*/false -}, -{ -"SparseC128", -R"(HloModule sparse_c128 - -ENTRY %sparse () -> c128[2,3,4] { - ROOT %foo = c128[2,3,4]sparse{10} constant({[0, 1, 2]: (1, 0), [1, 2, 2]: (2, 5), [1, 2, 3]: (3, 10)}) -} - -)", -/*enable_verification=*/false -}, -{ -"SparseEmpty", -R"(HloModule sparse_f32_empty - -ENTRY %sparse_f32_empty () -> f32[2,3,4] { - ROOT %foo = f32[2,3,4]sparse{10} constant({}) -} - -)", -/*enable_verification=*/false, -}, -{ -"SparseR1", -R"(HloModule sparse_f32_r1 - -ENTRY %sparse_f32_r1 () -> f32[9] { - ROOT %foo = f32[9]sparse{10} constant({1: 2, 3: 4, 5: 6}) -} - -)", -/*enable_verification=*/false, -}, -{ "Gather", R"(HloModule StringifyGather @@ -1982,17 +1938,6 @@ TEST_F(HloParserTest, ConstantBf16Overflow) { "out of range"); } -TEST_F(HloParserTest, ConstantF16OverflowInSparseArray) { - const string original = R"( - HloModule test_module - ENTRY test { - ROOT c = f16[5]sparse{10} constant({[0]: 0, [1]: -65520}) - })"; - ExpectHasSubstr( - ParseAndReturnUnverifiedModule(original).status().error_message(), - "is out of range for literal's primitive type F16"); -} - TEST_F(HloParserTest, ConstantUnsignedUnderflow) { const string original = R"( HloModule ConstantUnsignedUnderflow_module @@ -2852,50 +2797,6 @@ ENTRY %entrycomp (p: f32[2,2]) -> f32[2,2] { " with the shape of the operand instruction f32[2,2]{1,0}."); } -TEST_F(HloParserTest, OutOfRangeSparseIndex) { - const string original = R"( - HloModule test_module - ENTRY test { - ROOT c = f16[5]sparse{10} constant({[100]: 0}) - })"; - ExpectHasSubstr( - ParseAndReturnUnverifiedModule(original).status().error_message(), - "Invalid sparse index"); -} - -TEST_F(HloParserTest, NegativeSparseIndex) { - const string original = R"( - HloModule test_module - ENTRY test { - ROOT c = f16[5]sparse{10} constant({-1: 0}) - })"; - ExpectHasSubstr( - ParseAndReturnUnverifiedModule(original).status().error_message(), - "Invalid sparse index"); -} - -TEST_F(HloParserTest, SparseIndexWithRankTooLarge) { - const string original = R"( - HloModule test_module - ENTRY test { - ROOT c = f16[5]sparse{10} constant({[0, 0]: 0}) - })"; - ExpectHasSubstr( - ParseAndReturnUnverifiedModule(original).status().error_message(), - "Invalid sparse index"); -} - -TEST_F(HloParserTest, SparseIndexWithRankTooSmall) { - const string original = R"( - HloModule test_module - ENTRY test { - ROOT c = f16[5, 5]sparse{10} constant({[0]: 0}) - })"; - ExpectHasSubstr( - ParseAndReturnUnverifiedModule(original).status().error_message(), - "Invalid sparse index"); -} - TEST_F(HloParserTest, ParseShapeStringR2F32) { string shape_string = "f32[123,456]"; TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string)); @@ -2994,15 +2895,6 @@ TEST_F(HloParserTest, 
ParseShapeStringWithTilingLayout) { "Dimensions size is 3, but minor to major size is 1."); } -TEST_F(HloParserTest, ParseShapeStringWithSparseLayout) { - string shape_string = "f32[123,456]sparse{10}"; - TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string)); - Shape expected = ShapeUtil::MakeShapeWithSparseLayout(F32, {123, 456}, 10); - ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) - << "expected: " << ShapeUtil::HumanString(expected) - << "actual: " << ShapeUtil::HumanString(actual); -} - TEST_F(HloParserTest, ParseShapeStringWithMemorySpaceLayout) { // Tile, element size, and memory space. string shape_string = "pred[123,456]{1,0:T(2,128)E(1)S(3)}"; @@ -3047,10 +2939,8 @@ TEST_F(HloParserTest, ParseTokenType) { } TEST_F(HloParserTest, ParseInvalidShapeString) { - string shape_strings[] = { - "f32[123,456]foobar{0,1}", "f32[123,456]sparse{0,1}", "f32[123,456]{foo}", - "f32[123,456]dense{foo}", "f32[123,456]sparse{foo}", - }; + string shape_strings[] = {"f32[123,456]foobar{0,1}", "f32[123,456]{foo}", + "f32[123,456]dense{foo}"}; for (const string& shape_string : shape_strings) { StatusOr result = ParseShape(shape_string); ASSERT_FALSE(result.ok()) << "shape: " << shape_string; diff --git a/tensorflow/compiler/xla/service/hlo_query.cc b/tensorflow/compiler/xla/service/hlo_query.cc index defd6abd8f6..1b6494bf3cb 100644 --- a/tensorflow/compiler/xla/service/hlo_query.cc +++ b/tensorflow/compiler/xla/service/hlo_query.cc @@ -133,5 +133,17 @@ bool ContainsLayoutConstrainedAllReduce(const HloModule& module) { return false; } +int64 NextChannelId(const HloModule& module) { + int64 next_channel_id = 1; + for (const HloComputation* comp : module.computations()) { + for (const HloInstruction* hlo : comp->instructions()) { + if (DynCast(hlo)) { + next_channel_id = std::max(next_channel_id, *hlo->channel_id() + 1); + } + } + } + return next_channel_id; +} + } // namespace hlo_query } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_query.h b/tensorflow/compiler/xla/service/hlo_query.h index 0ea36ae83f8..b7fbc465dcb 100644 --- a/tensorflow/compiler/xla/service/hlo_query.h +++ b/tensorflow/compiler/xla/service/hlo_query.h @@ -77,6 +77,10 @@ bool MatchBinaryInstructionOperandOpcode(HloOpcode opcode, // layout. bool ContainsLayoutConstrainedAllReduce(const HloModule& module); +// Returns the next available channel id that can be used in the given module +// (for HloChannelInstructions). +int64 NextChannelId(const HloModule& module); + } // namespace hlo_query } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 445a3ea97d2..5d38bbeec32 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -370,7 +370,8 @@ class MemoryUsageTracker { const HloRematerialization::ShapeSizeFunction& size_function, const HloRematerialization::CompactShapeFunction& compact_shape_function, const TuplePointsToAnalysis& points_to_analysis, - const InstructionList& instruction_list); + const InstructionList& instruction_list, + HloRematerialization::RematerializationMode mode); // Starts the placement of the given instruction. This adds the sizes of the // LogicalBuffers defined by the instruction to the current memory @@ -607,6 +608,7 @@ class MemoryUsageTracker { // between the calling of BeginInstruction and EndInstruction. 
Item* in_progress_item_ = nullptr; + HloRematerialization::RematerializationMode mode_; // All buffers in the computation. std::vector<Buffer> buffers_; }; @@ -616,11 +618,13 @@ MemoryUsageTracker::MemoryUsageTracker( const HloRematerialization::ShapeSizeFunction& size_function, const HloRematerialization::CompactShapeFunction& compact_shape_function, const TuplePointsToAnalysis& points_to_analysis, - const InstructionList& instruction_list) + const InstructionList& instruction_list, + HloRematerialization::RematerializationMode mode) : computation_(computation), instruction_list_(instruction_list), size_function_(size_function), - compact_shape_function_(compact_shape_function) { + compact_shape_function_(compact_shape_function), + mode_(mode) { PointsToSet::BufferSet live_out_set = points_to_analysis.GetPointsToSet(computation_->root_instruction()) .CreateFlattenedSet(); @@ -1155,7 +1159,10 @@ MemoryUsageTracker::PickRematerializationCandidate( continue; } - if (item->buffers_output.size() == 1) { + if (item->buffers_output.size() == 1 && + (mode_ == HloRematerialization::RematerializationMode::kCompressOnly || + mode_ == HloRematerialization::RematerializationMode:: kRecomputeAndCompress)) { // Only consider compressing single output instruction. const Buffer& output_buffer = buffers_.at(item->buffers_output[0]); @@ -1196,6 +1203,11 @@ MemoryUsageTracker::PickRematerializationCandidate( continue; } + // Do not consider recomputation in compress-only mode. + if (mode_ == HloRematerialization::RematerializationMode::kCompressOnly) { + continue; + } + const int64 memory_reduced = MemoryReducedIfRematerialized(item); if (memory_reduced > 0) { @@ -1370,7 +1382,7 @@ StatusOr<int64> HloRematerialization::ComputePeakMemory( InstructionList instruction_list(order); MemoryUsageTracker tracker(computation, size_function_, compact_shape_function_, *points_to_analysis_, - instruction_list); + instruction_list, mode_); int64 peak_memory = tracker.memory_usage(); for (auto* item = instruction_list.first(); item != nullptr; item = instruction_list.next(item)) { @@ -1412,9 +1424,9 @@ StatusOr<bool> HloRematerialization::RematerializeComputation( CHECK(!ContainsKey(rematerialized_computations_, computation)); InstructionList instruction_list(schedule->sequence(computation)); - MemoryUsageTracker memory_tracker(computation, size_function_, - compact_shape_function_, - *points_to_analysis_, instruction_list); + MemoryUsageTracker memory_tracker( computation, size_function_, compact_shape_function_, *points_to_analysis_, instruction_list, mode_); bool changed = false; // If the rematerialization makes the source instruction dead, then the diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h index 9ab34b4862d..69cdc84991b 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.h +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h @@ -49,6 +49,13 @@ class HloRematerialization : public HloModulePass { int64 after_bytes; }; + // Mode in which the rematerialization algorithm should be run. + enum class RematerializationMode { + kRecomputeOnly, // Only consider the kRecompute RematStrategy. + kCompressOnly, // Only consider the kCompress RematStrategy. + kRecomputeAndCompress // Consider both kRecompute and kCompress.
+ }; + static Shape DefaultCompactShapeFunction(const Shape& shape) { return shape; } // Constructor parameters: @@ -69,13 +76,15 @@ class HloRematerialization : public HloModulePass { explicit HloRematerialization( const ShapeSizeFunction& size_function, int64 memory_limit_bytes, RematerializationSizes* sizes, - CompactShapeFunction compact_shape_function = nullptr) + CompactShapeFunction compact_shape_function = nullptr, + RematerializationMode mode = RematerializationMode::kRecomputeAndCompress) : size_function_(size_function), memory_limit_bytes_(memory_limit_bytes), sizes_(sizes), compact_shape_function_(compact_shape_function == nullptr ? DefaultCompactShapeFunction - : std::move(compact_shape_function)) {} + : std::move(compact_shape_function)), + mode_(mode) {} ~HloRematerialization() override = default; absl::string_view name() const override { return "rematerialization"; } @@ -152,6 +161,8 @@ class HloRematerialization : public HloModulePass { // uses of the original instruction and the original instruction is // dead. Hence, no net instructions were added. int64 net_instructions_added_ = 0; + + RematerializationMode mode_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 1218f7dfc6f..b2beb9dda55 100755 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -33,17 +33,6 @@ limitations under the License. namespace xla { -Status VerifyNotSparse(const Shape& shape) { - return ShapeUtil::ForEachSubshapeWithStatus( - shape, [](const Shape& subshape, const ShapeIndex&) -> Status { - if (LayoutUtil::IsSparseArray(subshape)) { - return InternalError("Sparse arrays are not yet fully supported: %s", - ShapeUtil::HumanStringWithLayout(subshape)); - } - return Status::OK(); - }); -} - bool IsCallerInstruction(HloInstruction* hlo) { switch (hlo->opcode()) { case HloOpcode::kCall: @@ -93,8 +82,6 @@ Status ShapeVerifier::Preprocess(HloInstruction* hlo) { "Called computations specified for non-caller instruction %s", hlo->ToString()); } - TF_RETURN_IF_ERROR(VerifyNotSparse(hlo->shape())); - absl::optional arity = HloOpcodeArity(hlo->opcode()); if (arity) { TF_RETURN_IF_ERROR(CheckOperandCount(hlo, *arity)); @@ -1109,8 +1096,6 @@ Status ShapeVerifier::VerifyEntryComputationLayout(const HloModule& module) { TF_RETURN_IF_ERROR( ShapeUtil::ValidateShapeWithOptionalLayout(result_layout.shape())); - TF_RETURN_IF_ERROR(VerifyNotSparse(result_layout.shape())); - if (!ShapeUtil::Compatible(computation->root_instruction()->shape(), result_layout.shape())) { return InternalError( @@ -1131,7 +1116,6 @@ Status ShapeVerifier::VerifyEntryComputationLayout(const HloModule& module) { const HloInstruction* parameter = computation->parameter_instruction(i); TF_RETURN_IF_ERROR( ShapeUtil::ValidateShapeWithOptionalLayout(layout.parameter_shape(i))); - TF_RETURN_IF_ERROR(VerifyNotSparse(layout.parameter_shape(i))); if (!ShapeUtil::Compatible(parameter->shape(), layout.parameter_shape(i))) { return InternalError( "Shape of the entry computation parameter %d is %s should be " @@ -1333,37 +1317,24 @@ Status VerifyLayoutConstrainedAllReduce(const HloModule& module) { return Status::OK(); } -// Checks various invariants of send and recv instructions. -Status VerifySendsAndRecvs(const HloModule& module) { - absl::flat_hash_map host_channels; - // Host send/recv instructions must have their own unique channel. 
- auto check_unique_host_channel = [&](const HloInstruction* instruction) { - const HloSendRecvInstruction* sendrecv = - DynCast(instruction); - if (sendrecv->is_host_transfer()) { - auto it_inserted = - host_channels.insert({*sendrecv->channel_id(), sendrecv}); - if (!it_inserted.second) { - return FailedPrecondition( - "Channel %d is used for multiple host send/recv instructions: " - "%s " - "and " - "%s", - *sendrecv->channel_id(), sendrecv->ToString(), - it_inserted.first->second->ToString()); - } - } - - return Status::OK(); - }; +// Checks various invariants of channel instructions (send/recv and +// collectives). +Status VerifyChannels(const HloModule& module) { + absl::flat_hash_map> + channel_instructions; // Send/Recv instruction must have a single user: the corresponding // SendDone/RecvDone. with matching channel. for (const HloComputation* computation : module.computations()) { for (const HloInstruction* instruction : computation->instructions()) { + auto channel_instr = DynCast(instruction); + if (!channel_instr || !channel_instr->channel_id()) { + continue; + } + channel_instructions[*channel_instr->channel_id()].push_back(instruction); + switch (instruction->opcode()) { case HloOpcode::kSend: { - TF_RETURN_IF_ERROR(check_unique_host_channel(instruction)); TF_RET_CHECK(instruction->users().size() == 1); const HloInstruction* send_done = instruction->users().front(); TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone); @@ -1372,7 +1343,6 @@ Status VerifySendsAndRecvs(const HloModule& module) { break; } case HloOpcode::kRecv: { - TF_RETURN_IF_ERROR(check_unique_host_channel(instruction)); TF_RET_CHECK(instruction->users().size() == 1); const HloInstruction* recv_done = instruction->users().front(); TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone); @@ -1393,6 +1363,39 @@ Status VerifySendsAndRecvs(const HloModule& module) { } } } + + // Iterate over each channel to check invariants. 
+ for (auto& pair : channel_instructions) { + auto& instructions = pair.second; + const HloInstruction* first = instructions[0]; + auto sendrecv = DynCast(first); + if (sendrecv) { + absl::flat_hash_set opcodes; + for (const HloInstruction* instr : instructions) { + opcodes.insert(instr->opcode()); + auto cast = DynCast(instr); + TF_RET_CHECK(cast != nullptr) + << "channel " << pair.first + << " is used for different types of channel instructions"; + } + if (sendrecv->is_host_transfer()) { + TF_RET_CHECK(instructions.size() == 2) + << "channel " << pair.first + << " is used for multiple host send/recv instructions"; + } else { + TF_RET_CHECK(instructions.size() == opcodes.size()) + << "channel " << pair.first + << " is used for multiple send/recv instructions"; + } + } else { + for (const HloInstruction* instr : instructions) { + TF_RET_CHECK(first->opcode() == instr->opcode()) + << "channel " << pair.first + << " is used for different types of channel instructions"; + } + } + } + return Status::OK(); } @@ -1696,7 +1699,7 @@ StatusOr HloVerifier::Run(HloModule* module) { TF_RETURN_IF_ERROR(VerifyHloStructure(module)); TF_RETURN_IF_ERROR(VerifyAsynchronousCopies(*module)); - TF_RETURN_IF_ERROR(VerifySendsAndRecvs(*module)); + TF_RETURN_IF_ERROR(VerifyChannels(*module)); std::unique_ptr shape_verifier = target_metadata_->GetVerifier(); diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc index 1b273909991..c174af6dec0 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -1013,5 +1013,56 @@ TEST_F(HloVerifierTest, AllReduceVerifier) { HasSubstr("mix of layout constrained and unconstrained AllReduce")); } +TEST_F(HloVerifierTest, ChannelVerifier) { + const char* const kModuleStr = R"( + HloModule test + + add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) + } + + ENTRY entry { + %input = f32[8,12] parameter(0) + %token0 = token[] after-all() + %send = (f32[8,12], u32[], token[]) send(%input, %token0), channel_id=1 + %send-done = token[] send-done(%send), channel_id=1 + %crs = f32[8,12] all-reduce(%input), replica_groups={}, to_apply=add, + channel_id=1 + ROOT result = (f32[8,12]{0,1}, f32[8,12]{0,1}) tuple(%input, %crs) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(kModuleStr)); + EXPECT_THAT(verifier().Run(module.get()).status().error_message(), + HasSubstr("used for different types of channel instructions")); +} + +TEST_F(HloVerifierTest, CollectiveChannelVerifier) { + const char* const kModuleStr = R"( + HloModule test + + add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) + } + + ENTRY entry { + %input = f32[8,12] parameter(0) + %permute = f32[8,12] collective-permute(%input), + source_target_pairs={{0,1},{1,0}}, channel_id=1 + %crs = f32[8,12] all-reduce(%input), replica_groups={}, to_apply=add, + channel_id=1 + ROOT result = (f32[8,12]{0,1}, f32[8,12]{0,1}) tuple(%permute, %crs) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(kModuleStr)); + EXPECT_THAT(verifier().Run(module.get()).status().error_message(), + HasSubstr("used for different types of channel instructions")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD index f0c29efffde..39399df7ad8 100644 --- 
a/tensorflow/compiler/xla/service/llvm_ir/BUILD +++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -42,7 +42,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", - "@llvm//:core", + "@llvm-project//llvm:core", ], ) @@ -78,10 +78,10 @@ cc_library( "@com_google_absl//absl/base", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@llvm//:core", - "@llvm//:support", - "@llvm//:target", - "@llvm//:transform_utils", + "@llvm-project//llvm:core", + "@llvm-project//llvm:support", + "@llvm-project//llvm:target", + "@llvm-project//llvm:transform_utils", ], ) @@ -100,7 +100,7 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@llvm//:core", + "@llvm-project//llvm:core", ], ) @@ -118,7 +118,7 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@llvm//:core", + "@llvm-project//llvm:core", ], ) @@ -136,7 +136,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/core:lib", "@com_google_absl//absl/strings:str_format", - "@llvm//:core", + "@llvm-project//llvm:core", ], ) @@ -161,7 +161,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", - "@llvm//:core", + "@llvm-project//llvm:core", ], ) @@ -200,8 +200,8 @@ cc_library( "//tensorflow/core:lib", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@llvm//:core", - "@llvm//:support", + "@llvm-project//llvm:core", + "@llvm-project//llvm:support", ], ) @@ -217,7 +217,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/core:lib", "@com_google_absl//absl/types:span", - "@llvm//:core", + "@llvm-project//llvm:core", ], ) @@ -229,7 +229,7 @@ cc_library( ":llvm_loop", ":llvm_util", "@com_google_absl//absl/strings", - "@llvm//:core", + "@llvm-project//llvm:core", ], ) @@ -249,7 +249,7 @@ cc_library( hdrs = ["math_ops.h"], deps = [ ":llvm_util", - "@llvm//:core", + "@llvm-project//llvm:core", ], ) @@ -258,6 +258,6 @@ cc_library( srcs = [], hdrs = ["ir_builder_mixin.h"], deps = [ - "@llvm//:core", + "@llvm-project//llvm:core", ], ) diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc index caf8fce0f2e..4c56bc55609 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc @@ -91,6 +91,11 @@ bool InstructionCountPrefetchIntervalPicker::CanAllocateInAlternateMemoryNoCopy( return end_time - start_time <= max_overlap_count_; } +int64 InstructionCountPrefetchIntervalPicker::PreferredEvictionEndTime( + const Shape& shape, int64 start_time, int64 latest_end_time) const { + return std::min(start_time + min_overlap_count_, latest_end_time); +} + void InstructionCountPrefetchIntervalPicker::Begin(const HloUse& use, int64 start_time, int64 end_time) { @@ -153,6 +158,21 @@ bool CostAnalysisPrefetchIntervalPicker::CanAllocateInAlternateMemoryNoCopy( logical_interval_elapsed; } +int64 CostAnalysisPrefetchIntervalPicker::PreferredEvictionEndTime( + const Shape& shape, int64 start_time, int64 latest_end_time) const { + float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape); + int64 end_time; + for (end_time = start_time + 1; end_time <= latest_end_time; ++end_time) { 
+ float logical_interval_elapsed = + GetLogicalIntervalElapsed(start_time, end_time); + if (logical_interval_elapsed >= + min_async_copy_to_overlap_ratio_ * async_copy_elapsed) { + break; + } + } + return end_time; +} + void CostAnalysisPrefetchIntervalPicker::Begin(const HloUse& use, int64 start_time, int64 end_time) { @@ -337,8 +357,7 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { absl::make_unique( value->defining_instruction(), value->defining_position(), aliased_allocation->memory_space(), aliased_allocation->chunk(), - aliased_allocation->start_time(), - aliased_allocation->end_time())); + definition_time, definition_time)); } // Iterate over the uses. @@ -418,6 +437,28 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { return result_; } +bool operator<(const AsynchronousCopy& a, const AsynchronousCopy& b) { + return (a.start_time < b.start_time && a.end_time <= b.end_time) || + (a.start_time <= b.start_time && a.end_time < b.end_time); +} + +void AsynchronousCopyOrdering::AddCopy(const AsynchronousCopy& copy) { + auto it_and_inserted = ranges_.insert(copy); + CHECK(it_and_inserted.second || + it_and_inserted.first->start_time == copy.start_time); +} + +bool AsynchronousCopyOrdering::ViolatesOrdering(int64 start_time, + int64 end_time) const { + // We allow identical start and end times. It is enough to check for just the + // start time in case we find a match in ranges_ because the found value will + // either be identical to {start_time, end_time} (and this doesn't violate) or + // its start_time will be smaller and end_time will be larger (this violates). + auto copy_it = ranges_.find( + {start_time, end_time, MemorySpaceAssignment::MemorySpace::kAlternate}); + return copy_it != ranges_.end() && copy_it->start_time != start_time; +} + void AlternateMemoryBestFitHeap::AddInputAndOutputRequiredAssignments() { // Go through the parameters and outputs and pin them to the corresponding // memory by adding a required assignment. @@ -520,14 +561,7 @@ void AlternateMemoryBestFitHeap::CommitPendingChunks() { kDummyChunk); } if (interval.destination == MemorySpace::kAlternate) { - // If there is already an asynchronous copy ending the same time, pick - // the earliest copy start time. - auto range_it = async_copy_range_map_.find(interval.end_time); - if (range_it != async_copy_range_map_.end()) { - range_it->second = std::min(range_it->second, interval.start_time); - } else { - async_copy_range_map_[interval.end_time] = interval.start_time; - } + async_copy_ordering_.AddCopy(interval); } } pending_async_copies_.clear(); @@ -627,48 +661,68 @@ bool AlternateMemoryBestFitHeap::FindAllocation( } } - // Since copies couldn't be removed, create an allocation in the default - // memory space. - if (prev_allocation_in_default_mem != nullptr) { - if (prev_allocation == prev_allocation_in_default_mem) { - // The latest allocation is also in the default memory, simply extend - // that. - prev_allocation->Extend(end_time); - } else { - // The latest allocation is different. Create a new allocation in default - // memory. 
- allocations->push_back( - absl::make_unique( - non_bitcast_operand, defining_position, MemorySpace::kDefault, - kDummyChunk, prev_allocation_in_default_mem->end_time(), - end_time)); - } - } else if (prev_allocation != nullptr && - prev_allocation->memory_space() == MemorySpace::kAlternate && - prev_allocation->defining_position() == defining_position) { + if (prev_allocation_in_default_mem == nullptr && prev_allocation != nullptr && + prev_allocation->memory_space() == MemorySpace::kAlternate && + prev_allocation->defining_position() == defining_position) { // If there was an allocation for this HloValue that was in the alternate // memory space, we also need to perform an eviction. - // TODO(berkin): For now evictions happen relative to the most recent - // allocation in the alternate memory. We can potentially start evictions - // earlier and end later. + int64 eviction_start_time = prev_allocation->start_time(); + int64 eviction_end_time = prev_allocation->end_time(); + CHECK(eviction_start_time <= eviction_end_time); + + int64 preferred_eviction_end_time = std::max( + options_.prefetch_interval_picker->PreferredEvictionEndTime( + non_bitcast_operand->shape(), eviction_start_time, end_time), + eviction_end_time); + + BufferInterval eviction_mem_interval; + eviction_mem_interval.buffer = buffer; + eviction_mem_interval.size = size; + // Try to reserve a buffer from the end of the previous allocation to the + // preferred eviction end time. + eviction_mem_interval.start = prev_allocation->end_time() + 1; + eviction_mem_interval.end = preferred_eviction_end_time; + int64 preferred_offset = prev_allocation->chunk().offset; + VLOG(4) << "Eviction (" << eviction_start_time << ", " << eviction_end_time + << ") preferred end time = " << preferred_eviction_end_time; + + while (preferred_eviction_end_time > eviction_end_time) { + ChunkCandidate chunk_candidate = + FindChunkCandidate(eviction_mem_interval, preferred_offset); + if (chunk_candidate.chunk.offset == preferred_offset) { + eviction_end_time = preferred_eviction_end_time; + AddToPendingChunks(eviction_mem_interval, chunk_candidate); + break; + } + eviction_mem_interval.end = --preferred_eviction_end_time; + } + VLOG(3) << "Evicting buffer at " << prev_allocation->chunk().offset << " (" - << prev_allocation->start_time() << ", " - << prev_allocation->end_time() << ")"; + << eviction_start_time << ", " << eviction_end_time << ")"; + + bool eviction_interval_too_short = + (eviction_start_time == eviction_end_time); + bool eviction_violates_outstanding_copies = + ViolatesMaximumOutstandingAsyncCopies(eviction_start_time, + eviction_end_time); // See if this interval would violate the asynchronous copy limit. 
- if (!ViolatesMaximumOutstandingAsyncCopies(prev_allocation->start_time(), - prev_allocation->end_time())) { + if (!eviction_interval_too_short && !eviction_violates_outstanding_copies) { + prev_allocation->Extend(eviction_end_time); AddAsyncCopy(*prev_allocation, MemorySpace::kDefault, kDummyChunk, - prev_allocation->start_time(), prev_allocation->end_time(), - prev_allocation->end_time(), allocations); - + eviction_start_time, prev_allocation->end_time(), + eviction_end_time, allocations); } else { - VLOG(3) << "This violates the maximum async copies."; + if (eviction_violates_outstanding_copies) { + VLOG(3) << "This violates the maximum async copies."; + } else { + VLOG(3) << "Eviction interval is too short (" << eviction_start_time + << ", " << eviction_end_time << ")."; + } // If the original interval violated the limit, try sub-intervals within // this interval. bool eviction_scheduled = false; - for (int64 time = prev_allocation->start_time(); - time <= prev_allocation->end_time(); ++time) { + for (int64 time = eviction_start_time; time < eviction_end_time; ++time) { VLOG(3) << "Try evicting (" << time << ", " << time << ")"; if (!ViolatesMaximumOutstandingAsyncCopies(time, time)) { VLOG(3) << "Eviction successful."; @@ -686,25 +740,31 @@ bool AlternateMemoryBestFitHeap::FindAllocation( << " because we hit the limit of maximum asynchronous copies " << "between " << hlo_live_range_.flattened_instruction_sequence() - .instructions()[prev_allocation->start_time()] + .instructions()[eviction_start_time] << " and " << hlo_live_range_.flattened_instruction_sequence() - .instructions()[prev_allocation->end_time()]; + .instructions()[eviction_end_time]; return false; } } - } else { + prev_allocation_in_default_mem = allocations->back().get(); + } else if (prev_allocation_in_default_mem == nullptr) { allocations->push_back(absl::make_unique( non_bitcast_operand, defining_position, MemorySpace::kDefault, kDummyChunk, start_time, end_time)); + prev_allocation_in_default_mem = allocations->back().get(); } + CHECK_NE(prev_allocation_in_default_mem, nullptr); + CHECK(prev_allocation_in_default_mem->memory_space() == + MemorySpace::kDefault); + // If the use requires the buffer to be in default memory, don't try to // prefetch. if (use_requires_buffer_in_default_mem) { VLOG(4) << "Not trying to prefetch because use requires buffer in default mem."; - allocations->back()->AddUse(use); + prev_allocation_in_default_mem->AddUse(use); return true; } @@ -736,8 +796,8 @@ bool AlternateMemoryBestFitHeap::FindAllocation( VLOG(4) << "This would violate the outstanding async copy limit."; continue; } - if (ViolatesAsynchronousCopyOrdering(alternate_mem_interval.start, - alternate_mem_interval.end)) { + if (async_copy_ordering_.ViolatesOrdering(alternate_mem_interval.start, + alternate_mem_interval.end)) { VLOG(4) << "This would violate asynchronous copy ordering."; continue; } @@ -754,7 +814,7 @@ bool AlternateMemoryBestFitHeap::FindAllocation( << options_.prefetch_interval_picker->ToDebugString(); AddToPendingChunks(alternate_mem_interval, chunk_candidate); - AddAsyncCopy(*allocations->back().get(), MemorySpace::kAlternate, + AddAsyncCopy(*prev_allocation_in_default_mem, MemorySpace::kAlternate, chunk_candidate.chunk, alternate_mem_interval.start, end_time, latest_prefetch_time, allocations); @@ -763,8 +823,9 @@ bool AlternateMemoryBestFitHeap::FindAllocation( } } - // If a copy wasn't inserted, then add this use to the latest allocation. 
- allocations->back()->AddUse(use); + // If a copy wasn't inserted, then add this use to the latest allocation in + // default memory. + prev_allocation_in_default_mem->AddUse(use); return true; } @@ -812,13 +873,6 @@ bool AlternateMemoryBestFitHeap::ViolatesMaximumOutstandingAsyncCopies( return num_async_copies + 1 > options_.max_outstanding_async_copies; } -bool AlternateMemoryBestFitHeap::ViolatesAsynchronousCopyOrdering( - int64 start_time, int64 end_time) const { - auto async_copy_range_it = async_copy_range_map_.lower_bound(end_time); - return async_copy_range_it != async_copy_range_map_.end() && - async_copy_range_it->second < start_time; -} - bool AlternateMemoryBestFitHeap::TryAllocatingInAlternateMemoryNoCopy( int64 start_time, int64 end_time, int64 last_use_time, HloPosition defining_position, HloUse use, @@ -844,7 +898,7 @@ bool AlternateMemoryBestFitHeap::TryAllocatingInAlternateMemoryNoCopy( } if (!options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy( - non_bitcast_operand->shape(), start_time, end_time)) { + non_bitcast_operand->shape(), start_time + 1, end_time)) { return false; } @@ -1032,6 +1086,10 @@ MemorySpaceAssignment::Run(HloModule* module, const Options& options) { VLOG(1) << "Maximum number of outstanding async copies: " << CountMaximumOutstandingAsyncCopies(*module); + if (options.verify || VLOG_IS_ON(1)) { + TF_RETURN_IF_ERROR(memory_space_assignment.Verify()); + } + return std::move(memory_space_assignment.preset_assignments_); } @@ -1313,6 +1371,13 @@ void MemorySpaceAssignment::EnsureInstructionAndOperandsInserted( return; } for (HloInstruction* operand : new_instruction->operands()) { + // CopyStart/CopyDone dependencies should always be already inserted; it is + // a red flag when they haven't already been inserted. + CHECK((operand->opcode() != HloOpcode::kCopyStart && + operand->opcode() != HloOpcode::kCopyDone) || + inserted_instructions->contains(operand)) + << "Inserted instruction " << new_instruction->ToString() + << " has un-inserted dependency: " << operand->ToString(); EnsureInstructionAndOperandsInserted(operand, new_sequence, inserted_instructions); } @@ -1404,10 +1469,14 @@ Status MemorySpaceAssignment::FixSchedule() { } HloInstruction* instruction = flattened_instructions_[instruction_index]; // Insert only if it is not deleted (SimplifyGraph sets it to nullptr if - // it was deleted) and not previously inserted. + // it was deleted) and not previously inserted. Also bitcasts and tuples + // are treated specially and only inserted as a result of operand + // dependencies. 
if (instruction != nullptr && !inserted_instructions.contains(instruction) && - instruction->parent() == computation) { + instruction->parent() == computation && + instruction->opcode() != HloOpcode::kBitcast && + instruction->opcode() != HloOpcode::kTuple) { EnsureInstructionAndOperandsInserted(instruction, &new_sequence, &inserted_instructions); } @@ -1435,4 +1504,62 @@ Status MemorySpaceAssignment::FixSchedule() { return Status::OK(); } +Status MemorySpaceAssignment::Verify() const { + VLOG(3) << "Verifying:"; + TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, + HloAliasAnalysis::Run(module_)); + TF_ASSIGN_OR_RETURN(std::unique_ptr hlo_live_range, + HloLiveRange::Run(module_->schedule(), *alias_analysis, + module_->entry_computation())); + + BufferIntervalTree interval_tree; + absl::flat_hash_set seen_buffers; + + for (const auto& position_and_chunk : preset_assignments_->chunks()) { + const HloPosition& position = position_and_chunk.first; + const Chunk& chunk = position_and_chunk.second; + const HloBuffer& buffer = + alias_analysis->GetUniqueBufferAt(position.instruction, position.index); + if (seen_buffers.contains(buffer.id())) { + continue; + } + seen_buffers.insert(buffer.id()); + + int64 start_time = INT64_MAX; + int64 end_time = -1; + for (const HloValue* value : buffer.values()) { + const HloLiveRange::TimeBound& time_bound = + hlo_live_range->buffer_live_ranges().at(value); + VLOG(3) << " value: " << value->ToShortString() << " (" + << time_bound.start << ", " << time_bound.end << ")"; + start_time = std::min(start_time, time_bound.start); + end_time = std::max(end_time, time_bound.end); + } + CHECK_GE(start_time, 0); + CHECK_GT(end_time, 0); + // Get the chunks overlapping in time and search if they overlap in space as + // well. + // TODO(berkin): For now checking against end_time - 1 (exclusive), but we + // really should check against end_time (inclusive) for cases where the + // operand can't share buffer with user (see + // HloDataflowAnalysis::CanShareOperandBufferWithUser). + for (const Chunk& overlapping_chunk : + interval_tree.ChunksOverlappingInTime(start_time, end_time - 1)) { + if (chunk.OverlapsWith(overlapping_chunk)) { + return InternalError( + ("Buffer %s (%d, %d) off: %d size: %d overlaps with another chunk" + " off: %d size: %d"), + buffer.ToString(), start_time, end_time, chunk.offset, chunk.size, + overlapping_chunk.offset, overlapping_chunk.size); + } + } + interval_tree.Add(start_time, end_time - 1, chunk); + VLOG(3) << " buffer: " << buffer.ToString() << ": (" << start_time << ", " + << end_time << ") off: " << position_and_chunk.second.offset + << ", size: " << position_and_chunk.second.size; + } + + return Status::OK(); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.h b/tensorflow/compiler/xla/service/memory_space_assignment.h index 67ced4c4909..d83e888f5ab 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.h +++ b/tensorflow/compiler/xla/service/memory_space_assignment.h @@ -123,6 +123,11 @@ class PrefetchIntervalPicker { int64 start_time, int64 end_time) const = 0; + // Returns the preferred end time for an eviction that starts at a given time + // and must end by the given end time. + virtual int64 PreferredEvictionEndTime(const Shape& shape, int64 start_time, + int64 latest_end_time) const = 0; + // Begins the iterator for the first start time of the prefetch. 
   // Begins the iterator for the first start time of the prefetch.
   virtual void Begin(const HloUse& use, int64 start_time, int64 end_time) = 0;
@@ -166,6 +171,9 @@ class InstructionCountPrefetchIntervalPicker : public PrefetchIntervalPicker {
   bool CanAllocateInAlternateMemoryNoCopy(const Shape& shape, int64 start_time,
                                           int64 end_time) const override;
 
+  int64 PreferredEvictionEndTime(const Shape& shape, int64 start_time,
+                                 int64 latest_end_time) const override;
+
   void Begin(const HloUse& use, int64 start_time, int64 end_time) override;
 
   int64 Next() override;
@@ -206,6 +214,9 @@ class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker {
   bool CanAllocateInAlternateMemoryNoCopy(const Shape& shape, int64 start_time,
                                           int64 end_time) const override;
 
+  int64 PreferredEvictionEndTime(const Shape& shape, int64 start_time,
+                                 int64 latest_end_time) const override;
+
   void Begin(const HloUse& use, int64 start_time, int64 end_time) override;
 
   int64 Next() override;
@@ -288,6 +299,10 @@ class MemorySpaceAssignment {
     // If true, tries allocating buffers across (e.g., before and inside a while
     // loop body) sequential calls (kWhile, kCall, and kConditional).
     bool allocate_across_sequential_calls = false;
+
+    // If true, verifies the memory space assignment against overlapping
+    // buffers.
+    bool verify = false;
   };
 
   // This class represents an allocation that might either be in the default or
@@ -460,6 +475,9 @@ class MemorySpaceAssignment {
   static BufferIntervalCompare GetMemoryBoundednessBufferIntervalCompare(
       const MemorySpaceAssignmentCostAnalysis& cost_analysis);
 
+  // Verify that the memory space assignment is free of overlapping buffers.
+  Status Verify() const;
+
  private:
   MemorySpaceAssignment(HloModule* module, int64 alternate_memory_space,
                         const HloLiveRange& hlo_live_range)
@@ -526,6 +544,48 @@ struct RequiredMemoryAssignment {
   int64 time;
 };
 
+// A struct representing an asynchronous copy with its logical start and end
+// time and its destination memory space.
+struct AsynchronousCopy {
+  int64 start_time;
+  int64 end_time;
+  MemorySpaceAssignment::MemorySpace destination;
+};
+
+// Compare asynchronous copies such that an earlier start time has the same or
+// earlier end time and an earlier end time has the same or earlier start time.
+bool operator<(const AsynchronousCopy& a, const AsynchronousCopy& b);
+
+// Helper class to enforce asynchronous copy ordering. We only allow
+// asynchronous copies that are pipelined: if an asynchronous copy ends earlier
+// than another asynchronous copy, it must start the same time or earlier than
+// the other asynchronous copy; and if an asynchronous copy starts earlier than
+// another asynchronous copy, it must end the same time or earlier than the
+// other asynchronous copy.
+class AsynchronousCopyOrdering {
+ public:
+  AsynchronousCopyOrdering() = default;
+
+  // Adds an asynchronous copy.
+  void AddCopy(const AsynchronousCopy& copy);
+
+  // Returns true if the addition of an asynchronous copy in the given time
+  // interval would violate the asynchronous copy ordering. E.g., consider the
+  // following scenario:
+  //                                 CS          CD
+  //  already committed async copy:  +-----------+
+  //               new async copy:     +--------+
+  //
+  // The new asynchronous copy would violate the ordering guarantee because the
+  // copy start is after an already committed asynchronous copy while its copy
+  // done is before the committed copy.
+  bool ViolatesOrdering(int64 start_time, int64 end_time) const;
+
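To make the pipelining invariant concrete, here is a small self-contained sketch of how a tree set can answer the violation query in O(log n). It is an illustration under one assumption (the patch's actual operator< and ViolatesOrdering bodies are in the .cc file, which this diff does not show): define operator< so that two copies that cross without nesting compare equivalent; then a std::set find() hit that is not the identical interval is exactly an ordering violation. The behavior below matches the AsynchronousCopyOrderingTest expectations later in this diff.

    #include <cstdint>
    #include <set>

    struct Copy {
      int64_t start_time;
      int64_t end_time;
    };

    // Strict weak ordering: a < b only if a is componentwise no later than b
    // and strictly earlier in at least one endpoint. Copies that strictly
    // "cross" (one starts earlier but ends later) compare equivalent, which is
    // exactly the violating configuration.
    bool operator<(const Copy& a, const Copy& b) {
      return (a.start_time < b.start_time && a.end_time <= b.end_time) ||
             (a.start_time <= b.start_time && a.end_time < b.end_time);
    }

    class CopyOrdering {
     public:
      void AddCopy(const Copy& c) { ranges_.insert(c); }

      bool ViolatesOrdering(int64_t start, int64_t end) const {
        auto it = ranges_.find(Copy{start, end});
        // A hit that is not the identical interval means the new copy crosses
        // a committed one, violating the pipelining order. An identical
        // interval is allowed.
        return it != ranges_.end() &&
               (it->start_time != start || it->end_time != end);
      }

     private:
      std::set<Copy> ranges_;
    };

For example, with {3,11}, {1,8}, {5,14}, and {7,14} committed, probing (2,16) or (9,12) finds an equivalent (crossing) element and reports a violation, while (5,13) and a repeated (5,14) do not.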
+ private:
+  // Stores asynchronous copies in a tree set respecting the pipelining order.
+  std::set<AsynchronousCopy> ranges_;
+};
+
 // This class inherits from GlobalDecreasingSizeBestFitHeap with a notion of
 // maximum size.
 class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
@@ -551,14 +611,6 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
   HeapSimulator::Result Finish() override;
 
  private:
-  // A struct representing an asynchronous copy with its logical start and end
-  // time and its destination memory space.
-  struct AsynchronousCopy {
-    int64 start_time;
-    int64 end_time;
-    MemorySpace destination;
-  };
-
   // Finds an allocation for the given interval. Internally, it will attempt to
   // find a suitable chunk candidate within the heap size and prefetch interval
   // limits, and append the new allocation(s) to allocations. The new
@@ -603,18 +655,6 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
   bool ViolatesMaximumOutstandingAsyncCopies(int64 start_time,
                                              int64 end_time) const;
 
-  // Returns true if the addition of an asynchronous copy in the the given time
-  // interval would violate the asynchronous copy ordering. E.g., consider the
-  // following scenario:
-  //                                 CS          CD
-  //  already committed async copy:  +-----------+
-  //               new async copy:     +--------+
-  //
-  // The new asynchronous copy would violate the ordering guarantee because the
-  // copy start is after an already committed asynchronous copy while its copy
-  // done is before the committed copy.
-  bool ViolatesAsynchronousCopyOrdering(int64 start_time, int64 end_time) const;
-
   // Adds an asynchronous copy to the allocations.
   void AddAsyncCopy(const MemorySpaceAssignment::Allocation& prev_allocation,
                     MemorySpace memory_space, Chunk chunk, int64 start_time,
@@ -639,9 +679,7 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
   // We use a interval tree to keep track of the number of outstanding
   // asynchronous copies.
   BufferIntervalTree async_copy_interval_tree_;
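As a rough illustration of how the interval tree above bounds copies in flight: each asynchronous copy occupies the logical-time interval between its CopyStart and CopyDone, and the check from the .cc hunk earlier in this diff reduces to counting overlaps. This is a sketch for exposition, assuming the ChunksOverlappingInTime query that Verify() also uses; the free-function form and parameter list are not the patch's actual signature.

    // Sketch: a new copy over [start_time, end_time] is rejected if it would
    // push the number of simultaneously outstanding copies over the limit.
    bool ViolatesMaxOutstandingCopies(const BufferIntervalTree& tree,
                                      int64 start_time, int64 end_time,
                                      int64 max_outstanding_async_copies) {
      if (max_outstanding_async_copies < 0) {
        return false;  // A negative limit means unbounded.
      }
      int64 num_async_copies =
          tree.ChunksOverlappingInTime(start_time, end_time).size();
      return num_async_copies + 1 > max_outstanding_async_copies;
    }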
-  // Given the logical time for CopyDone in key, stores the earliest time for
-  // the corresponding CopyStart.
-  std::map<int64, int64> async_copy_range_map_;
+  AsynchronousCopyOrdering async_copy_ordering_;
   std::vector<std::pair<BufferInterval, ChunkCandidate>> pending_chunks_;
   std::vector<AsynchronousCopy> pending_async_copies_;
   // This map contains required memory assignments for HloValues (e.g., input
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
index 238bbed37c4..1d015507867 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
@@ -67,9 +67,9 @@ class MemorySpaceAssignmentTest : public HloTestBase,
   std::unique_ptr<PresetAssignments> AssignMemorySpace(
       HloModule* module, int64 max_outstanding_async_copies = -1,
-      int64 max_prefetch_interval = 10) {
+      int64 max_prefetch_interval = 10, int64 min_prefetch_interval = 2) {
     InstructionCountPrefetchIntervalPicker prefetch_interval_picker(
-        /*min_overlap_count=*/2, max_prefetch_interval);
+        min_prefetch_interval, max_prefetch_interval);
     return AssignMemorySpace(module, max_outstanding_async_copies,
                              /*buffer_interval_compare=*/{},
                              &prefetch_interval_picker);
@@ -107,6 +107,7 @@ class MemorySpaceAssignmentTest : public HloTestBase,
     options.is_allowed_in_alternate_mem_fn = is_allowed_in_alternate_mem;
     options.max_outstanding_async_copies = max_outstanding_async_copies;
     options.allocate_across_sequential_calls = GetParam();
+    options.verify = true;
     std::unique_ptr<PresetAssignments> preset_assignments =
         MemorySpaceAssignment::Run(module, options).ValueOrDie();
     CheckPresetAssignments(preset_assignments.get());
@@ -430,6 +431,103 @@ TEST_P(MemorySpaceAssignmentTest, DontEvictWhenThereIsDefaultMemAllocation) {
   EXPECT_THAT(h, op::Multiply(op::Subtract(), op::Multiply()));
 }
 
+TEST_P(MemorySpaceAssignmentTest, EvictAndPrefetchAndPrefetch) {
+  // Test for a memory corruption bug involving the evict/prefetch/prefetch
+  // pattern, where the last prefetch copied from the original buffer in
+  // alternate memory instead of from the evicted buffer.
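+  //
+  // Roughly (diagram added for illustration; see the EXPECT_THATs at the end
+  // of this test):
+  //
+  //   tanh (alternate mem) --evict--> default mem --prefetch--> add0
+  //                                                \-prefetch--> add1
+  //
+  // Both prefetches must copy from the evicted buffer in default memory, not
+  // from the original tanh result in alternate memory.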
+ HloComputation::Builder builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); + HloInstruction* p0 = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0")); + HloInstruction* p1 = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1")); + HloInstruction* tanh = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0)); + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, tanh)); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kSubtract, p0, p1)); + HloInstruction* c = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, p0, p1)); + HloInstruction* d = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kSubtract, p0, p1)); + HloInstruction* e = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, b)); + HloInstruction* f = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, c)); + HloInstruction* g = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, d)); + HloInstruction* h = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, b, c)); + HloInstruction* i = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, b, d)); + HloInstruction* j = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, c, d)); + HloInstruction* k = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, e, f)); + HloInstruction* l = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, g, h)); + HloInstruction* m = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, i, j)); + HloInstruction* n = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, k, l)); + HloInstruction* o = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, n, m)); + HloInstruction* add0 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, o, tanh)); + HloInstruction* negate0 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, add0)); + HloInstruction* negate1 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0)); + HloInstruction* negate2 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1)); + HloInstruction* negate3 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2)); + HloInstruction* negate4 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3)); + HloInstruction* negate5 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4)); + HloInstruction* negate6 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5)); + HloInstruction* negate7 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate6)); + HloInstruction* negate8 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate7)); + HloInstruction* negate9 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate8)); + HloInstruction* add1 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, negate9, tanh)); + + auto module = 
CreateNewVerifiedModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); + + HloSchedule schedule(module.get()); + schedule.set_sequence( + computation, + {p0, p1, tanh, a, b, c, d, e, + f, g, h, i, j, k, l, m, + n, o, add0, negate0, negate1, negate2, negate3, negate4, + negate5, negate6, negate7, negate8, negate9, add1}); + TF_CHECK_OK(module->set_schedule(schedule)); + + AssignMemorySpace(module.get()); + + // Check that both prefetches (add0 and add1) prefetch from the eviction + // instead of tanh, which will be placed in the alternate memory directly. + EXPECT_THAT( + add0, + op::Add(op::Add(), + op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace, + op::AsyncCopy(kDefaultMemorySpace, + kAlternateMemorySpace, op::Tanh())))); + EXPECT_THAT( + add1, + op::Add(op::Negate(), + op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace, + op::AsyncCopy(kDefaultMemorySpace, + kAlternateMemorySpace, op::Tanh())))); +} + TEST_P(MemorySpaceAssignmentTest, While) { auto module = CreateNewVerifiedModule(); Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3}); @@ -759,6 +857,77 @@ TEST_P(MemorySpaceAssignmentTest, BitcastTuple) { AssignMemorySpace(module.get()); } +TEST_P(MemorySpaceAssignmentTest, BitcastScheduleBug) { + // Bitcasts can force asynchronous copies to be scheduled too early, possibly + // leading to memory corruption. + // Bug: + // p0------------------>neg-->neg-->neg ... -->neg-->neg-->neg->add + // / + // p1->cs->cd->bitcast-----------------------------------------+ + // + // Expected: + // p0-->neg-->neg-->neg ... -->neg-->neg-->neg------------->add + // / + // p1--------------------->cs----------------->cd->bitcast-+ + HloComputation::Builder builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); + Shape param_shape = ShapeUtil::MakeShape(F32, {6}); + HloInstruction* p0 = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0")); + HloInstruction* p1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, param_shape, "p1")); + HloInstruction* bitcast = + builder.AddInstruction(HloInstruction::CreateBitcast(shape, p1)); + HloInstruction* negate0 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0)); + HloInstruction* negate1 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0)); + HloInstruction* negate2 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1)); + HloInstruction* negate3 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2)); + HloInstruction* negate4 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3)); + HloInstruction* negate5 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4)); + HloInstruction* negate6 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5)); + HloInstruction* negate7 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate6)); + HloInstruction* negate8 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate7)); + HloInstruction* negate9 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate8)); + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, bitcast, negate9)); + + auto module = CreateNewVerifiedModule(); + HloComputation* computation = 
module->AddEntryComputation(builder.Build()); + + HloSchedule schedule(module.get()); + schedule.set_sequence( + computation, {p0, p1, bitcast, negate0, negate1, negate2, negate3, + negate4, negate5, negate6, negate7, negate8, negate9, add}); + TF_CHECK_OK(module->set_schedule(schedule)); + + AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1, + /*max_prefetch_interval=*/5, /*min_prefetch_interval=*/4); + + EXPECT_EQ(bitcast->shape().layout().memory_space(), kAlternateMemorySpace); + const auto& instructions = + module->schedule().sequence(module->entry_computation()).instructions(); + for (int i = 0; i < instructions.size(); ++i) { + // Expect that there is a negate before and after the CopyStart and there is + // a negate before CopyDone. + if (instructions.at(i)->opcode() == HloOpcode::kCopyStart) { + EXPECT_EQ(instructions.at(i - 1)->opcode(), HloOpcode::kNegate); + EXPECT_EQ(instructions.at(i + 1)->opcode(), HloOpcode::kNegate); + } else if (instructions.at(i)->opcode() == HloOpcode::kCopyDone) { + EXPECT_EQ(instructions.at(i - 1)->opcode(), HloOpcode::kNegate); + } + } +} + TEST_P(MemorySpaceAssignmentTest, LastUseOpt) { // Test that checks the last use optimization. It uses two buffers that should // be placed in alternate memory. @@ -2266,5 +2435,38 @@ INSTANTIATE_TEST_SUITE_P(MemorySpaceAssignmentInstantiation, MemorySpaceAssignmentTest, ::testing::Values(false, true)); +using AsynchronousCopyOrderingTest = ::testing::Test; + +TEST_F(AsynchronousCopyOrderingTest, Simple) { + // Given asynchronous copies like the following, ensure the pipelining order + // is maintained (earlier start time must have earlier end time). + // 3,11 +-------+ OK + // 1,8 +------+ OK + // 5,14 +--------+ OK + // 7,14 +------+ OK + // 2,16 +-------------+ Violate + // 9,12 +--+ Violate + // 6,17 +----------+ Violate + // 5,13 +-------+ OK (same start as 5,14) + // 5,14 +--------+ OK (same as 5,14) + auto alternate_mem_space = MemorySpaceAssignment::MemorySpace::kAlternate; + AsynchronousCopyOrdering ordering; + EXPECT_FALSE(ordering.ViolatesOrdering(3, 11)); + ordering.AddCopy({3, 11, alternate_mem_space}); + EXPECT_FALSE(ordering.ViolatesOrdering(1, 8)); + ordering.AddCopy({1, 8, alternate_mem_space}); + EXPECT_FALSE(ordering.ViolatesOrdering(5, 14)); + ordering.AddCopy({5, 14, alternate_mem_space}); + EXPECT_FALSE(ordering.ViolatesOrdering(7, 14)); + ordering.AddCopy({7, 14, alternate_mem_space}); + EXPECT_TRUE(ordering.ViolatesOrdering(2, 16)); + EXPECT_TRUE(ordering.ViolatesOrdering(9, 12)); + EXPECT_TRUE(ordering.ViolatesOrdering(6, 17)); + EXPECT_FALSE(ordering.ViolatesOrdering(5, 13)); + ordering.AddCopy({5, 13, alternate_mem_space}); + EXPECT_FALSE(ordering.ViolatesOrdering(5, 14)); + ordering.AddCopy({5, 14, alternate_mem_space}); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/mlir_gpu/BUILD b/tensorflow/compiler/xla/service/mlir_gpu/BUILD index b687d72d3d9..20b448286d5 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/BUILD +++ b/tensorflow/compiler/xla/service/mlir_gpu/BUILD @@ -37,7 +37,7 @@ cc_library( deps = [ "//tensorflow/compiler/xla/service:hlo", "@com_google_absl//absl/strings", - "@local_config_mlir//:IR", + "@llvm-project//mlir:IR", ], ) @@ -46,8 +46,8 @@ cc_library( srcs = ["inject_errors_pass.cc"], hdrs = ["inject_errors_pass.h"], deps = [ - "@local_config_mlir//:Pass", - "@local_config_mlir//:StandardOps", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:StandardOps", ], ) @@ -81,12 +81,12 @@ cc_library( 
"//tensorflow/stream_executor:stream_executor_headers", "//tensorflow/stream_executor/gpu:asm_compiler", "@com_google_absl//absl/container:flat_hash_map", - "@local_config_mlir//:GPUDialect", - "@local_config_mlir//:IR", - "@local_config_mlir//:LLVMDialect", - "@local_config_mlir//:StandardOps", - "@local_config_mlir//:Support", - "@local_config_mlir//:TargetNVVMIR", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TargetNVVMIR", ], alwayslink = True, # Contains compiler registration ) @@ -103,9 +103,9 @@ cc_library( "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla/service:hlo", "@com_google_absl//absl/types:span", - "@llvm//:support", - "@local_config_mlir//:IR", - "@local_config_mlir//:StandardOps", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:StandardOps", ], ) @@ -127,9 +127,9 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/stream_executor:stream_executor_headers", "@com_google_absl//absl/container:flat_hash_map", - "@local_config_mlir//:IR", - "@local_config_mlir//:LLVMDialect", - "@local_config_mlir//:StandardOps", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:StandardOps", ], ) @@ -151,26 +151,26 @@ cc_library( "//tensorflow/compiler/xla:util", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", - "@local_config_mlir//:AffineDialectRegistration", - "@local_config_mlir//:CFGTransforms", - "@local_config_mlir//:GPUDialect", - "@local_config_mlir//:GPUDialectRegistration", - "@local_config_mlir//:GPUToNVVMTransforms", - "@local_config_mlir//:GPUTransforms", - "@local_config_mlir//:IR", - "@local_config_mlir//:LLVMDialect", - "@local_config_mlir//:LLVMTransforms", - "@local_config_mlir//:Linalg", - "@local_config_mlir//:LinalgDialectRegistration", - "@local_config_mlir//:LinalgToLLVM", - "@local_config_mlir//:LoopDialectRegistration", - "@local_config_mlir//:LoopOps", - "@local_config_mlir//:LoopsToGPUPass", - "@local_config_mlir//:NVVMDialect", - "@local_config_mlir//:Pass", - "@local_config_mlir//:StandardDialectRegistration", - "@local_config_mlir//:StandardOps", - "@local_config_mlir//:Transforms", + "@llvm-project//mlir:AffineDialectRegistration", + "@llvm-project//mlir:CFGTransforms", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:GPUDialectRegistration", + "@llvm-project//mlir:GPUToNVVMTransforms", + "@llvm-project//mlir:GPUTransforms", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LLVMTransforms", + "@llvm-project//mlir:Linalg", + "@llvm-project//mlir:LinalgDialectRegistration", + "@llvm-project//mlir:LinalgToLLVM", + "@llvm-project//mlir:LoopDialectRegistration", + "@llvm-project//mlir:LoopOps", + "@llvm-project//mlir:LoopsToGPUPass", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:StandardDialectRegistration", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Transforms", ], ) @@ -191,8 +191,8 @@ cc_library( "//tensorflow/core:test", "//tensorflow/core/platform:test", "@com_google_absl//absl/memory", - "@llvm//:support", - "@local_config_mlir//:IR", - "@local_config_mlir//:Pass", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", ], ) diff --git a/tensorflow/compiler/xla/service/mlir_gpu/emission_context.cc 
b/tensorflow/compiler/xla/service/mlir_gpu/emission_context.cc index 08a133a9b52..3c27dc662fe 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/emission_context.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/emission_context.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/mlir_gpu/emission_context.h" #include "absl/strings/substitute.h" -#include "mlir/IR/Location.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir +#include "mlir/IR/Location.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project #include "tensorflow/compiler/xla/service/hlo_instruction.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/mlir_gpu/emission_context.h b/tensorflow/compiler/xla/service/mlir_gpu/emission_context.h index cbea4c48568..db702dbc014 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/emission_context.h +++ b/tensorflow/compiler/xla/service/mlir_gpu/emission_context.h @@ -18,7 +18,7 @@ limitations under the License. #include -#include "mlir/IR/Diagnostics.h" // TF:local_config_mlir +#include "mlir/IR/Diagnostics.h" // TF:llvm-project #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" diff --git a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/BUILD b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/BUILD index eda65583fb5..72acc5463ca 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/BUILD +++ b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/BUILD @@ -31,11 +31,11 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "@com_google_absl//absl/types:span", - "@llvm//:support", - "@local_config_mlir//:AffineOps", - "@local_config_mlir//:IR", - "@local_config_mlir//:StandardOps", - "@local_config_mlir//:TransformUtils", + "@llvm-project//llvm:support", + "@llvm-project//mlir:AffineOps", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:TransformUtils", ], ) @@ -50,13 +50,13 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/platform:test", - "@llvm//:support", - "@local_config_mlir//:AffineDialectRegistration", - "@local_config_mlir//:AffineToStandardTransforms", - "@local_config_mlir//:IR", - "@local_config_mlir//:LLVMTransforms", - "@local_config_mlir//:Pass", - "@local_config_mlir//:StandardDialectRegistration", - "@local_config_mlir//:Transforms", + "@llvm-project//llvm:support", + "@llvm-project//mlir:AffineDialectRegistration", + "@llvm-project//mlir:AffineToStandardTransforms", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMTransforms", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:StandardDialectRegistration", + "@llvm-project//mlir:Transforms", ], ) diff --git a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter.cc b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter.cc index 84e239ae196..4ed8745a251 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter.cc @@ -30,13 +30,13 @@ limitations under the License. 
#include "absl/types/span.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/Dialect/AffineOps/AffineOps.h" // TF:local_config_mlir -#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir -#include "mlir/IR/AffineExpr.h" // TF:local_config_mlir -#include "mlir/IR/AffineMap.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/Transforms/LoopUtils.h" // TF:local_config_mlir -#include "mlir/Transforms/RegionUtils.h" // TF:local_config_mlir +#include "mlir/Dialect/AffineOps/AffineOps.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/IR/AffineExpr.h" // TF:llvm-project +#include "mlir/IR/AffineMap.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/Transforms/LoopUtils.h" // TF:llvm-project +#include "mlir/Transforms/RegionUtils.h" // TF:llvm-project #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/window_util.h" @@ -117,18 +117,18 @@ bool IsSimpleLoop(mlir::AffineForOp loop) { struct BoundAffineMap { mlir::AffineMap affine_map; - std::vector operands; + std::vector operands; }; BoundAffineMap GetBoundAffineMapFrom(mlir::Operation* op) { if (auto load = mlir::dyn_cast(op)) { return {load.getAffineMap(), - std::vector(load.getMapOperands().begin(), - load.getMapOperands().end())}; + std::vector(load.getMapOperands().begin(), + load.getMapOperands().end())}; } else if (auto store = mlir::dyn_cast(op)) { return {store.getAffineMap(), - std::vector(store.getMapOperands().begin(), - store.getMapOperands().end())}; + std::vector(store.getMapOperands().begin(), + store.getMapOperands().end())}; } else { CHECK(false); } @@ -150,7 +150,7 @@ mlir::Operation* CloneWithNewAffineMap(mlir::Operation* op, } } -void SetMemRef(mlir::Operation* op, mlir::Value* memref) { +void SetMemRef(mlir::Operation* op, mlir::Value memref) { if (auto load = mlir::dyn_cast(op)) { load.setMemRef(memref); } else if (auto store = mlir::dyn_cast(op)) { @@ -257,7 +257,7 @@ mlir::AffineForOp TileLoop(mlir::AffineForOp loop, int64_t size, } for (mlir::IROperand& use : - llvm::make_early_inc_range(loop.getInductionVar()->getUses())) { + llvm::make_early_inc_range(loop.getInductionVar().getUses())) { mlir::Operation* owner = use.getOwner(); BoundAffineMap affine_map = GetBoundAffineMapFrom(owner); unsigned new_dim = affine_map.operands.size(); @@ -325,12 +325,12 @@ mlir::Operation* HoistAndFix(llvm::iplist::iterator begin_op, auto new_alloc = builder.create(builder.getUnknownLoc(), new_type); - std::vector indvars; + std::vector indvars; for (auto ancestor : ancestors) { indvars.push_back(ancestor.getInductionVar()); } for (mlir::IROperand& use : - llvm::make_early_inc_range(alloc.getResult()->getUses())) { + llvm::make_early_inc_range(alloc.getResult().getUses())) { mlir::Operation* owner = use.getOwner(); BoundAffineMap affine_map = GetBoundAffineMapFrom(owner); affine_map.operands.insert(affine_map.operands.begin(), indvars.begin(), @@ -418,7 +418,7 @@ struct InitialMlirConvAnchors { // output[...] 
= output_acc[] // } StatusOr CreateNaiveMlirConv( - mlir::Value* input, mlir::Value* filter, mlir::Value* output, + mlir::Value input, mlir::Value filter, mlir::Value output, const ShapeInfo& input_shape_info, const ShapeInfo& filter_shape_info, const ShapeInfo& output_shape_info, const Window& window, mlir::OpBuilder builder) { @@ -440,7 +440,7 @@ StatusOr CreateNaiveMlirConv( location, builder.create( location, mlir::FloatAttr::get(builder.getF32Type(), 0)), - output_acc, llvm::ArrayRef()); + output_acc, llvm::ArrayRef()); std::vector reduction_loops; reduction_loops = CreateNestedSimpleLoops( @@ -450,11 +450,11 @@ StatusOr CreateNaiveMlirConv( mlir::AffineForOp loop_o = cartesian_product_loops[1]; mlir::AffineForOp loop_c = reduction_loops[0]; - std::vector output_spatial_indvars; + std::vector output_spatial_indvars; for (auto loop : absl::MakeSpan(cartesian_product_loops).subspan(2)) { output_spatial_indvars.push_back(loop.getInductionVar()); } - std::vector filter_spatial_indvars; + std::vector filter_spatial_indvars; for (auto loop : absl::MakeSpan(reduction_loops).subspan(1)) { filter_spatial_indvars.push_back(loop.getInductionVar()); } @@ -463,7 +463,7 @@ StatusOr CreateNaiveMlirConv( builder = reduction_loops.back().getBodyBuilder(); - mlir::Value* loaded_input = [&] { + mlir::Value loaded_input = [&] { std::vector input_indices; input_indices.push_back(builder.getAffineDimExpr(0)); input_indices.push_back(builder.getAffineDimExpr(1)); @@ -479,7 +479,7 @@ StatusOr CreateNaiveMlirConv( builder.getAffineDimExpr(2 + num_spatial_dims + i) - window_dim.padding_low()); } - std::vector input_vars; + std::vector input_vars; input_vars.push_back(loop_n.getInductionVar()); input_vars.push_back(loop_c.getInductionVar()); input_vars.insert(input_vars.end(), output_spatial_indvars.begin(), @@ -499,8 +499,8 @@ StatusOr CreateNaiveMlirConv( builder.getF32Type()); }(); - mlir::Value* loaded_filter = [&] { - std::vector filter_vars; + mlir::Value loaded_filter = [&] { + std::vector filter_vars; filter_vars.push_back(loop_o.getInductionVar()); filter_vars.push_back(loop_c.getInductionVar()); filter_vars.insert(filter_vars.end(), filter_spatial_indvars.begin(), @@ -519,11 +519,11 @@ StatusOr CreateNaiveMlirConv( location, builder.createOrFold(location, output_acc), builder.create(location, loaded_input, loaded_filter)), - output_acc, llvm::ArrayRef()); + output_acc, llvm::ArrayRef()); builder.setInsertionPointAfter(reduction_loops[0]); { - std::vector output_vars; + std::vector output_vars; output_vars.push_back(loop_n.getInductionVar()); output_vars.push_back(loop_o.getInductionVar()); output_vars.insert(output_vars.end(), output_spatial_indvars.begin(), @@ -735,9 +735,9 @@ StatusOr EmitConvolutionForwardAsMlir( builder.create(builder.getUnknownLoc()); builder.setInsertionPointToStart(entry_block); - mlir::Value* input = entry_block->getArgument(1); - mlir::Value* filter = entry_block->getArgument(2); - mlir::Value* output = entry_block->getArgument(0); + mlir::Value input = entry_block->getArgument(1); + mlir::Value filter = entry_block->getArgument(2); + mlir::Value output = entry_block->getArgument(0); TF_RETURN_IF_ERROR(ConvIsImplemented(conv)); diff --git a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter.h b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter.h index f0b95876775..5f01dffb756 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter.h +++ 
b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_EXPERIMENTAL_CONV_EMITTER_CONV_EMITTER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_EXPERIMENTAL_CONV_EMITTER_CONV_EMITTER_H_ -#include "mlir/IR/Function.h" // TF:local_config_mlir +#include "mlir/IR/Function.h" // TF:llvm-project #include "tensorflow/compiler/xla/service/hlo_instruction.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_test.cc b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_test.cc index 00a93455a8b..78cc83dd0bd 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_test.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_test.cc @@ -18,13 +18,13 @@ limitations under the License. #include #include "llvm/Support/raw_ostream.h" -#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" // TF:local_config_mlir -#include "mlir/IR/Location.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/Module.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir -#include "mlir/Pass/PassManager.h" // TF:local_config_mlir -#include "mlir/Transforms/Passes.h" // TF:local_config_mlir +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" // TF:llvm-project +#include "mlir/IR/Location.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Pass/PassManager.h" // TF:llvm-project +#include "mlir/Transforms/Passes.h" // TF:llvm-project #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/tests/filecheck.h" #include "tensorflow/compiler/xla/tests/verified_hlo_module.h" diff --git a/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc b/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc index 60b5d086d15..ae3e42bc20d 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc @@ -16,10 +16,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.h" #include "llvm/ADT/STLExtras.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/IR/Types.h" // TF:local_config_mlir +#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/Types.h" // TF:llvm-project #include "tensorflow/compiler/mlir/xla/hlo_utils.h" #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/xla/comparison_util.h" @@ -43,14 +43,21 @@ using ::mlir::Value; namespace hlo = ::mlir::xla_hlo; // TODO(b/137624192) Use tablegen for this. 
-StatusOr InsertMlirOp( - HloOpcode opcode, OpBuilder func_builder, Location loc, ArrayRef rets, - ArrayRef args, ArrayRef> attrs) { +StatusOr InsertMlirOp(HloOpcode opcode, OpBuilder func_builder, + Location loc, ArrayRef rets, + ArrayRef args, + ArrayRef> attrs) { switch (opcode) { + case HloOpcode::kAbs: + return {func_builder.create(loc, rets, args, attrs)}; case HloOpcode::kAdd: return {func_builder.create(loc, rets, args, attrs)}; case HloOpcode::kAnd: return {func_builder.create(loc, rets, args, attrs)}; + case HloOpcode::kCeil: + return {func_builder.create(loc, rets, args, attrs)}; + case HloOpcode::kCos: + return {func_builder.create(loc, rets, args, attrs)}; case HloOpcode::kDivide: return {func_builder.create(loc, rets, args, attrs)}; case HloOpcode::kExp: @@ -61,10 +68,18 @@ StatusOr InsertMlirOp( return {func_builder.create(loc, rets, args, attrs)}; case HloOpcode::kMultiply: return {func_builder.create(loc, rets, args, attrs)}; + case HloOpcode::kNegate: + return {func_builder.create(loc, rets, args, attrs)}; + case HloOpcode::kRemainder: + return {func_builder.create(loc, rets, args, attrs)}; case HloOpcode::kSelect: return {func_builder.create(loc, rets, args, attrs)}; + case HloOpcode::kSign: + return {func_builder.create(loc, rets, args, attrs)}; case HloOpcode::kSubtract: return {func_builder.create(loc, rets, args, attrs)}; + case HloOpcode::kTanh: + return {func_builder.create(loc, rets, args, attrs)}; default: return tensorflow::errors::Internal(absl::StrCat( "HLO Opcode ", HloOpcodeString(opcode), " is not supported.")); @@ -78,7 +93,7 @@ mlir::Location HloDialectEmitter::getLocation( return emission_context_->getLocation(instr); } -StatusOr HloDialectEmitter::EmitComputation( +StatusOr HloDialectEmitter::EmitComputation( const HloComputation& computation) { const auto root = computation.root_instruction(); TF_RETURN_IF_ERROR(root->Accept(this)); @@ -88,7 +103,7 @@ StatusOr HloDialectEmitter::EmitComputation( Status HloDialectEmitter::DefaultAction(HloInstruction* instr) { TF_ASSIGN_OR_RETURN(auto res_type, ConvertTensorShapeToType( instr->shape(), builder_)); - llvm::SmallVector arguments; + llvm::SmallVector arguments; for (auto operand : instr->operands()) { arguments.push_back(instruction_to_values_[operand]); } @@ -135,7 +150,7 @@ Status HloDialectEmitter::HandleConstant(HloInstruction* constant) { } Status HloDialectEmitter::HandleReduce(HloInstruction* reduce) { - llvm::SmallVector operands; + llvm::SmallVector operands; for (auto operand : reduce->operands()) { operands.push_back(instruction_to_values_.at(operand)); } @@ -152,7 +167,7 @@ Status HloDialectEmitter::HandleReduce(HloInstruction* reduce) { { auto computation = reduce->to_apply(); auto block = new mlir::Block(); - llvm::SmallVector arguments; + llvm::SmallVector arguments; arguments.reserve(computation->num_parameters()); for (auto parameter : computation->parameter_instructions()) { TF_ASSIGN_OR_RETURN(auto param_type, @@ -166,7 +181,7 @@ Status HloDialectEmitter::HandleReduce(HloInstruction* reduce) { OpBuilder body_builder(block); body_builder.setInsertionPointToEnd(block); body_builder.create(getLocation(reduce), - ArrayRef{result}); + ArrayRef{result}); } // TODO(b/137624192) Add support for multiple results. 
instruction_to_values_[reduce] = reduceOp.getResult(0); @@ -180,7 +195,7 @@ Status HloDialectEmitter::HandleCompare(HloInstruction* compare) { "comparison_direction", builder_.getStringAttr( ComparisonDirectionToString(compare->comparison_direction()))); - llvm::SmallVector arguments; + llvm::SmallVector arguments; for (auto operand : compare->operands()) { arguments.push_back(instruction_to_values_[operand]); } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.h b/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.h index 86ed97b3c58..a1ec6d88644 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.h +++ b/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.h @@ -20,10 +20,10 @@ limitations under the License. #include "absl/types/span.h" #include "llvm/ADT/ArrayRef.h" -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Function.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/Module.h" // TF:local_config_mlir +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Module.h" // TF:llvm-project #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -37,19 +37,19 @@ class HloDialectEmitter : public DfsHloVisitorWithDefault { public: HloDialectEmitter(xla::mlir_gpu::EmissionContext* emission_context, ::mlir::Region* region, - llvm::ArrayRef<::mlir::Value*> arguments) + llvm::ArrayRef<::mlir::Value> arguments) : emission_context_(emission_context), builder_(region), arguments_(arguments) {} HloDialectEmitter(xla::mlir_gpu::EmissionContext* emission_context, ::mlir::OpBuilder builder, - llvm::ArrayRef<::mlir::Value*> arguments) + llvm::ArrayRef<::mlir::Value> arguments) : emission_context_(emission_context), builder_(builder), arguments_(arguments) {} - StatusOr EmitComputation(const HloComputation& computation); + StatusOr EmitComputation(const HloComputation& computation); Status DefaultAction(HloInstruction* instr) override; Status HandleBroadcast(HloInstruction* broadcast) override; @@ -64,8 +64,8 @@ class HloDialectEmitter : public DfsHloVisitorWithDefault { xla::mlir_gpu::EmissionContext* emission_context_; ::mlir::OpBuilder builder_; - llvm::ArrayRef<::mlir::Value*> arguments_; - absl::flat_hash_map + llvm::ArrayRef<::mlir::Value> arguments_; + absl::flat_hash_map instruction_to_values_; }; diff --git a/tensorflow/compiler/xla/service/mlir_gpu/inject_errors_pass.h b/tensorflow/compiler/xla/service/mlir_gpu/inject_errors_pass.h index 832d43ad562..1e0e41868ca 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/inject_errors_pass.h +++ b/tensorflow/compiler/xla/service/mlir_gpu/inject_errors_pass.h @@ -16,7 +16,7 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_INJECT_ERRORS_PASS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_INJECT_ERRORS_PASS_H_ -#include "mlir/Pass/Pass.h" // TF:local_config_mlir +#include "mlir/Pass/Pass.h" // TF:llvm-project namespace mlir { diff --git a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc index 186dacc06e6..c878c90ef2a 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc @@ -17,33 +17,32 @@ limitations under the License. #include -#include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" -#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" // TF:local_config_mlir -#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" // TF:local_config_mlir -#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h" // TF:local_config_mlir -#include "mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h" // TF:local_config_mlir -#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" // TF:local_config_mlir -#include "mlir/Dialect/GPU/GPUDialect.h" // TF:local_config_mlir -#include "mlir/Dialect/GPU/Passes.h" // TF:local_config_mlir -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // TF:local_config_mlir -#include "mlir/Dialect/LLVMIR/NVVMDialect.h" // TF:local_config_mlir -#include "mlir/Dialect/Linalg/IR/LinalgOps.h" // TF:local_config_mlir -#include "mlir/Dialect/Linalg/Passes.h" // TF:local_config_mlir -#include "mlir/Dialect/LoopOps/LoopOps.h" // TF:local_config_mlir -#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/BlockAndValueMapping.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Function.h" // TF:local_config_mlir -#include "mlir/IR/Module.h" // TF:local_config_mlir -#include "mlir/IR/OperationSupport.h" // TF:local_config_mlir -#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir -#include "mlir/IR/Region.h" // TF:local_config_mlir -#include "mlir/Pass/Pass.h" // TF:local_config_mlir -#include "mlir/Pass/PassManager.h" // TF:local_config_mlir -#include "mlir/Transforms/DialectConversion.h" // TF:local_config_mlir -#include "mlir/Transforms/Passes.h" // TF:local_config_mlir +#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" // TF:llvm-project +#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" // TF:llvm-project +#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h" // TF:llvm-project +#include "mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h" // TF:llvm-project +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" // TF:llvm-project +#include "mlir/Dialect/GPU/GPUDialect.h" // TF:llvm-project +#include "mlir/Dialect/GPU/Passes.h" // TF:llvm-project +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // TF:llvm-project +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" // TF:llvm-project +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" // TF:llvm-project +#include "mlir/Dialect/Linalg/Passes.h" // TF:llvm-project +#include "mlir/Dialect/LoopOps/LoopOps.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/Module.h" // TF:llvm-project +#include 
"mlir/IR/OperationSupport.h" // TF:llvm-project +#include "mlir/IR/PatternMatch.h" // TF:llvm-project +#include "mlir/IR/Region.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Pass/PassManager.h" // TF:llvm-project +#include "mlir/Transforms/DialectConversion.h" // TF:llvm-project +#include "mlir/Transforms/Passes.h" // TF:llvm-project #include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" #include "tensorflow/compiler/mlir/xla/transforms/rewriters.h" @@ -108,8 +107,8 @@ struct FusionOpRemover : public mlir::FunctionPass { struct SingleTripLoopRemoval : public mlir::FunctionPass { void runOnFunction() override { - auto getConstantValue = [](mlir::Value* value) -> llvm::Optional { - auto definingOp = value->getDefiningOp(); + auto getConstantValue = [](mlir::Value value) -> llvm::Optional { + auto definingOp = value.getDefiningOp(); if (!definingOp) return llvm::None; auto constantOp = llvm::dyn_cast(definingOp); if (!constantOp) return llvm::None; @@ -145,7 +144,7 @@ struct SingleTripLoopRemoval // same address with the stored value. This needs generalization. struct StoreForwardingPass : mlir::FunctionPass { void runOnFunction() override { - absl::flat_hash_map memrefToAllocOp; + llvm::DenseMap memrefToAllocOp; getFunction().walk([&](mlir::LoadOp loadOp) { auto* block = loadOp.getOperation()->getBlock(); @@ -180,10 +179,10 @@ struct StoreForwardingPass : mlir::FunctionPass { // Recursively checks defining ops until finds AllocOp. Return either AllocOp // if it is found or nullptr. - mlir::Operation* SearchAllocOp(mlir::Value* memref) { - mlir::Operation* defOp = memref->getDefiningOp(); + mlir::Operation* SearchAllocOp(mlir::Value memref) { + mlir::Operation* defOp = memref.getDefiningOp(); while (auto subviewOp = mlir::dyn_cast_or_null(defOp)) { - defOp = subviewOp.source()->getDefiningOp(); + defOp = subviewOp.source().getDefiningOp(); } if (auto allocOp = mlir::dyn_cast_or_null(defOp)) { return allocOp.getOperation(); @@ -193,8 +192,8 @@ struct StoreForwardingPass : mlir::FunctionPass { // Retrieves AllocOp from the cache or actually looks for it. mlir::Operation* GetAllocOp( - mlir::Value* memref, - absl::flat_hash_map* memrefToAllocOp) { + mlir::Value memref, + llvm::DenseMap* memrefToAllocOp) { auto allocOpIt = memrefToAllocOp->find(memref); if (allocOpIt != memrefToAllocOp->end()) { return allocOpIt->second; @@ -212,7 +211,7 @@ struct StoreForwardingPass : mlir::FunctionPass { struct DeadTempBufferRemoval : mlir::FunctionPass { bool operationConsideredDead(mlir::Operation* op) { for (auto result : op->getResults()) { - if (!llvm::all_of(result->getUsers(), [&](mlir::Operation* op) { + if (!llvm::all_of(result.getUsers(), [&](mlir::Operation* op) { // Store and Dealloc is OK. if (llvm::isa(op) || llvm::isa(op)) { @@ -236,7 +235,7 @@ struct DeadTempBufferRemoval : mlir::FunctionPass { void recursiveErase(mlir::Operation* op) { for (auto result : op->getResults()) { - for (auto user : llvm::make_early_inc_range(result->getUsers())) { + for (auto user : llvm::make_early_inc_range(result.getUsers())) { recursiveErase(user); } } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h index 3d4cdf49461..027c3c93dca 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h +++ b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h @@ -16,7 +16,7 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_KERNEL_LOWERING_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_KERNEL_LOWERING_H_ -#include "mlir/IR/Module.h" // TF:local_config_mlir +#include "mlir/IR/Module.h" // TF:llvm-project #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" diff --git a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc index fd38cd3bf5e..585223efa7b 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc @@ -15,14 +15,14 @@ limitations under the License. #include "tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // TF:local_config_mlir -#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Function.h" // TF:local_config_mlir -#include "mlir/IR/Identifier.h" // TF:local_config_mlir -#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/IR/Types.h" // TF:local_config_mlir +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/Identifier.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/Types.h" // TF:llvm-project #include "tensorflow/compiler/mlir/xla/hlo_utils.h" #include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" @@ -59,15 +59,24 @@ namespace lhlo = ::mlir::xla_lhlo; // TODO(b/137624192) Use tablegen for this. 
Status InsertMlirOp(HloOpcode opcode, OpBuilder func_builder, Location loc, - ArrayRef rets, ArrayRef args, + ArrayRef rets, ArrayRef args, ArrayRef> attrs) { switch (opcode) { + case HloOpcode::kAbs: + func_builder.create(loc, rets, args, attrs); + break; case HloOpcode::kAdd: func_builder.create(loc, rets, args, attrs); break; case HloOpcode::kAnd: func_builder.create(loc, rets, args, attrs); break; + case HloOpcode::kCeil: + func_builder.create(loc, rets, args, attrs); + break; + case HloOpcode::kCos: + func_builder.create(loc, rets, args, attrs); + break; case HloOpcode::kDivide: func_builder.create(loc, rets, args, attrs); break; @@ -83,12 +92,24 @@ Status InsertMlirOp(HloOpcode opcode, OpBuilder func_builder, Location loc, case HloOpcode::kMultiply: func_builder.create(loc, rets, args, attrs); break; + case HloOpcode::kNegate: + func_builder.create(loc, rets, args, attrs); + break; + case HloOpcode::kRemainder: + func_builder.create(loc, rets, args, attrs); + break; case HloOpcode::kSelect: func_builder.create(loc, rets, args, attrs); break; + case HloOpcode::kSign: + func_builder.create(loc, rets, args, attrs); + break; case HloOpcode::kSubtract: func_builder.create(loc, rets, args, attrs); break; + case HloOpcode::kTanh: + func_builder.create(loc, rets, args, attrs); + break; default: return tensorflow::errors::Internal(absl::StrCat( "LHLO opcode ", HloOpcodeString(opcode), " is not supported.")); @@ -168,8 +189,8 @@ StatusOr LhloDialectEmitter::CreateFunction( Status LhloDialectEmitter::DefaultAction(HloInstruction* instr) { TF_ASSIGN_OR_RETURN(auto function, CreateFunction(*instr)); OpBuilder func_builder(function.getBody()); - llvm::SmallVector arg_values{function.args_begin(), - function.args_end()}; + llvm::SmallVector arg_values{function.args_begin(), + function.args_end()}; TF_RETURN_IF_ERROR(InsertMlirOp(instr->opcode(), func_builder, getLocation(instr), ArrayRef{}, arg_values, llvm::None)); @@ -197,7 +218,7 @@ Status LhloDialectEmitter::HandleFusion(HloInstruction* fusion) { // Load the HLO argument tensors from the corresponding buffers. The last // argument is for the result, so no need to load it. OpBuilder body_builder(fusion_op.region()); - llvm::SmallVector arg_values; + llvm::SmallVector arg_values; for (int i = 0, e = function.getNumArguments() - 1; i < e; ++i) { arg_values.push_back(body_builder.create<::mlir::TensorLoadOp>( getLocation(fusion), function.getArgument(i))); @@ -211,7 +232,7 @@ Status LhloDialectEmitter::HandleFusion(HloInstruction* fusion) { // Insert the write-back from the HLO computation to the result argument // buffer. 
body_builder.setInsertionPoint(fusion_op.region().back().getTerminator()); - Value* result_memref = function.getArgument(function.getNumArguments() - 1); + Value result_memref = function.getArgument(function.getNumArguments() - 1); body_builder.create<::mlir::TensorStoreOp>(getLocation(fusion), result, result_memref); @@ -220,8 +241,8 @@ Status LhloDialectEmitter::HandleFusion(HloInstruction* fusion) { Status LhloDialectEmitter::HandleReduce(HloInstruction* reduce) { TF_ASSIGN_OR_RETURN(auto function, CreateFunction(*reduce)); - llvm::SmallVector arg_values{function.args_begin(), - function.args_end()}; + llvm::SmallVector arg_values{function.args_begin(), + function.args_end()}; OpBuilder builder(function.getBody()); auto loc = getLocation(reduce); int input_count = reduce->operand_count() / 3; @@ -239,7 +260,7 @@ Status LhloDialectEmitter::HandleReduce(HloInstruction* reduce) { OpBuilder body_builder(reduce_op.body()); auto block = body_builder.getInsertionBlock(); auto to_apply = reduce->to_apply(); - llvm::SmallVector reduce_arg_values; + llvm::SmallVector reduce_arg_values; // First map parameters to memrefs on the operation. for (auto param : to_apply->parameter_instructions()) { TF_ASSIGN_OR_RETURN(auto arg_type, ConvertShapeToType( @@ -280,8 +301,8 @@ Status LhloDialectEmitter::HandleCompare(HloInstruction* compare) { TF_ASSIGN_OR_RETURN(auto function, CreateFunction(*compare)); OpBuilder func_builder(function.getBody()); - llvm::SmallVector arg_values{function.args_begin(), - function.args_end()}; + llvm::SmallVector arg_values{function.args_begin(), + function.args_end()}; func_builder.create(getLocation(compare), llvm::None, arg_values, comparison_direction_attr); return Status::OK(); diff --git a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h index 09d6fc3a5bb..48d275ef5e0 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h +++ b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h @@ -19,10 +19,10 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/Function.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/Module.h" // TF:local_config_mlir +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Module.h" // TF:llvm-project #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/gpu/thunk_emitter.h" diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc index d332392ab2f..67ef9506fe2 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc @@ -18,17 +18,17 @@ limitations under the License. 
#include #include "absl/container/flat_hash_map.h" -#include "mlir/Dialect/GPU/GPUDialect.h" // TF:local_config_mlir -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // TF:local_config_mlir -#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Function.h" // TF:local_config_mlir -#include "mlir/IR/Location.h" // TF:local_config_mlir -#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir -#include "mlir/IR/Module.h" // TF:local_config_mlir -#include "mlir/IR/Value.h" // TF:local_config_mlir -#include "mlir/Support/LLVM.h" // TF:local_config_mlir -#include "mlir/Target/NVVMIR.h" // TF:local_config_mlir +#include "mlir/Dialect/GPU/GPUDialect.h" // TF:llvm-project +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/Location.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/IR/Value.h" // TF:llvm-project +#include "mlir/Support/LLVM.h" // TF:llvm-project +#include "mlir/Target/NVVMIR.h" // TF:llvm-project #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/dump.h" #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h" @@ -197,23 +197,23 @@ static absl::optional getLaunchBound(const mlir::gpu::KernelDim3& dim) { op->emitError() << "bound " << name << " is not constant"; return absl::nullopt; }; - auto y_op = dim.y->getDefiningOp(); + auto y_op = dim.y.getDefiningOp(); auto dim_y = get_constant(y_op, "y"); if (!dim_y.has_value() || dim_y.value() != 1) { y_op->emitError() << "bound 'y' is not constant 1"; return absl::nullopt; } - auto z_op = dim.z->getDefiningOp(); + auto z_op = dim.z.getDefiningOp(); auto dim_z = get_constant(z_op, "z"); if (!dim_z.has_value() || dim_z.value() != 1) { z_op->emitError() << "bound 'z' is not constant 1"; return absl::nullopt; } - return get_constant(dim.x->getDefiningOp(), "x"); + return get_constant(dim.x.getDefiningOp(), "x"); } using OperandToValueMap = - absl::flat_hash_map>; + absl::flat_hash_map>; static StatusOr> ComputeOperandToValueMap( OperandToValueMap* operand_to_value_map, const HloInstruction* instr, @@ -224,7 +224,7 @@ static StatusOr> ComputeOperandToValueMap( for (int kernel_index = 0; kernel_index < launchOp.getNumKernelOperands(); ++kernel_index) { auto launchop_operand = - dyn_cast(launchOp.getKernelOperand(kernel_index)); + launchOp.getKernelOperand(kernel_index).dyn_cast(); if (!launchop_operand) { launchOp.emitError("argument to kernel is not a function input"); has_failed = true; @@ -233,7 +233,7 @@ static StatusOr> ComputeOperandToValueMap( // host_index is the argument position to the surrounding function that // contains the launch. This index corresponds to HLO operand indices // by construction. - auto host_index = launchop_operand->getArgNumber(); + auto host_index = launchop_operand.getArgNumber(); // The trailing argument to the outer function are the results. auto operand = (host_index < operands.size()) ? 
        operands[host_index] : instr;
@@ -272,7 +272,7 @@ Status InsertBufferLoadPreduleIntoKernel(
   std::vector<Type> as_mlir_types(new_arg_types.begin(), new_arg_types.end());
   auto new_args = kernel.front().addArguments(as_mlir_types);
-  std::vector<Value*> buffer_args(new_args.begin(), new_args.end());
+  std::vector<Value> buffer_args(new_args.begin(), new_args.end());
   auto zero = builder.create<LLVM::ConstantOp>(
       loc, offset_type, builder.getI64IntegerAttr(0));
@@ -304,29 +304,27 @@ Status InsertBufferLoadPreduleIntoKernel(
       // { baseptr, dataptr, offset, shape_vect, stride_vect }
       // where shape_vect and stride_vect are integer vectors with length
       // matching the rank of the tensor.
-      auto target_type = value->getType().cast<LLVM::LLVMType>();
+      auto target_type = value.getType().cast<LLVM::LLVMType>();
       auto struct_type = target_type.getPointerElementTy();
       auto descPtr = builder.create<LLVM::AllocaOp>(loc, target_type, one, 0);
       // Fill the base and aligned pointers.
       auto casted = builder.create<LLVM::BitcastOp>(
-          loc, struct_type.getStructElementType(0),
-          llvm::ArrayRef<Value*>{ptr});
+          loc, struct_type.getStructElementType(0), llvm::ArrayRef<Value>{ptr});
       auto structPtrAddr = builder.create<LLVM::GEPOp>(
           loc, struct_type.getStructElementType(0), descPtr,
-          llvm::ArrayRef<Value*>{zero, baseIndex});
+          llvm::ArrayRef<Value>{zero, baseIndex});
       builder.create<LLVM::StoreOp>(loc, casted, structPtrAddr);
       casted = builder.create<LLVM::BitcastOp>(
-          loc, struct_type.getStructElementType(1),
-          llvm::ArrayRef<Value*>{ptr});
+          loc, struct_type.getStructElementType(1), llvm::ArrayRef<Value>{ptr});
       structPtrAddr = builder.create<LLVM::GEPOp>(
           loc, struct_type.getStructElementType(1), descPtr,
-          llvm::ArrayRef<Value*>{zero, dataIndex});
+          llvm::ArrayRef<Value>{zero, dataIndex});
       builder.create<LLVM::StoreOp>(loc, casted, structPtrAddr);
       // Fill the offset value.
       auto structOffsetAddr = builder.create<LLVM::GEPOp>(
           loc, struct_type.getStructElementType(1), descPtr,
-          llvm::ArrayRef<Value*>{zero, offsetIndex});
+          llvm::ArrayRef<Value>{zero, offsetIndex});
       builder.create<LLVM::StoreOp>(loc, offset, structOffsetAddr);
       // Fill the shape.
       auto shape = operand->shape();
@@ -341,7 +339,7 @@ Status InsertBufferLoadPreduleIntoKernel(
             loc, offset_type, builder.getI64IntegerAttr(extent.index()));
         auto shapeEntryPtr = builder.create<LLVM::GEPOp>(
             loc, entry_type, descPtr,
-            llvm::ArrayRef<Value*>{zero, shapeIndex, index});
+            llvm::ArrayRef<Value>{zero, shapeIndex, index});
         auto extentValue = builder.create<LLVM::ConstantOp>(
             loc, entry_type, builder.getI64IntegerAttr(extent.value()));
         builder.create<LLVM::StoreOp>(loc, extentValue, shapeEntryPtr);
@@ -349,13 +347,13 @@ Status InsertBufferLoadPreduleIntoKernel(
       // Finally, fill the strides.
       // TODO(b/137624192): Take assigned layout into account.
       entry_type = struct_type.getStructElementType(4).getArrayElementType();
-      Value* accumulator = nullptr;
+      Value accumulator = nullptr;
       for (int64 idx = shape.rank() - 1; idx >= 0; --idx) {
         auto indexValue = builder.create<LLVM::ConstantOp>(
             loc, offset_type, builder.getI64IntegerAttr(idx));
         auto strideEntryPtr = builder.create<LLVM::GEPOp>(
             loc, entry_type, descPtr,
-            llvm::ArrayRef<Value*>{zero, strideIndex, indexValue});
+            llvm::ArrayRef<Value>{zero, strideIndex, indexValue});
         if (accumulator) {
           auto strideValue = builder.create(
               loc, entry_type,
@@ -369,7 +367,7 @@
       }
     }
     // Now we can use the descriptor instead of the original argument.
-    value->replaceAllUsesWith(descPtr);
+    value.replaceAllUsesWith(descPtr);
   }
 }
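The preamble above fills the standard memref-to-LLVM descriptor. A hedged C view of that struct for a rank-2 f32[2,3] buffer with the default descending layout (the struct name is illustrative; the field values follow from the shape and stride loops above):

#include <cstdint>

struct MemRefDescriptorRank2 {
  float* base;         // struct element 0: base pointer
  float* aligned;      // struct element 1: aligned data pointer
  int64_t offset;      // struct element 2: element offset, 0 for a fresh buffer
  int64_t shape[2];    // struct element 3: {2, 3}
  int64_t strides[2];  // struct element 4: {3, 1}
};

The stride loop walks dimensions minor-to-major with a running accumulator, so stride[1] = 1 and stride[0] = 1 * extent(1) = 3.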
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h
index d84b72cadcf..bb852b47f22 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h
+++ b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h
@@ -17,8 +17,8 @@ limitations under the License.
 #define TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_MLIR_COMPILER_H_
 
 #include "absl/container/flat_hash_map.h"
-#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
-#include "mlir/IR/Module.h"  // TF:local_config_mlir
+#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
+#include "mlir/IR/Module.h"  // TF:llvm-project
 #include "tensorflow/compiler/xla/service/compiler.h"
 #include "tensorflow/compiler/xla/service/mlir_gpu/emission_context.h"
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_irgen_test_base.cc b/tensorflow/compiler/xla/service/mlir_gpu/mlir_irgen_test_base.cc
index da42e6462e2..dbc6efe9ec9 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/mlir_irgen_test_base.cc
+++ b/tensorflow/compiler/xla/service/mlir_gpu/mlir_irgen_test_base.cc
@@ -22,8 +22,8 @@ limitations under the License.
 
 #include "absl/memory/memory.h"
 #include "llvm/Support/raw_ostream.h"
-#include "mlir/IR/Module.h"  // TF:local_config_mlir
-#include "mlir/Pass/PassManager.h"  // TF:local_config_mlir
+#include "mlir/IR/Module.h"  // TF:llvm-project
+#include "mlir/Pass/PassManager.h"  // TF:llvm-project
 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
 #include "tensorflow/compiler/xla/service/mlir_gpu/failover_compiler.h"
 #include "tensorflow/compiler/xla/service/mlir_gpu/inject_errors_pass.h"
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc b/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc
index 505d16d11cc..afcac65bdc7 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc
+++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc
@@ -393,5 +393,104 @@ ENTRY %AddReduce (x: f32[100,10], c: f32[]) -> f32[100] {
 )");
 }
 
+TEST_F(LhloGenTest, Abs) {
+  CompileAndVerifyIr(R"(
+HloModule Abs
+ENTRY %Abs (val: f32[2,2]) -> f32[2,2] {
+  %val = f32[2,2]{1,0} parameter(0)
+  ROOT %abs = f32[2,2]{1,0} abs(f32[2,2]{1,0} %val)
+})",
+                     R"(
+;CHECK: func @abs(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) {
+;CHECK: "xla_lhlo.abs"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> ()
+;CHECK: }
+ )");
+}
+
+TEST_F(LhloGenTest, Ceil) {
+  CompileAndVerifyIr(R"(
+HloModule Ceil
+ENTRY %Ceil (val: f32[2,2]) -> f32[2,2] {
+  %val = f32[2,2]{1,0} parameter(0)
+  ROOT %ceil = f32[2,2]{1,0} ceil(f32[2,2]{1,0} %val)
+})",
+                     R"(
+;CHECK: func @ceil(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) {
+;CHECK: "xla_lhlo.ceil"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> ()
+;CHECK: }
+ )");
+}
+
+TEST_F(LhloGenTest, Cos) {
+  CompileAndVerifyIr(R"(
+HloModule Cos
+ENTRY %Cos (val: f32[2,2]) -> f32[2,2] {
+  %val = f32[2,2]{1,0} parameter(0)
+  ROOT %cos = f32[2,2]{1,0} cosine(f32[2,2]{1,0} %val)
+})",
+                     R"(
+;CHECK: func @cosine(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) {
+;CHECK: "xla_lhlo.cos"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> ()
+;CHECK: }
+ )");
+}
+
+TEST_F(LhloGenTest, Neg) {
+  CompileAndVerifyIr(R"(
+HloModule Neg
+ENTRY %Neg (val: f32[2,2]) -> f32[2,2] {
+  %val = f32[2,2]{1,0} parameter(0)
+  ROOT %neg = f32[2,2]{1,0} negate(f32[2,2]{1,0} %val)
+})",
+                     R"(
+;CHECK: func @negate(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) {
+;CHECK: "xla_lhlo.neg"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> ()
+;CHECK: }
+ )");
+}
+
+TEST_F(LhloGenTest, Rem) {
+  CompileAndVerifyIr(R"(
+HloModule Rem
+ENTRY %Rem(x: f32[2,2], y: f32[2,2]) -> f32[2,2] {
+  %x = f32[2,2]{1,0} parameter(0)
+  %y = f32[2,2]{1,0} parameter(1)
+  ROOT %rem = f32[2,2]{1,0} remainder(f32[2,2]{1,0} %x, f32[2,2]{1,0} %y)
+})",
+                     R"(
+;CHECK: func @remainder(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]], %[[ARG2:.*]]: [[TYPE]]) {
+;CHECK: "xla_lhlo.remainder"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> ()
+;CHECK: }
+ )");
+}
+
+TEST_F(LhloGenTest, Sign) {
+  CompileAndVerifyIr(R"(
+HloModule Sign
+ENTRY %Sign (val: f32[2,2]) -> f32[2,2] {
+  %val = f32[2,2]{1,0} parameter(0)
+  ROOT %sign = f32[2,2]{1,0} sign(f32[2,2]{1,0} %val)
+})",
+                     R"(
+;CHECK: func @sign(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) {
+;CHECK: "xla_lhlo.sign"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> ()
+;CHECK: }
+ )");
+}
+
+TEST_F(LhloGenTest, Tanh) {
+  CompileAndVerifyIr(R"(
+HloModule Tanh
+ENTRY %Tanh (val: f32[2,2]) -> f32[2,2] {
+  %val = f32[2,2]{1,0} parameter(0)
+  ROOT %tanh = f32[2,2]{1,0} tanh(f32[2,2]{1,0} %val)
+})",
+                     R"(
+;CHECK: func @tanh(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) {
+;CHECK: "xla_lhlo.tanh"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> ()
+;CHECK: }
+ )");
+}
+
 }  // namespace mlir_gpu
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.cc b/tensorflow/compiler/xla/service/multi_output_fusion.cc
index 41e2b0e9cb1..16e34331ac5 100644
--- a/tensorflow/compiler/xla/service/multi_output_fusion.cc
+++ b/tensorflow/compiler/xla/service/multi_output_fusion.cc
@@ -151,6 +151,37 @@ HloInstruction* MultiOutputFusion::Fuse(HloInstruction* instr1,
   return remaining;
 }
 
+HloInstruction* MultiOutputFusion::CreateFusion(HloInstruction* base,
+                                                HloInstruction* to_fuse) {
+  HloInstruction* input_fusion =
+      computation()->AddInstruction(HloInstruction::CreateFusion(
+          base->shape(), HloInstruction::FusionKind::kLoop, base));
+
+  // Update candidate_ and all_fusion_candidates_.
+  std::vector<std::pair<HloInstruction*, int64>> new_fusibles =
+      GetNewFusibles(base, to_fuse);
+  int64 index;
+  if (candidates_index_.contains(input_fusion)) {
+    index = candidates_index_[input_fusion];
+  } else {
+    index = candidates_.size();
+    InsertOrDie(&candidates_index_, input_fusion, index);
+    candidates_.emplace_back(input_fusion);
+    all_fusion_candidates_.push_back(input_fusion);
+  }
+
+  // Update the worklist_.
+  FusionCandidate& candidate_node = candidates_[index];
+  for (auto it : new_fusibles) {
+    candidate_node.fusibles.emplace_back(it.first, it.second);
+    worklist_.emplace(input_fusion, it.first, it.second);
+  }
+
+  reachability_->Replace(base, input_fusion);
+  TF_CHECK_OK(computation()->ReplaceInstruction(base, input_fusion));
+  return input_fusion;
+}
+
 bool MultiOutputFusion::IsProfitableOperand(HloInstruction* instr) {
   // kConstant instruction will not have memory reads, so it won't be a profit
   // source. Skip them.
@@ -167,29 +198,12 @@ bool MultiOutputFusion::IsProfitableOperand(HloInstruction* instr) {
   return true;
 }
 
-void MultiOutputFusion::Update(HloInstruction* instr1, HloInstruction* instr2) {
-  HloInstruction* fusion = instr1;
-  HloInstruction* fused = instr2;
-  if (is_fused(instr1)) {
-    fusion = instr2;
-    fused = instr1;
-  }
-
-  // Insert the newly created instruction (if any), to candidates_.
-  for (auto use : fusion->users()) {
-    if (candidates_index_.find(use) == candidates_index_.end()) {
-      int64 index = candidates_.size();
-      candidates_.emplace_back(use);
-      InsertOrDie(&candidates_index_, use, index++);
-    }
-  }
+std::vector<std::pair<HloInstruction*, int64>>
+MultiOutputFusion::GetNewFusibles(HloInstruction* fusion,
+                                  HloInstruction* fused) {
   FusionCandidate& fusion_node = candidates_[get_candidate_id(fusion)];
   FusionCandidate& fused_node = candidates_[get_candidate_id(fused)];
 
-  // Update the reachability graph.
-  UpdateReachability(fusion, fused, all_fusion_candidates_,
-                     [this](HloInstruction* instr) { return is_fused(instr); });
-
   // Update the fusible list for fusion. Variable new_fusibles keeps
   // track of the new or changed entries.
   std::vector<std::pair<HloInstruction*, int64>> new_fusibles;
@@ -227,6 +241,33 @@ void MultiOutputFusion::Update(HloInstruction* instr1, HloInstruction* instr2) {
   }
   fused_node.fusibles.clear();
 
+  return new_fusibles;
+}
+
+void MultiOutputFusion::Update(HloInstruction* instr1, HloInstruction* instr2) {
+  HloInstruction* fusion = instr1;
+  HloInstruction* fused = instr2;
+  if (is_fused(instr1)) {
+    fusion = instr2;
+    fused = instr1;
+  }
+
+  // Insert the newly created instruction (if any), to candidates_.
+  for (auto use : fusion->users()) {
+    if (candidates_index_.find(use) == candidates_index_.end()) {
+      int64 index = candidates_.size();
+      candidates_.emplace_back(use);
+      InsertOrDie(&candidates_index_, use, index++);
+    }
+  }
+
+  // Update the reachability graph.
+  UpdateReachability(fusion, fused, all_fusion_candidates_,
+                     [this](HloInstruction* instr) { return is_fused(instr); });
+
+  std::vector<std::pair<HloInstruction*, int64>> new_fusibles =
+      GetNewFusibles(fusion, fused);
+
   // Update the worklist_.
   for (auto it : new_fusibles) {
     worklist_.emplace(fusion, it.first, it.second);
@@ -235,10 +276,15 @@ void MultiOutputFusion::Update(HloInstruction* instr1, HloInstruction* instr2) {
 
 bool MultiOutputFusion::LegalToFuse(HloInstruction* instr1,
                                     HloInstruction* instr2) {
-  if (instr1 == instr2) {
+  if (instr1->opcode() != HloOpcode::kFusion) {
     return false;
   }
-  if (instr1->opcode() != HloOpcode::kFusion) {
+  return LegalToFuseMainConstraints(instr1, instr2);
+}
+
+bool MultiOutputFusion::LegalToFuseMainConstraints(HloInstruction* instr1,
+                                                   HloInstruction* instr2) {
+  if (instr1 == instr2) {
     return false;
   }
 
@@ -342,7 +388,12 @@ bool MultiOutputFusion::Perform() {
       }
       Update(instr1, instr2);
      HloInstruction* ret = Fuse(instr1, instr2);
-      set_is_fused(ret == instr1 ? instr2 : instr1);
+      if (ret != instr1) {
+        set_is_fused(instr1);
+      }
+      if (ret != instr2) {
+        set_is_fused(instr2);
+      }
       changed = true;
       VLOG(2) << "After fusion, \t this: " << ret->name() << "\n"
               << ret->fused_instructions_computation()->ToString(
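The Update/GetNewFusibles split above preserves one invariant: each (instruction, profit) pair returned by GetNewFusibles becomes a worklist entry, and Perform() always pops the most profitable pair first. A self-contained toy model of that ordering (not TF code; names and profit values are made up):

#include <cstdint>
#include <queue>

// Toy stand-in for the pass's profit-ordered worklist: higher profit wins.
struct ToyCandidate {
  int id;          // stands in for an HloInstruction*
  int64_t profit;  // e.g. bytes of duplicated reads saved by fusing
  bool operator<(const ToyCandidate& other) const {
    return profit < other.profit;
  }
};

int main() {
  std::priority_queue<ToyCandidate> worklist;
  worklist.push({1, 128});   // pairs as GetNewFusibles would report them
  worklist.push({2, 64});
  return worklist.top().id;  // 1: the 128-profit candidate is fused first
}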
diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.h b/tensorflow/compiler/xla/service/multi_output_fusion.h
index 9be69f808c4..55cb15e94fc 100644
--- a/tensorflow/compiler/xla/service/multi_output_fusion.h
+++ b/tensorflow/compiler/xla/service/multi_output_fusion.h
@@ -79,6 +79,11 @@ class MultiOutputFusion : public HloModulePass {
   // Test if it's legal to fuse instr1 and instr2 into one fusion instruction.
   virtual bool LegalToFuse(HloInstruction* instr1, HloInstruction* instr2);
 
+  // Test if it's legal to fuse instr1 and instr2 into one fusion instruction
+  // using main constraints.
+  bool LegalToFuseMainConstraints(HloInstruction* instr1,
+                                  HloInstruction* instr2);
+
   // Fuse HloInstruction instr1 and instr2 and return the fused instruction.
   // The other instruction is removed from its parent computation.
   virtual HloInstruction* Fuse(HloInstruction* instr1, HloInstruction* instr2);
@@ -105,6 +110,17 @@ class MultiOutputFusion : public HloModulePass {
   // InstructionFusion instead.
   virtual bool DoProducerConsumerMultiOutputFusion();
 
+  // Returns a list of new fusible instructions that can be fused into the
+  // fusion of `fusion' and `fused'. The second entry in each pair is the
+  // profit from fusing the corresponding instruction.
+  std::vector<std::pair<HloInstruction*, int64>> GetNewFusibles(
+      HloInstruction* fusion, HloInstruction* fused);
+
+  // Create a new fusion instruction and add `base' into it.
+  // Prepare for fusing `to_fuse' into the created fusion by updating
+  // reachability, worklist, and fusion candidates.
+  HloInstruction* CreateFusion(HloInstruction* base, HloInstruction* to_fuse);
+
  private:
   // An internal data structure for each instruction in current computation.
   // When an instruction is removed, member 'hlo' is set to nullptr.
diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h
index 32e4c636327..3a5f6da3b7c 100644
--- a/tensorflow/compiler/xla/service/pattern_matcher.h
+++ b/tensorflow/compiler/xla/service/pattern_matcher.h
@@ -73,7 +73,7 @@ namespace xla {
 //   - EqualTo
 //   - CompatibleTo
 //   - IsScalar/IsEffectiveScalar/IsArray/IsTuple
-//   - IsDenseArray/IsSparseArray
+//   - IsDenseArray
 //   - WithLayout: layout shape's layout matches the given pattern (e.g.
 //     Layout().WithDenseFormat())
 //   - WithLayoutEqualTo: shape's layout equals the argument (i.e. another
@@ -87,7 +87,7 @@ namespace xla {
 //
 // Layout():
 //   - EqualTo
-//   - WithDenseFormat/WithSparseFormat
+//   - WithDenseFormat
 //
 // Op(), Shape(), and Layout() may be passed an argument of type
 // HloInstruction**, Shape**, or Layout**, respectively, or const versions of
@@ -506,12 +506,6 @@ class LayoutPattern {
     return AppendImpl(LayoutPatternFormatImpl(DENSE));
   }
 
-  // Modifies the pattern to match only if the layout has a sparse format.
-  constexpr auto WithSparseFormat() const
-      -> decltype(this->AppendImpl(LayoutPatternFormatImpl(SPARSE))) {
-    return AppendImpl(LayoutPatternFormatImpl(SPARSE));
-  }
-
  private:
   Impl impl_;
   LayoutType** matched_layout_;
@@ -1060,11 +1054,6 @@ class ShapePattern {
     return WithLayout(Layout().WithDenseFormat());
   }
 
-  constexpr auto IsSparseArray() const
-      -> decltype(this->WithLayout(Layout().WithSparseFormat())) {
-    return WithLayout(Layout().WithSparseFormat());
-  }
-
   // Modifies the pattern to match only if the shape has a subshape that matches
   // the given pattern.
template diff --git a/tensorflow/compiler/xla/service/pattern_matcher_gmock_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_gmock_test.cc index f51a18b1389..a2ba8b888dc 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher_gmock_test.cc +++ b/tensorflow/compiler/xla/service/pattern_matcher_gmock_test.cc @@ -56,9 +56,6 @@ TEST(PatternMatcherGmock, MatchShape) { TEST(PatternMatcherGmock, MatchLayout) { Layout l = LayoutUtil::MakeLayout({0, 1}); EXPECT_THAT(l, GmockMatch(m::Layout())); - EXPECT_THAT(&l, Not(GmockMatch(m::Layout().WithSparseFormat()))); - EXPECT_THAT(Describe(GmockMatch(m::Layout().WithSparseFormat())), - "a layout with format SPARSE"); } TEST(PatternMatchGmock, MatchInstruction) { diff --git a/tensorflow/compiler/xla/service/pattern_matcher_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_test.cc index b923117318a..5e1287e5ddc 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher_test.cc +++ b/tensorflow/compiler/xla/service/pattern_matcher_test.cc @@ -89,7 +89,6 @@ TEST_F(PatternMatcherTest, DenseArrayShape) { EXPECT_TRUE(Match(&array_shape, match::Shape(&matched_shape).IsArray())); EXPECT_EQ(matched_shape, &array_shape); EXPECT_TRUE(Match(&array_shape, match::Shape().IsDenseArray())); - EXPECT_FALSE(Match(&array_shape, match::Shape().IsSparseArray())); EXPECT_FALSE(Match(&array_shape, match::Shape().IsScalar())); EXPECT_FALSE(Match(&array_shape, match::Shape().IsTuple())); EXPECT_TRUE(Match(&array_shape, match::Shape().WithElementType(F32))); @@ -97,38 +96,12 @@ TEST_F(PatternMatcherTest, DenseArrayShape) { EXPECT_FALSE( Match(&array_shape, match::Shape().WithSubshape({0}, match::Shape()))); Layout* matched_layout; - EXPECT_FALSE(Match(&array_shape, - match::Shape().WithLayout( - match::Layout(&matched_layout).WithSparseFormat()))); EXPECT_TRUE(Match(&array_shape, match::Shape().WithLayout( match::Layout(&matched_layout).WithDenseFormat()))); EXPECT_EQ(matched_layout, &array_shape.layout()); } -TEST_F(PatternMatcherTest, SparseArrayShape) { - auto array_shape = ShapeUtil::MakeShapeWithSparseLayout(F32, {2, 3, 4}, 10); - Shape* matched_shape; - EXPECT_TRUE(Match(&array_shape, match::Shape(&matched_shape).IsArray())); - EXPECT_EQ(matched_shape, &array_shape); - EXPECT_FALSE(Match(&array_shape, match::Shape().IsDenseArray())); - EXPECT_TRUE(Match(&array_shape, match::Shape().IsSparseArray())); - EXPECT_FALSE(Match(&array_shape, match::Shape().IsScalar())); - EXPECT_FALSE(Match(&array_shape, match::Shape().IsTuple())); - EXPECT_TRUE(Match(&array_shape, match::Shape().WithElementType(F32))); - EXPECT_TRUE(Match(&array_shape, match::Shape().WithRank(3))); - EXPECT_FALSE( - Match(&array_shape, match::Shape().WithSubshape({0}, match::Shape()))); - Layout* matched_layout; - EXPECT_FALSE(Match(&array_shape, - match::Shape().WithLayout( - match::Layout(&matched_layout).WithDenseFormat()))); - EXPECT_TRUE(Match(&array_shape, - match::Shape().WithLayout( - match::Layout(&matched_layout).WithSparseFormat()))); - EXPECT_EQ(matched_layout, &array_shape.layout()); -} - TEST_F(PatternMatcherTest, TupleShape) { auto tuple_shape = ShapeUtil::MakeTupleShape({ ShapeUtil::MakeShape(F32, {1, 2, 3}), @@ -568,15 +541,6 @@ TEST_F(PatternMatcherTest, LayoutDescribeToAndExplain) { EXPECT_DESC_AND_EXPLANATION(layout2, m::Layout().EqualTo(&layout), "a layout equal to {1,2}", "Layout {2,2} is not equal to expected {1,2}"); - EXPECT_DESC_AND_EXPLANATION(layout2, m::Layout().WithSparseFormat(), - "a layout with format SPARSE", - "Layout has format DENSE but expected 
SPARSE"); - EXPECT_DESC_AND_EXPLANATION(layout, - m::Layout().EqualTo(&layout).WithSparseFormat(), - "a layout:\n" - " * equal to {1,2} AND\n" - " * with format SPARSE", - "Layout has format DENSE but expected SPARSE"); } TEST_F(PatternMatcherTest, CustomCallTargetMatcherDescribeAndExplain) { @@ -665,11 +629,6 @@ TEST_F(PatternMatcherTest, ShapeDescribeToAndExplain) { "a shape with\n a layout equal to {0,1}", "Layout {1,0} is not equal to expected {0,1}\n" "in f32[1,2]{1,0}"); - EXPECT_DESC_AND_EXPLANATION( - shape, m::Shape().WithLayout(m::Layout().WithSparseFormat()), - "a shape with\n a layout with format SPARSE", - "Layout has format DENSE but expected SPARSE\n" - "in f32[1,2]{0,1}"); EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().WithSubshapeEqualTo({10}, &shape), "a shape with subshape at index {10} which is\n" diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index ec6a97e928a..816047fcf5d 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -408,7 +408,18 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, for (size_t i = 1; i < arg_shapes.size(); ++i) { new_dimensions[dimension] += arg_shapes[i]->dimensions(dimension); } - return ShapeUtil::MakeShape(element_type, new_dimensions); + + Shape result = ShapeUtil::MakeShape(element_type, new_dimensions); + + // Set dynamic dimensions if any input has dynamic dimension. + for (const Shape* shape : arg_shapes) { + for (int64 i = 0; i < shape->dimensions_size(); ++i) { + if (shape->is_dynamic_dimension(i)) { + result.set_dynamic_dimension(i, true); + } + } + } + return result; } /* static */ StatusOr ShapeInference::InferConvertShape( @@ -1720,7 +1731,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const int64 kernel_output_features = rhs.dimensions(dnums.kernel_output_feature_dimension()); - if (batch_group_count > 1 && kernel_output_features != batch_group_count) { + if (batch_group_count > 1 && + kernel_output_features % batch_group_count != 0) { return InvalidArgument( "Expected output feature dimension size (value %d) to be equal to " "batch group count %d; got (%s, %s)\n" @@ -1759,7 +1771,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, dnums.DebugString()); } - if (input_batch % batch_group_count > 0) { + if (input_batch % batch_group_count != 0) { return InvalidArgument( "Expected input batch dimension (value %d) to be divisible by " "batch_group_count (value %d); " @@ -1793,6 +1805,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, std::vector dimensions(num_dims); dimensions[dnums.output_batch_dimension()] = input_batch / batch_group_count; dimensions[dnums.output_feature_dimension()] = kernel_output_features; + + if (batch_group_count > 1) { + dimensions[dnums.output_batch_dimension()] = + kernel_output_features / batch_group_count; + dimensions[dnums.output_feature_dimension()] = batch_group_count; + } + for (int i = 0; i < num_spatial_dims; ++i) { dimensions[dnums.output_spatial_dimensions(i)] = window_output_shape.dimensions(i); diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index b189e047254..41a54e81792 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -1592,6 +1592,20 @@ TEST_F(ShapeInferenceTest, WhileWithBadShapes) { 
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index b189e047254..41a54e81792 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -1592,6 +1592,20 @@ TEST_F(ShapeInferenceTest, WhileWithBadShapes) {
               HasSubstr("parameter of condition and body"));
 }
 
+// Tests for the concatenate instruction with dynamic shapes.
+TEST_F(ShapeInferenceTest, ConcatenateWithDynamicShapes) {
+  auto dynamic_shape_1 =
+      ShapeUtil::MakeShape(F32, {32, 160, 10}, {true, false, false});
+  auto dynamic_shape_2 =
+      ShapeUtil::MakeShape(F32, {32, 160, 10}, {false, true, false});
+  auto inferred_status = ShapeInference::InferConcatOpShape(
+      {&dynamic_shape_1, &dynamic_shape_2}, /*dimension=*/0);
+  ASSERT_IS_OK(inferred_status.status());
+  Shape inferred = inferred_status.ValueOrDie();
+  ASSERT_TRUE(ShapeUtil::Equal(
+      ShapeUtil::MakeShape(F32, {64, 160, 10}, {true, true, false}), inferred));
+}
+
 // Tests for the concatenate instruction with proper shapes.
 TEST_F(ShapeInferenceTest, ConcatenateWithCorrectShapes) {
   auto inferred_status_1 = ShapeInference::InferConcatOpShape(
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index 484673b8b6b..146d03fa0c5 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -229,16 +229,6 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
   return MakeShapeWithLayout(element_type, dimensions, layout);
 }
 
-/* static */ Shape ShapeUtil::MakeShapeWithSparseLayout(
-    PrimitiveType element_type, absl::Span<const int64> dimensions,
-    int64 max_sparse_elements) {
-  CHECK(IsArrayPrimitiveType(element_type));
-  Shape shape = ShapeUtil::MakeShape(element_type, dimensions);
-  *shape.mutable_layout() = LayoutUtil::MakeSparseLayout(max_sparse_elements);
-  TF_DCHECK_OK(ShapeUtil::ValidateShape(shape));
-  return shape;
-}
-
 /* static */ Shape
 ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
     const Shape& shape) {
@@ -637,9 +627,6 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
     return ByteSizeOfTupleIndexTable(shape, pointer_size);
   } else if (shape.IsArray()) {
     int64 byte_size = ByteSizeOfElements(shape);
-    if (LayoutUtil::IsSparseArray(shape)) {
-      byte_size += ByteSizeOfSparseIndices(shape);
-    }
     return byte_size;
   } else if (shape.element_type() == TOKEN) {
     return 0;
@@ -664,23 +651,12 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
   CHECK(shape.IsArray());
   int64 allocated_element_count;
 
-  if (LayoutUtil::IsSparseArray(shape)) {
-    allocated_element_count = LayoutUtil::MaxSparseElements(shape.layout());
-  } else {
-    CHECK(LayoutUtil::IsDenseArray(shape)) << shape.ShortDebugString();
-    allocated_element_count = ElementsIn(shape);
-  }
+  CHECK(LayoutUtil::IsDenseArray(shape)) << shape.ShortDebugString();
+  allocated_element_count = ElementsIn(shape);
   return allocated_element_count *
          ByteSizeOfPrimitiveType(shape.element_type());
 }
 
-/* static */ int64 ShapeUtil::ByteSizeOfSparseIndices(const Shape& shape) {
-  TF_DCHECK_OK(ValidateShape(shape));
-  CHECK(LayoutUtil::IsSparseArray(shape));
-  return LayoutUtil::MaxSparseElements(shape.layout()) * shape.rank() *
-         sizeof(int64);
-}
-
 /* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal(
     const Shape& shape) {
   if (shape.element_type() == PRIMITIVE_TYPE_INVALID ||
@@ -721,9 +697,6 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
     return Status::OK();
   }
 
-  if (LayoutUtil::IsSparseArray(shape) && shape.rank() == 0) {
-    return InvalidArgument("sparse arrays must have rank > 0");
-  }
   for (int64 i = 0; i < shape.rank(); ++i) {
     int64 dimension = shape.dimensions(i);
     if (dimension < 0) {
@@ -744,43 +717,7 @@
     return Status::OK();
   }
 
-  // We can only reason about some
aspects of array's shape if it has a valid - // layout, these aspects will be ignored otherwise. - bool shape_has_valid_layout = LayoutUtil::HasLayout(shape) && - LayoutUtil::ValidateLayoutInShape(shape).ok(); - int64 shape_size = [&]() { - if (shape_has_valid_layout && LayoutUtil::IsSparseArray(shape)) { - int64 max_sparse_elements = LayoutUtil::MaxSparseElements(shape.layout()); - if (max_sparse_elements < 0) { - return max_sparse_elements; - } - int64 sparse_elements_size = MultiplyWithoutOverflow( - max_sparse_elements, ByteSizeOfPrimitiveType(shape.element_type())); - if (sparse_elements_size < 0) { - return sparse_elements_size; - } - int64 sparse_indices_size = - MultiplyWithoutOverflow(max_sparse_elements, shape.rank()); - if (sparse_indices_size < 0) { - return sparse_indices_size; - } - sparse_indices_size = - MultiplyWithoutOverflow(sparse_indices_size, sizeof(int64)); - if (sparse_indices_size < 0) { - return sparse_indices_size; - } - // At this point, both sparse_indices_size and sparse_elements_size are - // non-negative, so we can easily check if adding them wraps. - if (static_cast(sparse_elements_size) + - static_cast(sparse_indices_size) > - INT64_MAX) { - return static_cast(-1); - } - } - - // This is intentionally unconditional: even if the shape is sparse, we want - // to verify the densified version has a reasonable size. int64 dense_shape_size = 1; if (shape.dimensions().empty()) { return dense_shape_size; diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 769094b1f0b..7e05e17865d 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -192,10 +192,7 @@ class ShapeUtil { }; // Returns the number of elements are contained within the provided shape; - // e.g. for rank 0 (scalars) the result is always 1. Note that sparse shapes - // may not actually be able to store this number of elements. See - // LayoutUtil::MaxSparseElements(shape) to obtain the maximum number of - // elements that can be stored in a sparse shape. + // e.g. for rank 0 (scalars) the result is always 1. // Precondition: shape.IsArray() static int64 ElementsIn(const Shape& shape); @@ -228,20 +225,12 @@ class ShapeUtil { int64 pointer_size); // Returns the number of bytes required for the elements in an allocation of - // `shape`, which must be an array shape. The return value does not include - // the bytes needed to store sparse indices. Dense shapes use a separate + // `shape`, which must be an array shape. Shapes use a separate // memory location for each element, and so for these shapes, - // `ByteSizeOf(shape) == ByteSizeOfElements(shape)`. For dense shapes, this - // size also includes padding if present in the layout. For sparse shapes, - // `ByteSizeOf(shape) == ByteSizeOfElements(shape) + - // ByteSizeOfSparseindices(shape)`. + // `ByteSizeOf(shape) == ByteSizeOfElements(shape)`. This + // size also includes padding if present in the layout. static int64 ByteSizeOfElements(const Shape& shape); - // Returns the number of bytes required for the sparse indices in an - // allocation of shape. The shape must be an array shape. The return value - // does not include the bytes needed to store sparse indices. - static int64 ByteSizeOfSparseIndices(const Shape& shape); - // Returns a human-readable string that represents the given shape, with or // without layout. e.g. "f32[42x12] {0, 1}" or "f32[64]". 
static string HumanString(const Shape& shape); @@ -427,9 +416,6 @@ class ShapeUtil { int64 element_size_in_bits = 0, int64 memory_space = 0); - static Shape MakeShapeWithSparseLayout(PrimitiveType element_type, - absl::Span dimensions, - int64 max_sparse_elements); // Returns the same shape except with all dimensions set to be static. static Shape MakeShapeWithStaticDimensions(const Shape& shape); diff --git a/tensorflow/compiler/xla/sparse_index_array.cc b/tensorflow/compiler/xla/sparse_index_array.cc deleted file mode 100644 index 82091bdee65..00000000000 --- a/tensorflow/compiler/xla/sparse_index_array.cc +++ /dev/null @@ -1,109 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/sparse_index_array.h" - -#include "tensorflow/compiler/xla/index_util.h" -#include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/shape_util.h" - -namespace xla { - -SparseIndexArray::SparseIndexArray() : rank_(0), max_indices_(0) {} - -SparseIndexArray::SparseIndexArray(int64 max_indices, int64 rank, - std::vector indices) - : indices_(std::move(indices)), rank_(rank), max_indices_(max_indices) { - CHECK_GT(rank_, 0); - CHECK_EQ(indices_.size() % rank_, 0) - << "indices_.size(): " << indices_.size() << ", rank_: " << rank_; - CHECK_LE(index_count(), max_indices_); -} - -SparseIndexArray::SparseIndexArray(int64 max_indices, int64 rank, - absl::Span indices) - : SparseIndexArray(max_indices, rank, - std::vector(indices.begin(), indices.end())) {} - -SparseIndexArray::SparseIndexArray(int64 max_indices, - const Array2D& indices) - : SparseIndexArray(max_indices, indices.n2(), - std::vector(indices.begin(), indices.end())) {} - -int64 SparseIndexArray::index_count() const { - CHECK_GT(rank_, 0); - CHECK_EQ(indices_.size() % rank_, 0); - return indices_.size() / rank_; -} - -absl::Span SparseIndexArray::At( - int64 sparse_element_number) const { - CHECK_GT(rank_, 0); - CHECK_GE(sparse_element_number, 0); - CHECK_LE(rank_ * sparse_element_number + rank_, indices_.size()); - return absl::Span( - indices_.data() + rank_ * sparse_element_number, rank_); -} - -absl::Span SparseIndexArray::At(int64 sparse_element_number) { - CHECK_GT(rank_, 0); - CHECK_GE(sparse_element_number, 0); - CHECK_LE(rank_ * sparse_element_number + rank_, indices_.size()); - return absl::Span(indices_.data() + rank_ * sparse_element_number, - rank_); -} - -void SparseIndexArray::Append(absl::Span index) { - CHECK_GT(rank_, 0); - CHECK_EQ(index.size(), rank_); - indices_.insert(indices_.end(), index.begin(), index.end()); -} - -void SparseIndexArray::Clear() { indices_.clear(); } - -void SparseIndexArray::Resize(int64 num_indices) { - CHECK_GT(rank_, 0); - indices_.resize(rank_ * num_indices); -} - -bool SparseIndexArray::Validate(const Shape& shape) const { - if (rank_ == 0 || rank_ != shape.rank()) { - return false; - } - int64 num_indices = index_count(); 
- if (num_indices > LayoutUtil::MaxSparseElements(shape.layout())) { - return false; - } - if (num_indices < 2) { - return true; - } - absl::Span last = At(0); - if (!IndexUtil::IndexInBounds(shape, last)) { - return false; - } - for (int64 n = 1; n < num_indices; ++n) { - absl::Span next = At(n); - if (!IndexUtil::IndexInBounds(shape, next)) { - return false; - } - if (IndexUtil::CompareIndices(last, next) >= 0) { - return false; - } - last = next; - } - return true; -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/sparse_index_array.h b/tensorflow/compiler/xla/sparse_index_array.h deleted file mode 100644 index 0c25355467d..00000000000 --- a/tensorflow/compiler/xla/sparse_index_array.h +++ /dev/null @@ -1,176 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Utility class for managing sparse array indices. - -#ifndef TENSORFLOW_COMPILER_XLA_SPARSE_INDEX_ARRAY_H_ -#define TENSORFLOW_COMPILER_XLA_SPARSE_INDEX_ARRAY_H_ - -#include - -#include "absl/container/inlined_vector.h" -#include "absl/types/span.h" -#include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/index_util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" - -namespace xla { - -// Encapsulates the array of indices for a sparse array. A SparseIndexArray -// contain indices for up to `max_indices` elements of a sparse array. Each -// sparse index is an array of `rank` int64 value that gives the location of a -// value within a sparse array. Note that the dimensions of the array are not -// checked (except for the rank). To avoid confusion, we refer to the position -// of an index within a SparseIndexArray as a sparse index number. -class SparseIndexArray { - public: - SparseIndexArray(); - SparseIndexArray(const SparseIndexArray&) = default; - SparseIndexArray(SparseIndexArray&&) = default; - SparseIndexArray& operator=(const SparseIndexArray&) = default; - SparseIndexArray& operator=(SparseIndexArray&&) = default; - - // Constructs a SparseIndexArray that can hold up to `max_indices` sparse - // indices, with an initial contents obtained from the given array. The rank - // is taken from the minor dimension of the array. The major dimension of the - // array must not exceed `max_indices`. - SparseIndexArray(int64 max_indices, const Array2D& indices); - - // Like above, but the array is flattened. For example, the following are - // equivalent: - // - // SparseIndexArray(10, 3, - // Array2D{ - // {0, 1, 2}, - // {3, 4, 5}, - // {6, 7, 8}, - // {9, 10, 11}, - // }) - // - // SparseIndexArray(10, 3, - // {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}) - // - SparseIndexArray(int64 max_indices, int64 rank, - std::vector indices = {}); - SparseIndexArray(int64 max_indices, int64 rank, - absl::Span indices); - - // Returns the number of elements represented by the indices stored in the - // array. 
- int64 index_count() const; - - // Returns a slice that refers to the given sparse index number. The argument - // must be in the range [0, element_count()). - absl::Span At(int64 sparse_element_number) const; - absl::Span At(int64 sparse_element_number); - - // Adds the given index at the end of the array. The new size of the - // SparseIndexArray must not exceed `max_indices`. - void Append(absl::Span index); - - // Removes all indices from the array. - void Clear(); - - // Resizes the array to contain the given number of sparse indices. The new - // size must be smaller than `max_indices`. If the new size is larger than - // the old size, the value of the new indices is not specified. - void Resize(int64 num_indices); - - // Returns true iff all indices are unique and occur in sorted order, and are - // valid for the given shape. - bool Validate(const Shape& shape) const; - - int64 rank() const { return rank_; } - int64 max_indices() const { return max_indices_; } - - // Returns a pointer to the int64 array that holds the sparse indices. - absl::Span mutable_data() { return absl::MakeSpan(indices_); } - absl::Span data() const { return indices_; } - - // Sorts this sparse index array along with the set of corresponding values. - // The indices and values are sorted in the lexicographic order of the - // indices, from smallest to largest. - // - // For example: - // - // std::vector v{10.0, 11.0, 12.0}; - // SparseIndexArray a(10, 3, - // {{3, 4, 5}, - // {1, 2, 3}, - // {2, 3, 4}}); - // a.SortWithValues(&v); - // // Prints "11.0, 12.0, 10.0": - // std::cout << v[0] << ", " << v[1] << ", " << v[2] << std::endl; - // - template - void SortWithValues(absl::Span values); - - private: - std::vector indices_; - int64 rank_; - int64 max_indices_; -}; - -template -void SparseIndexArray::SortWithValues(absl::Span values) { - int64 num_elements = index_count(); - CHECK_EQ(values.size(), num_elements); - std::vector sort_order; - sort_order.reserve(num_elements); - for (int64 i = 0; i < num_elements; ++i) { - sort_order.push_back(i); - } - auto sort_order_less = [this](int64 lhs, int64 rhs) { - return IndexUtil::CompareIndices(At(lhs), At(rhs)) < 0; - }; - absl::c_sort(sort_order, sort_order_less); - - // Reorder the array elements according to sort_order. Work through the array - // and follow cycles so we can do the reorder in-place. - absl::InlinedVector saved_index(rank()); - for (int64 i = 0; i < num_elements; ++i) { - // sort_order[i] == -1 indicates the element has already been copied. - if (sort_order[i] < 0) { - continue; - } else if (i == sort_order[i]) { - // The element is already in sorted order. - sort_order[i] = -1; - continue; - } - - std::copy_n(At(i).begin(), rank(), saved_index.begin()); - NativeT saved_value = values[i]; - int64 j = i; - for (;;) { - if (sort_order[j] == i) { - std::copy_n(saved_index.begin(), rank(), At(j).begin()); - values[j] = saved_value; - sort_order[j] = -1; - break; - } - - std::copy_n(At(sort_order[j]).begin(), rank(), At(j).begin()); - values[j] = values[sort_order[j]]; - - int64 k = sort_order[j]; - sort_order[j] = -1; - j = k; - } - } -} - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SPARSE_INDEX_ARRAY_H_ diff --git a/tensorflow/compiler/xla/sparse_index_array_test.cc b/tensorflow/compiler/xla/sparse_index_array_test.cc deleted file mode 100644 index e54057c4007..00000000000 --- a/tensorflow/compiler/xla/sparse_index_array_test.cc +++ /dev/null @@ -1,43 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/sparse_index_array.h" - -#include - -#include "tensorflow/compiler/xla/test.h" - -namespace xla { -namespace { - -TEST(SparseIndexArrayTest, Sort) { - SparseIndexArray a(10, 3); - a.Append({2, 3, 4}); - a.Append({3, 4, 5}); - a.Append({1, 2, 3}); - a.Append({5, 6, 7}); - a.Append({4, 5, 6}); - a.Append({6, 7, 8}); - std::vector values = { - 12.0, 13.0, 11.0, 15.0, 14.0, 16.0, - }; - a.SortWithValues(absl::MakeSpan(values)); - ASSERT_EQ(a.data(), std::vector({1, 2, 3, 2, 3, 4, 3, 4, 5, 4, 5, 6, 5, - 6, 7, 6, 7, 8})); - ASSERT_EQ(values, std::vector({11.0, 12.0, 13.0, 14.0, 15.0, 16.0})); -} - -} // namespace -} // namespace xla diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 0a0eaa190ee..b2cc8050c42 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -175,7 +175,7 @@ tf_cc_binary( "//tensorflow/compiler/xla/service/cpu:cpu_compiler", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", - "@llvm//:support", + "@llvm-project//llvm:support", ], ) @@ -255,7 +255,7 @@ cc_library( srcs = ["filecheck.cc"], hdrs = ["filecheck.h"], data = [ - "@llvm//:FileCheck", + "@llvm-project//llvm:FileCheck", ], deps = [ "//tensorflow/compiler/xla:statusor", @@ -2136,7 +2136,7 @@ tf_cc_test( "//tensorflow/core:test_main", "//tensorflow/stream_executor", "@com_google_absl//absl/memory", - "@llvm//:core", + "@llvm-project//llvm:core", ], ) diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc old mode 100644 new mode 100755 index 17e37607be1..07465885a69 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -364,7 +364,6 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( instruction->set_raw_backend_config_string(backend_config); } - // return ::testing::AssertionSuccess(); auto output = test_runner_.Execute(std::move(module), fake_argument_ptrs, /*run_hlo_passes=*/run_hlo_passes, /*profile=*/profile); @@ -501,6 +500,19 @@ HloInstruction* HloTestBase::FindInstruction(HloModule* module, return nullptr; } +HloInstruction* HloTestBase::FindInstruction(HloModule* module, + HloOpcode opcode) { + for (const HloComputation* c : module->computations()) { + auto instructions = c->instructions(); + auto it = absl::c_find_if( + instructions, [&](HloInstruction* i) { return i->opcode() == opcode; }); + if (it != instructions.end()) { + return *it; + } + } + return nullptr; +} + Backend& HloTestBase::backend() { return test_runner_.backend(); } /* static */ diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h old mode 100644 new mode 100755 index 848b334cfec..45917f39b6c --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -274,6 
+274,8 @@ class HloTestBase : public ::testing::Test {
   // inspect a particular computation or instruction.
   HloComputation* FindComputation(HloModule* module, absl::string_view name);
   HloInstruction* FindInstruction(HloModule* module, absl::string_view name);
+  // Gets the instruction from the given module with the given opcode.
+  HloInstruction* FindInstruction(HloModule* module, HloOpcode opcode);
 
   // Return an HLO verifier constructed for the test backend.
   HloVerifier& verifier() const { return *hlo_verifier_; }
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index 4563d7e0df2..76488917257 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -218,6 +218,23 @@ void PopulateWithFloatingPointData(Literal* literal,
   }
 }
 
+// uniform_int_distribution is not defined for 8-bit integers.
+// Use 'short' for those types.
+template <typename IntT>
+struct RngT {
+  using type = IntT;
+};
+
+template <>
+struct RngT<int8> {
+  using type = int16;
+};
+
+template <>
+struct RngT<uint8> {
+  using type = uint16;
+};
+
 template <typename IntT>
 void PopulateWithRandomIntegralData(Literal* literal, std::minstd_rand0* engine,
                                     bool no_duplicates) {
@@ -230,7 +247,7 @@ void PopulateWithRandomIntegralData(Literal* literal, std::minstd_rand0* engine,
     std::shuffle(literal->data<IntT>().begin(), literal->data<IntT>().end(),
                  *engine);
   } else {
-    std::uniform_int_distribution<IntT> generator(
+    std::uniform_int_distribution<typename RngT<IntT>::type> generator(
        std::numeric_limits<IntT>::lowest(), std::numeric_limits<IntT>::max());
     for (IntT& value : literal->data<IntT>()) {
       value = generator(*engine);
@@ -324,9 +341,6 @@ StatusOr<Literal> MakeFakeLiteralInternal(const Shape& shape,
           }));
       break;
     }
-    // Token requires no data.
-    case TOKEN:
-      break;
     default:
       return Unimplemented("Unsupported type for fake literal generation: %s",
                            ShapeUtil::HumanString(shape));
@@ -341,7 +355,7 @@ void PopulateWithRandomIntegralDataWithBounds(Literal* literal,
   CHECK(engine != nullptr);
   CHECK_EQ(literal->shape().element_type(),
            primitive_util::NativeToPrimitiveType<IntT>());
-  std::uniform_int_distribution<IntT> generator(min, max);
+  std::uniform_int_distribution<typename RngT<IntT>::type> generator(min, max);
   for (IntT& value : literal->data<IntT>()) {
     value = generator(*engine);
   }
diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc
index 9db08a5b72f..2a0d98ad1f1 100644
--- a/tensorflow/compiler/xla/tests/test_utils_test.cc
+++ b/tensorflow/compiler/xla/tests/test_utils_test.cc
@@ -56,24 +56,6 @@ XLA_TEST_F(TestUtilsTest, UnusedParam) {
   TF_ASSERT_OK(MakeFakeArguments(&module).status());
 }
 
-XLA_TEST_F(TestUtilsTest, Token) {
-  auto module = ParseAndReturnUnverifiedModule(
-      R"(HloModule outfeed_module
-
-    ENTRY InfeedToOutfeed {
-      token0 = token[] parameter(0)
-      infeed = ((u32[3]{0}, pred[]), token[]) infeed(token0)
-      infeed.data = (u32[3]{0}, pred[]) get-tuple-element(infeed), index=0
-      outfeed = token[] outfeed(infeed.data, token0)
-      ROOT infeed.1 = ((u32[3]{0}, pred[]), token[]) infeed(token0)
-      infeed.1.data = (u32[3]{0}, pred[]) get-tuple-element(infeed.1), index=0
-      infeed.1.token = token[] get-tuple-element(infeed.1), index=1
-      outfeed.1 = token[] outfeed(infeed.1.data, infeed.1.token)
-    })")
-                    .ValueOrDie();
-  TF_ASSERT_OK(MakeFakeArguments(module.get()).status());
-}
-
 XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicSlices) {
   auto module = ParseAndReturnVerifiedModule(
       R"(HloModule index_space_module
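The RngT indirection above works around the C++ standard, which does not define std::uniform_int_distribution for 8-bit integer types. A self-contained restatement of the same promotion trick using standard types:

#include <cstdint>
#include <limits>
#include <random>

// Distribute over int16_t, then narrow: the drawn value always fits in
// int8_t because the distribution bounds are int8_t's own limits.
int8_t RandomInt8(std::minstd_rand0* engine) {
  std::uniform_int_distribution<int16_t> dist(
      std::numeric_limits<int8_t>::lowest(),
      std::numeric_limits<int8_t>::max());
  return static_cast<int8_t>(dist(*engine));
}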
diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD
index 603e94ca938..db819c308ce 100644
--- a/tensorflow/compiler/xla/tools/BUILD
+++ b/tensorflow/compiler/xla/tools/BUILD
@@ -206,6 +206,7 @@ tf_cc_test(
         ":hlo_extractor",
         "//tensorflow/compiler/xla/service:hlo_matchers",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:test",
     ],
 )
diff --git a/tensorflow/compiler/xla/tools/driver.cc b/tensorflow/compiler/xla/tools/driver.cc
index 4b3ed2b58b7..8949843b67b 100644
--- a/tensorflow/compiler/xla/tools/driver.cc
+++ b/tensorflow/compiler/xla/tools/driver.cc
@@ -365,6 +365,9 @@ void Fill(void* buffer, ArrayShape shape) {
 }
 
 template <typename T>
+#if defined(MEMORY_SANITIZER)
+__attribute__((no_sanitize_memory))
+#endif
 void DisplayT(void* buffer, int num_elements) {
   T* casted = static_cast<T*>(buffer);
   for (int i = 0; i < num_elements; i++) {
diff --git a/tensorflow/compiler/xla/tools/hlo_module_loader.cc b/tensorflow/compiler/xla/tools/hlo_module_loader.cc
index 8eb170b25e5..0b16c877964 100644
--- a/tensorflow/compiler/xla/tools/hlo_module_loader.cc
+++ b/tensorflow/compiler/xla/tools/hlo_module_loader.cc
@@ -86,8 +86,8 @@ StatusOr<std::unique_ptr<HloModule>> LoadModuleFromData(
       return InvalidArgument("Failed to parse input as HLO protobuf binary");
     }
   } else if (format == "pbtxt") {
-    if (!proto2::TextFormat::ParseFromString(data, &proto) &&
-        !proto2::TextFormat::ParseFromString(data, proto.mutable_hlo())) {
+    if (!google::protobuf::TextFormat::ParseFromString(data, &proto) &&
+        !google::protobuf::TextFormat::ParseFromString(data, proto.mutable_hlo())) {
       return InvalidArgument("Failed to parse input as HLO protobuf text");
     }
   } else {
diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc
index 095655085e5..639f91b8b53 100644
--- a/tensorflow/compiler/xla/tools/replay_computation.cc
+++ b/tensorflow/compiler/xla/tools/replay_computation.cc
@@ -349,7 +349,7 @@ StatusOr<std::vector<HloSnapshot>> ParseRecordIoFile(absl::string_view filename,
   tensorflow::tstring record;
   while (reader.ReadRecord(&offset, &record).ok()) {
     HloSnapshot snapshot;
-    if (snapshot.mutable_hlo()->ParseFromStringPiece(record)) {
+    if (snapshot.mutable_hlo()->ParseFromString(record)) {
       snapshots.push_back(std::move(snapshot));
     } else {
       LOG(ERROR) << "Encountered bad proto";
diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto
index b0b97f1eb45..5a3da69f9fc 100644
--- a/tensorflow/compiler/xla/xla_data.proto
+++ b/tensorflow/compiler/xla/xla_data.proto
@@ -115,9 +115,8 @@ enum Format {
   INVALID_FORMAT = 0;
   // The default layout, with exactly one storage location per element.
   DENSE = 1;
-  // A sparsely encoded layout, providing only the index/value pairs of non-zero
-  // elements.
-  SPARSE = 2;
+  reserved 2;
+  reserved "SPARSE";
 }
 
 // Describes a tile used in tiling-based layout. Refer to
@@ -156,10 +155,8 @@ message LayoutProto {
   reserved 3;
   reserved "padding_value";
 
-  // The maximum number of elements that can be stored for SPARSE formats. This
-  // can be used to determine the maximum size in bytes of arrays stored in
-  // memory. This field must be unset unless the format is SPARSE.
-  int64 max_sparse_elements = 5;
+  reserved 5;
+  reserved "max_sparse_elements";
 
   // A sequence of tiles, starting from the tile that's applied first to the
   // Shape.
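One note on the driver.cc hunk above: MemorySanitizer only tracks initialization performed by instrumented code, so the display path opts out of checking reads of buffers filled by an external mechanism such as a device-to-host copy. A minimal sketch of the same pattern under that assumption (illustrative, not the PR's exact code):

#if defined(MEMORY_SANITIZER)
__attribute__((no_sanitize_memory))
#endif
int ReadUntrackedByte(const unsigned char* externally_filled) {
  // Without the attribute, MSan would flag this read as using
  // uninitialized memory, since it never saw the external writer.
  return externally_filled[0];
}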
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 29cb438473a..fbdcb4d65c8 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -79,7 +79,6 @@ load( "tf_cc_tests", "tf_copts", "tf_cuda_library", - "tf_cuda_only_cc_test", "tf_features_nomodules_if_android", "tf_features_nomodules_if_emscripten", "tf_gen_op_libs", @@ -89,15 +88,29 @@ load( "tf_opts_nortti_if_emscripten", "transitive_hdrs", ) + +# buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "if_nccl") + +# buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "tensorflow_opensource_extra_deps") +# buildifier: disable=same-origin-load # load("//tensorflow:tensorflow.bzl", "tf_android_full_lite_protos") + +# buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "tf_cc_test_gpu") + +# buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "tf_cc_tests_gpu") + +# buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") +# buildifier: disable=same-origin-load # Placeholder: load("//tensorflow:tensorflow.bzl", "tf_portable_proto_lib") + +# buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "tf_portable_proto_library") # For platform specific build config @@ -311,20 +324,15 @@ cc_library( "//tensorflow/core/platform:threadpool_interface", "//tensorflow/core/platform:threadpool_options", "//tensorflow/core/platform:types", - "//tensorflow/core/platform/default/build_config:base", "@com_google_absl//absl/base", "@com_google_absl//absl/strings", ], ) -cc_library( +alias( name = "framework_bounds_check", - hdrs = ["//tensorflow/core/framework:bounds_check.h"], + actual = "//tensorflow/core/framework:bounds_check", visibility = ["//tensorflow/core/kernels:friends"], - deps = [ - ":platform_base", - "//third_party/eigen3", - ], ) filegroup( @@ -492,7 +500,7 @@ cc_library( "//tensorflow/core/lib/monitoring:legacy_lib_monitoring_lib_headers", "//tensorflow/core/lib/random:legacy_lib_random_headers", "//tensorflow/core/lib/strings:legacy_lib_string_headers", - "//tensorflow/core/util:gpu_cuda_alias.h", + "//tensorflow/core/util:lib_hdrs", ], visibility = ["//visibility:public"], deps = [ @@ -556,7 +564,7 @@ cc_library( "//tensorflow/core/lib/core:legacy_lib_core_status_test_util_header", "//tensorflow/core/platform:test.h", "//tensorflow/core/platform:test_benchmark.h", - "//tensorflow/core/util:reporter.h", + "//tensorflow/core/util:test_hdrs", ], copts = tf_copts(), linkopts = select({ @@ -644,46 +652,15 @@ tf_cuda_library( "//tensorflow/core/framework:variant_op_registry.h", "//tensorflow/core/framework:variant_tensor_data.h", "//tensorflow/core/util/sparse:framework_group", - "//tensorflow/core/util:activation_mode.h", - "//tensorflow/core/util:batch_util.h", - "//tensorflow/core/util:bcast.h", - "//tensorflow/core/util:debug_events_writer.h", - "//tensorflow/core/util:device_name_utils.h", - "//tensorflow/core/util:dump_graph.h", - "//tensorflow/core/util:einsum_op_util.h", - "//tensorflow/core/util:events_writer.h", - "//tensorflow/core/util:example_proto_fast_parsing.h", - "//tensorflow/core/util:example_proto_helper.h", - "//tensorflow/core/util:gpu_kernel_helper.h", - "//tensorflow/core/util:guarded_philox_random.h", - "//tensorflow/core/util:matmul_autotune.h", - "//tensorflow/core/util:matmul_bcast.h", - "//tensorflow/core/util:mirror_pad_mode.h", - "//tensorflow/core/util:padding.h", - "//tensorflow/core/util:port.h", - "//tensorflow/core/util:ptr_util.h", - 
"//tensorflow/core/util:reffed_status_callback.h", - "//tensorflow/core/util:saved_tensor_slice_util.h", - "//tensorflow/core/util:stat_summarizer.h", - "//tensorflow/core/util:stat_summarizer_options.h", - "//tensorflow/core/util:stream_executor_util.h", - "//tensorflow/core/util:strided_slice_op.h", - "//tensorflow/core/util:tensor_format.h", - "//tensorflow/core/util:tensor_ops_util.h", - "//tensorflow/core/util:tensor_slice_reader.h", - "//tensorflow/core/util:tensor_slice_reader_cache.h", - "//tensorflow/core/util:tensor_slice_writer.h", - "//tensorflow/core/util:use_cudnn.h", - "//tensorflow/core/util:util.h", - "//tensorflow/core/util:work_sharder.h", - "public/version.h", + "//tensorflow/core/util:framework_srcs", + "//tensorflow/core/public:version.h", ] + select({ "//tensorflow:windows": [], "//conditions:default": [ "//tensorflow/core/util:memmapped_file_system_hdrs", ], }) + if_mkl([ - "//tensorflow/core/util:mkl_util.h", + "//tensorflow/core/util:mkl_util_hdrs", ]), visibility = ["//visibility:public"], deps = [ @@ -706,24 +683,19 @@ alias( visibility = ["//visibility:public"], ) -cc_library( +alias( name = "overflow", - hdrs = ["//tensorflow/core/util:overflow.h"], - deps = [ - ":framework_lite", - ":lib", - ], + actual = "//tensorflow/core/util:overflow", ) -cc_library( +alias( name = "exec_on_stall", - hdrs = ["//tensorflow/core/util:exec_on_stall.h"], - deps = [":framework_lite"], + actual = "//tensorflow/core/util:exec_on_stall", ) -cc_library( +alias( name = "ptr_util", - hdrs = ["//tensorflow/core/util:ptr_util.h"], + actual = "//tensorflow/core/util:ptr_util", ) # TODO(gonnet): Remove this alias once all users have been moved to the actual target. @@ -742,7 +714,7 @@ alias( cc_library( name = "session_options", - hdrs = ["public/session_options.h"], + hdrs = ["//tensorflow/core/public:session_options.h"], visibility = ["//visibility:public"], deps = [ ":lib", @@ -1169,8 +1141,8 @@ tf_cuda_library( "graph/node_builder.h", "graph/validate.h", "graph/while_context.h", - "public/session.h", - "public/session_options.h", + "//tensorflow/core/public:session.h", + "//tensorflow/core/public:session_options.h", ], visibility = ["//visibility:public"], deps = [ @@ -1295,7 +1267,9 @@ cc_library( cc_library( name = "dynamic_kernels_impl", visibility = [":__subpackages__"], - deps = [], + deps = [ + "//tensorflow/core/kernels:sobol_op", + ], ) cc_library( @@ -1471,8 +1445,9 @@ filegroup( "//tensorflow/core/lib/random:legacy_lib_random_all_srcs", "//tensorflow/core/lib/strings:legacy_lib_strings_all_headers", "//tensorflow/core/lib/strings:legacy_lib_strings_all_srcs", - "//tensorflow/core/platform/default/build_config:android_srcs", + "//tensorflow/core/platform:legacy_mobile_srcs", "//tensorflow/core/profiler:mobile_srcs", + "//tensorflow/core/public:mobile_srcs_no_runtime", "//tensorflow/core/util/ctc:android_srcs", "//tensorflow/core/util/sparse:mobile_srcs_no_runtime_group", "//tensorflow/core/util:mobile_srcs_no_runtime", @@ -1481,7 +1456,6 @@ filegroup( "client/**/*.cc", "lib/**/*.h", "lib/**/*.cc", - "public/**/*.h", ], exclude = [ "**/*test.*", @@ -1748,8 +1722,7 @@ filegroup( "//tensorflow/core/framework:android_test_hdrs", "//tensorflow/core/framework:android_test_srcs", "//tensorflow/core/platform:test.h", - "//tensorflow/core/util:reporter.cc", - "//tensorflow/core/util:reporter.h", + "//tensorflow/core/util:android_test_srcs", ], visibility = ["//visibility:public"], ) @@ -1761,8 +1734,7 @@ filegroup( "//tensorflow/core/framework:android_test_hdrs", 
"//tensorflow/core/framework:android_test_srcs_no_core", "//tensorflow/core/platform:test.h", - "//tensorflow/core/util:reporter.cc", - "//tensorflow/core/util:reporter.h", + "//tensorflow/core/util:android_test_srcs", ], visibility = ["//visibility:public"], ) @@ -1774,6 +1746,7 @@ cc_library( srcs = if_android([":android_test_srcs"]), hdrs = [ "//tensorflow/core/framework:android_test_hdrs", + "//tensorflow/core/util:android_test_hdrs", ], copts = tf_copts(android_optimization_level_override = None), tags = [ @@ -1783,7 +1756,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":android_tensorflow_lib", - ":protos_cc", + ":protos_all_cc", "//tensorflow/core/platform/default/build_config:gtest", "//third_party/eigen3", ], @@ -2012,7 +1985,7 @@ LIB_INTERNAL_PUBLIC_HEADERS = [ "//tensorflow/core/platform:tracing.h", "//tensorflow/core/platform:unbounded_work_queue.h", "//tensorflow/core/platform:legacy_platform_lib_hdrs", - "//tensorflow/core/util:env_var.h", + "//tensorflow/core/util:lib_internal_public_hdrs", ] cc_library( @@ -2276,7 +2249,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":platform_base", - "//tensorflow/core/platform/default/build_config:logging", + "//tensorflow/core/platform:logging", ], ) @@ -2309,8 +2282,8 @@ cc_library( ":core_stringpiece", "//tensorflow/core/platform:dynamic_annotations", "//tensorflow/core/platform:jpeg", + "//tensorflow/core/platform:logging", "//tensorflow/core/platform:stringpiece", - "//tensorflow/core/platform/default/build_config:logging", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings", ], @@ -2344,10 +2317,10 @@ cc_library( "//tensorflow/core/lib/strings:strcat", "//tensorflow/core/platform:dynamic_annotations", "//tensorflow/core/platform:gif", + "//tensorflow/core/platform:logging", "//tensorflow/core/platform:numbers", "//tensorflow/core/platform:strcat", "//tensorflow/core/platform:stringpiece", - "//tensorflow/core/platform/default/build_config:logging", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings", ], @@ -2402,11 +2375,9 @@ alias( actual = "//tensorflow/core/lib/core:error_codes_proto_cc", ) -cc_library( +alias( name = "version_lib", - srcs = ["//tensorflow/core/util:version_info.cc"], - hdrs = ["public/version.h"], - copts = tf_copts(), + actual = "//tensorflow/core/util:version_info", ) FRAMEWORK_INTERNAL_PRIVATE_HEADERS = [ @@ -2551,6 +2522,9 @@ tf_cuda_library( "//tensorflow/core/framework:attr_value_proto_text", "//tensorflow/core/framework:bfloat16", "//tensorflow/core/framework:numeric_types", + "//tensorflow/core/framework:resource_handle", + "//tensorflow/core/framework:tensor", + "//tensorflow/core/framework:tensor_shape", "//tensorflow/core/kernels:bounds_check", "//tensorflow/core/platform/default/build_config:platformlib", "//tensorflow/core/profiler/internal:annotation_stack_impl", @@ -2558,6 +2532,7 @@ tf_cuda_library( "//tensorflow/core/profiler/lib:traceme", "//tensorflow/core/util:port", "//tensorflow/core/util:stats_calculator_portable", + "//tensorflow/compiler/jit:common", ] + if_static( extra_deps = ["@com_google_protobuf//:protobuf"], otherwise = ["@com_google_protobuf//:protobuf_headers"], @@ -2621,20 +2596,10 @@ cc_library( ], ) -tf_cuda_library( +alias( name = "cuda_device_functions", - hdrs = [ - "//tensorflow/core/util:gpu_device_functions.h", - ], + actual = "//tensorflow/core/util:gpu_device_functions", visibility = ["//visibility:public"], - deps = [":framework_lite"], -) - -# TODO(josh11b): Is this needed, or can we 
just use ":protos_all_cc"? -cc_library( - name = "protos_cc", - visibility = ["//visibility:public"], - deps = ["//tensorflow/core/platform/default/build_config:protos_cc"], ) # Library containing all of the graph construction code that is @@ -2713,7 +2678,7 @@ CORE_CPU_BASE_HDRS = GRAPH_HDRS + [ tf_cuda_library( name = "core_cpu_base", - hdrs = CORE_CPU_BASE_HDRS + ["public/session.h"], + hdrs = CORE_CPU_BASE_HDRS + ["//tensorflow/core/public:session.h"], copts = tf_copts(), deps = [":core_cpu_base_no_ops"] + if_static([ ":function_ops_op_lib", @@ -2735,10 +2700,10 @@ tf_cuda_library( "common_runtime/graph_optimizer.h", "graph/graph_constructor.cc", # Depends on common_runtime. "graph/graph_def_builder_util.cc", # Depends on common_runtime. - "public/session_options.h", - "public/version.h", + "//tensorflow/core/public:session_options.h", + "//tensorflow/core/public:version.h", ] + CORE_CPU_BASE_HDRS, - hdrs = CORE_CPU_BASE_HDRS + ["public/session.h"], + hdrs = CORE_CPU_BASE_HDRS + ["//tensorflow/core/public:session.h"], copts = tf_copts(), deps = [ ":graph", @@ -2880,9 +2845,9 @@ tf_cuda_library( "graph/mkl_layout_pass.cc", "graph/mkl_tfconversion_pass.cc", "graph/quantize_training.cc", - "public/session.h", - "public/session_options.h", - "public/version.h", + "//tensorflow/core/public:session.h", + "//tensorflow/core/public:session_options.h", + "//tensorflow/core/public:version.h", ], hdrs = CORE_CPU_LIB_HEADERS, copts = tf_copts() + tf_openmp_copts(), @@ -3006,7 +2971,7 @@ tf_cuda_library( srcs = ["common_runtime/direct_session.cc"], hdrs = [ "common_runtime/direct_session.h", - "//tensorflow/core/util:env_var.h", + "//tensorflow/core/util:lib_internal_public_hdrs", ], copts = tf_copts(), deps = [ @@ -3513,30 +3478,6 @@ tf_cc_test( ], ) -tf_cc_test( - name = "util_overflow_test", - size = "small", - srcs = ["//tensorflow/core/util:overflow_test.cc"], - deps = [ - ":framework_lite", - ":overflow", - ":test", - ":test_main", - ], -) - -tf_cc_test( - name = "exec_on_stall_test", - size = "small", - srcs = ["//tensorflow/core/util:exec_on_stall_test.cc"], - deps = [ - ":exec_on_stall", - ":framework_lite", - ":test", - ":test_main", - ], -) - tf_cc_test( name = "lib_jpeg_jpeg_mem_unittest", srcs = ["lib/jpeg/jpeg_mem_unittest.cc"], @@ -3628,6 +3569,7 @@ test_suite( tests = [ ":core_higher_level_tests", "//tensorflow/core/framework:higher_level_tests", + "//tensorflow/core/util:higher_level_tests", ], ) @@ -3660,29 +3602,6 @@ tf_cc_tests( "graph/subgraph_test.cc", "graph/tensor_id_test.cc", "graph/validate_test.cc", - "//tensorflow/core/util:bcast_test.cc", - "//tensorflow/core/util:command_line_flags_test.cc", - "//tensorflow/core/util:debug_events_writer_test.cc", - "//tensorflow/core/util:device_name_utils_test.cc", - "//tensorflow/core/util:dump_graph_test.cc", - "//tensorflow/core/util:equal_graph_def_test.cc", - "//tensorflow/core/util:events_writer_test.cc", - "//tensorflow/core/util:example_proto_fast_parsing_test.cc", - "//tensorflow/core/util:example_proto_helper_test.cc", - "//tensorflow/core/util:matmul_bcast_test.cc", - "//tensorflow/core/util:memmapped_file_system_test.cc", - "//tensorflow/core/util:presized_cuckoo_map_test.cc", - "//tensorflow/core/util:reffed_status_callback_test.cc", - "//tensorflow/core/util:reporter_test.cc", - "//tensorflow/core/util:saved_tensor_slice_util_test.cc", - "//tensorflow/core/util:semver_test.cc", - "//tensorflow/core/util:stat_summarizer_test.cc", - "//tensorflow/core/util:tensor_format_test.cc", - 
"//tensorflow/core/util:tensor_slice_reader_test.cc", - "//tensorflow/core/util:tensor_slice_set_test.cc", - "//tensorflow/core/util:tensor_slice_util_test.cc", - "//tensorflow/core/util:tensor_slice_writer_test.cc", - "//tensorflow/core/util:work_sharder_test.cc", "//tensorflow/core/util/sparse:higher_level_tests_group", ], create_named_test_suite = True, @@ -3910,7 +3829,7 @@ tf_cc_test_mkl( srcs = [ "graph/mkl_layout_pass_test.cc", "graph/mkl_tfconversion_pass_test.cc", - "//tensorflow/core/util:mkl_util_test.cc", + "//tensorflow/core/util:mkl_util_test_srcs", ], linkstatic = 1, deps = [ @@ -4063,18 +3982,6 @@ tf_cc_test_gpu( ], ) -tf_cuda_only_cc_test( - name = "util_gpu_kernel_helper_test", - srcs = [ - "//tensorflow/core/util:gpu_kernel_helper_test.cu.cc", - ], - deps = [ - ":test", - ":test_main", - "//third_party/eigen3", - ] + mkl_deps(), -) - tf_cc_test_gpu( name = "memory_types_test", size = "small", diff --git a/tensorflow/core/api_def/base_api/api_def_Asin.pbtxt b/tensorflow/core/api_def/base_api/api_def_Asin.pbtxt index 16531612fdf..1d5b62703ce 100644 --- a/tensorflow/core/api_def/base_api/api_def_Asin.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_Asin.pbtxt @@ -5,7 +5,7 @@ op { The `tf.math.asin` operation returns the inverse of `tf.math.sin`, such that if `y = tf.math.sin(x)` then, `x = tf.math.asin(y)`. -**Note**: The output of `tf.math.asin` will lie within the invertible range +**Note**: The output of `tf.math.asin` will lie within the invertible range of sine, i.e [-pi/2, pi/2]. For example: diff --git a/tensorflow/core/api_def/base_api/api_def_Atan.pbtxt b/tensorflow/core/api_def/base_api/api_def_Atan.pbtxt index 65ce42cb942..8ab19b7515a 100644 --- a/tensorflow/core/api_def/base_api/api_def_Atan.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_Atan.pbtxt @@ -5,7 +5,7 @@ op { The `tf.math.atan` operation returns the inverse of `tf.math.tan`, such that if `y = tf.math.tan(x)` then, `x = tf.math.atan(y)`. -**Note**: The output of `tf.math.atan` will lie within the invertible range +**Note**: The output of `tf.math.atan` will lie within the invertible range of tan, i.e (-pi/2, pi/2). For example: diff --git a/tensorflow/core/api_def/base_api/api_def_AudioSpectrogram.pbtxt b/tensorflow/core/api_def/base_api/api_def_AudioSpectrogram.pbtxt index 172696395ba..8af18098574 100644 --- a/tensorflow/core/api_def/base_api/api_def_AudioSpectrogram.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_AudioSpectrogram.pbtxt @@ -43,8 +43,8 @@ This op expects to receive audio data as an input, stored as floats in the range -1 to 1, together with a window width in samples, and a stride specifying how far to move the window between slices. From this it generates a three dimensional output. The first dimension is for the channels in the input, so a -stereo audio input would have two here for example. The second dimension is time, -with successive frequency slices. The third dimension has an amplitude value for +stereo audio input would have two here for example. The second dimension is time, +with successive frequency slices. The third dimension has an amplitude value for each frequency during that time slice. 
This means the layout when converted and saved as an image is rotated 90 degrees diff --git a/tensorflow/core/api_def/base_api/api_def_BlockLSTMV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_BlockLSTMV2.pbtxt index 4da9ebaf863..936099e70af 100644 --- a/tensorflow/core/api_def/base_api/api_def_BlockLSTMV2.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_BlockLSTMV2.pbtxt @@ -129,7 +129,7 @@ for x1 in unpack(x): h.append(h1) return pack(i), pack(cs), pack(f), pack(o), pack(ci), pack(ch), pack(h) -Note that unlike LSTMBlockCell (and BlockLSTM) which uses ICFO gate layout, +Note that unlike LSTMBlockCell (and BlockLSTM) which uses ICFO gate layout, this op uses IFCO. So in order for the following snippet to be equivalent all gate-related outputs should be reordered. ``` diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesCalculateBestFeatureSplitV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesCalculateBestFeatureSplitV2.pbtxt new file mode 100644 index 00000000000..2bbaba26257 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesCalculateBestFeatureSplitV2.pbtxt @@ -0,0 +1,124 @@ +op { + graph_op_name: "BoostedTreesCalculateBestFeatureSplitV2" + visibility: HIDDEN + in_arg { + name: "node_id_range" + description: <bik`), the contracted axis label is `j`. - (e) Expand Diagonal: If the output subcripts contain repeated (explicit) axis + (e) Expand Diagonal: If the output subscripts contain repeated (explicit) axis labels, the opposite operation of (a) is applied. For example, in the equation `i->iii`, and input shape `[3]`, the output of shape `[3, 3, 3]` are all zeros, except for the (generalized) diagonal which is populated @@ -70,7 +70,7 @@ Operations are applied to the input(s) according to the following rules: Note: This operation is not supported by `np.einsum` or `tf.einsum`; it is provided to enable computing the symbolic gradient of `tf.einsum`. -The output subcripts must contain only labels appearing in at least one of the +The output subscripts must contain only labels appearing in at least one of the input subscripts. Furthermore, all dimensions mapping to the same axis label must be equal. @@ -82,7 +82,7 @@ according to standard NumPy broadcasting The broadcasted dimensions are placed in the corresponding location of the ellipsis in the output subscript. If the broadcasted dimensions are non-empty -and the output subcripts do not contain ellipsis, then an InvalidArgument error +and the output subscripts do not contain ellipsis, then an InvalidArgument error is raised. 
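As a concrete sketch of the broadcasting rule above, using the public `tf.einsum` wrapper (assumed here to dispatch to this op):

```python
import tensorflow as tf

# Ellipsis broadcasting: the [5, 1] and [4] batch dimensions broadcast
# to [5, 4] and land where the ellipsis appears in the output subscript.
a = tf.random.normal([5, 1, 2, 3])
b = tf.random.normal([4, 3, 7])
c = tf.einsum('...ij,...jk->...ik', a, b)
print(c.shape)  # (5, 4, 2, 7)
```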
@compatibility(numpy) diff --git a/tensorflow/core/api_def/base_api/api_def_LeftShift.pbtxt b/tensorflow/core/api_def/base_api/api_def_LeftShift.pbtxt index 3855c5095a7..b7bf38535a2 100644 --- a/tensorflow/core/api_def/base_api/api_def_LeftShift.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_LeftShift.pbtxt @@ -16,9 +16,9 @@ dtype_list = [tf.int8, tf.int16, tf.int32, tf.int64] for dtype in dtype_list: lhs = tf.constant([-1, -5, -3, -14], dtype=dtype) rhs = tf.constant([5, 0, 7, 11], dtype=dtype) - + left_shift_result = bitwise_ops.left_shift(lhs, rhs) - + print(left_shift_result) # This will print: diff --git a/tensorflow/core/api_def/base_api/api_def_LowerBound.pbtxt b/tensorflow/core/api_def/base_api/api_def_LowerBound.pbtxt index 5ce825ae043..c2b0405c93d 100644 --- a/tensorflow/core/api_def/base_api/api_def_LowerBound.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_LowerBound.pbtxt @@ -28,7 +28,7 @@ Each set of rows with the same index in (sorted_inputs, values) is treated independently. The resulting row is the equivalent of calling `np.searchsorted(sorted_inputs, values, side='left')`. -The result is not a global index to the entire +The result is not a global index to the entire `Tensor`, but rather just the index in the last dimension. A 2-D example: diff --git a/tensorflow/core/api_def/base_api/api_def_MatrixSolveLs.pbtxt b/tensorflow/core/api_def/base_api/api_def_MatrixSolveLs.pbtxt index e667c328ae5..4fc86807200 100644 --- a/tensorflow/core/api_def/base_api/api_def_MatrixSolveLs.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_MatrixSolveLs.pbtxt @@ -49,7 +49,7 @@ in the batch: If `fast` is `True`, then the solution is computed by solving the normal equations using Cholesky decomposition. Specifically, if \\(m \ge n\\) then \\(X = (A^H A + \lambda I)^{-1} A^H B\\), which solves the least-squares -problem \\(X = \mathrm{argmin}_{Z \in \Re^{n \times k} } ||A Z - B||_F^2 + \lambda ||Z||_F^2\\). +problem \\(X = \mathrm{argmin}_{Z \in \Re^{n \times k} } ||A Z - B||_F^2 + \lambda ||Z||_F^2\\). If \\(m \lt n\\) then `output` is computed as \\(X = A^H (A A^H + \lambda I)^{-1} B\\), which (for \\(\lambda = 0\\)) is the minimum-norm solution to the under-determined linear system, i.e. diff --git a/tensorflow/core/api_def/base_api/api_def_MatrixSquareRoot.pbtxt b/tensorflow/core/api_def/base_api/api_def_MatrixSquareRoot.pbtxt index a9f1e593ccb..1e1a80e7648 100644 --- a/tensorflow/core/api_def/base_api/api_def_MatrixSquareRoot.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_MatrixSquareRoot.pbtxt @@ -24,10 +24,10 @@ The input matrix should be invertible. If the input matrix is real, it should have no eigenvalues which are real and negative (pairs of complex conjugate eigenvalues are allowed). -The matrix square root is computed by first reducing the matrix to -quasi-triangular form with the real Schur decomposition. The square root -of the quasi-triangular matrix is then computed directly. Details of -the algorithm can be found in: Nicholas J. Higham, "Computing real +The matrix square root is computed by first reducing the matrix to +quasi-triangular form with the real Schur decomposition. The square root +of the quasi-triangular matrix is then computed directly. Details of +the algorithm can be found in: Nicholas J. Higham, "Computing real square roots of a real matrix", Linear Algebra Appl., 1987. 
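A minimal numerical check of the property this op computes, assuming the public wrapper `tf.linalg.sqrtm` maps to it:

```python
import tensorflow as tf

# Build a symmetric positive definite (hence invertible) matrix, which
# has no negative real eigenvalues, as the description above requires.
x = tf.random.normal([4, 4], dtype=tf.float64)
a = tf.matmul(x, x, transpose_b=True) + 4.0 * tf.eye(4, dtype=tf.float64)
s = tf.linalg.sqrtm(a)
# sqrtm(A) @ sqrtm(A) should recover A up to floating-point error.
print(tf.reduce_max(tf.abs(tf.matmul(s, s) - a)).numpy())
```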
The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions diff --git a/tensorflow/core/api_def/base_api/api_def_ParallelInterleaveDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ParallelInterleaveDataset.pbtxt index 939c64fe925..e30395cbfd3 100644 --- a/tensorflow/core/api_def/base_api/api_def_ParallelInterleaveDataset.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_ParallelInterleaveDataset.pbtxt @@ -11,14 +11,14 @@ END name: "other_arguments" description: <